# line_predictor.py

from typing import Any, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign
from scipy.ndimage import gaussian_filter

from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
from models.config.config_tool import read_yaml
from .roi_heads import RoIHeads

FEATURE_DIM = 8


def non_maximum_suppression(a):
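    """Keep only the local maxima of ``a`` (3x3 neighborhood); zero elsewhere."""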
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float().clamp(min=0.0)
    return a * mask


class Bottleneck1D(nn.Module):
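    """Pre-activation 1-D bottleneck block (BN-ReLU-Conv x3) with a residual
    add; the skip connection assumes ``inplanes == outplanes``."""
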
    def __init__(self, inplanes, outplanes):
        super().__init__()
        planes = outplanes // 2
        self.op = nn.Sequential(
            nn.BatchNorm1d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv1d(inplanes, planes, kernel_size=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, outplanes, kernel_size=1),
        )

    def forward(self, x):
        return x + self.op(x)


class LineRCNNPredictor(nn.Module):
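    """Line-verification head in the style of L-CNN's LineVectorizer.

    Decodes junction candidates from the predicted junction heatmap/offset
    maps, pairs them into candidate line segments, pools the line map along
    each candidate, and scores the candidates with a small MLP (``fc2``).
    The defaults mirror the values from ``config/wireframe.yaml``.
    """
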
    def __init__(
        self,
        n_pts0=32,
        n_pts1=8,
        n_stc_posl=300,
        dim_loi=1,
        use_conv=0,
        dim_fc=1024,
        n_out_line=2500,
        n_out_junc=250,
        n_dyn_junc=300,
        eval_junc_thres=0.008,
        n_dyn_posl=300,
        n_dyn_negl=80,
        n_dyn_othr=600,
        use_cood=0,
        use_slop=0,
        n_stc_negl=40,
        head_size=[[2], [1], [2]],
        **kwargs,
    ):
        super().__init__()
        self.n_pts0 = n_pts0
        self.n_pts1 = n_pts1
        self.n_stc_posl = n_stc_posl
        self.dim_loi = dim_loi
        self.use_conv = use_conv
        self.dim_fc = dim_fc
        self.n_out_line = n_out_line
        self.n_out_junc = n_out_junc
        self.n_dyn_junc = n_dyn_junc
        self.eval_junc_thres = eval_junc_thres
        self.n_dyn_posl = n_dyn_posl
        self.n_dyn_negl = n_dyn_negl
        self.n_dyn_othr = n_dyn_othr
        self.use_cood = use_cood
        self.use_slop = use_slop
        self.n_stc_negl = n_stc_negl
        self.head_size = head_size
        self.num_class = sum(sum(self.head_size, []))
        self.head_off = np.cumsum([sum(h) for h in self.head_size])
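        # With head_size=[[2], [1], [2]]: 2 junction-map channels, 1 line-map
        # channel, and 2 junction-offset channels; head_off = [2, 3, 5] gives
        # the cumulative offsets used to split the head output in forward().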
        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
        self.register_buffer("lambda_", lambda_)
        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0

        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
        scale_factor = self.n_pts0 // self.n_pts1
        if self.use_conv:
            self.pooling = nn.Sequential(
                nn.MaxPool1d(scale_factor, scale_factor),
                Bottleneck1D(self.dim_loi, self.dim_loi),
            )
            self.fc2 = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1),
            )
        else:
            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
            self.fc2 = nn.Sequential(
                nn.Linear(self.dim_loi * FEATURE_DIM, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, 1),
            )

        self.loss = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, inputs, features, targets=None):
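        """Score candidate line segments decoded from the stacked head output.

        ``inputs`` is the head output ([B, C, H, W]) whose channels are split
        into jmap/lmap/joff according to ``head_off``; ``targets`` switches
        between the training and inference sampling paths. Returns the line
        logits ``x`` plus the labels, per-image index offsets, junction
        candidates, and the sizes used downstream (n_out_line, n_out_junc).
        """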
        batch, channel, row, col = inputs.shape
        if targets is not None:
            self.training = True
            wires_targets = [t["wires"] for t in targets]
            # Gather the 'junc_map', 'junc_offset', and 'line_map' tensors
            # from every target.
            junc_maps = [d["junc_map"] for d in wires_targets]
            junc_offsets = [d["junc_offset"] for d in wires_targets]
            line_maps = [d["line_map"] for d in wires_targets]

            junc_map_tensor = torch.stack(junc_maps, dim=0)
            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
            line_map_tensor = torch.stack(line_maps, dim=0)

            wires_meta = {
                "junc_map": junc_map_tensor,
                "junc_offset": junc_offset_tensor,
            }
        else:
            self.training = False
            t = {
                "junc_coords": torch.zeros(1, 2),
                "jtyp": torch.zeros(1, dtype=torch.uint8),
                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }
            wires_targets = [t for _ in range(inputs.size(0))]
            wires_meta = {
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }

        T = wires_meta.copy()
        n_jtyp = T["junc_map"].shape[1]
        offset = self.head_off
        result = {}
        for stack, output in enumerate([inputs]):
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0]: offset[1]]
            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
            if stack == 0:
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }

        h = result["preds"]
        lmap = inputs[:, 2:3, :, :].sigmoid()
        n_batch, n_channel, row, col = lmap.shape

        ys, fs, ps, idx, jcs = [], [], [], [0], []
        for i, meta in enumerate(wires_targets):
            p, label, feat, jc = self.sample_lines(
                meta, h["jmap"][i], h["joff"][i], lmap[i]
            )
            ys.append(label)
            if self.training and self.do_static_sampling:
                # Append the statically sampled lines provided by the target.
                p = torch.cat([p, meta["lpre"]])
                feat = torch.cat([feat, meta["lpre_feat"]])
                ys.append(meta["lpre_label"])
                del jc
            else:
                jcs.append(jc)
                ps.append(p)
            fs.append(feat)
            idx.append(idx[-1] + feat.shape[0])

        y = torch.cat(ys)
        f = torch.cat(fs).to(dtype=torch.float32)
        x = self.fc2(f).flatten()

        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc

    def sample_lines(self, meta, jmap, joff, lmap):
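        """Build and label candidate line segments for one image.

        Junction candidates are decoded from ``jmap``/``joff`` by NMS and
        top-k selection, matched against the ground-truth junctions, and
        paired into candidates labeled via the ``line_pos_idx`` /
        ``line_neg_idx`` adjacency. Returns ``(line, label, xp, jcs)``: the
        endpoint pairs, their binary labels, the line features pooled from
        ``lmap``, and the junction candidates kept for postprocessing.
        """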
        device = jmap.device
        with torch.no_grad():
            junc = meta["junc_coords"].to(device)  # [N, 2]
            jtyp = meta["jtyp"].to(device)  # [N]
            Lpos = meta["line_pos_idx"].to(device)
            Lneg = meta["line_neg_idx"].to(device)

            n_type = jmap.shape[0]
            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
            joff = joff.reshape(n_type, 2, -1)
            max_K = self.n_dyn_junc // n_type
            N = len(junc)
            if not self.training:
                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
            else:
                K = min(int(N * 2 + 2), max_K)
            K = max(K, 2)
            # index: [N_TYPE, K]
            score, index = torch.topk(jmap, k=K)
            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5

            # xy: [N_TYPE, K, 2]
            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
            xy_ = xy[..., None, :]
            del x, y, index

            # dist: [N_TYPE, K, N]
            dist = torch.sum((xy_ - junc) ** 2, -1)
            cost, match = torch.min(dist, -1)

            # match: [N_TYPE, K]; candidates of the wrong type or farther than
            # 1.5 px from any ground-truth junction are marked unmatched (N).
            for t in range(n_type):
                match[t, jtyp[match[t]] != t] = N
            match[cost > 1.5 * 1.5] = N
            match = match.flatten()

            _ = torch.arange(n_type * K, device=device)
            u, v = torch.meshgrid(_, _, indexing="ij")  # "ij" keeps the pre-1.10 default
            u, v = u.flatten(), v.flatten()
            up, vp = match[u], match[v]
            label = Lpos[up, vp]
            if self.training:
                c = torch.zeros_like(label, dtype=torch.bool)

                # sample positive lines
                cdx = label.nonzero().flatten()
                if len(cdx) > self.n_dyn_posl:
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample negative lines
                cdx = Lneg[up, vp].nonzero().flatten()
                if len(cdx) > self.n_dyn_negl:
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample other (unmatched) lines
                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
                c[cdx] = 1
            else:
                c = (u < v).flatten()
            # keep the selected candidate lines
            u, v, label = u[c], v[c], label[c]
            xy = xy.reshape(n_type * K, 2)
            xyu, xyv = xy[u], xy[v]

            u2v = xyu - xyv
            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)

            # feat: [N_LINE, FEATURE_DIM] geometric descriptor
            # (endpoint coords, direction, and out-of-range indicators)
            feat = torch.cat(
                [
                    xyu / 128 * self.use_cood,
                    xyv / 128 * self.use_cood,
                    u2v * self.use_slop,
                    (u[:, None] > K).float(),
                    (v[:, None] > K).float(),
                ],
                1,
            )

        # Pool the line map along each candidate segment. This is kept outside
        # the no_grad block so the line-classification loss can backpropagate
        # into lmap.
        line = torch.cat([xyu[:, None], xyv[:, None]], 1)
        n_channel, row, col = lmap.shape

        # p: [N_LINE x N_POINT, 2_XY] -- n_pts0 points interpolated between
        # the two endpoints of each candidate line
        p = line[:, 0:1, :] * self.lambda_ + line[:, 1:2, :] * (1 - self.lambda_) - 0.5
        p = p.reshape(-1, 2)
        px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
        px0 = px.floor().clamp(min=0, max=127)
        py0 = py.floor().clamp(min=0, max=127)
        px1 = (px0 + 1).clamp(min=0, max=127)
        py1 = (py0 + 1).clamp(min=0, max=127)
        px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()

        # xp: [N_LINE, N_CHANNEL, N_POINT] -- bilinear samples of lmap along
        # each line, max-pooled from n_pts0 down to n_pts1 points
        xp = (
            (
                lmap[:, px0l, py0l] * (px1 - px) * (py1 - py)
                + lmap[:, px1l, py0l] * (px - px0) * (py1 - py)
                + lmap[:, px0l, py1l] * (px1 - px) * (py - py0)
                + lmap[:, px1l, py1l] * (px - px0) * (py - py0)
            )
            .reshape(n_channel, -1, self.n_pts0)
            .permute(1, 0, 2)
        )
        xp = self.pooling(xp).squeeze(1)

        xy = xy.reshape(n_type, K, 2)
        jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
        return line, label.float(), xp, jcs


_COMMON_META = {
    "categories": _COCO_PERSON_CATEGORIES,
    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
    "min_size": (1, 1),
}
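
# A minimal usage sketch (hypothetical, for illustration only): with the
# default head_size=[[2], [1], [2]] the predictor expects a 5-channel head
# output on 128x128 maps, e.g.
#
#   model = LineRCNNPredictor()
#   outputs = torch.rand(2, 5, 128, 128)   # stacked jmap/lmap/joff logits
#   feats = torch.rand(2, 256, 128, 128)   # backbone feature maps
#   x, y, idx, jcs, n_batch, ps, n_line, n_junc = model(outputs, feats)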