- from typing import Dict, List, Optional, Tuple
- import matplotlib.pyplot as plt
- import torch
- import torch.nn.functional as F
- import torchvision
- # from scipy.optimize import linear_sum_assignment
- from torch import nn, Tensor
- from libs.vision_libs.ops import boxes as box_ops, roi_align
- import libs.vision_libs.models.detection._utils as det_utils
- from collections import OrderedDict
- def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
- # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
- """
- Computes the loss for Faster R-CNN.
- Args:
- class_logits (Tensor)
- box_regression (Tensor)
- labels (list[Tensor])
- regression_targets (list[Tensor])
- Returns:
- classification_loss (Tensor)
- box_loss (Tensor)
- """
- # print(f'compute fastrcnn_loss:{labels}')
- labels = torch.cat(labels, dim=0)
- regression_targets = torch.cat(regression_targets, dim=0)
- classification_loss = F.cross_entropy(class_logits, labels)
- # get indices that correspond to the regression targets for
- # the corresponding ground truth labels, to be used with
- # advanced indexing
- sampled_pos_inds_subset = torch.where(labels > 0)[0]
- labels_pos = labels[sampled_pos_inds_subset]
- N, num_classes = class_logits.shape
- box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
- box_loss = F.smooth_l1_loss(
- box_regression[sampled_pos_inds_subset, labels_pos],
- regression_targets[sampled_pos_inds_subset],
- beta=1 / 9,
- reduction="sum",
- )
- box_loss = box_loss / labels.numel()
- return classification_loss, box_loss
- def maskrcnn_inference(x, labels):
- # type: (Tensor, List[Tensor]) -> List[Tensor]
- """
- From the results of the CNN, post process the masks
- by taking the mask corresponding to the class with max
- probability (which are of fixed size and directly output
- by the CNN) and return the masks in the mask field of the BoxList.
- Args:
- x (Tensor): the mask logits
- labels (list[BoxList]): bounding boxes that are used as
- reference, one for each image
- Returns:
- results (list[BoxList]): one BoxList for each image, containing
- the extra field mask
- """
- mask_prob = x.sigmoid()
- # select masks corresponding to the predicted classes
- num_masks = x.shape[0]
- boxes_per_image = [label.shape[0] for label in labels]
- labels = torch.cat(labels)
- index = torch.arange(num_masks, device=labels.device)
- mask_prob = mask_prob[index, labels][:, None]
- mask_prob = mask_prob.split(boxes_per_image, dim=0)
- return mask_prob
- def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
- # type: (Tensor, Tensor, Tensor, int) -> Tensor
- """
- Given segmentation masks and the bounding boxes corresponding
- to the location of the masks in the image, this function
- crops and resizes the masks in the position defined by the
- boxes. This prepares the masks for them to be fed to the
- loss computation as the targets.
- """
- matched_idxs = matched_idxs.to(boxes)
- rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
- gt_masks = gt_masks[:, None].to(rois)
- return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
- def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- """
- Args:
- proposals (list[BoxList])
- mask_logits (Tensor)
- targets (list[BoxList])
- Return:
- mask_loss (Tensor): scalar tensor containing the loss
- """
- discretization_size = mask_logits.shape[-1]
- labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
- mask_targets = [
- project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
- ]
- labels = torch.cat(labels, dim=0)
- mask_targets = torch.cat(mask_targets, dim=0)
- # torch.mean (in binary_cross_entropy_with_logits) doesn't
- # accept empty tensors, so handle it separately
- if mask_targets.numel() == 0:
- return mask_logits.sum() * 0
- mask_loss = F.binary_cross_entropy_with_logits(
- mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
- )
- return mask_loss
- def normalize_tensor(t):
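- """Min-max normalize a tensor to [0, 1]; the 1e-6 term guards against division by zero."""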
- return (t - t.min()) / (t.max() - t.min() + 1e-6)
- def line_length(lines):
- """
- Compute the length of each line segment.
- lines: [N, 2, 2], N segments, each defined by two endpoints
- Returns: [N]
- """
- return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
- def line_direction(lines):
- """
- Compute the unit direction vector of each line segment.
- lines: [N, 2, 2]
- Returns: [N, 2] unit direction vectors
- """
- vec = lines[:, 1] - lines[:, 0]
- return F.normalize(vec, dim=-1)
- def angle_loss_cosine(pred_dir, gt_dir):
- """
- Measure direction mismatch via cosine similarity.
- pred_dir: [N, 2]
- gt_dir: [N, 2]
- Returns: [N]
- """
- cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
- return 1.0 - cos_sim # torch.acos(cos_sim) / pi would also work
- def single_point_to_heatmap(keypoints, rois, heatmap_size):
- # type: (Tensor, Tensor, int) -> Tensor
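- # Builds a full-image Gaussian heatmap for each single ground-truth point and
- # zeroes everything outside the matching ROI, so the target stays in image
- # coordinates rather than ROI-local coordinates.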
- x = keypoints[..., 0].unsqueeze(1)
- y = keypoints[..., 1].unsqueeze(1)
- gs = generate_gaussian_heatmaps(x, y, heatmap_size=heatmap_size, num_points=1, sigma=2.0)
- # show_heatmap(gs[0],'target')
- all_roi_heatmap = []
- for roi, heatmap in zip(rois, gs):
- # show_heatmap(heatmap, 'target')
- # print(f'heatmap:{heatmap.shape}')
- heatmap = heatmap.unsqueeze(0)
- x1, y1, x2, y2 = map(int, roi)
- roi_heatmap = torch.zeros_like(heatmap)
- roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
- # show_heatmap(roi_heatmap[0],'roi_heatmap')
- all_roi_heatmap.append(roi_heatmap)
- all_roi_heatmap = torch.cat(all_roi_heatmap)
- return all_roi_heatmap
- def line_points_to_heatmap(keypoints, rois, heatmap_size):
- # type: (Tensor, Tensor, int) -> Tensor
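- # Same idea as single_point_to_heatmap, but for the two endpoints of each
- # line segment: full-image Gaussian targets masked to each proposal box.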
- x = keypoints[..., 0]
- y = keypoints[..., 1]
- gs = generate_gaussian_heatmaps(x, y, heatmap_size, num_points=2, sigma=1.0)
- # show_heatmap(gs[0],'target')
- all_roi_heatmap = []
- for roi, heatmap in zip(rois, gs):
- # print(f'heatmap:{heatmap.shape}')
- heatmap = heatmap.unsqueeze(0)
- x1, y1, x2, y2 = map(int, roi)
- roi_heatmap = torch.zeros_like(heatmap)
- roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
- # show_heatmap(roi_heatmap,'roi_heatmap')
- all_roi_heatmap.append(roi_heatmap)
- all_roi_heatmap = torch.cat(all_roi_heatmap)
- return all_roi_heatmap
- """
- 修改适配的原结构的点 转热图,适用于带roi_pool版本的
- """
- def line_points_to_heatmap_(keypoints, rois, heatmap_size):
- # type: (Tensor, Tensor, int) -> Tensor
- offset_x = rois[:, 0]
- offset_y = rois[:, 1]
- scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
- scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
- offset_x = offset_x[:, None]
- offset_y = offset_y[:, None]
- scale_x = scale_x[:, None]
- scale_y = scale_y[:, None]
- # batch_size, num_keypoints, _ = keypoints.shape
- x = keypoints[..., 0]
- y = keypoints[..., 1]
- # gs=generate_gaussian_heatmaps(x,y,512,1.0)
- # print(f'gs_heatmap shape:{gs.shape}')
- #
- # show_heatmap(gs[0],'target')
- x_boundary_inds = x == rois[:, 2][:, None]
- y_boundary_inds = y == rois[:, 3][:, None]
- x = (x - offset_x) * scale_x
- x = x.floor().long()
- y = (y - offset_y) * scale_y
- y = y.floor().long()
- x[x_boundary_inds] = heatmap_size - 1
- y[y_boundary_inds] = heatmap_size - 1
- # print(f'heatmaps x:{x}')
- # print(f'heatmaps y:{y}')
- valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
- vis = keypoints[..., 2] > 0
- valid = (valid_loc & vis).long()
- gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, num_points=2, sigma=1.0)
- # show_heatmap(gs_heatmap[0], 'feature')  # debug visualization
- # print(f'gs_heatmap:{gs_heatmap.shape}')
- #
- # lin_ind = y * heatmap_size + x
- # print(f'lin_ind:{lin_ind.shape}')
- # heatmaps = lin_ind * valid
- return gs_heatmap
- def generate_gaussian_heatmaps(xs, ys, heatmap_size, num_points=2, sigma=2.0, device=None):
- """
- Generate and merge Gaussian heatmaps for a set of points.
- Args:
- xs (Tensor): x coordinates of all points, shape (N, num_points)
- ys (Tensor): y coordinates of all points, shape (N, num_points)
- heatmap_size (int): heatmap size, H = W
- num_points (int): number of points per instance
- sigma (float): standard deviation of the Gaussian kernel
- device (torch.device, optional): defaults to the device of xs
- Returns:
- Tensor: merged heatmaps of shape (N, H, W)
- """
- assert xs.shape == ys.shape, "x and y must have the same shape"
- if device is None:
- device = xs.device
- N = xs.shape[0]
- # build the coordinate grid
- grid_y, grid_x = torch.meshgrid(
- torch.arange(heatmap_size, device=device),
- torch.arange(heatmap_size, device=device),
- indexing='ij'
- )
- # initialize the output heatmaps
- combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
- for i in range(N):
- heatmap = torch.zeros((heatmap_size, heatmap_size), device=device)
- for j in range(num_points):
- mu_x = xs[i, j].clamp(0, heatmap_size - 1).item()
- mu_y = ys[i, j].clamp(0, heatmap_size - 1).item()
- # squared distance from every grid cell to the point
- dist = (grid_x - mu_x) ** 2 + (grid_y - mu_y) ** 2
- # Gaussian response for this point
- heatmap += torch.exp(-dist / (2 * sigma ** 2))
- # write the accumulated per-instance heatmap into the output
- combined_heatmap[i] = heatmap
- return combined_heatmap
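- # A vectorized alternative sketch (not part of the original code): it produces
- # the same (N, H, W) output as generate_gaussian_heatmaps above but replaces
- # the per-point Python loops with broadcasting, which is typically much faster
- # on GPU. Peak responses are summed per instance, matching the loop version.
- def generate_gaussian_heatmaps_vectorized(xs, ys, heatmap_size, sigma=2.0):
- # xs, ys: (N, P) coordinates; returns (N, heatmap_size, heatmap_size)
- device = xs.device
- grid_y, grid_x = torch.meshgrid(
- torch.arange(heatmap_size, device=device),
- torch.arange(heatmap_size, device=device),
- indexing='ij',
- )
- mu_x = xs.float().clamp(0, heatmap_size - 1)[..., None, None] # (N, P, 1, 1)
- mu_y = ys.float().clamp(0, heatmap_size - 1)[..., None, None]
- dist = (grid_x - mu_x) ** 2 + (grid_y - mu_y) ** 2 # (N, P, H, W)
- return torch.exp(-dist / (2 * sigma ** 2)).sum(dim=1) # (N, H, W)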
- # helper for visualizing heatmaps
- def show_heatmap(heatmap, title="Heatmap"):
- """
- Display a heatmap with matplotlib.
- Args:
- heatmap (Tensor): heatmap tensor to display
- title (str): figure title
- """
- # move to CPU (detached from the graph) before converting to numpy
- heatmap = heatmap.detach().cpu().numpy()
- plt.imshow(heatmap, cmap='hot', interpolation='nearest')
- plt.colorbar()
- plt.title(title)
- plt.show()
- def keypoints_to_heatmap(keypoints, rois, heatmap_size):
- # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
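- # Maps keypoints into ROI-local discrete heatmap coordinates and returns the
- # linear indices plus a validity mask; this follows torchvision's keypoint
- # R-CNN implementation.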
- offset_x = rois[:, 0]
- offset_y = rois[:, 1]
- scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
- scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
- offset_x = offset_x[:, None]
- offset_y = offset_y[:, None]
- scale_x = scale_x[:, None]
- scale_y = scale_y[:, None]
- x = keypoints[..., 0]
- y = keypoints[..., 1]
- x_boundary_inds = x == rois[:, 2][:, None]
- y_boundary_inds = y == rois[:, 3][:, None]
- x = (x - offset_x) * scale_x
- x = x.floor().long()
- y = (y - offset_y) * scale_y
- y = y.floor().long()
- x[x_boundary_inds] = heatmap_size - 1
- y[y_boundary_inds] = heatmap_size - 1
- valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
- vis = keypoints[..., 2] > 0
- valid = (valid_loc & vis).long()
- lin_ind = y * heatmap_size + x
- heatmaps = lin_ind * valid
- return heatmaps, valid
- def _onnx_heatmaps_to_keypoints(
- maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
- ):
- num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
- width_correction = widths_i / roi_map_width
- height_correction = heights_i / roi_map_height
- roi_map = F.interpolate(
- maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
- )[:, 0]
- w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
- pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- x_int = pos % w
- y_int = (pos - x_int) // w
- x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
- dtype=torch.float32
- )
- y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
- dtype=torch.float32
- )
- xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
- xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
- xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
- xy_preds_i = torch.stack(
- [
- xy_preds_i_0.to(dtype=torch.float32),
- xy_preds_i_1.to(dtype=torch.float32),
- xy_preds_i_2.to(dtype=torch.float32),
- ],
- 0,
- )
- # TODO: simplify when indexing without rank will be supported by ONNX
- base = num_keypoints * num_keypoints + num_keypoints + 1
- ind = torch.arange(num_keypoints)
- ind = ind.to(dtype=torch.int64) * base
- end_scores_i = (
- roi_map.index_select(1, y_int.to(dtype=torch.int64))
- .index_select(2, x_int.to(dtype=torch.int64))
- .view(-1)
- .index_select(0, ind.to(dtype=torch.int64))
- )
- return xy_preds_i, end_scores_i
- @torch.jit._script_if_tracing
- def _onnx_heatmaps_to_keypoints_loop(
- maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
- ):
- xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
- end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
- for i in range(int(rois.size(0))):
- xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
- maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
- )
- xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
- end_scores = torch.cat(
- (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
- )
- return xy_preds, end_scores
- def heatmaps_to_keypoints(maps, rois):
- """Extract predicted keypoint locations from heatmaps. Output has shape
- (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
- for each keypoint.
- """
- # This function converts a discrete image coordinate in a HEATMAP_SIZE x
- # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
- # consistency with keypoints_to_heatmap_labels by using the conversion from
- # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
- # continuous coordinate.
- offset_x = rois[:, 0]
- offset_y = rois[:, 1]
- widths = rois[:, 2] - rois[:, 0]
- heights = rois[:, 3] - rois[:, 1]
- widths = widths.clamp(min=1)
- heights = heights.clamp(min=1)
- widths_ceil = widths.ceil()
- heights_ceil = heights.ceil()
- num_keypoints = maps.shape[1]
- if torchvision._is_tracing():
- xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
- maps,
- rois,
- widths_ceil,
- heights_ceil,
- widths,
- heights,
- offset_x,
- offset_y,
- torch.scalar_tensor(num_keypoints, dtype=torch.int64),
- )
- return xy_preds.permute(0, 2, 1), end_scores
- xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
- end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
- for i in range(len(rois)):
- roi_map_width = int(widths_ceil[i].item())
- roi_map_height = int(heights_ceil[i].item())
- width_correction = widths[i] / roi_map_width
- height_correction = heights[i] / roi_map_height
- roi_map = F.interpolate(
- maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
- )[:, 0]
- # roi_map_probs = scores_to_probs(roi_map.copy())
- w = roi_map.shape[2]
- pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- x_int = pos % w
- y_int = torch.div(pos - x_int, w, rounding_mode="floor")
- # assert (roi_map_probs[k, y_int, x_int] ==
- # roi_map_probs[k, :, :].max())
- x = (x_int.float() + 0.5) * width_correction
- y = (y_int.float() + 0.5) * height_correction
- xy_preds[i, 0, :] = x + offset_x[i]
- xy_preds[i, 1, :] = y + offset_y[i]
- xy_preds[i, 2, :] = 1
- end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
- return xy_preds.permute(0, 2, 1), end_scores
- def non_maximum_suppression(a):
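- # Keeps only local maxima: a 3x3 max-pool marks peaks, everything else is
- # zeroed, so a subsequent topk picks spatially separated peaks.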
- ap = F.max_pool2d(a, 3, stride=1, padding=1)
- mask = (a == ap).float().clamp(min=0.0)
- return a * mask
- def heatmaps_to_lines(maps, rois):
- """Extract line-endpoint and single-point predictions from heatmaps.
- maps carries one point channel (index 0) and one line channel (index 1) per
- ROI. For each ROI, the top-2 NMS peaks of the line channel become the two
- endpoint predictions and the top-1 peak of the point channel becomes the
- point prediction.
- Returns (line_preds [N, 2, 3], line_end_scores [N, 2],
- point_preds [N, 2], point_end_scores [N, 1]).
- """
- # Peak locations are used directly as discrete coordinates; the Heckbert
- # c = d + 0.5 continuous conversion used in heatmaps_to_keypoints is not
- # applied here.
- line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
- line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
- point_preds = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
- point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
- point_maps = maps[:, 0]
- line_maps = maps[:, 1]
- for i in range(len(rois)):
- line_roi_map = line_maps[i].unsqueeze(0)
- w = line_roi_map.shape[2]  # width, not height (shape is [1, H, W])
- flatten_line_roi_map = non_maximum_suppression(line_roi_map).reshape(1, -1)
- line_score, line_index = torch.topk(flatten_line_roi_map, k=2)
- # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- pos = line_index
- line_x = pos % w
- line_y = torch.div(pos - line_x, w, rounding_mode="floor")
- line_preds[i, 0, :] = line_x
- line_preds[i, 1, :] = line_y
- line_preds[i, 2, :] = 1
- line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
- point_roi_map = point_maps[i].unsqueeze(0)
- w = point_roi_map.shape[2]
- flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
- point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
- # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- point_x = point_index % w
- point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
- point_preds[i, 0] = point_x
- point_preds[i, 1] = point_y
- point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
- return line_preds.permute(0, 2, 1), line_end_scores, point_preds, point_end_scores
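- # lines_features_align zeroes everything outside each proposal box on a shared
- # full-resolution feature map, yielding one masked copy of the map per proposal
- # (a memory-heavy alternative to ROI pooling that preserves image coordinates).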
- def lines_features_align(features, proposals, img_size):
- align_feat_list = []
- for feat, proposals_per_img in zip(features, proposals):
- if proposals_per_img.shape[0] > 0:
- feat = feat.unsqueeze(0)
- for proposal in proposals_per_img:
- align_feat = torch.zeros_like(feat)
- x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
- # copy the region inside the proposal box into the matching position
- align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
- align_feat_list.append(align_feat)
- feats_tensor = torch.cat(align_feat_list)
- return feats_tensor
- def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- N, K, H, W = line_logits.shape
- if H != W:
- raise ValueError(
- f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
- )
- discretization_size = H
- heatmaps = []
- gs_heatmaps = []
- valid = []
- for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
- kp = gt_kp_in_image[midx]
- gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
- gs_heatmaps.append(gs_heatmaps_per_img)
- # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
- # heatmaps.append(heatmaps_per_image.view(-1))
- # valid.append(valid_per_image.view(-1))
- # line_targets = torch.cat(heatmaps, dim=0)
- gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
- # print(f'line_targets:{line_targets.shape},{line_targets}')
- # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
- # valid = torch.where(valid)[0]
- # print(f' line_targets[valid]:{line_targets[valid]}')
- # torch.mean (in binary_cross_entropy_with_logits) doesn't
- # accept empty tensors, so handle it separately
- # if line_targets.numel() == 0 or len(valid) == 0:
- # return line_logits.sum() * 0
- # line_logits = line_logits.view(N * K, H * W)
- # print(f'line_logits[valid]:{line_logits[valid].shape}')
- line_logits = line_logits.squeeze(1)
- # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
- line_loss = F.cross_entropy(line_logits, gs_heatmaps)
- return line_loss
- def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- N, K, H, W = line_logits.shape
- if H != W:
- raise ValueError(
- f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
- )
- discretization_size = H
- gs_heatmaps = []
- # print(f'point_matched_idxs:{point_matched_idxs}')
- for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_points, point_matched_idxs):
- kp = gt_kp_in_image[midx]
- # print(f'gt_kp_in_image:{gt_kp_in_image}')
- gs_heatmaps_per_img = single_point_to_heatmap(kp, proposals_per_image, discretization_size)
- gs_heatmaps.append(gs_heatmaps_per_img)
- gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
- point_logits = line_logits[:, 0]
- point_loss = F.cross_entropy(point_logits, gs_heatmaps)
- return point_loss
- def lines_to_boxes(lines, img_size=511):
- """
- 输入:
- lines: Tensor of shape (N, 2, 2),表示 N 条线段,每个线段有两个端点 (x, y)
- img_size: int,图像尺寸,用于 clamp 边界
- 输出:
- boxes: Tensor of shape (N, 4),表示 N 个包围盒 [x_min, y_min, x_max, y_max]
- """
- # 提取所有线段的两个端点
- p1 = lines[:, 0] # (N, 2)
- p2 = lines[:, 1] # (N, 2)
- # 每条线段的 x 和 y 坐标
- x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1) # (N, 2)
- y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1) # (N, 2)
- # 计算包围盒边界
- x_min = x_coords.min(dim=1).values
- y_min = y_coords.min(dim=1).values
- x_max = x_coords.max(dim=1).values
- y_max = y_coords.max(dim=1).values
- # 扩展边界并限制在图像范围内
- x_min = (x_min - 1).clamp(min=0, max=img_size)
- y_min = (y_min - 1).clamp(min=0, max=img_size)
- x_max = (x_max + 1).clamp(min=0, max=img_size)
- y_max = (y_max + 1).clamp(min=0, max=img_size)
- # 合成包围盒
- boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1) # (N, 4)
- return boxes
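- # Example: a segment from (10, 20) to (30, 40) yields the box
- # [9., 19., 31., 41.] (expanded by 1 px per side, clamped to the image).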
- def box_iou_pairwise(box1, box2):
- """
- 输入:
- box1: shape (N, 4)
- box2: shape (M, 4)
- 输出:
- ious: shape (min(N, M), ), 只计算 i = j 的配对
- """
- N = min(len(box1), len(box2))
- lt = torch.max(box1[:N, :2], box2[:N, :2]) # å·¦ä¸è§
- rb = torch.min(box1[:N, 2:], box2[:N, 2:]) # å³ä¸è§
- wh = (rb - lt).clamp(min=0) # 宽é«
- inter_area = wh[:, 0] * wh[:, 1] # 交éé¢ç§¯
- area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
- area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
- union_area = area1 + area2 - inter_area
- ious = inter_area / (union_area + 1e-6)
- return ious
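- # Note: box_iou_pairwise returns a diagonal (aligned i = j) IoU, unlike
- # box_ops.box_iou, which returns the full N x M matrix.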
- def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
- """
- Args:
- x: [N,1,H,W] 热力图
- boxes: [N,4] 框坐标
- gt_lines: [N,2,3] GT线段(含可见性)
- matched_idx: 匹配 index
- img_size: 图像尺寸
- alpha: IoU 损失权重
- beta: 长度损失权重
- gamma: 方向角度损失权重
- """
- losses = []
- boxes_per_image = [box.size(0) for box in boxes]
- x2 = x.split(boxes_per_image, dim=0)
- for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
- # heatmaps_to_lines returns four values; only the line predictions are needed here
- pred_lines, _, _, _ = heatmaps_to_lines(xx, bb)
- gt_line_points = gt_line[mid]
- if len(pred_lines) == 0 or len(gt_line_points) == 0:
- continue
- # IoU loss
- pred_boxes = lines_to_boxes(pred_lines, img_size)
- gt_boxes = lines_to_boxes(gt_line_points, img_size)
- ious = box_iou_pairwise(pred_boxes, gt_boxes)
- iou_loss = 1.0 - ious # [N]
- # length loss
- pred_len = line_length(pred_lines)
- gt_len = line_length(gt_line_points)
- length_diff = F.l1_loss(pred_len, gt_len, reduction='none') # [N]
- # direction (angle) loss
- pred_dir = line_direction(pred_lines)
- gt_dir = line_direction(gt_line_points)
- ang_loss = angle_loss_cosine(pred_dir, gt_dir) # [N]
- # normalize each loss term
- norm_iou = normalize_tensor(iou_loss)
- norm_len = normalize_tensor(length_diff)
- norm_ang = normalize_tensor(ang_loss)
- total = alpha * norm_iou + beta * norm_len + gamma * norm_ang
- losses.append(total)
- if not losses:
- return None
- return torch.mean(torch.cat(losses))
- def line_inference(x, boxes):
- # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
- lines_probs = []
- lines_scores = []
- points_probs = []
- points_scores = []
- boxes_per_image = [box.size(0) for box in boxes]
- x2 = x.split(boxes_per_image, dim=0)
- for xx, bb in zip(x2, boxes):
- line_prob, line_scores, point_prob, point_scores = heatmaps_to_lines(xx, bb)
- lines_probs.append(line_prob)
- lines_scores.append(line_scores)
- points_probs.append(point_prob.unsqueeze(1))
- points_scores.append(point_scores)
- return lines_probs, lines_scores, points_probs, points_scores
- def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- N, K, H, W = keypoint_logits.shape
- if H != W:
- raise ValueError(
- f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
- )
- discretization_size = H
- heatmaps = []
- valid = []
- for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
- kp = gt_kp_in_image[midx]
- heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
- heatmaps.append(heatmaps_per_image.view(-1))
- valid.append(valid_per_image.view(-1))
- keypoint_targets = torch.cat(heatmaps, dim=0)
- valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
- valid = torch.where(valid)[0]
- # torch.mean (in binary_cross_entropy_with_logits) doesn't
- # accept empty tensors, so handle it separately
- if keypoint_targets.numel() == 0 or len(valid) == 0:
- return keypoint_logits.sum() * 0
- keypoint_logits = keypoint_logits.view(N * K, H * W)
- keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
- return keypoint_loss
- def keypointrcnn_inference(x, boxes):
- # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
- kp_probs = []
- kp_scores = []
- boxes_per_image = [box.size(0) for box in boxes]
- x2 = x.split(boxes_per_image, dim=0)
- for xx, bb in zip(x2, boxes):
- kp_prob, scores = heatmaps_to_keypoints(xx, bb)
- kp_probs.append(kp_prob)
- kp_scores.append(scores)
- return kp_probs, kp_scores
- def _onnx_expand_boxes(boxes, scale):
- # type: (Tensor, float) -> Tensor
- w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
- h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
- x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
- y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
- w_half = w_half.to(dtype=torch.float32) * scale
- h_half = h_half.to(dtype=torch.float32) * scale
- boxes_exp0 = x_c - w_half
- boxes_exp1 = y_c - h_half
- boxes_exp2 = x_c + w_half
- boxes_exp3 = y_c + h_half
- boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
- return boxes_exp
- # the next two functions should be merged inside Masker
- # but are kept here for the moment while we need them
- # temporarily for paste_mask_in_image
- def expand_boxes(boxes, scale):
- # type: (Tensor, float) -> Tensor
- if torchvision._is_tracing():
- return _onnx_expand_boxes(boxes, scale)
- w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
- h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
- x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
- y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
- w_half *= scale
- h_half *= scale
- boxes_exp = torch.zeros_like(boxes)
- boxes_exp[:, 0] = x_c - w_half
- boxes_exp[:, 2] = x_c + w_half
- boxes_exp[:, 1] = y_c - h_half
- boxes_exp[:, 3] = y_c + h_half
- return boxes_exp
- @torch.jit.unused
- def expand_masks_tracing_scale(M, padding):
- # type: (int, int) -> float
- return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
- def expand_masks(mask, padding):
- # type: (Tensor, int) -> Tuple[Tensor, float]
- M = mask.shape[-1]
- if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
- scale = expand_masks_tracing_scale(M, padding)
- else:
- scale = float(M + 2 * padding) / M
- padded_mask = F.pad(mask, (padding,) * 4)
- return padded_mask, scale
- def paste_mask_in_image(mask, box, im_h, im_w):
- # type: (Tensor, Tensor, int, int) -> Tensor
- TO_REMOVE = 1
- w = int(box[2] - box[0] + TO_REMOVE)
- h = int(box[3] - box[1] + TO_REMOVE)
- w = max(w, 1)
- h = max(h, 1)
- # Set shape to [batchxCxHxW]
- mask = mask.expand((1, 1, -1, -1))
- # Resize mask
- mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
- mask = mask[0][0]
- im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
- x_0 = max(box[0], 0)
- x_1 = min(box[2] + 1, im_w)
- y_0 = max(box[1], 0)
- y_1 = min(box[3] + 1, im_h)
- im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
- return im_mask
- def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
- one = torch.ones(1, dtype=torch.int64)
- zero = torch.zeros(1, dtype=torch.int64)
- w = box[2] - box[0] + one
- h = box[3] - box[1] + one
- w = torch.max(torch.cat((w, one)))
- h = torch.max(torch.cat((h, one)))
- # Set shape to [batchxCxHxW]
- mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
- # Resize mask
- mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
- mask = mask[0][0]
- x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
- x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
- y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
- y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
- unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
- # TODO : replace below with a dynamic padding when support is added in ONNX
- # pad y
- zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
- zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
- concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
- # pad x
- zeros_x0 = torch.zeros(concat_0.size(0), x_0)
- zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
- im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
- return im_mask
- @torch.jit._script_if_tracing
- def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
- res_append = torch.zeros(0, im_h, im_w)
- for i in range(masks.size(0)):
- mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
- mask_res = mask_res.unsqueeze(0)
- res_append = torch.cat((res_append, mask_res))
- return res_append
- def paste_masks_in_image(masks, boxes, img_shape, padding=1):
- # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
- masks, scale = expand_masks(masks, padding=padding)
- boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
- im_h, im_w = img_shape
- if torchvision._is_tracing():
- return _onnx_paste_masks_in_image_loop(
- masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
- )[:, None]
- res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
- if len(res) > 0:
- ret = torch.stack(res, dim=0)[:, None]
- else:
- ret = masks.new_empty((0, 1, im_h, im_w))
- return ret
- class RoIHeads(nn.Module):
- __annotations__ = {
- "box_coder": det_utils.BoxCoder,
- "proposal_matcher": det_utils.Matcher,
- "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
- }
- def __init__(
- self,
- box_roi_pool,
- box_head,
- box_predictor,
- # Faster R-CNN training
- fg_iou_thresh,
- bg_iou_thresh,
- batch_size_per_image,
- positive_fraction,
- bbox_reg_weights,
- # Faster R-CNN inference
- score_thresh,
- nms_thresh,
- detections_per_img,
- # Line
- line_roi_pool=None,
- line_head=None,
- line_predictor=None,
- # Mask
- mask_roi_pool=None,
- mask_head=None,
- mask_predictor=None,
- keypoint_roi_pool=None,
- keypoint_head=None,
- keypoint_predictor=None,
- ):
- super().__init__()
- self.box_similarity = box_ops.box_iou
- # assign ground-truth boxes for each proposal
- self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
- self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
- if bbox_reg_weights is None:
- bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
- self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
- self.box_roi_pool = box_roi_pool
- self.box_head = box_head
- self.box_predictor = box_predictor
- self.score_thresh = score_thresh
- self.nms_thresh = nms_thresh
- self.detections_per_img = detections_per_img
- self.line_roi_pool = line_roi_pool
- self.line_head = line_head
- self.line_predictor = line_predictor
- self.mask_roi_pool = mask_roi_pool
- self.mask_head = mask_head
- self.mask_predictor = mask_predictor
- self.keypoint_roi_pool = keypoint_roi_pool
- self.keypoint_head = keypoint_head
- self.keypoint_predictor = keypoint_predictor
- self.channel_compress = nn.Sequential(
- nn.Conv2d(256, 8, kernel_size=1),
- nn.BatchNorm2d(8),
- nn.ReLU(inplace=True)
- )
- def has_mask(self):
- if self.mask_roi_pool is None:
- return False
- if self.mask_head is None:
- return False
- if self.mask_predictor is None:
- return False
- return True
- def has_keypoint(self):
- if self.keypoint_roi_pool is None:
- return False
- if self.keypoint_head is None:
- return False
- if self.keypoint_predictor is None:
- return False
- return True
- def has_line(self):
- # if self.line_roi_pool is None:
- # return False
- if self.line_head is None:
- return False
- # if self.line_predictor is None:
- # return False
- return True
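- # Note: the roi_pool and predictor checks in has_line are commented out, so
- # only line_head gates whether the line branch runs.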
- def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
- # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
- matched_idxs = []
- labels = []
- for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
- if gt_boxes_in_image.numel() == 0:
- # Background image
- device = proposals_in_image.device
- clamped_matched_idxs_in_image = torch.zeros(
- (proposals_in_image.shape[0],), dtype=torch.int64, device=device
- )
- labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
- else:
- # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
- match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
- matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
- clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
- labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
- labels_in_image = labels_in_image.to(dtype=torch.int64)
- # Label background (below the low threshold)
- bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
- labels_in_image[bg_inds] = 0
- # Label ignore proposals (between low and high thresholds)
- ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
- labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
- matched_idxs.append(clamped_matched_idxs_in_image)
- labels.append(labels_in_image)
- return matched_idxs, labels
- def subsample(self, labels):
- # type: (List[Tensor]) -> List[Tensor]
- sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
- sampled_inds = []
- for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
- img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
- sampled_inds.append(img_sampled_inds)
- return sampled_inds
- def add_gt_proposals(self, proposals, gt_boxes):
- # type: (List[Tensor], List[Tensor]) -> List[Tensor]
- proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
- return proposals
- def check_targets(self, targets):
- # type: (Optional[List[Dict[str, Tensor]]]) -> None
- if targets is None:
- raise ValueError("targets should not be None")
- if not all(["boxes" in t for t in targets]):
- raise ValueError("Every element of targets should have a boxes key")
- if not all(["labels" in t for t in targets]):
- raise ValueError("Every element of targets should have a labels key")
- if self.has_mask():
- if not all(["masks" in t for t in targets]):
- raise ValueError("Every element of targets should have a masks key")
- def select_training_samples(
- self,
- proposals, # type: List[Tensor]
- targets, # type: Optional[List[Dict[str, Tensor]]]
- ):
- # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
- self.check_targets(targets)
- if targets is None:
- raise ValueError("targets should not be None")
- dtype = proposals[0].dtype
- device = proposals[0].device
- gt_boxes = [t["boxes"].to(dtype) for t in targets]
- gt_labels = [t["labels"] for t in targets]
- # append ground-truth bboxes to propos
- proposals = self.add_gt_proposals(proposals, gt_boxes)
- # get matching gt indices for each proposal
- matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
- # sample a fixed proportion of positive-negative proposals
- sampled_inds = self.subsample(labels)
- matched_gt_boxes = []
- num_images = len(proposals)
- for img_id in range(num_images):
- img_sampled_inds = sampled_inds[img_id]
- proposals[img_id] = proposals[img_id][img_sampled_inds]
- labels[img_id] = labels[img_id][img_sampled_inds]
- matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
- gt_boxes_in_image = gt_boxes[img_id]
- if gt_boxes_in_image.numel() == 0:
- gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
- matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
- regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
- return proposals, matched_idxs, labels, regression_targets
- def postprocess_detections(
- self,
- class_logits, # type: Tensor
- box_regression, # type: Tensor
- proposals, # type: List[Tensor]
- image_shapes, # type: List[Tuple[int, int]]
- ):
- # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
- device = class_logits.device
- num_classes = class_logits.shape[-1]
- boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
- pred_boxes = self.box_coder.decode(box_regression, proposals)
- pred_scores = F.softmax(class_logits, -1)
- pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
- pred_scores_list = pred_scores.split(boxes_per_image, 0)
- all_boxes = []
- all_scores = []
- all_labels = []
- for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
- boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
- # create labels for each prediction
- labels = torch.arange(num_classes, device=device)
- labels = labels.view(1, -1).expand_as(scores)
- # remove predictions with the background label
- boxes = boxes[:, 1:]
- scores = scores[:, 1:]
- labels = labels[:, 1:]
- # batch everything, by making every class prediction be a separate instance
- boxes = boxes.reshape(-1, 4)
- scores = scores.reshape(-1)
- labels = labels.reshape(-1)
- # remove low scoring boxes
- inds = torch.where(scores > self.score_thresh)[0]
- boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
- # remove empty boxes
- keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
- boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
- # non-maximum suppression, independently done per class
- keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
- # keep only topk scoring predictions
- keep = keep[: self.detections_per_img]
- boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
- all_boxes.append(boxes)
- all_scores.append(scores)
- all_labels.append(labels)
- return all_boxes, all_scores, all_labels
- def forward(
- self,
- features, # type: Dict[str, Tensor]
- proposals, # type: List[Tensor]
- image_shapes, # type: List[Tuple[int, int]]
- targets=None, # type: Optional[List[Dict[str, Tensor]]]
- ):
- # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
- """
- Args:
- features (List[Tensor])
- proposals (List[Tensor[N, 4]])
- image_shapes (List[Tuple[H, W]])
- targets (List[Dict])
- """
- if targets is not None:
- for t in targets:
- # TODO: https://github.com/pytorch/pytorch/issues/26731
- floating_point_types = (torch.float, torch.double, torch.half)
- if not t["boxes"].dtype in floating_point_types:
- raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
- if not t["labels"].dtype == torch.int64:
- raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
- if self.has_keypoint():
- if not t["keypoints"].dtype == torch.float32:
- raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
- if self.training:
- proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
- else:
- if targets is not None:
- proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
- else:
- labels = None
- regression_targets = None
- matched_idxs = None
- box_features = self.box_roi_pool(features, proposals, image_shapes)
- box_features = self.box_head(box_features)
- class_logits, box_regression = self.box_predictor(box_features)
- result: List[Dict[str, torch.Tensor]] = []
- losses = {}
- # _, C, H, W = features['0'].shape  # ignore batch_size; only C, H, W matter here
- if self.training:
- if labels is None:
- raise ValueError("labels cannot be None")
- if regression_targets is None:
- raise ValueError("regression_targets cannot be None")
- loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
- losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
- else:
- if targets is not None:
- loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
- losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
- boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals,
- image_shapes)
- num_images = len(boxes)
- for i in range(num_images):
- result.append(
- {
- "boxes": boxes[i],
- "labels": labels[i],
- "scores": scores[i],
- }
- )
- if self.has_line():
- line_proposals = [p["boxes"] for p in result]
- point_proposals = [p["boxes"] for p in result]
- # if line_proposals is None or len(line_proposals) == 0:
- #     # return empty features or skip this part of the computation
- #     return torch.empty(0, C, H, W).to(features['0'].device)
- if self.training:
- # during training, only focus on positive boxes
- num_images = len(proposals)
- line_proposals = []
- point_proposals = []
- arc_proposals = []
- pos_matched_idxs = []
- line_pos_matched_idxs = []
- point_pos_matched_idxs = []
- if matched_idxs is None:
- raise ValueError("if in trainning, matched_idxs should not be None")
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- line_pos = torch.where(labels[img_id] == 2)[0]
- point_pos = torch.where(labels[img_id] == 1)[0]
- line_proposals.append(proposals[img_id][line_pos])
- point_proposals.append(proposals[img_id][point_pos])
- line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
- point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
- # pos_matched_idxs.append(matched_idxs[img_id][pos])
- else:
- if targets is not None:
- pos_matched_idxs = []
- num_images = len(proposals)
- line_proposals = []
- point_proposals = []
- arc_proposals = []
- line_pos_matched_idxs = []
- point_pos_matched_idxs = []
- if matched_idxs is None:
- raise ValueError("if in trainning, matched_idxs should not be None")
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- # line_proposals.append(proposals[img_id][pos])
- # pos_matched_idxs.append(matched_idxs[img_id][pos])
- line_pos = torch.where(labels[img_id] == 2)[0]
- point_pos = torch.where(labels[img_id] == 1)[0]
- line_proposals.append(proposals[img_id][line_pos])
- point_proposals.append(proposals[img_id][point_pos])
- line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
- point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
- else:
- pos_matched_idxs = None
- # line_features = self.line_roi_pool(features, line_proposals, image_shapes)
- # (b, 256, 512, 512)
- line_features = self.channel_compress(features['0'])
- # (b, 8, 512, 512)
- all_proposals = line_proposals + point_proposals
- filtered_proposals = [proposal for proposal in all_proposals if proposal.shape[0] > 0]
- # line_features = lines_features_align(line_features, filtered_proposals, image_shapes)
- line_features = lines_features_align(line_features, point_proposals, image_shapes)
- line_features = self.line_head(line_features)
- # (N, 1, 512, 512)
- # line_logits = self.line_predictor(line_features)
- line_logits = line_features
- loss_line = {}
- loss_line_iou = {}
- loss_point = {}
- if self.training:
- if targets is None or pos_matched_idxs is None:
- raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
- gt_lines = [t["lines"] for t in targets]
- gt_points = [t["points"] for t in targets]
- print(f'gt_lines:{gt_lines[0].shape}')
- h, w = targets[0]["img_size"]
- img_size = h
- # rcnn_loss_line = lines_point_pair_loss(
- # line_logits, line_proposals, gt_lines, pos_matched_idxs
- # )
- # iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
- gt_lines_tensor=torch.cat(gt_lines)
- gt_points_tensor = torch.cat(gt_points)
- print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
- print(f'gt_points_tensor:{gt_points_tensor.shape}')
- if gt_lines_tensor.shape[0]>0 :
- loss_line = lines_point_pair_loss(
- line_logits, line_proposals, gt_lines, line_pos_matched_idxs
- )
- loss_line_iou = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
- if gt_points_tensor.shape[0]>0:
- loss_point = compute_point_loss(
- line_logits, point_proposals, gt_points, point_pos_matched_idxs
- )
- if not loss_line:
- loss_line = torch.tensor(0.0, device=line_features.device)
- if not loss_line_iou:
- loss_line_iou = torch.tensor(0.0, device=line_features.device)
- if not loss_point:
- loss_point = torch.tensor(0.0, device=line_features.device)
- loss_line = {"loss_line": loss_line}
- loss_line_iou = {'loss_line_iou': loss_line_iou}
- loss_point = {"loss_point": loss_point}
- else:
- if targets is not None:
- h, w = targets[0]["img_size"]
- img_size = h
- gt_lines = [t["lines"] for t in targets]
- gt_points = [t["points"] for t in targets]
- gt_lines_tensor = torch.cat(gt_lines)
- gt_points_tensor = torch.cat(gt_points)
- if gt_lines_tensor.shape[0] > 0:
- loss_line = lines_point_pair_loss(
- line_logits, line_proposals, gt_lines, line_pos_matched_idxs
- )
- loss_line_iou = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs,
- img_size)
- if gt_points_tensor.shape[0] > 0:
- loss_point = compute_point_loss(
- line_logits, point_proposals, gt_points, point_pos_matched_idxs
- )
- if not loss_line:
- loss_line = torch.tensor(0.0, device=line_features.device)
- if not loss_line_iou:
- loss_line_iou = torch.tensor(0.0, device=line_features.device)
- if not loss_point:
- loss_point = torch.tensor(0.0, device=line_features.device)
- loss_line = {"loss_line": loss_line}
- loss_line_iou = {'loss_line_iou': loss_line_iou}
- loss_point = {"loss_point": loss_point}
- else:
- if line_logits is None or line_proposals is None:
- raise ValueError(
- "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
- )
- lines_probs, lines_scores, point_probs, points_scores = line_inference(line_logits, line_proposals)
- for line_prob, ls, points, ps, r in zip(lines_probs, lines_scores, point_probs, points_scores, result):
- r["lines"] = line_prob
- r["lines_scores"] = ls
- r["points"] = points
- r["points_scores"] = ps
- losses.update(loss_line)
- losses.update(loss_line_iou)
- losses.update(loss_point)
- if self.has_mask():
- mask_proposals = [p["boxes"] for p in result]
- if self.training:
- if matched_idxs is None:
- raise ValueError("if in training, matched_idxs should not be None")
- # during training, only focus on positive boxes
- num_images = len(proposals)
- mask_proposals = []
- pos_matched_idxs = []
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- mask_proposals.append(proposals[img_id][pos])
- pos_matched_idxs.append(matched_idxs[img_id][pos])
- else:
- pos_matched_idxs = None
- if self.mask_roi_pool is not None:
- mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
- mask_features = self.mask_head(mask_features)
- mask_logits = self.mask_predictor(mask_features)
- else:
- raise Exception("Expected mask_roi_pool to be not None")
- loss_mask = {}
- if self.training:
- if targets is None or pos_matched_idxs is None or mask_logits is None:
- raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
- gt_masks = [t["masks"] for t in targets]
- gt_labels = [t["labels"] for t in targets]
- rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
- loss_mask = {"loss_mask": rcnn_loss_mask}
- else:
- labels = [r["labels"] for r in result]
- masks_probs = maskrcnn_inference(mask_logits, labels)
- for mask_prob, r in zip(masks_probs, result):
- r["masks"] = mask_prob
- losses.update(loss_mask)
- # keep none checks in if conditional so torchscript will conditionally
- # compile each branch
- if self.has_keypoint():
- keypoint_proposals = [p["boxes"] for p in result]
- if self.training:
- # during training, only focus on positive boxes
- num_images = len(proposals)
- keypoint_proposals = []
- pos_matched_idxs = []
- if matched_idxs is None:
- raise ValueError("if in trainning, matched_idxs should not be None")
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- keypoint_proposals.append(proposals[img_id][pos])
- pos_matched_idxs.append(matched_idxs[img_id][pos])
- else:
- pos_matched_idxs = None
- keypoint_features = self.line_roi_pool(features, keypoint_proposals, image_shapes)
- keypoint_features = self.line_head(keypoint_features)
- keypoint_logits = self.line_predictor(keypoint_features)
- loss_keypoint = {}
- if self.training:
- if targets is None or pos_matched_idxs is None:
- raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
- gt_keypoints = [t["keypoints"] for t in targets]
- rcnn_loss_keypoint = keypointrcnn_loss(
- keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
- )
- loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
- else:
- if keypoint_logits is None or keypoint_proposals is None:
- raise ValueError(
- "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
- )
- keypoints_probs, lines_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
- for keypoint_prob, kps, r in zip(keypoints_probs, lines_scores, result):
- r["keypoints"] = keypoint_prob
- r["keypoints_scores"] = kps
- losses.update(loss_keypoint)
- return result, losses