# wirepoint_rcnn.py

import os
import os.path as osp
from typing import Optional, Any

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from skimage import io
from torch import nn
# from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter  # replaces the duplicate tensorboardX import
from torchvision import transforms
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection._utils import overwrite_eps
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNHeads, KeypointRCNNPredictor, \
    KeypointRCNN_ResNet50_FPN_Weights
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import misc as misc_nn_ops
from torchvision.utils import draw_bounding_boxes
# from visdom import Visdom

from models.config import config_tool
from models.config.config_tool import read_yaml
from models.ins.trainer import get_transform
from models.wirenet.head import RoIHeads
from models.wirenet.postprocess import postprocess
from models.wirenet.wirepoint_dataset import WirePointDataset
from tools import utils

FEATURE_DIM = 8

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


def non_maximum_suppression(a):
    # Keep only pixels that equal their 3x3 neighborhood maximum.
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float().clamp(min=0.0)
    return a * mask
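
# Example (sketch): applied to a junction heatmap, only strict 3x3 local
# maxima survive and everything else is zeroed:
#   a = torch.rand(1, 1, 128, 128)
#   keep = non_maximum_suppression(a)  # same shape, sparser support
#   assert keep.shape == a.shape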


class Bottleneck1D(nn.Module):
    def __init__(self, inplanes, outplanes):
        super(Bottleneck1D, self).__init__()
        planes = outplanes // 2
        self.op = nn.Sequential(
            nn.BatchNorm1d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv1d(inplanes, planes, kernel_size=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, outplanes, kernel_size=1),
        )

    def forward(self, x):
        return x + self.op(x)
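
# Shape sketch: the residual `x + self.op(x)` requires inplanes == outplanes,
# as in the predictor's `Bottleneck1D(self.dim_loi, self.dim_loi)` usage:
#   blk = Bottleneck1D(128, 128)
#   out = blk(torch.randn(2, 128, 32))  # -> [2, 128, 32]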


class WirepointRCNN(FasterRCNN):
    def __init__(
            self,
            backbone,
            num_classes=None,
            # transform parameters
            min_size=None,
            max_size=1333,
            image_mean=None,
            image_std=None,
            # RPN parameters
            rpn_anchor_generator=None,
            rpn_head=None,
            rpn_pre_nms_top_n_train=2000,
            rpn_pre_nms_top_n_test=1000,
            rpn_post_nms_top_n_train=2000,
            rpn_post_nms_top_n_test=1000,
            rpn_nms_thresh=0.7,
            rpn_fg_iou_thresh=0.7,
            rpn_bg_iou_thresh=0.3,
            rpn_batch_size_per_image=256,
            rpn_positive_fraction=0.5,
            rpn_score_thresh=0.0,
            # Box parameters
            box_roi_pool=None,
            box_head=None,
            box_predictor=None,
            box_score_thresh=0.05,
            box_nms_thresh=0.5,
            box_detections_per_img=100,
            box_fg_iou_thresh=0.5,
            box_bg_iou_thresh=0.5,
            box_batch_size_per_image=512,
            box_positive_fraction=0.25,
            bbox_reg_weights=None,
            # keypoint parameters
            keypoint_roi_pool=None,
            keypoint_head=None,
            keypoint_predictor=None,
            num_keypoints=None,
            wirepoint_roi_pool=None,
            wirepoint_head=None,
            wirepoint_predictor=None,
            **kwargs,
    ):
        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
            )
        if min_size is None:
            min_size = (640, 672, 704, 736, 768, 800)

        if num_keypoints is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
        else:
            num_keypoints = 17

        out_channels = backbone.out_channels

        if wirepoint_roi_pool is None:
            wirepoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=128,
                                                    sampling_ratio=2)

        if wirepoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            wirepoint_head = WirepointHead(out_channels, keypoint_layers)

        if wirepoint_predictor is None:
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            wirepoint_predictor = WirepointPredictor()

        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            **kwargs,
        )

        # Rebuild the box branch locally so the custom RoIHeads below gets
        # fully constructed components even when the caller passed None.
        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )
        self.roi_heads = roi_heads
        self.roi_heads.wirepoint_roi_pool = wirepoint_roi_pool
        self.roi_heads.wirepoint_head = wirepoint_head
        self.roi_heads.wirepoint_predictor = wirepoint_predictor
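
# Note: FasterRCNN.__init__ already builds a default roi_heads; the custom
# models.wirenet.head.RoIHeads above replaces it, and the wirepoint
# pool/head/predictor are then attached as attributes that the custom
# RoIHeads is expected to pick up during its forward pass.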


class WirepointHead(nn.Module):
    def __init__(self, input_channels, num_class):
        super(WirepointHead, self).__init__()
        # head_size is hard-coded here; num_class is currently unused.
        self.head_size = [[2], [1], [2]]
        m = int(input_channels / 4)
        heads = []
        # One small conv head per output group in head_size.
        for output_channels in sum(self.head_size, []):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(m, output_channels, kernel_size=1),
                )
            )
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        outputs = torch.cat([head(x) for head in self.heads], dim=1)
        features = x
        return outputs, features
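
# Output-channel sketch: with head_size = [[2], [1], [2]] the concatenated map
# has 2 + 1 + 2 = 5 channels. Assuming the config's head_size matches, this
# lines up with head_off = cumsum([2, 1, 2]) = [2, 3, 5], which
# WirepointPredictor.forward uses to slice out jmap / lmap / joff.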


class WirepointPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        # self.backbone = backbone
        # self.cfg = read_yaml(cfg)
        self.cfg = read_yaml('wirenet.yaml')
        self.n_pts0 = self.cfg['model']['n_pts0']
        self.n_pts1 = self.cfg['model']['n_pts1']
        self.n_stc_posl = self.cfg['model']['n_stc_posl']
        self.dim_loi = self.cfg['model']['dim_loi']
        self.use_conv = self.cfg['model']['use_conv']
        self.dim_fc = self.cfg['model']['dim_fc']
        self.n_out_line = self.cfg['model']['n_out_line']
        self.n_out_junc = self.cfg['model']['n_out_junc']
        self.loss_weight = self.cfg['model']['loss_weight']
        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
        self.use_cood = self.cfg['model']['use_cood']
        self.use_slop = self.cfg['model']['use_slop']
        self.n_stc_negl = self.cfg['model']['n_stc_negl']
        self.head_size = self.cfg['model']['head_size']
        self.num_class = sum(sum(self.head_size, []))
        self.head_off = np.cumsum([sum(h) for h in self.head_size])

        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
        self.register_buffer("lambda_", lambda_)
        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0

        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
        scale_factor = self.n_pts0 // self.n_pts1
        if self.use_conv:
            self.pooling = nn.Sequential(
                nn.MaxPool1d(scale_factor, scale_factor),
                Bottleneck1D(self.dim_loi, self.dim_loi),
            )
            self.fc2 = nn.Sequential(
                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
            )
        else:
            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
            self.fc2 = nn.Sequential(
                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, 1),
            )
        self.loss = nn.BCEWithLogitsLoss(reduction="none")
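
    # Dimension sketch (assuming the usual LCNN-style config, e.g. n_pts0=32,
    # n_pts1=8, dim_loi=128): each line is sampled at n_pts0 points, pooled
    # down to a [dim_loi, n_pts1] feature, flattened, and concatenated with
    # the FEATURE_DIM=8 geometric descriptor before fc2 scores it with a
    # single logit.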

    def forward(self, inputs, features, targets=None):
        batch, channel, row, col = inputs.shape

        if targets is not None:
            # Switch the module's training flag based on whether targets exist.
            self.training = True
            wires_targets = [t["wires"] for t in targets]
            # Stack the 'junc_map', 'junc_offset', and 'line_map' tensors
            # across the batch.
            junc_maps = [d["junc_map"] for d in wires_targets]
            junc_offsets = [d["junc_offset"] for d in wires_targets]
            line_maps = [d["line_map"] for d in wires_targets]

            junc_map_tensor = torch.stack(junc_maps, dim=0)
            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
            line_map_tensor = torch.stack(line_maps, dim=0)

            wires_meta = {
                "junc_map": junc_map_tensor,
                "junc_offset": junc_offset_tensor,
                # "line_map": line_map_tensor,
            }
        else:
            self.training = False
            # Placeholder targets for inference.
            t = {
                "junc_coords": torch.zeros(1, 2).to(device),
                "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
                "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
            }
            wires_targets = [t for b in range(inputs.size(0))]
            wires_meta = {
                "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
            }

        T = wires_meta.copy()
        n_jtyp = T["junc_map"].shape[1]
        offset = self.head_off
        result = {}
        for stack, output in enumerate([inputs]):
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0]: offset[1]].squeeze(0)
            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
            if stack == 0:
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }

        h = result["preds"]
        x = self.fc1(features)
        n_batch, n_channel, row, col = x.shape

        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
        for i, meta in enumerate(wires_targets):
            p, label, feat, jc = self.sample_lines(
                meta, h["jmap"][i], h["joff"][i],
            )
            ys.append(label)
            if self.training and self.do_static_sampling:
                p = torch.cat([p, meta["lpre"]])
                feat = torch.cat([feat, meta["lpre_feat"]])
                ys.append(meta["lpre_label"])
                del jc
            else:
                jcs.append(jc)
                ps.append(p)
            fs.append(feat)

            # Sample n_pts0 points along each candidate line, then bilinearly
            # interpolate the LOI feature map at those points.
            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
            px0 = px.floor().clamp(min=0, max=127)
            py0 = py.floor().clamp(min=0, max=127)
            px1 = (px0 + 1).clamp(min=0, max=127)
            py1 = (py0 + 1).clamp(min=0, max=127)
            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()

            # xp: [N_LINE, N_CHANNEL, N_POINT]
            xp = (
                (
                    x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
                    + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
                    + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
                    + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
                )
                .reshape(n_channel, -1, self.n_pts0)
                .permute(1, 0, 2)
            )
            xp = self.pooling(xp)
            xs.append(xp)
            idx.append(idx[-1] + xp.shape[0])

        x, y = torch.cat(xs), torch.cat(ys)
        f = torch.cat(fs)
        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
        x = torch.cat([x, f], 1)
        x = x.to(dtype=torch.float32)
        x = self.fc2(x).flatten()

        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc

    def sample_lines(self, meta, jmap, joff):
        with torch.no_grad():
            junc = meta["junc_coords"]  # [N, 2]
            jtyp = meta["jtyp"]  # [N]
            Lpos = meta["line_pos_idx"]
            Lneg = meta["line_neg_idx"]

            n_type = jmap.shape[0]
            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
            joff = joff.reshape(n_type, 2, -1)
            max_K = self.n_dyn_junc // n_type
            N = len(junc)
            if not self.training:
                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
            else:
                K = min(int(N * 2 + 2), max_K)
            if K < 2:
                K = 2
            device = jmap.device

            # index: [N_TYPE, K]
            score, index = torch.topk(jmap, k=K)
            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5

            # xy: [N_TYPE, K, 2]
            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
            xy_ = xy[..., None, :]
            del x, y, index

            # Match each detected junction to its nearest ground-truth junction.
            # dist: [N_TYPE, K, N]
            dist = torch.sum((xy_ - junc) ** 2, -1)
            cost, match = torch.min(dist, -1)

            # xy: [N_TYPE * K, 2]
            # match: [N_TYPE, K]
            for t in range(n_type):
                match[t, jtyp[match[t]] != t] = N
            match[cost > 1.5 * 1.5] = N
            match = match.flatten()

            _ = torch.arange(n_type * K, device=device)
            # Explicit "ij" indexing matches the old default and silences the
            # deprecation warning on newer PyTorch versions.
            u, v = torch.meshgrid(_, _, indexing="ij")
            u, v = u.flatten(), v.flatten()
            up, vp = match[u], match[v]
            label = Lpos[up, vp]

            if self.training:
                c = torch.zeros_like(label, dtype=torch.bool)

                # sample positive lines
                cdx = label.nonzero().flatten()
                if len(cdx) > self.n_dyn_posl:
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample negative lines
                cdx = Lneg[up, vp].nonzero().flatten()
                if len(cdx) > self.n_dyn_negl:
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample other (unmatched) lines
                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
                c[cdx] = 1
            else:
                c = (u < v).flatten()

            # sample lines
            u, v, label = u[c], v[c], label[c]
            xy = xy.reshape(n_type * K, 2)
            xyu, xyv = xy[u], xy[v]

            u2v = xyu - xyv
            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
            feat = torch.cat(
                [
                    xyu / 128 * self.use_cood,
                    xyv / 128 * self.use_cood,
                    u2v * self.use_slop,
                    (u[:, None] > K).float(),
                    (v[:, None] > K).float(),
                ],
                1,
            )
            line = torch.cat([xyu[:, None], xyv[:, None]], 1)

            xy = xy.reshape(n_type, K, 2)
            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
            return line, label.float(), feat, jcs
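
# FEATURE_DIM bookkeeping: `feat` concatenates xyu (2) + xyv (2) + u2v (2)
# + the two endpoint indicator columns (1 + 1) = 8 = FEATURE_DIM, which is
# exactly the geometric descriptor appended to the pooled LOI features.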


# def wirepointrcnn_resnet50_fpn(
#         *,
#         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
#         progress: bool = True,
#         num_classes: Optional[int] = None,
#         num_keypoints: Optional[int] = None,
#         weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
#         trainable_backbone_layers: Optional[int] = None,
#         **kwargs: Any,
# ) -> WirepointRCNN:
#     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
#     weights_backbone = ResNet50_Weights.verify(weights_backbone)
#
#     is_trained = weights is not None or weights_backbone is not None
#     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
#
#     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
#
#     backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
#     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
#     model = WirepointRCNN(backbone, num_classes=5, **kwargs)
#
#     if weights is not None:
#         model.load_state_dict(weights.get_state_dict(progress=progress))
#         if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
#             overwrite_eps(model, 0.0)
#
#     return model


def wirepointrcnn_resnet18_fpn(
        *,
        weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
        progress: bool = True,
        num_classes: Optional[int] = None,
        num_keypoints: Optional[int] = None,
        weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
        trainable_backbone_layers: Optional[int] = None,
        **kwargs: Any,
) -> WirepointRCNN:
    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet18_Weights.verify(weights_backbone)

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = WirepointRCNN(backbone, num_classes=num_classes if num_classes is not None else 5, **kwargs)

    if weights is not None:
        # Caveat: the KeypointRCNN ResNet-50 checkpoint will not match a
        # ResNet-18 backbone, so loading it here would fail.
        model.load_state_dict(weights.get_state_dict(progress=progress))
        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)
    return model
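
# Usage sketch (assumes a 'wirenet.yaml' reachable from the working
# directory, since WirepointPredictor reads it at construction time):
#   model = wirepointrcnn_resnet18_fpn(num_classes=5)
#   model.to(device)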


def _loss(losses):
    total_loss = 0
    for i in losses.keys():
        if i != "loss_wirepoint":
            total_loss += losses[i]
        else:
            loss_labels = losses[i]["losses"]
            loss_labels_k = list(loss_labels[0].keys())
            for j, name in enumerate(loss_labels_k):
                loss = loss_labels[0][name].mean()
                total_loss += loss
    return total_loss
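
# Expected structure (a sketch; the key names below are illustrative, not
# taken from this file): `losses` maps names to scalar tensors, except for
# 'loss_wirepoint', which nests {'losses': [{'loss_jmap': ..., ...}]};
# each nested entry is .mean()-ed into the total.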


cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])


def c(x):
    return sm.to_rgba(x)


def imshow(im):
    plt.close()
    plt.tight_layout()
    plt.imshow(im)
    plt.colorbar(sm, fraction=0.046)
    plt.xlim([0, im.shape[1]])  # width on the x axis
    plt.ylim([im.shape[0], 0])  # height on the y axis, origin at the top
    # plt.show()


def _plot_samples(img, i, result, prefix, epoch, writer):
    def draw_vecl(lines, sline, juncs, junts, fn):
        # Make sure the output directory exists.
        directory = os.path.dirname(fn)
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Draw the image with predicted lines and junctions.
        plt.figure()
        plt.imshow(img.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')  # optional: hide the axes
        if len(lines) > 0 and not (lines[0] == 0).all():
            for idx, ((a, b), s) in enumerate(zip(lines, sline)):
                if idx > 0 and (lines[idx] == lines[0]).all():
                    break
                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1)
        if not (juncs[0] == 0).all():
            for idx, j in enumerate(juncs):
                if idx > 0 and (j == juncs[0]).all():
                    break
                plt.scatter(j[1], j[0], c="red", s=20, zorder=100)
        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
            for idx, j in enumerate(junts):
                if idx > 0 and (j == junts[0]).all():
                    break
                plt.scatter(j[1], j[0], c="blue", s=20, zorder=100)

        # Convert the matplotlib figure to a numpy array.
        plt.tight_layout()
        fig = plt.gcf()
        fig.canvas.draw()
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
            fig.canvas.get_width_height()[::-1] + (3,))
        plt.close()
        return image_from_plot

    # Fetch the per-image results and convert them to numpy arrays
    # (coordinates are in the 128x128 heatmap frame, hence the * 4 upscale).
    rjuncs = result["juncs"][i].cpu().numpy() * 4
    rjunts = None
    if "junts" in result:
        rjunts = result["junts"][i].cpu().numpy() * 4

    vecl_result = result["lines"][i].cpu().numpy() * 4
    score = result["score"][i].cpu().numpy()

    # Render the overlay and log both it and the original image to TensorBoard.
    image_path = f"{prefix}_vecl_b.jpg"
    image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path)
    image_tensor = transforms.ToTensor()(image_array)
    writer.add_image('output_epoch', image_tensor, global_step=epoch)
    writer.add_image('ori_epoch', img, global_step=epoch)


def show_line(img, pred, prefix, epoch, writer):
    fn = f"{prefix}_line.jpg"
    directory = os.path.dirname(fn)
    if not os.path.exists(directory):
        os.makedirs(directory)
    print(fn)

    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    H = pred
    im = img.permute(1, 2, 0)

    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = H["score"][0].cpu().numpy()
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break

    # Post-process lines to remove overlapping ones.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    for i, t in enumerate([0.5]):
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < t:
                continue
            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            plt.scatter(a[1], a[0], **PLTOPTS)
            plt.scatter(b[1], b[0], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im)
        plt.savefig(fn, bbox_inches="tight")
        plt.show()
        plt.close()

        img2 = cv2.imread(fn)  # rendered prediction image (HWC, BGR)
        # cv2 returns an HWC array, so tell add_image the data format.
        writer.add_image("output", img2, epoch, dataformats='HWC')


if __name__ == '__main__':
    cfg = 'wirenet.yaml'
    cfg = read_yaml(cfg)
    print(f'cfg:{cfg}')
    print(cfg['model']['n_dyn_negl'])

    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
    train_sampler = torch.utils.data.RandomSampler(dataset_train)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
    train_collate_fn = utils.collate_fn_wirepoint
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
    )

    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
    val_sampler = torch.utils.data.RandomSampler(dataset_val)
    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
    val_collate_fn = utils.collate_fn_wirepoint
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
    )

    model = wirepointrcnn_resnet18_fpn().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
    writer = SummaryWriter(cfg['io']['logdir'])


    def move_to_device(data, device):
        if isinstance(data, (list, tuple)):
            return type(data)(move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # leave non-tensor data unchanged
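
    # Example (sketch): recursively moves nested containers, e.g.
    #   batch = [{'img': torch.zeros(3, 512, 512)}, ('meta', 1)]
    #   batch = move_to_device(batch, device)  # tensors moved, rest untouched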


    def writer_loss(writer, losses, epoch):
        # Log every scalar loss to TensorBoard; the mojibake comments in the
        # original are restated here in English.
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    # Log each wirepoint sub-loss individually.
                    for subdict in losses['loss_wirepoint']['losses']:
                        for subkey, subvalue in subdict.items():
                            # Use .item() when the value is a tensor.
                            writer.add_scalar(f'loss_wirepoint/{subkey}',
                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                              epoch)
                elif isinstance(value, torch.Tensor):
                    # Plain scalar tensor losses.
                    writer.add_scalar(key, value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")


    for epoch in range(cfg['optim']['max_epoch']):
        print(f"epoch:{epoch}")
        model.train()

        for imgs, targets in data_loader_train:
            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
            loss = _loss(losses)
            print(f"loss:{loss}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer_loss(writer, losses, epoch)

        # Validation / visualization pass (disabled):
        # model.eval()
        # with torch.no_grad():
        #     for batch_idx, (imgs, targets) in enumerate(data_loader_val):
        #         pred = model(move_to_device(imgs, device))
        #         if batch_idx == 0:
        #             result = pred[1]['wires']  # pred[0].keys(): ['boxes', 'labels', 'scores']
        #             # imgs[0] is [3, 512, 512]; imshow expects (H, W, C)
        #             _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
        #             show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)