# roi_heads.py
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision
from torch import nn, Tensor

import libs.vision_libs.models.detection._utils as det_utils
from libs.vision_libs.ops import boxes as box_ops, roi_align


def l2loss(input, target):
    return ((target - input) ** 2).mean(2).mean(1)


def cross_entropy_loss(logits, positive):
    nlogp = -F.log_softmax(logits, dim=0)
    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)


def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
    logp = torch.sigmoid(logits) + offset
    loss = torch.abs(logp - target)
    if mask is not None:
        w = mask.mean(2, True).mean(1, True)
        w[w == 0] = 1
        loss = loss * (mask / w)
    return loss.mean(2).mean(1)
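
# Hedged usage sketch (added for illustration; shapes are assumed): shows the
# per-image reduction of sigmoid_l1_loss. The -0.5 offset recenters sigmoid
# outputs onto [-0.5, 0.5], matching how junction offsets are regressed below.
def _demo_sigmoid_l1_loss():
    logits = torch.randn(2, 8, 8)               # (batch, H, W) raw scores
    target = torch.rand(2, 8, 8) - 0.5          # offsets in [-0.5, 0.5]
    mask = (torch.rand(2, 8, 8) > 0.5).float()  # weight only the masked cells
    loss = sigmoid_l1_loss(logits, target, offset=-0.5, mask=mask)
    assert loss.shape == (2,)                   # one loss value per image
    return loss
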
class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        probs = probs.view(-1)
        targets = targets.view(-1).float()
        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return 1.0 - dice


bce_loss = nn.BCEWithLogitsLoss()
dice_loss = DiceLoss()


def combined_loss(preds, targets, alpha=0.5):
    bce = bce_loss(preds, targets)
    d = dice_loss(preds, targets)
    return alpha * bce + (1 - alpha) * d
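
# Hedged usage sketch (synthetic shapes): the BCE + Dice blend used for the
# junction/line heatmap losses further down.
def _demo_combined_loss():
    preds = torch.randn(2, 1, 16, 16)                   # raw logits
    targets = (torch.rand(2, 1, 16, 16) > 0.5).float()  # binary ground truth
    return combined_loss(preds, targets, alpha=0.5)     # scalar tensor
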

# Compute the multi-task (multi-head) losses for the line head.
def line_head_loss(input_dict, outputs, feature, loss_weight, mode_train):
    # image = input_dict["image"]
    # target_b = input_dict["target_b"]
    # outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # in train mode aaa is the loss; in val mode it is the boxes
    result = {"feature": feature}
    batch, channel, row, col = outputs[0].shape
    T = input_dict["target"].copy()
    n_jtyp = T["junc_map"].shape[1]

    # switch to CNHW
    for task in ["junc_map"]:
        T[task] = T[task].permute(1, 0, 2, 3)
    for task in ["junc_offset"]:
        T[task] = T[task].permute(1, 2, 0, 3, 4)

    offset = [2, 3, 5]
    losses = []
    for stack, output in enumerate(outputs):
        output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
        jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
        lmap = output[offset[0]: offset[1]].squeeze(0)
        joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
        if stack == 0:
            result["preds"] = {
                "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                "lmap": lmap.sigmoid(),
                "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
            }
            if not mode_train:
                return result

        L = OrderedDict()
        L["jmap"] = sum(
            cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
        )
        L["lmap"] = (
            F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
            .mean(2)
            .mean(1)
        )
        L["joff"] = sum(
            sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
            for i in range(n_jtyp)
            for j in range(2)
        )
        for loss_name in L:
            L[loss_name].mul_(loss_weight[loss_name])
        losses.append(L)
    result["losses"] = losses
    # result["aaa"] = aaa
    return result

# Compute the line-vectorizer losses.
def line_vectorizer_loss(result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc, loss_weight, mode_train):
    if not mode_train:
        p = torch.cat(ps)
        s = torch.sigmoid(x)
        b = s > 0.5
        lines = []
        score = []
        for i in range(n_batch):
            p0 = p[idx[i]: idx[i + 1]]
            s0 = s[idx[i]: idx[i + 1]]
            mask = b[idx[i]: idx[i + 1]]
            p0 = p0[mask]
            s0 = s0[mask]
            if len(p0) == 0:
                lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
                score.append(torch.zeros([1, n_out_line], device=p.device))
            else:
                arg = torch.argsort(s0, descending=True)
                p0, s0 = p0[arg], s0[arg]
                lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
                score.append(s0[None, torch.arange(n_out_line) % len(s0)])
            for j in range(len(jcs[i])):
                if len(jcs[i][j]) == 0:
                    jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
                jcs[i][j] = jcs[i][j][
                    None, torch.arange(n_out_junc) % len(jcs[i][j])
                ]
        result["preds"]["lines"] = torch.cat(lines)
        result["preds"]["score"] = torch.cat(score)
        result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
        if len(jcs[i]) > 1:
            result["preds"]["junts"] = torch.cat(
                [jcs[i][1] for i in range(n_batch)]
            )

    # if input_dict["mode"] != "testing":
    y = torch.cat(ys)
    loss = nn.BCEWithLogitsLoss(reduction="none")
    loss = loss(x, y)
    lpos_mask, lneg_mask = y, 1 - y
    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask

    def sum_batch(x):
        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
        return torch.cat(xs)

    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
    if mode_train:
        del result["preds"]
    return result

def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
    # output, feature: results returned by the head
    # x, y, idx: intermediate results from the line vectorizer
    result = {}
    batch, channel, row, col = output.shape

    wires_targets = [t["wires"] for t in targets]
    wires_targets = wires_targets.copy()
    # Extract all 'junc_map', 'junc_offset', and 'line_map' tensors.
    junc_maps = [d["junc_map"] for d in wires_targets]
    junc_offsets = [d["junc_offset"] for d in wires_targets]
    line_maps = [d["line_map"] for d in wires_targets]

    junc_map_tensor = torch.stack(junc_maps, dim=0)
    junc_offset_tensor = torch.stack(junc_offsets, dim=0)
    line_map_tensor = torch.stack(line_maps, dim=0)
    T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}

    n_jtyp = T["junc_map"].shape[1]
    for task in ["junc_map"]:
        T[task] = T[task].permute(1, 0, 2, 3)
    for task in ["junc_offset"]:
        T[task] = T[task].permute(1, 2, 0, 3, 4)

    offset = [2, 3, 5]
    losses = []
    output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
    jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
    lmap = output[offset[0]: offset[1]].squeeze(0)
    joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)

    L = OrderedDict()
    # L["junc_map"] = sum(
    #     cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
    # ).mean()
    # L["line_map"] = (
    #     F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
    #     .mean(2)
    #     .mean(1)
    # ).mean()
    L["junc_map"] = combined_loss(jmap[:, 1, :, :, :], T["junc_map"])
    L["line_map"] = combined_loss(lmap, T["line_map"])
    L["junc_offset"] = sum(
        sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
        for i in range(n_jtyp)
        for j in range(2)
    ).mean()
    for loss_name in L:
        L[loss_name].mul_(loss_weight[loss_name])
    losses.append(L)
    result["losses"] = losses

    loss = nn.BCEWithLogitsLoss(reduction="none")
    loss = loss(x, y)
    lpos_mask, lneg_mask = y, 1 - y
    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask

    def sum_batch(x):
        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
        return torch.cat(xs)

    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
    result["losses"][0]["lpos"] = (lpos * loss_weight["lpos"]).mean()
    result["losses"][0]["lneg"] = (lneg * loss_weight["lneg"]).mean()
    return result
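
# Hedged usage sketch (the input contract is assumed from the code above, not
# documented in the original): synthetic 5-channel head output and "wires"
# targets. The channel layout follows offset = [2, 3, 5]: two junction-map
# logits, one line-map logit, and two junction-offset channels.
def _demo_wirepoint_head_line_loss():
    batch, row, col = 2, 32, 32
    output = torch.randn(batch, 5, row, col)
    targets = [
        {
            "wires": {
                "junc_map": (torch.rand(1, row, col) > 0.95).float(),
                "junc_offset": torch.rand(1, 2, row, col) - 0.5,
                "line_map": (torch.rand(row, col) > 0.9).float(),
            }
        }
        for _ in range(batch)
    ]
    x = torch.randn(20)                 # sampled line logits for both images
    y = (torch.rand(20) > 0.5).float()  # positive/negative line labels
    idx = [0, 10, 20]                   # per-image slice boundaries into x
    loss_weight = {"junc_map": 8.0, "line_map": 0.5, "junc_offset": 0.25, "lpos": 1, "lneg": 1}
    return wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight)
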

def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
    result = {}
    result["wires"] = {}
    p = torch.cat(ps)
    s = torch.sigmoid(input)
    b = s > 0.5
    lines = []
    score = []
    for i in range(n_batch):
        p0 = p[idx[i]: idx[i + 1]]
        s0 = s[idx[i]: idx[i + 1]]
        mask = b[idx[i]: idx[i + 1]]
        p0 = p0[mask]
        s0 = s0[mask]
        if len(p0) == 0:
            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
            score.append(torch.zeros([1, n_out_line], device=p.device))
        else:
            arg = torch.argsort(s0, descending=True)
            p0, s0 = p0[arg], s0[arg]
            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
        for j in range(len(jcs[i])):
            if len(jcs[i][j]) == 0:
                jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
            jcs[i][j] = jcs[i][j][
                None, torch.arange(n_out_junc) % len(jcs[i][j])
            ]
    result["wires"]["lines"] = torch.cat(lines)
    result["wires"]["score"] = torch.cat(score)
    result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
    if len(jcs[i]) > 1:
        # this previously wrote to a nonexistent "preds" key; store under "wires"
        result["wires"]["junts"] = torch.cat(
            [jcs[i][1] for i in range(n_batch)]
        )
    return result

def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor)
        box_regression (Tensor)
        labels (list[BoxList])
        regression_targets (Tensor)

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss
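
# Hedged usage sketch (synthetic numbers): two images with four sampled
# proposals each and three classes, exercising both loss terms.
def _demo_fastrcnn_loss():
    num_classes = 3
    class_logits = torch.randn(8, num_classes)
    box_regression = torch.randn(8, num_classes * 4)
    labels = [torch.tensor([0, 1, 2, 1]), torch.tensor([0, 0, 2, 1])]
    regression_targets = [torch.randn(4, 4), torch.randn(4, 4)]
    return fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
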

def maskrcnn_inference(x, labels):
    # type: (Tensor, List[Tensor]) -> List[Tensor]
    """
    From the results of the CNN, post process the masks
    by taking the ins corresponding to the class with max
    probability (which are of fixed size and directly output
    by the CNN) and return the masks in the ins field of the BoxList.

    Args:
        x (Tensor): the ins logits
        labels (list[BoxList]): bounding boxes that are used as
            reference, one for each image

    Returns:
        results (list[BoxList]): one BoxList for each image, containing
            the extra field ins
    """
    mask_prob = x.sigmoid()

    # select masks corresponding to the predicted classes
    num_masks = x.shape[0]
    boxes_per_image = [label.shape[0] for label in labels]
    labels = torch.cat(labels)
    index = torch.arange(num_masks, device=labels.device)
    mask_prob = mask_prob[index, labels][:, None]
    mask_prob = mask_prob.split(boxes_per_image, dim=0)

    return mask_prob
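
# Hedged usage sketch: five detections split across two images, three classes,
# 28x28 mask logits; each detection keeps only its predicted class channel.
def _demo_maskrcnn_inference():
    x = torch.randn(5, 3, 28, 28)
    labels = [torch.tensor([0, 2]), torch.tensor([1, 1, 2])]
    probs = maskrcnn_inference(x, labels)
    assert [p.shape[0] for p in probs] == [2, 3]
    return probs
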

def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    # type: (Tensor, Tensor, Tensor, int) -> Tensor
    """
    Given segmentation masks and the bounding boxes corresponding
    to the location of the masks in the image, this function
    crops and resizes the masks in the position defined by the
    boxes. This prepares the masks for them to be fed to the
    loss computation as the targets.
    """
    matched_idxs = matched_idxs.to(boxes)
    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
    gt_masks = gt_masks[:, None].to(rois)
    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
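
# Hedged usage sketch (assumes libs.vision_libs.ops.roi_align matches the
# torchvision signature): crop two ground-truth masks to 14x14 loss targets.
def _demo_project_masks_on_boxes():
    gt_masks = (torch.rand(2, 32, 32) > 0.5).float()
    boxes = torch.tensor([[0.0, 0.0, 16.0, 16.0], [8.0, 8.0, 31.0, 31.0]])
    matched_idxs = torch.tensor([0, 1])  # which gt mask each box matches
    targets = project_masks_on_boxes(gt_masks, boxes, matched_idxs, 14)
    assert targets.shape == (2, 14, 14)
    return targets
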

def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """
    Args:
        proposals (list[BoxList])
        mask_logits (Tensor)
        targets (list[BoxList])

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """

    discretization_size = mask_logits.shape[-1]
    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    labels = torch.cat(labels, dim=0)
    mask_targets = torch.cat(mask_targets, dim=0)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
    )
    return mask_loss
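
# Hedged usage sketch (synthetic shapes): two images, two positive proposals
# each, all matched to a single ground-truth instance of class 1.
def _demo_maskrcnn_loss():
    mask_logits = torch.randn(4, 3, 28, 28)  # (num_pos, num_classes, M, M)
    proposals = [torch.tensor([[0.0, 0.0, 16.0, 16.0], [4.0, 4.0, 20.0, 20.0]])] * 2
    gt_masks = [(torch.rand(1, 32, 32) > 0.5).float()] * 2
    gt_labels = [torch.tensor([1])] * 2
    matched_idxs = [torch.tensor([0, 0])] * 2
    return maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, matched_idxs)
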

def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()

    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    lin_ind = y * heatmap_size + x
    heatmaps = lin_ind * valid

    return heatmaps, valid
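
# Hedged usage sketch: one RoI, two keypoints on a 56x56 heatmap. The second
# keypoint lies outside the RoI, so it is marked invalid.
def _demo_keypoints_to_heatmap():
    rois = torch.tensor([[10.0, 10.0, 50.0, 50.0]])
    keypoints = torch.tensor([[[20.0, 20.0, 1.0], [60.0, 60.0, 1.0]]])
    heatmaps, valid = keypoints_to_heatmap(keypoints, rois, 56)
    assert valid.tolist() == [[1, 0]]
    return heatmaps
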

def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )

    return xy_preds_i, end_scores_i


@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)

    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores

def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps. Output has shape
    (#rois, #keypoints, 3), the last dimension holding (x, y, 1) for each
    keypoint; the heatmap score at each keypoint is returned separately.
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap_labels by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = rois[:, 2] - rois[:, 0]
    heights = rois[:, 3] - rois[:, 1]
    widths = widths.clamp(min=1)
    heights = heights.clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_keypoints = maps.shape[1]

    if torchvision._is_tracing():
        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
            maps,
            rois,
            widths_ceil,
            heights_ceil,
            widths,
            heights,
            offset_x,
            offset_y,
            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
        )
        return xy_preds.permute(0, 2, 1), end_scores

    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
    for i in range(len(rois)):
        roi_map_width = int(widths_ceil[i].item())
        roi_map_height = int(heights_ceil[i].item())
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        roi_map = F.interpolate(
            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
        )[:, 0]
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

        x_int = pos % w
        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
        # assert (roi_map_probs[k, y_int, x_int] ==
        #         roi_map_probs[k, :, :].max())
        x = (x_int.float() + 0.5) * width_correction
        y = (y_int.float() + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]

    return xy_preds.permute(0, 2, 1), end_scores
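
# Hedged usage sketch: recover (x, y) peak locations from random heatmaps for
# two RoIs with 17 COCO-style keypoints each.
def _demo_heatmaps_to_keypoints():
    maps = torch.randn(2, 17, 56, 56)
    rois = torch.tensor([[0.0, 0.0, 32.0, 32.0], [10.0, 5.0, 40.0, 25.0]])
    xy_preds, scores = heatmaps_to_keypoints(maps, rois)
    assert xy_preds.shape == (2, 17, 3) and scores.shape == (2, 17)
    return xy_preds, scores
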

def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    N, K, H, W = keypoint_logits.shape
    if H != W:
        raise ValueError(
            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    heatmaps = []
    valid = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        kp = gt_kp_in_image[midx]
        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    keypoint_targets = torch.cat(heatmaps, dim=0)
    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
    valid = torch.where(valid)[0]

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if keypoint_targets.numel() == 0 or len(valid) == 0:
        return keypoint_logits.sum() * 0

    keypoint_logits = keypoint_logits.view(N * K, H * W)

    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
    return keypoint_loss
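
# Hedged usage sketch (synthetic shapes): two positive proposals in one image,
# three keypoints each; the third keypoint is marked not visible.
def _demo_keypointrcnn_loss():
    keypoint_logits = torch.randn(2, 3, 56, 56)  # (num_pos, K, H, W)
    proposals = [torch.tensor([[0.0, 0.0, 32.0, 32.0], [8.0, 8.0, 48.0, 48.0]])]
    gt_keypoints = [torch.tensor([[[10.0, 10.0, 1.0], [20.0, 20.0, 1.0], [30.0, 30.0, 0.0]]])]
    matched_idxs = [torch.tensor([0, 0])]
    return keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, matched_idxs)
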

def keypointrcnn_inference(x, boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    kp_probs = []
    kp_scores = []

    boxes_per_image = [box.size(0) for box in boxes]
    x2 = x.split(boxes_per_image, dim=0)

    for xx, bb in zip(x2, boxes):
        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
        kp_probs.append(kp_prob)
        kp_scores.append(scores)

    return kp_probs, kp_scores

def _onnx_expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half = w_half.to(dtype=torch.float32) * scale
    h_half = h_half.to(dtype=torch.float32) * scale

    boxes_exp0 = x_c - w_half
    boxes_exp1 = y_c - h_half
    boxes_exp2 = x_c + w_half
    boxes_exp3 = y_c + h_half
    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
    return boxes_exp


# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    if torchvision._is_tracing():
        return _onnx_expand_boxes(boxes, scale)
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half *= scale
    h_half *= scale

    boxes_exp = torch.zeros_like(boxes)
    boxes_exp[:, 0] = x_c - w_half
    boxes_exp[:, 2] = x_c + w_half
    boxes_exp[:, 1] = y_c - h_half
    boxes_exp[:, 3] = y_c + h_half
    return boxes_exp
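
# Hedged usage sketch: scaling a box by 1.5 about its center.
def _demo_expand_boxes():
    boxes = torch.tensor([[10.0, 10.0, 30.0, 30.0]])
    out = expand_boxes(boxes, 1.5)
    assert out.tolist() == [[5.0, 5.0, 35.0, 35.0]]
    return out
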

@torch.jit.unused
def expand_masks_tracing_scale(M, padding):
    # type: (int, int) -> float
    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)


def expand_masks(mask, padding):
    # type: (Tensor, int) -> Tuple[Tensor, float]
    M = mask.shape[-1]
    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
        scale = expand_masks_tracing_scale(M, padding)
    else:
        scale = float(M + 2 * padding) / M
    padded_mask = F.pad(mask, (padding,) * 4)
    return padded_mask, scale
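
# Hedged usage sketch: padding 28x28 masks by one pixel on each side and
# recovering the matching box scale factor.
def _demo_expand_masks():
    mask = torch.rand(3, 1, 28, 28)
    padded, scale = expand_masks(mask, padding=1)
    assert padded.shape == (3, 1, 30, 30) and abs(scale - 30 / 28) < 1e-6
    return padded, scale
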

def paste_mask_in_image(mask, box, im_h, im_w):
    # type: (Tensor, Tensor, int, int) -> Tensor
    TO_REMOVE = 1
    w = int(box[2] - box[0] + TO_REMOVE)
    h = int(box[3] - box[1] + TO_REMOVE)
    w = max(w, 1)
    h = max(h, 1)

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, -1, -1))

    # Resize ins
    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
    return im_mask


def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize ins
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX
    # pad y
    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask

@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    res_append = torch.zeros(0, im_h, im_w)
    for i in range(masks.size(0)):
        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
        mask_res = mask_res.unsqueeze(0)
        res_append = torch.cat((res_append, mask_res))
    return res_append


def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
    masks, scale = expand_masks(masks, padding=padding)
    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
    im_h, im_w = img_shape

    if torchvision._is_tracing():
        return _onnx_paste_masks_in_image_loop(
            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
        )[:, None]
    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
    if len(res) > 0:
        ret = torch.stack(res, dim=0)[:, None]
    else:
        ret = masks.new_empty((0, 1, im_h, im_w))
    return ret
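
# Hedged usage sketch: paste two 28x28 mask probability maps back onto a
# 64x64 image canvas.
def _demo_paste_masks_in_image():
    masks = torch.rand(2, 1, 28, 28)
    boxes = torch.tensor([[4.0, 4.0, 24.0, 24.0], [30.0, 30.0, 60.0, 60.0]])
    pasted = paste_masks_in_image(masks, boxes, (64, 64), padding=1)
    assert pasted.shape == (2, 1, 64, 64)
    return pasted
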

class RoIHeads(nn.Module):
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        line_head,
        line_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Mask
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
    ):
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        if bbox_reg_weights is None:
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.line_head = line_head
        self.line_predictor = line_predictor

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

        self.mask_roi_pool = mask_roi_pool
        self.mask_head = mask_head
        self.mask_predictor = mask_predictor

        self.keypoint_roi_pool = keypoint_roi_pool
        self.keypoint_head = keypoint_head
        self.keypoint_predictor = keypoint_predictor

    def has_line(self):
        # if self.mask_roi_pool is None:
        #     return False
        if self.line_head is None:
            return False
        if self.line_predictor is None:
            return False
        return True

    def has_mask(self):
        if self.mask_roi_pool is None:
            return False
        if self.mask_head is None:
            return False
        if self.mask_predictor is None:
            return False
        return True

    def has_keypoint(self):
        if self.keypoint_roi_pool is None:
            return False
        if self.keypoint_head is None:
            return False
        if self.keypoint_predictor is None:
            return False
        return True
    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
            if gt_boxes_in_image.numel() == 0:
                # Background image
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels

    def subsample(self, labels):
        # type: (List[Tensor]) -> List[Tensor]
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
        return proposals

    def check_targets(self, targets):
        # type: (Optional[List[Dict[str, Tensor]]]) -> None
        if targets is None:
            raise ValueError("targets should not be None")
        if not all(["boxes" in t for t in targets]):
            raise ValueError("Every element of targets should have a boxes key")
        if not all(["labels" in t for t in targets]):
            raise ValueError("Every element of targets should have a labels key")
        if self.has_mask():
            if not all(["masks" in t for t in targets]):
                raise ValueError("Every element of targets should have a masks key")

    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        self.check_targets(targets)
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            if gt_boxes_in_image.numel() == 0:
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets
    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
        pred_scores_list = pred_scores.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)

            # remove low scoring boxes
            inds = torch.where(scores > self.score_thresh)[0]
            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)

        return all_boxes, all_scores, all_labels
    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        lines,
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Args:
            features (List[Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
        """
        if targets is not None:
            for t in targets:
                # TODO: https://github.com/pytorch/pytorch/issues/26731
                floating_point_types = (torch.float, torch.double, torch.half)
                if t["boxes"].dtype not in floating_point_types:
                    raise TypeError(f"target boxes must be of float type, instead got {t['boxes'].dtype}")
                if t["labels"].dtype != torch.int64:
                    raise TypeError(f"target labels must be of int64 type, instead got {t['labels'].dtype}")
                if self.has_keypoint():
                    if t["keypoints"].dtype != torch.float32:
                        raise TypeError(f"target keypoints must be of float type, instead got {t['keypoints'].dtype}")

        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
        else:
            if targets is not None:
                proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
            else:
                labels = None
                regression_targets = None
                matched_idxs = None

        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        if self.training:
            if labels is None:
                raise ValueError("labels cannot be None")
            if regression_targets is None:
                raise ValueError("regression_targets cannot be None")
            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
        else:
            if targets is not None:
                loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
                losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
            else:
                boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
                num_images = len(boxes)
                for i in range(num_images):
                    result.append(
                        {
                            "boxes": boxes[i],
                            "labels": labels[i],
                            "scores": scores[i],
                            "lines": lines[i],
                        }
                    )

        line_features = features["0"]
        if self.has_line():
            # outputs = self.line_head(features_lcnn)
            # outputs = line_features[:, 0:5, :, :]
            loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
            x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
                inputs=line_features, features=line_features, targets=targets
            )

            # line_loss (multitask learner):
            # if self.training:
            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=True)
            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
            #                                        loss_weight, mode_train=True)
            # else:
            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=False)
            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
            #                                        loss_weight, mode_train=False)

            if self.training:
                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
                loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
            else:
                if targets is not None:
                    rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
                    loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
                else:
                    pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
                    result.append(line_features)
                    result.append(pred)
                    loss_wirepoint = {}
            losses.update(loss_wirepoint)

        if self.has_mask():
            mask_proposals = [p["boxes"] for p in result]
            if self.training:
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")

                # during training, only focus on positive boxes
                num_images = len(proposals)
                mask_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    mask_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            if self.mask_roi_pool is not None:
                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
                mask_features = self.mask_head(mask_features)
                mask_logits = self.mask_predictor(mask_features)
            else:
                raise Exception("Expected mask_roi_pool to be not None")

            loss_mask = {}
            if self.training:
                if targets is None or pos_matched_idxs is None or mask_logits is None:
                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")

                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
                loss_mask = {"loss_mask": rcnn_loss_mask}
            else:
                labels = [r["labels"] for r in result]
                masks_probs = maskrcnn_inference(mask_logits, labels)
                for mask_prob, r in zip(masks_probs, result):
                    r["masks"] = mask_prob

            losses.update(loss_mask)

        # keep none checks in if conditional so torchscript will conditionally
        # compile each branch
        if (
            self.keypoint_roi_pool is not None
            and self.keypoint_head is not None
            and self.keypoint_predictor is not None
        ):
            keypoint_proposals = [p["boxes"] for p in result]
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                keypoint_proposals = []
                pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")

                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    keypoint_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.keypoint_head(keypoint_features)
            keypoint_logits = self.keypoint_predictor(keypoint_features)

            loss_keypoint = {}
            if self.training:
                if targets is None or pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")

                gt_keypoints = [t["keypoints"] for t in targets]
                rcnn_loss_keypoint = keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
                )
                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
            else:
                if keypoint_logits is None or keypoint_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )

                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                    r["keypoints"] = keypoint_prob
                    r["keypoints_scores"] = kps
            losses.update(loss_keypoint)

        return result, losses