# line_vectorizer.py
import itertools
import random
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lcnn.config import M
  9. FEATURE_DIM = 8
  10. class LineVectorizer(nn.Module):
  11. def __init__(self, backbone):
  12. super().__init__()
  13. self.backbone = backbone
  14. lambda_ = torch.linspace(0, 1, M.n_pts0)[:, None]
  15. self.register_buffer("lambda_", lambda_)
  16. self.do_static_sampling = M.n_stc_posl + M.n_stc_negl > 0
  17. self.fc1 = nn.Conv2d(256, M.dim_loi, 1)
  18. scale_factor = M.n_pts0 // M.n_pts1
  19. if M.use_conv:
  20. self.pooling = nn.Sequential(
  21. nn.MaxPool1d(scale_factor, scale_factor),
  22. Bottleneck1D(M.dim_loi, M.dim_loi),
  23. )
  24. self.fc2 = nn.Sequential(
  25. nn.ReLU(inplace=True), nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, 1)
  26. )
  27. else:
  28. self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
  29. self.fc2 = nn.Sequential(
  30. nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, M.dim_fc),
  31. nn.ReLU(inplace=True),
  32. nn.Linear(M.dim_fc, M.dim_fc),
  33. nn.ReLU(inplace=True),
  34. nn.Linear(M.dim_fc, 1),
  35. )
  36. self.loss = nn.BCEWithLogitsLoss(reduction="none")
  37. def forward(self, input_dict):
  38. result = self.backbone(input_dict)
  39. h = result["preds"]
  40. x = self.fc1(result["feature"])
  41. n_batch, n_channel, row, col = x.shape
  42. xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
  43. for i, meta in enumerate(input_dict["meta"]):
  44. p, label, feat, jc = self.sample_lines(
  45. meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
  46. )
  47. # print("p.shape:", p.shape)
  48. ys.append(label)
  49. if input_dict["mode"] == "training" and self.do_static_sampling:
  50. p = torch.cat([p, meta["lpre"]])
  51. feat = torch.cat([feat, meta["lpre_feat"]])
  52. ys.append(meta["lpre_label"])
  53. del jc
  54. else:
  55. jcs.append(jc)
  56. ps.append(p)
  57. fs.append(feat)
  58. p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
  59. p = p.reshape(-1, 2) # [N_LINE x N_POINT, 2_XY]
  60. px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
  61. px0 = px.floor().clamp(min=0, max=127)
  62. py0 = py.floor().clamp(min=0, max=127)
  63. px1 = (px0 + 1).clamp(min=0, max=127)
  64. py1 = (py0 + 1).clamp(min=0, max=127)
  65. px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
  66. # xp: [N_LINE, N_CHANNEL, N_POINT]
  67. xp = (
  68. (
  69. x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
  70. + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
  71. + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
  72. + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
  73. )
  74. .reshape(n_channel, -1, M.n_pts0)
  75. .permute(1, 0, 2)
  76. )
  77. xp = self.pooling(xp)
  78. xs.append(xp)
  79. idx.append(idx[-1] + xp.shape[0])
  80. x, y = torch.cat(xs), torch.cat(ys)
  81. f = torch.cat(fs)
  82. x = x.reshape(-1, M.n_pts1 * M.dim_loi)
  83. x = torch.cat([x, f], 1)
  84. x = self.fc2(x.float()).flatten()
  85. if input_dict["mode"] != "training":
  86. p = torch.cat(ps)
  87. s = torch.sigmoid(x)
  88. b = s > 0.5
  89. lines = []
  90. score = []
  91. for i in range(n_batch):
  92. p0 = p[idx[i]: idx[i + 1]]
  93. s0 = s[idx[i]: idx[i + 1]]
  94. mask = b[idx[i]: idx[i + 1]]
  95. p0 = p0[mask]
  96. s0 = s0[mask]
  97. if len(p0) == 0:
  98. lines.append(torch.zeros([1, M.n_out_line, 2, 2], device=p.device))
  99. score.append(torch.zeros([1, M.n_out_line], device=p.device))
  100. else:
  101. arg = torch.argsort(s0, descending=True)
  102. p0, s0 = p0[arg], s0[arg]
  103. lines.append(p0[None, torch.arange(M.n_out_line) % len(p0)])
  104. score.append(s0[None, torch.arange(M.n_out_line) % len(s0)])
  105. for j in range(len(jcs[i])):
  106. if len(jcs[i][j]) == 0:
  107. jcs[i][j] = torch.zeros([M.n_out_junc, 2], device=p.device)
  108. jcs[i][j] = jcs[i][j][
  109. None, torch.arange(M.n_out_junc) % len(jcs[i][j])
  110. ]
  111. result["preds"]["lines"] = torch.cat(lines)
  112. result["preds"]["score"] = torch.cat(score)
  113. result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
  114. # print(result)
  115. result["box"] = result['aaa']
  116. del result['aaa']
  117. if len(jcs[i]) > 1:
  118. result["preds"]["junts"] = torch.cat(
  119. [jcs[i][1] for i in range(n_batch)]
  120. )
  121. if input_dict["mode"] != "testing":
  122. y = torch.cat(ys)
  123. loss = self.loss(x, y)
  124. lpos_mask, lneg_mask = y, 1 - y
  125. loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
  126. def sum_batch(x):
  127. xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
  128. return torch.cat(xs)
  129. lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
  130. lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
  131. result["losses"][0]["lpos"] = lpos * M.loss_weight["lpos"]
  132. result["losses"][0]["lneg"] = lneg * M.loss_weight["lneg"]
  133. if input_dict["mode"] == "training":
  134. for i in result["aaa"].keys():
  135. result["losses"][0][i] = result["aaa"][i]
  136. del result["preds"]
  137. return result
  138. def sample_lines(self, meta, jmap, joff, mode):
  139. with torch.no_grad():
  140. junc = meta["junc_coords"] # [N, 2]
  141. jtyp = meta["jtyp"] # [N]
  142. Lpos = meta["line_pos_idx"]
  143. Lneg = meta["line_neg_idx"]
  144. n_type = jmap.shape[0]
  145. jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
  146. joff = joff.reshape(n_type, 2, -1)
  147. max_K = M.n_dyn_junc // n_type
  148. N = len(junc)
  149. if mode != "training":
  150. K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
  151. else:
  152. K = min(int(N * 2 + 2), max_K)
  153. if K < 2:
  154. K = 2
  155. device = jmap.device
  156. # index: [N_TYPE, K]
  157. score, index = torch.topk(jmap, k=K)
  158. y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
  159. x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
  160. # xy: [N_TYPE, K, 2]
  161. xy = torch.cat([y[..., None], x[..., None]], dim=-1)
  162. xy_ = xy[..., None, :]
  163. del x, y, index
  164. # dist: [N_TYPE, K, N]
  165. dist = torch.sum((xy_ - junc) ** 2, -1)
  166. cost, match = torch.min(dist, -1)
  167. # xy: [N_TYPE * K, 2]
  168. # match: [N_TYPE, K]
  169. for t in range(n_type):
  170. match[t, jtyp[match[t]] != t] = N
  171. match[cost > 1.5 * 1.5] = N
  172. match = match.flatten()
  173. _ = torch.arange(n_type * K, device=device)
  174. u, v = torch.meshgrid(_, _)
  175. u, v = u.flatten(), v.flatten()
  176. up, vp = match[u], match[v]
  177. label = Lpos[up, vp]
  178. if mode == "training":
  179. c = torch.zeros_like(label, dtype=torch.bool)
  180. # sample positive lines
  181. cdx = label.nonzero().flatten()
  182. if len(cdx) > M.n_dyn_posl:
  183. # print("too many positive lines")
  184. perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_posl]
  185. cdx = cdx[perm]
  186. c[cdx] = 1
  187. # sample negative lines
  188. cdx = Lneg[up, vp].nonzero().flatten()
  189. if len(cdx) > M.n_dyn_negl:
  190. # print("too many negative lines")
  191. perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_negl]
  192. cdx = cdx[perm]
  193. c[cdx] = 1
  194. # sample other (unmatched) lines
  195. cdx = torch.randint(len(c), (M.n_dyn_othr,), device=device)
  196. c[cdx] = 1
  197. else:
  198. c = (u < v).flatten()
  199. # sample lines
  200. u, v, label = u[c], v[c], label[c]
  201. xy = xy.reshape(n_type * K, 2)
  202. xyu, xyv = xy[u], xy[v]
  203. u2v = xyu - xyv
  204. u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
  205. feat = torch.cat(
  206. [
  207. xyu / 128 * M.use_cood,
  208. xyv / 128 * M.use_cood,
  209. u2v * M.use_slop,
  210. (u[:, None] > K).float(),
  211. (v[:, None] > K).float(),
  212. ],
  213. 1,
  214. )
  215. line = torch.cat([xyu[:, None], xyv[:, None]], 1)
  216. xy = xy.reshape(n_type, K, 2)
  217. jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
  218. return line, label.float(), feat, jcs
  219. def non_maximum_suppression(a):
  220. ap = F.max_pool2d(a, 3, stride=1, padding=1)
  221. mask = (a == ap).float().clamp(min=0.0)
  222. return a * mask
  223. class Bottleneck1D(nn.Module):
  224. def __init__(self, inplanes, outplanes):
  225. super(Bottleneck1D, self).__init__()
  226. planes = outplanes // 2
  227. self.op = nn.Sequential(
  228. nn.BatchNorm1d(inplanes),
  229. nn.ReLU(inplace=True),
  230. nn.Conv1d(inplanes, planes, kernel_size=1),
  231. nn.BatchNorm1d(planes),
  232. nn.ReLU(inplace=True),
  233. nn.Conv1d(planes, planes, kernel_size=3, padding=1),
  234. nn.BatchNorm1d(planes),
  235. nn.ReLU(inplace=True),
  236. nn.Conv1d(planes, outplanes, kernel_size=1),
  237. )
  238. def forward(self, x):
  239. return x + self.op(x)