Add line segment angle and length losses; fix the ROI mapping problem

RenLiqiang, 5 months ago
parent commit 5338e7af68
1 changed file with 201 additions and 142 deletions
models/line_detect/loi_heads.py  (+201, -142)

@@ -6,7 +6,7 @@ import torch.nn.functional as F
 import torchvision
 # from scipy.optimize import linear_sum_assignment
 from torch import nn, Tensor
-from  libs.vision_libs.ops import boxes as box_ops, roi_align
+from libs.vision_libs.ops import boxes as box_ops, roi_align
 
 import libs.vision_libs.models.detection._utils as det_utils
 
@@ -129,19 +129,77 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
     )
     return mask_loss
 
+def normalize_tensor(t):
+    return (t - t.min()) / (t.max() - t.min() + 1e-6)
+
+def line_length(lines):
+    """
+    Compute the length of each line segment.
+    lines: [N, 2, 2] tensor of N segments, each given by two endpoints
+    returns: [N]
+    """
+    return torch.norm(lines[:, 1] - lines[:, 0], dim=-1)
+
+def line_direction(lines):
+    """
+    Compute the unit direction vector of each line segment.
+    lines: [N, 2, 2]
+    returns: [N, 2] unit direction vectors
+    """
+    vec = lines[:, 1] - lines[:, 0]
+    return F.normalize(vec, dim=-1)
+
+def angle_loss_cosine(pred_dir, gt_dir):
+    """
+    Measure the direction difference via cosine similarity.
+    pred_dir: [N, 2]
+    gt_dir: [N, 2]
+    returns: [N]
+    """
+    cos_sim = torch.sum(pred_dir * gt_dir, dim=-1).clamp(-1.0, 1.0)
+    return 1.0 - cos_sim  # alternatively: torch.acos(cos_sim) / pi
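
A minimal usage sketch of the three helpers above on made-up segments (tensors and values are illustrative only):

import torch
import torch.nn.functional as F

# one horizontal and one 45-degree segment, as [N, 2, 2]
lines = torch.tensor([[[0.0, 0.0], [3.0, 0.0]],
                      [[0.0, 0.0], [3.0, 3.0]]])
lengths = torch.norm(lines[:, 1] - lines[:, 0], dim=-1)   # tensor([3.0000, 4.2426])
dirs = F.normalize(lines[:, 1] - lines[:, 0], dim=-1)     # unit direction per segment
# cosine penalty between the two directions: 1 - cos(45 deg), about 0.2929
loss = 1.0 - torch.sum(dirs[0] * dirs[1]).clamp(-1.0, 1.0)

Note that the penalty is direction-sensitive: swapping a segment's endpoints flips the sign of cos_sim, so 1 - cos_sim treats an antiparallel match as maximally wrong even when the undirected lines coincide.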
+
+
 def line_points_to_heatmap(keypoints, rois, heatmap_size):
-    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    # type: (Tensor, Tensor, int) -> Tensor
     print(f'rois:{rois.shape}')
     print(f'heatmap_size:{heatmap_size}')
-    offset_x = rois[:, 0]
-    offset_y = rois[:, 1]
-    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
-    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
-
-    offset_x = offset_x[:, None]
-    offset_y = offset_y[:, None]
-    scale_x = scale_x[:, None]
-    scale_y = scale_y[:, None]
+    # offset_x = rois[:, 0]
+    # offset_y = rois[:, 1]
+    # scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    # scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+    #
+    # offset_x = offset_x[:, None]
+    # offset_y = offset_y[:, None]
+    # scale_x = scale_x[:, None]
+    # scale_y = scale_y[:, None]
 
     print(f'keypoints.shape:{keypoints.shape}')
     # batch_size, num_keypoints, _ = keypoints.shape
@@ -149,28 +207,29 @@ def line_points_to_heatmap(keypoints, rois, heatmap_size):
     x = keypoints[..., 0]
     y = keypoints[..., 1]
 
-    gs=generate_gaussian_heatmaps(x,y,heatmap_size,1.0)
+    gs = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
     # show_heatmap(gs[0],'target')
-    all_roi_heatmap=[]
-    for roi ,heatmap in zip(rois,gs):
+    all_roi_heatmap = []
+    for roi, heatmap in zip(rois, gs):
         # print(f'heatmap:{heatmap.shape}')
-        heatmap=heatmap.unsqueeze(0)
+        heatmap = heatmap.unsqueeze(0)
         x1, y1, x2, y2 = map(int, roi)
         roi_heatmap = torch.zeros_like(heatmap)
-        roi_heatmap[..., y1:y2+1, x1:x2+1]=heatmap[..., y1:y2+1, x1:x2+1]
+        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
         # show_heatmap(roi_heatmap,'roi_heatmap')
         all_roi_heatmap.append(roi_heatmap)
 
-    all_roi_heatmap=torch.cat(all_roi_heatmap)
+    all_roi_heatmap = torch.cat(all_roi_heatmap)
     print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
 
-
     return all_roi_heatmap
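
The target construction above keeps the Gaussian heatmap in full-image coordinates and simply zeroes everything outside each ROI, instead of rescaling keypoints into ROI-local coordinates as the commented-out offset/scale code did. A minimal sketch of the masking step, with assumed sizes:

import torch

heatmap = torch.rand(1, 512, 512)     # full-resolution Gaussian target (illustrative size)
x1, y1, x2, y2 = 100, 120, 300, 260   # an ROI already in heatmap coordinates
roi_heatmap = torch.zeros_like(heatmap)
roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
# outside the box everything is zero; no offset or scale bookkeeping is needed

This only lines up with the logits because lines_features_align below keeps the features at the same full resolution rather than pooling each ROI to a fixed grid.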
 
 
 """
 Point-to-heatmap conversion adapted to the original structure; for the version that uses roi_pool.
 """
+
+
 def line_points_to_heatmap_(keypoints, rois, heatmap_size):
     # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
     print(f'rois:{rois.shape}')
@@ -193,7 +252,6 @@ def line_points_to_heatmap_(keypoints, rois, heatmap_size):
 
     # gs=generate_gaussian_heatmaps(x,y,512,1.0)
 
-
     # print(f'gs_heatmap shape:{gs.shape}')
     #
     # show_heatmap(gs[0],'target')
@@ -215,9 +273,9 @@ def line_points_to_heatmap_(keypoints, rois, heatmap_size):
     vis = keypoints[..., 2] > 0
     valid = (valid_loc & vis).long()
 
-    gs_heatmap=generate_gaussian_heatmaps(x,y,heatmap_size,1.0)
+    gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
 
-    show_heatmap(gs_heatmap[0],'feature')
+    show_heatmap(gs_heatmap[0], 'feature')
 
     # print(f'gs_heatmap:{gs_heatmap.shape}')
     #
@@ -256,9 +314,8 @@ def generate_gaussian_heatmaps(xs, ys, heatmap_size, sigma=2.0, device='cuda'):
 
     # print(f'heatmap_size:{heatmap_size}')
     # initialize the output heatmaps
-    combined_heatmap = torch.zeros((N,heatmap_size, heatmap_size), device=device)
+    combined_heatmap = torch.zeros((N, heatmap_size, heatmap_size), device=device)
     for i in range(N):
-
         mu_x1 = xs[i, 0].clamp(0, heatmap_size - 1).item()
         mu_y1 = ys[i, 0].clamp(0, heatmap_size - 1).item()
 
@@ -277,10 +334,10 @@ def generate_gaussian_heatmaps(xs, ys, heatmap_size, sigma=2.0, device='cuda'):
         # evaluate the Gaussian distribution
         heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
 
-        heatmap=heatmap1+heatmap2
+        heatmap = heatmap1 + heatmap2
 
         # write the current heatmap into the result
-        combined_heatmap[i]= heatmap
+        combined_heatmap[i] = heatmap
 
     return combined_heatmap
 
@@ -305,6 +362,7 @@ def show_heatmap(heatmap, title="Heatmap"):
     plt.title(title)
     plt.show()
 
+
 def keypoints_to_heatmap(keypoints, rois, heatmap_size):
     # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
     offset_x = rois[:, 0]
@@ -335,7 +393,6 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
     vis = keypoints[..., 2] > 0
     valid = (valid_loc & vis).long()
 
-
     lin_ind = y * heatmap_size + x
     heatmaps = lin_ind * valid
 
@@ -343,7 +400,7 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
 
 
 def _onnx_heatmaps_to_keypoints(
-    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+        maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
 ):
     num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
 
@@ -395,7 +452,7 @@ def _onnx_heatmaps_to_keypoints(
 
 @torch.jit._script_if_tracing
 def _onnx_heatmaps_to_keypoints_loop(
-    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+        maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
 ):
     xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
     end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
@@ -473,11 +530,14 @@ def heatmaps_to_keypoints(maps, rois):
         end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
 
     return xy_preds.permute(0, 2, 1), end_scores
+
+
 def non_maximum_suppression(a):
     ap = F.max_pool2d(a, 3, stride=1, padding=1)
     mask = (a == ap).float().clamp(min=0.0)
     return a * mask
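
non_maximum_suppression keeps only local maxima: a value survives exactly when it equals the 3x3 max-pool of its own neighbourhood. A small sketch on an assumed 5x5 map:

import torch
import torch.nn.functional as F

a = torch.zeros(1, 1, 5, 5)
a[0, 0, 2, 2] = 1.0   # a peak
a[0, 0, 2, 3] = 0.5   # a weaker neighbour of the peak
ap = F.max_pool2d(a, 3, stride=1, padding=1)
mask = (a == ap).float().clamp(min=0.0)
out = a * mask        # the 1.0 peak survives, the 0.5 neighbour is zeroed

One caveat of this trick: on a plateau of equal maxima, every plateau cell equals the pooled maximum, so all of them survive.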
 
+
 def heatmaps_to_lines(maps, rois):
     """Extract predicted keypoint locations from heatmaps. Output has shape
     (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
@@ -488,54 +548,36 @@ def heatmaps_to_lines(maps, rois):
     # consistency with keypoints_to_heatmap_labels by using the conversion from
     # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
     # continuous coordinate.
-    offset_x = rois[:, 0]
-    offset_y = rois[:, 1]
-
-    widths = rois[:, 2] - rois[:, 0]
-    heights = rois[:, 3] - rois[:, 1]
-    widths = widths.clamp(min=1)
-    heights = heights.clamp(min=1)
-    widths_ceil = widths.ceil()
-    heights_ceil = heights.ceil()
-
-    num_keypoints = maps.shape[1]
-
     xy_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
     end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
 
     for i in range(len(rois)):
-        roi_map_width = int(widths_ceil[i].item())
-        roi_map_height = int(heights_ceil[i].item())
-        width_correction = widths[i] / roi_map_width
-        height_correction = heights[i] / roi_map_height
-        roi_map = F.interpolate(
-            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
-        )[:, 0]
+        roi_map = maps[i]
+
         print(f'roi_map:{roi_map.shape}')
         # roi_map_probs = scores_to_probs(roi_map.copy())
         w = roi_map.shape[2]
-        flatten_map=non_maximum_suppression(roi_map).reshape(1, -1)
+        flatten_map = non_maximum_suppression(roi_map).reshape(1, -1)
         score, index = torch.topk(flatten_map, k=2)
 
         print(f'index:{index}')
 
         # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
 
-        pos=index
+        pos = index
 
-        x_int = pos % w
+        # x_int = pos % w
+        #
+        # y_int = torch.div(pos - x_int, w, rounding_mode="floor")
 
-        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        x = pos % w
 
+        y = torch.div(pos - x, w, rounding_mode="floor")
 
-        # assert (roi_map_probs[k, y_int, x_int] ==
-        #         roi_map_probs[k, :, :].max())
-        x = (x_int.float() + 0.5) * width_correction
-        y = (y_int.float() + 0.5) * height_correction
-        xy_preds[i, 0, :] = x + offset_x[i]
-        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 0, :] = x
+        xy_preds[i, 1, :] = y
         xy_preds[i, 2, :] = 1
-        end_scores[i, :] = roi_map[torch.arange(1, device=roi_map.device), y_int, x_int]
+        end_scores[i, :] = roi_map[torch.arange(1, device=roi_map.device), y, x]
 
     return xy_preds.permute(0, 2, 1), end_scores
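
In the rewritten heatmaps_to_lines, the two strongest NMS responses of each map are taken directly as the segment endpoints; the per-ROI bicubic interpolation, the 0.5 sub-pixel shift, and the offset_x/offset_y terms are dropped because the maps are already in full-image coordinates. A sketch of the index arithmetic, assuming a square map of width w:

import torch

w = 512
flatten_map = torch.rand(1, w * w)           # NMS-ed heatmap, flattened (illustrative)
score, index = torch.topk(flatten_map, k=2)  # top-2 peaks = candidate endpoints
x = index % w                                # column of each peak
y = torch.div(index - x, w, rounding_mode="floor")  # row of each peak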
 
@@ -544,30 +586,29 @@ def lines_features_align(features, proposals, img_size):
     print(f'lines_features_align features:{features.shape}')
 
     align_feat_list = []
-    for feat, proposals_per_img  in zip(features,proposals):
+    for feat, proposals_per_img in zip(features, proposals):
         # print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
 
-        feat=feat.unsqueeze(0)
+        feat = feat.unsqueeze(0)
         for proposal in proposals_per_img:
             align_feat = torch.zeros_like(feat)
             # print(f'align_feat:{align_feat.shape}')
             x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
             # copy the region inside each proposal box into the matching slot of align_feat
-            align_feat[:,:, y1:y2 + 1, x1:x2 + 1] = feat[:,:, y1:y2 + 1, x1:x2 + 1]
+            align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
             align_feat_list.append(align_feat)
 
-
-    feats_tensor=torch.cat(align_feat_list)
+    feats_tensor = torch.cat(align_feat_list)
 
     print(f'align features :{feats_tensor.shape}')
 
-    return  feats_tensor
+    return feats_tensor
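
lines_features_align stands in for line_roi_pool: instead of resampling each box to a fixed grid, it emits one full-resolution copy of the feature map per proposal with everything outside the box zeroed. A sketch of a single iteration, with assumed shapes:

import torch

feat = torch.rand(1, 16, 512, 512)   # per-image features after channel_compress (illustrative)
proposal = torch.tensor([100., 120., 300., 260.])
x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
align_feat = torch.zeros_like(feat)
align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]

This keeps predictions and heatmap targets in one shared coordinate frame, which is what the ROI-mapping fix is after, at the cost of materializing one full H x W tensor per proposal.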
 
 
 def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     N, K, H, W = line_logits.shape
-    len_proposals=len(proposals)
+    len_proposals = len(proposals)
     print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
     if H != W:
         raise ValueError(
@@ -575,7 +616,7 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
         )
     discretization_size = H
     heatmaps = []
-    gs_heatmaps=[]
+    gs_heatmaps = []
     valid = []
     for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
         print(f'proposals_per_image:{proposals_per_image.shape}')
@@ -584,13 +625,12 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
         gs_heatmaps.append(gs_heatmaps_per_img)
         # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
 
-
         # heatmaps.append(heatmaps_per_image.view(-1))
 
         # valid.append(valid_per_image.view(-1))
 
     # line_targets = torch.cat(heatmaps, dim=0)
-    gs_heatmaps=torch.cat(gs_heatmaps,dim=0)
+    gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
     print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
     # print(f'line_targets:{line_targets.shape},{line_targets}')
 
@@ -606,10 +646,10 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
 
     # line_logits = line_logits.view(N * K, H * W)
     # print(f'line_logits[valid]:{line_logits[valid].shape}')
-    line_logits=line_logits.squeeze(1)
+    line_logits = line_logits.squeeze(1)
 
     # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
-    line_loss=F.cross_entropy(line_logits,gs_heatmaps)
+    line_loss = F.cross_entropy(line_logits, gs_heatmaps)
 
     return line_loss
 
@@ -647,6 +687,7 @@ def lines_to_boxes(lines, img_size=511):
     boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1)  # (N, 4)
     return boxes
 
+
 def box_iou_pairwise(box1, box2):
     """
     Inputs:
@@ -669,33 +710,63 @@ def box_iou_pairwise(box1, box2):
     ious = inter_area / (union_area + 1e-6)
 
     return ious
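
box_iou_pairwise is aligned, not all-pairs: the i-th box of box1 is compared with the i-th box of box2, giving an [N] tensor rather than an N x M matrix. For example, with made-up boxes:

import torch

box1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
box2 = torch.tensor([[5.0, 5.0, 15.0, 15.0]])
# intersection 5 * 5 = 25, union 100 + 100 - 25 = 175, IoU approx. 0.1429
ious = box_iou_pairwise(box1, box2)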
-def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511):
+
+
+def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta=1.0, gamma=1.0):
+    """
+    Args:
+        x: [N,1,H,W] 热力图
+        boxes: [N,4] 框坐标
+        gt_lines: [N,2,3] GT线段(含可见性)
+        matched_idx: 匹配 index
+        img_size: 图像尺寸
+        alpha: IoU 损失权重
+        beta: 长度损失权重
+        gamma: 方向角度损失权重
+    """
     losses = []
     boxes_per_image = [box.size(0) for box in boxes]
     x2 = x.split(boxes_per_image, dim=0)
 
     for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
-        p_prob, scores = heatmaps_to_lines(xx, bb)
+        p_prob, _ = heatmaps_to_lines(xx, bb)
         pred_lines = p_prob
         gt_line_points = gt_line[mid]
 
         if len(pred_lines) == 0 or len(gt_line_points) == 0:
             continue
 
+        # IoU loss
         pred_boxes = lines_to_boxes(pred_lines, img_size)
         gt_boxes = lines_to_boxes(gt_line_points, img_size)
-
         ious = box_iou_pairwise(pred_boxes, gt_boxes)
+        iou_loss = 1.0 - ious  # [N]
+
+        # length loss
+        pred_len = line_length(pred_lines)
+        gt_len = line_length(gt_line_points)
+        length_diff = F.l1_loss(pred_len, gt_len, reduction='none')  # [N]
+
+        # direction/angle loss
+        pred_dir = line_direction(pred_lines)
+        gt_dir = line_direction(gt_line_points)
+        ang_loss = angle_loss_cosine(pred_dir, gt_dir)  # [N]
 
-        loss = 1.0 - ious
-        losses.append(loss)
+        # normalize each loss term
+        norm_iou = normalize_tensor(iou_loss)
+        norm_len = normalize_tensor(length_diff)
+        norm_ang = normalize_tensor(ang_loss)
 
-    if not losses:  # if the loss list is empty, return a default value or raise a custom exception
-        print("Warning: No valid losses were computed.")
-        return torch.tensor(1.0, requires_grad=True).to(x.device)  # return a scalar tensor
+        total = alpha * norm_iou + beta * norm_len + gamma * norm_ang
+        losses.append(total)
+
+    if not losses:
+        return None
+
+    return torch.mean(torch.cat(losses))
 
-    total_loss = torch.mean(torch.cat(losses))
-    return total_loss
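
line_iou_loss now returns None when no proposal/GT pair survives, instead of the old dummy 1.0 loss, so the caller is expected to guard that case. A hypothetical caller sketch (the forward pass below currently assigns the value without a check):

iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
loss_line_iou = {} if iou_loss is None else {'loss_line_iou': iou_loss}

One design note: each of the three terms is min-max normalized per batch by normalize_tensor before the alpha/beta/gamma weighting, which makes the weights comparable across terms but ties the loss scale to the spread of values within the batch.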
 
 def line_inference(x, boxes):
     # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
@@ -710,12 +781,9 @@ def line_inference(x, boxes):
         points_probs.append(p_prob)
         points_scores.append(scores)
 
-
-
-
-
     return points_probs, points_scores
 
+
 def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     N, K, H, W = keypoint_logits.shape
@@ -843,7 +911,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
     y_0 = max(box[1], 0)
     y_1 = min(box[3] + 1, im_h)
 
-    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
     return im_mask
 
 
@@ -868,7 +936,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
     y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
     y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
 
-    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
 
     # TODO : replace below with a dynamic padding when support is added in ONNX
 
@@ -919,31 +987,31 @@ class RoIHeads(nn.Module):
     }
 
     def __init__(
-        self,
-        box_roi_pool,
-        box_head,
-        box_predictor,
-        # Faster R-CNN training
-        fg_iou_thresh,
-        bg_iou_thresh,
-        batch_size_per_image,
-        positive_fraction,
-        bbox_reg_weights,
-        # Faster R-CNN inference
-        score_thresh,
-        nms_thresh,
-        detections_per_img,
-        # Line
-        line_roi_pool=None,
-        line_head=None,
-        line_predictor=None,
-        # Mask
-        mask_roi_pool=None,
-        mask_head=None,
-        mask_predictor=None,
-        keypoint_roi_pool=None,
-        keypoint_head=None,
-        keypoint_predictor=None,
+            self,
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            # Faster R-CNN training
+            fg_iou_thresh,
+            bg_iou_thresh,
+            batch_size_per_image,
+            positive_fraction,
+            bbox_reg_weights,
+            # Faster R-CNN inference
+            score_thresh,
+            nms_thresh,
+            detections_per_img,
+            # Line
+            line_roi_pool=None,
+            line_head=None,
+            line_predictor=None,
+            # Mask
+            mask_roi_pool=None,
+            mask_head=None,
+            mask_predictor=None,
+            keypoint_roi_pool=None,
+            keypoint_head=None,
+            keypoint_predictor=None,
     ):
         super().__init__()
 
@@ -978,8 +1046,8 @@ class RoIHeads(nn.Module):
         self.keypoint_predictor = keypoint_predictor
 
         self.channel_compress = nn.Sequential(
-             nn.Conv2d(256, 16, kernel_size=1),
-             nn.BatchNorm2d(16),
+            nn.Conv2d(256, 16, kernel_size=1),
+            nn.BatchNorm2d(16),
             nn.ReLU(inplace=True)
         )
 
@@ -1073,9 +1141,9 @@ class RoIHeads(nn.Module):
                 raise ValueError("Every element of targets should have a masks key")
 
     def select_training_samples(
-        self,
-        proposals,  # type: List[Tensor]
-        targets,  # type: Optional[List[Dict[str, Tensor]]]
+            self,
+            proposals,  # type: List[Tensor]
+            targets,  # type: Optional[List[Dict[str, Tensor]]]
     ):
         # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
         self.check_targets(targets)
@@ -1111,11 +1179,11 @@ class RoIHeads(nn.Module):
         return proposals, matched_idxs, labels, regression_targets
 
     def postprocess_detections(
-        self,
-        class_logits,  # type: Tensor
-        box_regression,  # type: Tensor
-        proposals,  # type: List[Tensor]
-        image_shapes,  # type: List[Tuple[int, int]]
+            self,
+            class_logits,  # type: Tensor
+            box_regression,  # type: Tensor
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
     ):
         # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
         device = class_logits.device
@@ -1170,11 +1238,11 @@ class RoIHeads(nn.Module):
         return all_boxes, all_scores, all_labels
 
     def forward(
-        self,
-        features,  # type: Dict[str, Tensor]
-        proposals,  # type: List[Tensor]
-        image_shapes,  # type: List[Tuple[int, int]]
-        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+            self,
+            features,  # type: Dict[str, Tensor]
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
+            targets=None,  # type: Optional[List[Dict[str, Tensor]]]
     ):
         # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
         """
@@ -1239,7 +1307,6 @@ class RoIHeads(nn.Module):
                     }
                 )
 
-
         if self.has_line():
             print(f'roi_heads forward has_line()!!!!')
             line_proposals = [p["boxes"] for p in result]
@@ -1281,27 +1348,23 @@ class RoIHeads(nn.Module):
 
             print(f'line_proposals:{len(line_proposals)}')
 
-
-
             # line_features = self.line_roi_pool(features, line_proposals, image_shapes)
 
-
             # print(f'line_features from line_roi_pool:{line_features.shape}')
 
-            line_features=self.channel_compress(features['0'])
-
-            line_features=lines_features_align(line_features,line_proposals,image_shapes)
+            line_features = self.channel_compress(features['0'])
 
+            line_features = lines_features_align(line_features, line_proposals, image_shapes)
 
             line_features = self.line_head(line_features)
             print(f'line_features from line_head:{line_features.shape}')
             # line_logits = self.line_predictor(line_features)
 
-            line_logits=line_features
+            line_logits = line_features
             print(f'line_logits:{line_logits.shape}')
 
             loss_line = {}
-            loss_line_iou={}
+            loss_line_iou = {}
 
             if self.training:
 
@@ -1315,7 +1378,7 @@ class RoIHeads(nn.Module):
                 rcnn_loss_line = lines_point_pair_loss(
                     line_logits, line_proposals, gt_lines, pos_matched_idxs
                 )
-                iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
+                iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
 
                 loss_line = {"loss_line": rcnn_loss_line}
                 loss_line_iou = {'loss_line_iou': iou_loss}
@@ -1330,8 +1393,8 @@ class RoIHeads(nn.Module):
                     )
                     loss_line = {"loss_line": rcnn_loss_lines}
 
-                    iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
-                    loss_line_iou={'loss_line_iou':iou_loss}
+                    iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
+                    loss_line_iou = {'loss_line_iou': iou_loss}
 
 
                 else:
@@ -1349,8 +1412,6 @@ class RoIHeads(nn.Module):
             losses.update(loss_line)
             losses.update(loss_line_iou)
 
-
-
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
             if self.training:
@@ -1413,8 +1474,6 @@ class RoIHeads(nn.Module):
 
             keypoint_features = self.line_roi_pool(features, keypoint_proposals, image_shapes)
 
-
-
             keypoint_features = self.line_head(keypoint_features)
             keypoint_logits = self.line_predictor(keypoint_features)