# wirepoint_rcnn.py

import os
from typing import Optional, Any

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
# from tensorboardX import SummaryWriter  # superseded by torch.utils.tensorboard above

# from torchinfo import summary
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection._utils import overwrite_eps
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNHeads, KeypointRCNNPredictor, \
    KeypointRCNN_ResNet50_FPN_Weights
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import misc as misc_nn_ops

# from visdom import Visdom
from models.config import config_tool
from models.config.config_tool import read_yaml
from models.ins.trainer import get_transform
from models.wirenet.head import RoIHeads
from models.wirenet.wirepoint_dataset import WirePointDataset
from tools import utils

FEATURE_DIM = 8
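# FEATURE_DIM matches the width of the hand-crafted line descriptor assembled
# in WirepointPredictor.sample_lines below: two normalized endpoints (2 + 2
# values), a unit direction vector (2), and two index-indicator flags (1 + 1).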


def non_maximum_suppression(a):
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float().clamp(min=0.0)
    return a * mask
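
# non_maximum_suppression keeps only local maxima: max_pool2d computes each
# pixel's 3x3 neighborhood maximum, and (a == ap) zeroes every pixel that is
# not the peak of its own neighborhood. A tiny illustrative check (an
# assumption, not part of the pipeline):
#   non_maximum_suppression(torch.tensor([[[[0.1, 0.9, 0.2]]]]))
#   # -> tensor([[[[0.0, 0.9, 0.0]]]])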

class Bottleneck1D(nn.Module):
    def __init__(self, inplanes, outplanes):
        super(Bottleneck1D, self).__init__()
        planes = outplanes // 2
        self.op = nn.Sequential(
            nn.BatchNorm1d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv1d(inplanes, planes, kernel_size=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, outplanes, kernel_size=1),
        )

    def forward(self, x):
        return x + self.op(x)
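
# Bottleneck1D is a pre-activation residual bottleneck (BN -> ReLU -> Conv,
# three times) that halves the channel count for the 3x3 convolution; note
# that the skip connection x + self.op(x) requires inplanes == outplanes.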

class WirepointRCNN(FasterRCNN):
    def __init__(
            self,
            backbone,
            num_classes=None,
            # transform parameters
            min_size=None,
            max_size=1333,
            image_mean=None,
            image_std=None,
            # RPN parameters
            rpn_anchor_generator=None,
            rpn_head=None,
            rpn_pre_nms_top_n_train=2000,
            rpn_pre_nms_top_n_test=1000,
            rpn_post_nms_top_n_train=2000,
            rpn_post_nms_top_n_test=1000,
            rpn_nms_thresh=0.7,
            rpn_fg_iou_thresh=0.7,
            rpn_bg_iou_thresh=0.3,
            rpn_batch_size_per_image=256,
            rpn_positive_fraction=0.5,
            rpn_score_thresh=0.0,
            # Box parameters
            box_roi_pool=None,
            box_head=None,
            box_predictor=None,
            box_score_thresh=0.05,
            box_nms_thresh=0.5,
            box_detections_per_img=100,
            box_fg_iou_thresh=0.5,
            box_bg_iou_thresh=0.5,
            box_batch_size_per_image=512,
            box_positive_fraction=0.25,
            bbox_reg_weights=None,
            # keypoint parameters
            keypoint_roi_pool=None,
            keypoint_head=None,
            keypoint_predictor=None,
            num_keypoints=None,
            wirepoint_roi_pool=None,
            wirepoint_head=None,
            wirepoint_predictor=None,
            **kwargs,
    ):
        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
            )
        if min_size is None:
            min_size = (640, 672, 704, 736, 768, 800)
        if num_keypoints is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
        else:
            num_keypoints = 17

        out_channels = backbone.out_channels

        if wirepoint_roi_pool is None:
            wirepoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=128,
                                                    sampling_ratio=2)
        if wirepoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            print(f'WirepointHead in_channels: {out_channels}, layers: {keypoint_layers}')
            wirepoint_head = WirepointHead(out_channels, keypoint_layers)
        if wirepoint_predictor is None:
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            wirepoint_predictor = WirepointPredictor()
        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            **kwargs,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            # wirepoint_roi_pool=wirepoint_roi_pool,
            # wirepoint_head=wirepoint_head,
            # wirepoint_predictor=wirepoint_predictor,
        )

        self.roi_heads = roi_heads
        self.roi_heads.wirepoint_roi_pool = wirepoint_roi_pool
        self.roi_heads.wirepoint_head = wirepoint_head
        self.roi_heads.wirepoint_predictor = wirepoint_predictor
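
# Note that WirepointRCNN swaps in the custom RoIHeads from models.wirenet.head
# and then attaches the wirepoint modules as attributes rather than constructor
# arguments (the corresponding keyword arguments are commented out above),
# presumably so that RoIHeads can look them up during its forward pass.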

class WirepointHead(nn.Module):
    def __init__(self, input_channels, num_class):
        super(WirepointHead, self).__init__()
        self.head_size = [[2], [1], [2]]
        m = int(input_channels / 4)
        heads = []
        for output_channels in sum(self.head_size, []):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(m, output_channels, kernel_size=1),
                )
            )
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        # for idx, head in enumerate(self.heads):
        #     print(f'{idx}, multitask head: {head(x).shape}, input x: {x.shape}')
        outputs = torch.cat([head(x) for head in self.heads], dim=1)
        features = x
        return outputs, features
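
# With head_size = [[2], [1], [2]] the concatenated output has 5 channels.
# WirepointPredictor slices them via head_off = cumsum([2, 1, 2]) = [2, 3, 5]
# (assuming the configured head_size matches): channels [0:2] -> jmap,
# [2:3] -> lmap, [3:5] -> joff.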

class WirepointPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        # self.backbone = backbone
        # self.cfg = read_yaml(cfg)
        self.cfg = read_yaml('wirenet.yaml')
        self.n_pts0 = self.cfg['model']['n_pts0']
        self.n_pts1 = self.cfg['model']['n_pts1']
        self.n_stc_posl = self.cfg['model']['n_stc_posl']
        self.dim_loi = self.cfg['model']['dim_loi']
        self.use_conv = self.cfg['model']['use_conv']
        self.dim_fc = self.cfg['model']['dim_fc']
        self.n_out_line = self.cfg['model']['n_out_line']
        self.n_out_junc = self.cfg['model']['n_out_junc']
        self.loss_weight = self.cfg['model']['loss_weight']
        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
        self.use_cood = self.cfg['model']['use_cood']
        self.use_slop = self.cfg['model']['use_slop']
        self.n_stc_negl = self.cfg['model']['n_stc_negl']
        self.head_size = self.cfg['model']['head_size']
        self.num_class = sum(sum(self.head_size, []))
        self.head_off = np.cumsum([sum(h) for h in self.head_size])

        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
        self.register_buffer("lambda_", lambda_)
        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0

        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
        scale_factor = self.n_pts0 // self.n_pts1
        if self.use_conv:
            self.pooling = nn.Sequential(
                nn.MaxPool1d(scale_factor, scale_factor),
                Bottleneck1D(self.dim_loi, self.dim_loi),
            )
            self.fc2 = nn.Sequential(
                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
            )
        else:
            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
            self.fc2 = nn.Sequential(
                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, 1),
            )
        self.loss = nn.BCEWithLogitsLoss(reduction="none")
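
    # fc2 scores one candidate line per row: its input concatenates the pooled
    # line-of-interest features (dim_loi * n_pts1 values) with the
    # FEATURE_DIM-dimensional geometric descriptor from sample_lines, and its
    # single logit output pairs with self.loss (BCEWithLogitsLoss).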
    def forward(self, inputs, features, targets=None):
        batch, channel, row, col = inputs.shape
        # print(f'inputs shape: {inputs.shape}')

        # Derive train/eval mode from the presence of targets
        # (this overrides nn.Module's training flag).
        if targets is not None:
            self.training = True
            wires_targets = [t["wires"] for t in targets]
            # Collect the 'junc_map', 'junc_offset' and 'line_map' tensors.
            junc_maps = [d["junc_map"] for d in wires_targets]
            junc_offsets = [d["junc_offset"] for d in wires_targets]
            line_maps = [d["line_map"] for d in wires_targets]

            junc_map_tensor = torch.stack(junc_maps, dim=0)
            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
            line_map_tensor = torch.stack(line_maps, dim=0)

            wires_meta = {
                "junc_map": junc_map_tensor,
                "junc_offset": junc_offset_tensor,
                # "line_map": line_map_tensor,
            }
        else:
            self.training = False
            t = {
                "junc_coords": torch.zeros(1, 2),
                "jtyp": torch.zeros(1, dtype=torch.uint8),
                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }
            wires_targets = [t for b in range(inputs.size(0))]
            wires_meta = {
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }

        T = wires_meta.copy()
        n_jtyp = T["junc_map"].shape[1]
        offset = self.head_off

        result = {}
        for stack, output in enumerate([inputs]):
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0]: offset[1]].squeeze(0)
            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
            if stack == 0:
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }

        h = result["preds"]
        x = self.fc1(features)
        n_batch, n_channel, row, col = x.shape

        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
        for i, meta in enumerate(wires_targets):
            p, label, feat, jc = self.sample_lines(
                meta, h["jmap"][i], h["joff"][i],
            )
            ys.append(label)
            if self.training and self.do_static_sampling:
                p = torch.cat([p, meta["lpre"]])
                feat = torch.cat([feat, meta["lpre_feat"]])
                ys.append(meta["lpre_label"])
                del jc
            else:
                jcs.append(jc)
                ps.append(p)
            fs.append(feat)

            # Sample n_pts0 points along each candidate line by linear
            # interpolation between its two endpoints.
            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
            px0 = px.floor().clamp(min=0, max=127)
            py0 = py.floor().clamp(min=0, max=127)
            px1 = (px0 + 1).clamp(min=0, max=127)
            py1 = (py0 + 1).clamp(min=0, max=127)
            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()

            # Bilinear interpolation of the feature map at the sampled points.
            # xp: [N_LINE, N_CHANNEL, N_POINT]
            xp = (
                (
                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
                )
                .reshape(n_channel, -1, self.n_pts0)
                .permute(1, 0, 2)
            )
            xp = self.pooling(xp)
            xs.append(xp)
            idx.append(idx[-1] + xp.shape[0])

        x, y = torch.cat(xs), torch.cat(ys)
        f = torch.cat(fs)
        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
        # print(f'ps: {ps}')
        x = torch.cat([x, f], 1)
        x = x.to(dtype=torch.float32)
        x = self.fc2(x).flatten()

        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
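
    # forward returns, in order: per-line logits x, their labels y, the prefix
    # sums idx delimiting each image's lines, junction candidates jcs
    # (inference only), the batch size, raw line endpoints ps, and the
    # configured n_out_line / n_out_junc limits.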
    def sample_lines(self, meta, jmap, joff):
        with torch.no_grad():
            junc = meta["junc_coords"]  # [N, 2]
            jtyp = meta["jtyp"]  # [N]
            Lpos = meta["line_pos_idx"]
            Lneg = meta["line_neg_idx"]

            n_type = jmap.shape[0]
            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
            joff = joff.reshape(n_type, 2, -1)
            max_K = self.n_dyn_junc // n_type
            N = len(junc)
            if not self.training:
                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
            else:
                K = min(int(N * 2 + 2), max_K)
            if K < 2:
                K = 2
            device = jmap.device

            # index: [N_TYPE, K]
            score, index = torch.topk(jmap, k=K)
            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5

            # xy: [N_TYPE, K, 2]
            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
            xy_ = xy[..., None, :]
            del x, y, index

            # dist: [N_TYPE, K, N]
            dist = torch.sum((xy_ - junc) ** 2, -1)
            cost, match = torch.min(dist, -1)

            # xy: [N_TYPE * K, 2]
            # match: [N_TYPE, K]
            for t in range(n_type):
                match[t, jtyp[match[t]] != t] = N
            match[cost > 1.5 * 1.5] = N
            match = match.flatten()

            _ = torch.arange(n_type * K, device=device)
            u, v = torch.meshgrid(_, _)
            u, v = u.flatten(), v.flatten()
            up, vp = match[u], match[v]
            label = Lpos[up, vp]

            if self.training:
                c = torch.zeros_like(label, dtype=torch.bool)

                # sample positive lines
                cdx = label.nonzero().flatten()
                if len(cdx) > self.n_dyn_posl:
                    # print("too many positive lines")
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample negative lines
                cdx = Lneg[up, vp].nonzero().flatten()
                if len(cdx) > self.n_dyn_negl:
                    # print("too many negative lines")
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample other (unmatched) lines
                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
                c[cdx] = 1
            else:
                c = (u < v).flatten()

            # sample lines
            u, v, label = u[c], v[c], label[c]
            xy = xy.reshape(n_type * K, 2)
            xyu, xyv = xy[u], xy[v]

            u2v = xyu - xyv
            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
            feat = torch.cat(
                [
                    xyu / 128 * self.use_cood,
                    xyv / 128 * self.use_cood,
                    u2v * self.use_slop,
                    (u[:, None] > K).float(),
                    (v[:, None] > K).float(),
                ],
                1,
            )
            line = torch.cat([xyu[:, None], xyv[:, None]], 1)

            xy = xy.reshape(n_type, K, 2)
            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
            return line, label.float(), feat, jcs


def wirepointrcnn_resnet50_fpn(
        *,
        weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
        progress: bool = True,
        num_classes: Optional[int] = None,
        num_keypoints: Optional[int] = None,
        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
        trainable_backbone_layers: Optional[int] = None,
        **kwargs: Any,
) -> WirepointRCNN:
    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)

    if num_classes is None:
        num_classes = 5  # previously hardcoded; keep 5 as the default
    model = WirepointRCNN(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
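
# A minimal smoke-test sketch (an illustrative assumption, not part of the
# training flow below): build the model without detector weights and push one
# dummy image through it in eval mode. It presumes 'wirenet.yaml' is readable,
# since WirepointPredictor() loads it at construction time.
#
#   model = wirepointrcnn_resnet50_fpn(weights=None)
#   model.eval()
#   with torch.no_grad():
#       pred = model([torch.rand(3, 512, 512)])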


def _loss(losses):
    total_loss = 0
    for i in losses.keys():
        if i != "loss_wirepoint":
            total_loss += losses[i]
        else:
            # 'loss_wirepoint' nests its per-head losses under ['losses'][0]
            loss_labels = losses[i]["losses"]
            loss_labels_k = list(loss_labels[0].keys())
            for j, name in enumerate(loss_labels_k):
                loss = loss_labels[0][name].mean()
                total_loss += loss
    return total_loss


if __name__ == '__main__':
    cfg = 'wirenet.yaml'
    cfg = read_yaml(cfg)
    print(f'cfg:{cfg}')
    print(cfg['model']['n_dyn_negl'])
    # net = WirepointPredictor()

    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        device_name = "cpu"
        print("CUDA is not available, falling back to CPU")
    device = torch.device(device_name)

    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
    train_sampler = torch.utils.data.RandomSampler(dataset_train)
    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
    train_collate_fn = utils.collate_fn_wirepoint
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
    )

    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
    val_sampler = torch.utils.data.RandomSampler(dataset_val)
    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
    val_collate_fn = utils.collate_fn_wirepoint
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
    )

    model = wirepointrcnn_resnet50_fpn().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
    writer = SummaryWriter(cfg['io']['logdir'])

    def move_to_device(data, device):
        if isinstance(data, (list, tuple)):
            return type(data)(move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # leave non-tensor data unchanged
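
    # move_to_device recurses through lists/tuples and dicts so that nested
    # structures such as the per-image "wires" target dicts land on the same
    # device as the model, e.g. targets = move_to_device(targets, device).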

    def writer_loss(writer, losses, epoch):
        # Log each loss term to TensorBoard.
        for key, value in losses.items():
            if isinstance(value, dict):  # nested loss dict, e.g. 'loss_wirepoint'
                for subkey, subvalue in value['losses'][0].items():
                    writer.add_scalar(f'{key}/{subkey}', subvalue.item(), epoch)
            else:
                writer.add_scalar(key, value.item(), epoch)

    for epoch in range(cfg['optim']['max_epoch']):
        model.train()

        for imgs, targets in data_loader_train:
            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
            loss = _loss(losses)
            print(loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer_loss(writer, losses, epoch)

        model.eval()
        with torch.no_grad():
            for imgs, targets in data_loader_val:
                pred = model(move_to_device(imgs, device))
                print(f"pred:{pred}")

    # imgs, targets = next(iter(data_loader))
    # model.train()
    # pred = model(imgs, targets)
    # print(f'pred:{pred}')
    # result, losses = model(imgs, targets)
    # print(f'result:{result}')
    # print(f'pred:{losses}')

'''
########### predict #############
img_path = r"I:\wirenet_dateset\images\train\00030078_2.png"
transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
img = read_image(img_path)
img = transforms(img)

img = torch.ones((2, 3, 512, 512))
# print(f'img shape:{img.shape}')
model.eval()
onnx_file_path = "./wirenet.onnx"

# Export the model to ONNX format
# torch.onnx.export(model, img, onnx_file_path, verbose=True, input_names=['input'],
#                   output_names=['output'])
# torch.save(model, './wirenet.pt')

# 5. Name of the output ONNX file
# onnx_file_path = "./wirepoint_rcnn.onnx"

# Prepare an example input: Mask R-CNN expects a list of images,
# each an image tensor of shape [C, H, W]
img = [torch.ones((3, 800, 800))]  # example input: 800x800, 3 channels

# Name of the output ONNX file
# onnx_file_path = "./mask_rcnn.onnx"

# model_scripted = torch.jit.script(model)
# torch.onnx.export(model_scripted, input, "model.onnx", verbose=True, input_names=["input"],
#                   output_names=["output"])
#
# print(f"Model has been converted to ONNX and saved to {onnx_file_path}")

pred = model(img)
print(f'pred:{pred}')
########### end predict #############

########## training ###################################
# imgs, targets = next(iter(data_loader))
# model.train()
# pred = model(imgs, targets)

# class WrapperModule(torch.nn.Module):
#     def __init__(self, model):
#         super(WrapperModule, self).__init__()
#         self.model = model
#
#     def forward(self, img, targets):
#         # Convert the complex input structure into a form suitable for tracing
#         return self.model(img, targets)

# torch.save(model.state_dict(), './wire.pt')
# Wrap the original model
# wrapped_model = WrapperModule(model)
# model_scripted = torch.jit.trace(wrapped_model, img)
# writer = SummaryWriter('./')
# writer.add_graph(wrapped_model, (imgs, targets))
# writer.close()
#
# print(f'pred:{pred}')
########## end training ###################################

# for imgs, targets in data_loader:
#     print(f'imgs:{imgs}')
#     print(f'targets:{targets}')
'''