RenLiqiang 8 mesiacov pred
rodič
commit
660c1be194

+ 4 - 6
config/wireframe.yaml

@@ -1,12 +1,10 @@
 io:
   logdir: logs/
-  datadir: D:\python\PycharmProjects\data
-#  datadir: /home/dieu/lcnn/dataset/line_data_104
+  datadir: I:/datasets/wirenet_1000
   resume_from:
-#  resume_from: /home/dieu/lcnn/logs/241112-163302-175fb79-my_data_104_resume
   num_workers: 0
   tensorboard_port: 0
-  validation_interval: 300    # 评估间隔
+  validation_interval: 300
 
 model:
   image:
@@ -17,14 +15,14 @@ model:
   batch_size_eval: 2
 
   # backbone multi-task parameters
-  head_size: [[2], [1], [2],[4]]
+  head_size: [[2], [1], [2]]
   loss_weight:
     jmap: 8.0
     lmap: 0.5
     joff: 0.25
     lpos: 1
     lneg: 1
-    boxes: 1.0  # 新增 box loss 权重
+    boxes: 1.0
 
   # backbone parameters
   backbone: fasterrcnn_resnet50

+ 876 - 0
lcnn/models/detection/ROI_heads.py

@@ -0,0 +1,876 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for ech image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+        .index_select(2, x_int.to(dtype=torch.int64))
+        .view(-1)
+        .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it sepaartely
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+        self,
+        box_roi_pool,
+        box_head,
+        box_predictor,
+        # Faster R-CNN training
+        fg_iou_thresh,
+        bg_iou_thresh,
+        batch_size_per_image,
+        positive_fraction,
+        bbox_reg_weights,
+        # Faster R-CNN inference
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        # Mask
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to propos
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+        self,
+        features,  # type: Dict[str, Tensor]
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+            self.keypoint_roi_pool is not None
+            and self.keypoint_head is not None
+            and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            keypoint_features = self.keypoint_head(keypoint_features)
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        return result, losses

+ 846 - 0
lcnn/models/detection/faster_rcnn.py

@@ -0,0 +1,846 @@
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from ...ops import misc as misc_nn_ops
+from ...transforms._presets import ObjectDetection
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
+from ..resnet import resnet50, ResNet50_Weights
+from ._utils import overwrite_eps
+from .anchor_utils import AnchorGenerator
+from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
+from .generalized_rcnn import GeneralizedRCNN
+from .roi_heads import RoIHeads
+from .rpn import RegionProposalNetwork, RPNHead
+from .transform import GeneralizedRCNNTransform
+
+
+__all__ = [
+    "FasterRCNN",
+    "FasterRCNN_ResNet50_FPN_Weights",
+    "FasterRCNN_ResNet50_FPN_V2_Weights",
+    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
+    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
+    "fasterrcnn_resnet50_fpn",
+    "fasterrcnn_resnet50_fpn_v2",
+    "fasterrcnn_mobilenet_v3_large_fpn",
+    "fasterrcnn_mobilenet_v3_large_320_fpn",
+]
+
+
+def _default_anchorgen():
+    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    return AnchorGenerator(anchor_sizes, aspect_ratios)
+
+
+class FasterRCNN(GeneralizedRCNN):
+    """
+    Implements Faster R-CNN.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses for both the RPN and the R-CNN.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores or each prediction
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or and OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+            If box_predictor is specified, num_classes should be None.
+        min_size (int): Images are rescaled before feeding them to the backbone:
+            we attempt to preserve the aspect ratio and scale the shorter edge
+            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
+            then downscale so that the longer edge does not exceed ``max_size``.
+            This may result in the shorter edge beeing lower than ``min_size``.
+        max_size (int): See ``min_size``.
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
+        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+            the locations indicated by the bounding boxes
+        box_head (nn.Module): module that takes the cropped feature maps as input
+        box_predictor (nn.Module): module that takes the output of box_head and returns the
+            classification logits and box regression deltas.
+        box_score_thresh (float): during inference, only return proposals with a classification score
+            greater than box_score_thresh
+        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+        box_detections_per_img (int): maximum number of detections per image, for all classes.
+        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+            considered as positive during training of the classification head
+        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+            considered as negative during training of the classification head
+        box_batch_size_per_image (int): number of proposals that are sampled during training of the
+            classification head
+        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+            of the classification head
+        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+            bounding boxes
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import FasterRCNN
+        >>> from torchvision.models.detection.rpn import AnchorGenerator
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # FasterRCNN needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the RPN generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # let's define what are the feature maps that we will
+        >>> # use to perform the region of interest cropping, as well as
+        >>> # the size of the crop after rescaling.
+        >>> # if your backbone returns a Tensor, featmap_names is expected to
+        >>> # be ['0']. More generally, the backbone should return an
+        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+        >>> # feature maps to use.
+        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                 output_size=7,
+        >>>                                                 sampling_ratio=2)
+        >>>
+        >>> # put the pieces together inside a FasterRCNN model
+        >>> model = FasterRCNN(backbone,
+        >>>                    num_classes=2,
+        >>>                    rpn_anchor_generator=anchor_generator,
+        >>>                    box_roi_pool=roi_pooler)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(
+        self,
+        backbone,
+        num_classes=None,
+        # transform parameters
+        min_size=512,   # 原800
+        max_size=1333,
+        image_mean=None,
+        image_std=None,
+        # RPN parameters
+        rpn_anchor_generator=None,
+        rpn_head=None,
+        rpn_pre_nms_top_n_train=2000,
+        rpn_pre_nms_top_n_test=1000,
+        rpn_post_nms_top_n_train=2000,
+        rpn_post_nms_top_n_test=1000,
+        rpn_nms_thresh=0.7,
+        rpn_fg_iou_thresh=0.7,
+        rpn_bg_iou_thresh=0.3,
+        rpn_batch_size_per_image=256,
+        rpn_positive_fraction=0.5,
+        rpn_score_thresh=0.0,
+        # Box parameters
+        box_roi_pool=None,
+        box_head=None,
+        box_predictor=None,
+        box_score_thresh=0.05,
+        box_nms_thresh=0.5,
+        box_detections_per_img=100,
+        box_fg_iou_thresh=0.5,
+        box_bg_iou_thresh=0.5,
+        box_batch_size_per_image=512,
+        box_positive_fraction=0.25,
+        bbox_reg_weights=None,
+        **kwargs,
+    ):
+
+        if not hasattr(backbone, "out_channels"):
+            raise ValueError(
+                "backbone should contain an attribute out_channels "
+                "specifying the number of output channels (assumed to be the "
+                "same for all the levels)"
+            )
+
+        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
+            raise TypeError(
+                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
+            )
+        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
+            )
+
+        if num_classes is not None:
+            if box_predictor is not None:
+                raise ValueError("num_classes should be None when box_predictor is specified")
+        else:
+            if box_predictor is None:
+                raise ValueError("num_classes should not be None when box_predictor is not specified")
+
+        out_channels = backbone.out_channels
+
+        if rpn_anchor_generator is None:
+            rpn_anchor_generator = _default_anchorgen()
+        if rpn_head is None:
+            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+
+        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+
+        rpn = RegionProposalNetwork(
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_pre_nms_top_n,
+            rpn_post_nms_top_n,
+            rpn_nms_thresh,
+            score_thresh=rpn_score_thresh,
+        )
+
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+        )
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+
+        super().__init__(backbone, rpn, roi_heads, transform)
+
+
+class TwoMLPHead(nn.Module):
+    """
+    Standard heads for FPN-based models
+
+    Args:
+        in_channels (int): number of input channels
+        representation_size (int): size of the intermediate representation
+    """
+
+    def __init__(self, in_channels, representation_size):
+        super().__init__()
+
+        self.fc6 = nn.Linear(in_channels, representation_size)
+        self.fc7 = nn.Linear(representation_size, representation_size)
+
+    def forward(self, x):
+        x = x.flatten(start_dim=1)
+
+        x = F.relu(self.fc6(x))
+        x = F.relu(self.fc7(x))
+
+        return x
+
+
+class FastRCNNConvFCHead(nn.Sequential):
+    def __init__(
+        self,
+        input_size: Tuple[int, int, int],
+        conv_layers: List[int],
+        fc_layers: List[int],
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Args:
+            input_size (Tuple[int, int, int]): the input size in CHW format.
+            conv_layers (list): feature dimensions of each Convolution layer
+            fc_layers (list): feature dimensions of each FCN layer
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        in_channels, in_height, in_width = input_size
+
+        blocks = []
+        previous_channels = in_channels
+        for current_channels in conv_layers:
+            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
+            previous_channels = current_channels
+        blocks.append(nn.Flatten())
+        previous_channels = previous_channels * in_height * in_width
+        for current_channels in fc_layers:
+            blocks.append(nn.Linear(previous_channels, current_channels))
+            blocks.append(nn.ReLU(inplace=True))
+            previous_channels = current_channels
+
+        super().__init__(*blocks)
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+
+class FastRCNNPredictor(nn.Module):
+    """
+    Standard classification + bounding box regression layers
+    for Fast R-CNN.
+
+    Args:
+        in_channels (int): number of input channels
+        num_classes (int): number of output classes (including background)
+    """
+
+    def __init__(self, in_channels, num_classes):
+        super().__init__()
+        self.cls_score = nn.Linear(in_channels, num_classes)
+        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
+
+    def forward(self, x):
+        if x.dim() == 4:
+            torch._assert(
+                list(x.shape[2:]) == [1, 1],
+                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
+            )
+        x = x.flatten(start_dim=1)
+        scores = self.cls_score(x)
+        bbox_deltas = self.bbox_pred(x)
+
+        return scores, bbox_deltas
+
+
+_COMMON_META = {
+    "categories": _COCO_CATEGORIES,
+    "min_size": (1, 1),
+}
+
+
+class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 41755286,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 37.0,
+                }
+            },
+            "_ops": 134.38,
+            "_file_size": 159.743,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 43712278,
+            "recipe": "https://github.com/pytorch/vision/pull/5763",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 46.7,
+                }
+            },
+            "_ops": 280.371,
+            "_file_size": 167.104,
+            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 32.8,
+                }
+            },
+            "_ops": 4.494,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 22.8,
+                }
+            },
+            "_ops": 0.719,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_resnet50_fpn(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
+    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
+    paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and a targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detections:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each detection
+        - scores (``Tensor[N]``): the scores of each detection
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> # For training
+        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
+        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
+        >>> labels = torch.randint(1, 91, (4, 11))
+        >>> images = list(image for image in images)
+        >>> targets = []
+        >>> for i in range(len(images)):
+        >>>     d = {}
+        >>>     d['boxes'] = boxes[i]
+        >>>     d['labels'] = labels[i]
+        >>>     targets.append(d)
+        >>> output = model(images, targets)
+        >>> # For inference
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_resnet50_fpn_v2(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = None,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
+    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
+        :members:
+    """
+    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    backbone = resnet50(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
+    rpn_anchor_generator = _default_anchorgen()
+    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
+    box_head = FastRCNNConvFCHead(
+        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
+    )
+    model = FasterRCNN(
+        backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchor_generator,
+        rpn_head=rpn_head,
+        box_head=box_head,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+def _fasterrcnn_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
+    progress: bool,
+    num_classes: Optional[int],
+    weights_backbone: Optional[MobileNet_V3_Large_Weights],
+    trainable_backbone_layers: Optional[int],
+    **kwargs: Any,
+) -> FasterRCNN:
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
+    anchor_sizes = (
+        (
+            32,
+            64,
+            128,
+            256,
+            512,
+        ),
+    ) * 3
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    model = FasterRCNN(
+        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_mobilenet_v3_large_320_fpn(
+    *,
+    weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    defaults = {
+        "min_size": 320,
+        "max_size": 640,
+        "rpn_pre_nms_top_n_test": 150,
+        "rpn_post_nms_top_n_test": 150,
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    defaults = {
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )

+ 336 - 0
lcnn/models/detection/transform.py

@@ -0,0 +1,336 @@
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torchvision
+from torch import nn, Tensor
+
+# from .image_list import ImageList
+# from .roi_heads import paste_masks_in_image
+from .ROI_heads import paste_masks_in_image
+
+class ImageList:
+    """
+    Structure that holds a list of images (of possibly
+    varying sizes) as a single tensor.
+    This works by padding the images to the same size,
+    and storing in a field the original sizes of each image
+
+    Args:
+        tensors (tensor): Tensor containing images.
+        image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
+    """
+
+    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
+        self.tensors = tensors
+        self.image_sizes = image_sizes
+
+    def to(self, device: torch.device) -> "ImageList":
+        cast_tensor = self.tensors.to(device)
+        return ImageList(cast_tensor, self.image_sizes)
+
+
+def _get_shape_onnx(image: Tensor) -> Tensor:
+    from torch.onnx import operators
+
+    return operators.shape_as_tensor(image)[-2:]
+
+def _fake_cast_onnx(v: Tensor) -> float:
+    # ONNX requires a tensor but here we fake its type for JIT.
+    return v
+
+def _resize_image_and_masks(
+    image: Tensor,
+    self_min_size: int,
+    self_max_size: int,
+    target: Optional[Dict[str, Tensor]] = None,
+    fixed_size: Optional[Tuple[int, int]] = None,
+) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    if torchvision._is_tracing():
+        im_shape = _get_shape_onnx(image)
+    elif torch.jit.is_scripting():
+        im_shape = torch.tensor(image.shape[-2:])
+    else:
+        im_shape = image.shape[-2:]
+
+    size: Optional[List[int]] = None
+    scale_factor: Optional[float] = None
+    recompute_scale_factor: Optional[bool] = None
+    if fixed_size is not None:
+        size = [fixed_size[1], fixed_size[0]]
+    else:
+        if torch.jit.is_scripting() or torchvision._is_tracing():
+            min_size = torch.min(im_shape).to(dtype=torch.float32)
+            max_size = torch.max(im_shape).to(dtype=torch.float32)
+            self_min_size_f = float(self_min_size)
+            self_max_size_f = float(self_max_size)
+            scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)
+
+            if torchvision._is_tracing():
+                scale_factor = _fake_cast_onnx(scale)
+            else:
+                scale_factor = scale.item()
+
+        else:
+            # Do it the normal way
+            min_size = min(im_shape)
+            max_size = max(im_shape)
+            scale_factor = min(self_min_size / min_size, self_max_size / max_size)
+
+        recompute_scale_factor = True
+
+    image = torch.nn.functional.interpolate(
+        image[None],
+        size=size,
+        scale_factor=scale_factor,
+        mode="bilinear",
+        recompute_scale_factor=recompute_scale_factor,
+        align_corners=False,
+    )[0]
+
+    if target is None:
+        return image, target
+
+    if "masks" in target:
+        mask = target["masks"]
+        mask = torch.nn.functional.interpolate(
+            mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
+        )[:, 0].byte()
+        target["masks"] = mask
+    return image, target
+
+
+def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
+    ratios = [
+        torch.tensor(s, dtype=torch.float32, device=boxes.device)
+        / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
+        for s, s_orig in zip(new_size, original_size)
+    ]
+    ratio_height, ratio_width = ratios
+    xmin, ymin, xmax, ymax = boxes.unbind(1)
+
+    xmin = xmin * ratio_width
+    xmax = xmax * ratio_width
+    ymin = ymin * ratio_height
+    ymax = ymax * ratio_height
+    return torch.stack((xmin, ymin, xmax, ymax), dim=1)
+
+
+def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
+    ratios = [
+        torch.tensor(s, dtype=torch.float32, device=keypoints.device)
+        / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
+        for s, s_orig in zip(new_size, original_size)
+    ]
+    ratio_h, ratio_w = ratios
+    resized_data = keypoints.clone()
+    if torch._C._get_tracing_state():
+        resized_data_0 = resized_data[:, :, 0] * ratio_w
+        resized_data_1 = resized_data[:, :, 1] * ratio_h
+        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
+    else:
+        resized_data[..., 0] *= ratio_w
+        resized_data[..., 1] *= ratio_h
+    return resized_data
+
+
+class GeneralizedRCNNTransform(nn.Module):
+    """
+    Performs input / target transformation before feeding the data to a GeneralizedRCNN
+    model.
+
+    The transformations it performs are:
+        - input normalization (mean subtraction and std division)
+        - input / target resizing to match min_size / max_size
+
+    It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
+    """
+
+    def __init__(
+        self,
+        min_size: int,
+        max_size: int,
+        image_mean: List[float],
+        image_std: List[float],
+        size_divisible: int = 32,
+        fixed_size: Optional[Tuple[int, int]] = None,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        if not isinstance(min_size, (list, tuple)):
+            min_size = (min_size,)
+        self.min_size = min_size
+        self.max_size = max_size
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.size_divisible = size_divisible
+        self.fixed_size = fixed_size
+        self._skip_resize = kwargs.pop("_skip_resize", False)
+
+    def forward(
+        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
+    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
+        images = [img for img in images]
+        if targets is not None:
+            # make a copy of targets to avoid modifying it in-place
+            # once torchscript supports dict comprehension
+            # this can be simplified as follows
+            # targets = [{k: v for k,v in t.items()} for t in targets]
+            targets_copy: List[Dict[str, Tensor]] = []
+            for t in targets:
+                data: Dict[str, Tensor] = {}
+                for k, v in t.items():
+                    data[k] = v
+                targets_copy.append(data)
+            targets = targets_copy
+        for i in range(len(images)):
+            image = images[i]
+            target_index = targets[i] if targets is not None else None
+
+            if image.dim() != 3:
+                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
+            image = self.normalize(image)
+            image, target_index = self.resize(image, target_index)
+            images[i] = image
+            if targets is not None and target_index is not None:
+                targets[i] = target_index
+
+        image_sizes = [img.shape[-2:] for img in images]
+        images = self.batch_images(images, size_divisible=self.size_divisible)
+        image_sizes_list: List[Tuple[int, int]] = []
+        for image_size in image_sizes:
+            torch._assert(
+                len(image_size) == 2,
+                f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
+            )
+            image_sizes_list.append((image_size[0], image_size[1]))
+
+        image_list = ImageList(images, image_sizes_list)
+        return image_list, targets
+
+    def normalize(self, image: Tensor) -> Tensor:
+        if not image.is_floating_point():
+            raise TypeError(
+                f"Expected input images to be of floating type (in range [0, 1]), "
+                f"but found type {image.dtype} instead"
+            )
+        dtype, device = image.dtype, image.device
+        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
+        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
+        return (image - mean[:, None, None]) / std[:, None, None]
+
+    def torch_choice(self, k: List[int]) -> int:
+        """
+        Implements `random.choice` via torch ops, so it can be compiled with
+        TorchScript and we use PyTorch's RNG (not native RNG)
+        """
+        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
+        return k[index]
+
+    def resize(
+        self,
+        image: Tensor,
+        target: Optional[Dict[str, Tensor]] = None,
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        h, w = image.shape[-2:]
+        if self.training:
+            if self._skip_resize:
+                return image, target
+            size = self.torch_choice(self.min_size)
+        else:
+            size = self.min_size[-1]
+        image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)
+
+        if target is None:
+            return image, target
+
+        bbox = target["boxes"]
+        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
+        target["boxes"] = bbox
+
+        if "keypoints" in target:
+            keypoints = target["keypoints"]
+            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
+            target["keypoints"] = keypoints
+        return image, target
+
+    # _onnx_batch_images() is an implementation of
+    # batch_images() that is supported by ONNX tracing.
+    @torch.jit.unused
+    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
+        max_size = []
+        for i in range(images[0].dim()):
+            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
+            max_size.append(max_size_i)
+        stride = size_divisible
+        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size = tuple(max_size)
+
+        # work around for
+        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+        # which is not yet supported in onnx
+        padded_imgs = []
+        for img in images:
+            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+            padded_imgs.append(padded_img)
+
+        return torch.stack(padded_imgs)
+
+    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
+        maxes = the_list[0]
+        for sublist in the_list[1:]:
+            for index, item in enumerate(sublist):
+                maxes[index] = max(maxes[index], item)
+        return maxes
+
+    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
+        if torchvision._is_tracing():
+            # batch_images() does not export well to ONNX
+            # call _onnx_batch_images() instead
+            return self._onnx_batch_images(images, size_divisible)
+
+        max_size = self.max_by_axis([list(img.shape) for img in images])
+        stride = float(size_divisible)
+        max_size = list(max_size)
+        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
+        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
+
+        batch_shape = [len(images)] + max_size
+        batched_imgs = images[0].new_full(batch_shape, 0)
+        for i in range(batched_imgs.shape[0]):
+            img = images[i]
+            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+
+        return batched_imgs
+
+    def postprocess(
+        self,
+        result: List[Dict[str, Tensor]],
+        image_shapes: List[Tuple[int, int]],
+        original_image_sizes: List[Tuple[int, int]],
+    ) -> List[Dict[str, Tensor]]:
+        if self.training:
+            return result
+        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
+            boxes = pred["boxes"]
+            boxes = resize_boxes(boxes, im_s, o_im_s)
+            result[i]["boxes"] = boxes
+            if "masks" in pred:
+                masks = pred["masks"]
+                masks = paste_masks_in_image(masks, boxes, o_im_s)
+                result[i]["masks"] = masks
+            if "keypoints" in pred:
+                keypoints = pred["keypoints"]
+                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
+                result[i]["keypoints"] = keypoints
+        return result
+
+    def __repr__(self) -> str:
+        format_string = f"{self.__class__.__name__}("
+        _indent = "\n    "
+        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
+        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
+        format_string += "\n)"
+        return format_string

+ 77 - 0
lcnn/models/fasterrcnn_resnet50.py

@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+import torchvision
+from torchvision.models.detection.transform import GeneralizedRCNNTransform
+# from .detection.transform import GeneralizedRCNNTransform
+
+def get_model(num_classes):
+    # 加载预训练的ResNet-50 FPN backbone
+    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+
+    # 获取分类器的输入特征数
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+
+    # 替换分类器以适应新的类别数量
+    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
+
+    return model
+
+
+
+class Fasterrcnn_resnet50(nn.Module):
+    def __init__(self, num_classes=5, num_stacks=1):
+        super(Fasterrcnn_resnet50, self).__init__()
+
+        self.model = get_model(num_classes=5)
+        self.backbone = self.model.backbone
+
+        # self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
+
+        # out_channels = self.backbone.out_channels
+        # resolution = self.box_roi_pool.output_size[0]
+        # representation_size = 1024
+        # self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+        #
+        # self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        # 多任务输出层
+        self.score_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=3, padding=1),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, num_classes, kernel_size=1)
+            )
+            for _ in range(num_stacks)
+        ])
+
+    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
+
+        transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
+                                             image_std=[0.229, 0.224, 0.225])
+        images, targets = transform(x, target1)
+        x_ = self.backbone(images.tensors)
+
+        # x_ = self.backbone(x)  # '0'  '1'  '2'  '3'   'pool'
+        # print(f'backbone:{self.backbone}')
+        # print(f'Fasterrcnn_resnet50 x_:{x_}')
+        feature_ = x_['0']  # 图片特征
+        outputs = []
+        for score_layer in self.score_layers:
+            output = score_layer(feature_)
+            outputs.append(output)  # 多头
+
+        if train_or_val == "training":
+            loss_box = self.model(x, target1)
+            return outputs, feature_, loss_box
+        else:
+            box_all = self.model(x, target1)
+            return outputs, feature_, box_all
+
+
+def fasterrcnn_resnet50(**kwargs):
+    model = Fasterrcnn_resnet50(
+        num_classes=kwargs.get("num_classes", 5),
+        num_stacks=kwargs.get("num_stacks", 1)
+    )
+    return model

+ 427 - 0
lcnn/trainer.py

@@ -0,0 +1,427 @@
+import atexit
+import os
+import os.path as osp
+import shutil
+import signal
+import subprocess
+import threading
+import time
+from timeit import default_timer as timer
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+from skimage import io
+from tensorboardX import SummaryWriter
+
+from lcnn.config import C, M
+from lcnn.utils import recursive_to
+import matplotlib
+
+from 冻结参数训练 import verify_freeze_params
+import os
+
+from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+from .postprocess import postprocess
+
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+
+# matplotlib.use('Agg')  # 使用无窗口后端
+
+
+# 绘图
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["box"][0]["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    H = pred["preds"]
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.7]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
+
+class Trainer(object):
+    def __init__(self, device, model, optimizer, train_loader, val_loader, out):
+        self.device = device
+
+        self.model = model
+        self.optim = optimizer
+
+        self.train_loader = train_loader
+        self.val_loader = val_loader
+        self.batch_size = C.model.batch_size
+
+        self.validation_interval = C.io.validation_interval
+
+        self.out = out
+        if not osp.exists(self.out):
+            os.makedirs(self.out)
+
+        # self.run_tensorboard()
+        self.writer = SummaryWriter('logs/')
+        time.sleep(1)
+
+        self.epoch = 0
+        self.iteration = 0
+        self.max_epoch = C.optim.max_epoch
+        self.lr_decay_epoch = C.optim.lr_decay_epoch
+        self.num_stacks = C.model.num_stacks
+        self.mean_loss = self.best_mean_loss = 1e1000
+
+        self.loss_labels = None
+        self.avg_metrics = None
+        self.metrics = np.zeros(0)
+
+        self.show_line = show_line
+
+    # def run_tensorboard(self):
+    #     board_out = osp.join(self.out, "tensorboard")
+    #     if not osp.exists(board_out):
+    #         os.makedirs(board_out)
+    #     self.writer = SummaryWriter(board_out)
+    #     os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    #     p = subprocess.Popen(
+    #         ["tensorboard", f"--logdir={board_out}", f"--port={C.io.tensorboard_port}"]
+    #     )
+    #
+    #     def killme():
+    #         os.kill(p.pid, signal.SIGTERM)
+    #
+    #     atexit.register(killme)
+
+    def _loss(self, result):
+        losses = result["losses"]
+        # Don't move loss label to other place.
+        # If I want to change the loss, I just need to change this function.
+        if self.loss_labels is None:
+            self.loss_labels = ["sum"] + list(losses[0].keys())
+            self.metrics = np.zeros([self.num_stacks, len(self.loss_labels)])
+            print()
+            print(
+                "| ".join(
+                    ["progress "]
+                    + list(map("{:7}".format, self.loss_labels))
+                    + ["speed"]
+                )
+            )
+            with open(f"{self.out}/loss.csv", "a") as fout:
+                print(",".join(["progress"] + self.loss_labels), file=fout)
+
+        total_loss = 0
+        for i in range(self.num_stacks):
+            for j, name in enumerate(self.loss_labels):
+                if name == "sum":
+                    continue
+                if name not in losses[i]:
+                    assert i != 0
+                    continue
+                loss = losses[i][name].mean()
+                self.metrics[i, 0] += loss.item()
+                self.metrics[i, j] += loss.item()
+                total_loss += loss
+        return total_loss
+
+
+    def validate(self):
+        tprint("Running validation...", " " * 75)
+        training = self.model.training
+        self.model.eval()
+
+        # viz = osp.join(self.out, "viz", f"{self.iteration * M.batch_size_eval:09d}")
+        # npz = osp.join(self.out, "npz", f"{self.iteration * M.batch_size_eval:09d}")
+        # osp.exists(viz) or os.makedirs(viz)
+        # osp.exists(npz) or os.makedirs(npz)
+
+        total_loss = 0
+        self.metrics[...] = 0
+        with torch.no_grad():
+            for batch_idx, (image, meta, target, target_b) in enumerate(self.val_loader):
+                input_dict = {
+                    "image": recursive_to(image, self.device),
+                    "meta": recursive_to(meta, self.device),
+                    "target": recursive_to(target, self.device),
+                    "target_b": recursive_to(target_b, self.device),
+                    "mode": "validation",
+                }
+                result = self.model(input_dict)
+                # print(f'image:{image.shape}')
+                # print(result["box"])
+
+                # total_loss += self._loss(result)
+
+                print(f'self.epoch:{self.epoch}')
+                # print(result.keys())
+                self.show_line(image[0], result, self.epoch, self.writer)
+
+                # H = result["preds"]
+                # for i in range(H["jmap"].shape[0]):
+                #     index = batch_idx * M.batch_size_eval + i
+                #     np.savez(
+                #         f"{npz}/{index:06}.npz",
+                #         **{k: v[i].cpu().numpy() for k, v in H.items()},
+                #     )
+                #     if index >= 20:
+                #         continue
+                #     self._plot_samples(i, index, H, meta, target, f"{viz}/{index:06}")
+
+        # self._write_metrics(len(self.val_loader), total_loss, "validation", True)
+        # self.mean_loss = total_loss / len(self.val_loader)
+
+        torch.save(
+            {
+                "iteration": self.iteration,
+                "arch": self.model.__class__.__name__,
+                "optim_state_dict": self.optim.state_dict(),
+                "model_state_dict": self.model.state_dict(),
+                "best_mean_loss": self.best_mean_loss,
+            },
+            osp.join(self.out, "checkpoint_latest.pth"),
+        )
+        # shutil.copy(
+        #     osp.join(self.out, "checkpoint_latest.pth"),
+        #     osp.join(npz, "checkpoint.pth"),
+        # )
+        if self.mean_loss < self.best_mean_loss:
+            self.best_mean_loss = self.mean_loss
+            shutil.copy(
+                osp.join(self.out, "checkpoint_latest.pth"),
+                osp.join(self.out, "checkpoint_best.pth"),
+            )
+
+        if training:
+            self.model.train()
+
+    def verify_freeze_params(model, freeze_config):
+        """
+        验证参数冻结是否生效
+        """
+        print("\n===== Verifying Parameter Freezing =====")
+
+        for name, module in model.named_children():
+            if name in freeze_config:
+                if freeze_config[name]:
+                    print(f"\nChecking module: {name}")
+                    for param_name, param in module.named_parameters():
+                        print(f"  {param_name}: requires_grad = {param.requires_grad}")
+
+            # 特别处理fc2子模块
+            if name == 'fc2' and 'fc2_submodules' in freeze_config:
+                for subname, submodule in module.named_children():
+                    if subname in freeze_config['fc2_submodules']:
+                        if freeze_config['fc2_submodules'][subname]:
+                            print(f"\nChecking fc2 submodule: {subname}")
+                            for param_name, param in submodule.named_parameters():
+                                print(f"  {param_name}: requires_grad = {param.requires_grad}")
+
+    def train_epoch(self):
+        self.model.train()
+
+        time = timer()
+        for batch_idx, (image, meta, target, target_b) in enumerate(self.train_loader):
+            self.optim.zero_grad()
+            self.metrics[...] = 0
+
+            input_dict = {
+                "image": recursive_to(image, self.device),
+                "meta": recursive_to(meta, self.device),
+                "target": recursive_to(target, self.device),
+                "target_b": recursive_to(target_b, self.device),
+                "mode": "training",
+            }
+            result = self.model(input_dict)
+
+            loss = self._loss(result)
+            if np.isnan(loss.item()):
+                raise ValueError("loss is nan while training")
+            loss.backward()
+            self.optim.step()
+
+            if self.avg_metrics is None:
+                self.avg_metrics = self.metrics
+            else:
+                self.avg_metrics = self.avg_metrics * 0.9 + self.metrics * 0.1
+            self.iteration += 1
+            self._write_metrics(1, loss.item(), "training", do_print=False)
+
+            if self.iteration % 4 == 0:
+                tprint(
+                    f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
+                    + "| ".join(map("{:.5f}".format, self.avg_metrics[0]))
+                    + f"| {4 * self.batch_size / (timer() - time):04.1f} "
+                )
+                time = timer()
+            num_images = self.batch_size * self.iteration
+            # if num_images % self.validation_interval == 0 or num_images == 4:
+            #     self.validate()
+            #     time = timer()
+        self.validate()
+        # verify_freeze_params()
+
+    def _write_metrics(self, size, total_loss, prefix, do_print=False):
+        for i, metrics in enumerate(self.metrics):
+            for label, metric in zip(self.loss_labels, metrics):
+                self.writer.add_scalar(
+                    f"{prefix}/{i}/{label}", metric / size, self.iteration
+                )
+            if i == 0 and do_print:
+                csv_str = (
+                        f"{self.epoch:03}/{self.iteration * self.batch_size:07},"
+                        + ",".join(map("{:.11f}".format, metrics / size))
+                )
+                prt_str = (
+                        f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
+                        + "| ".join(map("{:.5f}".format, metrics / size))
+                )
+                with open(f"{self.out}/loss.csv", "a") as fout:
+                    print(csv_str, file=fout)
+                pprint(prt_str, " " * 7)
+        self.writer.add_scalar(
+            f"{prefix}/total_loss", total_loss / size, self.iteration
+        )
+        return total_loss
+
+    def _plot_samples(self, i, index, result, meta, target, prefix):
+        fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
+        img = io.imread(fn)
+        imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
+
+        mask_result = result["jmap"][i].cpu().numpy()
+        mask_target = target["jmap"][i].cpu().numpy()
+        for ch, (ia, ib) in enumerate(zip(mask_target, mask_result)):
+            imshow(ia), plt.savefig(f"{prefix}_mask_{ch}a.jpg"), plt.close()
+            imshow(ib), plt.savefig(f"{prefix}_mask_{ch}b.jpg"), plt.close()
+
+        line_result = result["lmap"][i].cpu().numpy()
+        line_target = target["lmap"][i].cpu().numpy()
+        imshow(line_target), plt.savefig(f"{prefix}_line_a.jpg"), plt.close()
+        imshow(line_result), plt.savefig(f"{prefix}_line_b.jpg"), plt.close()
+
+        def draw_vecl(lines, sline, juncs, junts, fn):
+            imshow(img)
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (i == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+            if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+                for i, j in enumerate(junts):
+                    if i > 0 and (i == junts[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
+            plt.savefig(fn), plt.close()
+
+        junc = meta[i]["junc"].cpu().numpy() * 4
+        jtyp = meta[i]["jtyp"].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+        rjuncs = result["juncs"][i].cpu().numpy() * 4
+        rjunts = None
+        if "junts" in result:
+            rjunts = result["junts"][i].cpu().numpy() * 4
+
+        lpre = meta[i]["lpre"].cpu().numpy() * 4
+        vecl_target = meta[i]["lpre_label"].cpu().numpy()
+        vecl_result = result["lines"][i].cpu().numpy() * 4
+        score = result["score"][i].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
+        draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
+
+    def train(self):
+        plt.rcParams["figure.figsize"] = (24, 24)
+        # if self.iteration == 0:
+        #     self.validate()
+        epoch_size = len(self.train_loader)
+        start_epoch = self.iteration // epoch_size
+
+        for self.epoch in range(start_epoch, self.max_epoch):
+            print(f"Epoch {self.epoch}/{C.optim.max_epoch} - Iteration {self.iteration}/{epoch_size}")
+            if self.epoch == self.lr_decay_epoch:
+                self.optim.param_groups[0]["lr"] /= 10
+            self.train_epoch()
+
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+
+
+def tprint(*args):
+    """Temporarily prints things on the screen"""
+    print("\r", end="")
+    print(*args, end="")
+
+
+def pprint(*args):
+    """Permanently prints things on the screen"""
+    print("\r", end="")
+    print(*args)
+
+
+def _launch_tensorboard(board_out, port, out):
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    p = subprocess.Popen(["tensorboard", f"--logdir={board_out}", f"--port={port}"])
+
+    def kill():
+        os.kill(p.pid, signal.SIGTERM)
+
+    atexit.register(kill)

+ 1 - 1
libs/vision_libs/models/__init__.py

@@ -13,7 +13,7 @@ from .squeezenet import *
 from .vgg import *
 from .vision_transformer import *
 from .swin_transformer import *
-from .maxvit import *
+# from .maxvit import *
 from . import detection, optical_flow, quantization, segmentation, video
 
 # The Weights and WeightsEnum are developer-facing utils that we make public for

+ 1 - 6
models/dataset_tool.py

@@ -224,17 +224,12 @@ def line_boxes(target):
     lines = lpre
     sline = np.ones(lpre.shape[0])
 
-    keypoints = []
-
     if len(lines) > 0 and not (lines[0] == 0).all():
         for i, ((a, b), s) in enumerate(zip(lines, sline)):
             if i > 0 and (lines[i] == lines[0]).all():
                 break
             # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
 
-            keypoints.append([a[0], b[0]])
-            keypoints.append([a[1], b[1]])
-
             if a[1] > b[1]:
                 ymax = a[1] + 10
                 ymin = b[1] - 10
@@ -249,7 +244,7 @@ def line_boxes(target):
                 xmax = b[0] + 10
             boxs.append([ymin, xmin, ymax, xmax])
 
-    return torch.tensor(boxs), torch.tensor(keypoints)
+    return torch.tensor(boxs)
 
 
 def read_polygon_points_wire(lbl_path, shape):

+ 0 - 0
models/ins/__init__.py


+ 143 - 0
models/ins/maskrcnn.py

@@ -0,0 +1,143 @@
+import math
+import os
+import sys
+from datetime import datetime
+from typing import Mapping, Any
+import cv2
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision.io import read_image
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.utils import draw_bounding_boxes
+
+from models.config.config_tool import read_yaml
+from models.ins.trainer import train_cfg
+from tools import utils
+
+
+class MaskRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=0, transforms=None):
+        super(MaskRCNNModel, self).__init__()
+        self.__model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
+            weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
+        if transforms is None:
+            self.transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        if num_classes != 0:
+            self.set_num_classes(num_classes)
+            # self.__num_classes=0
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        outputs = self.__model(inputs)
+        return outputs
+
+    def train(self, cfg):
+        parameters = read_yaml(cfg)
+        num_classes=parameters['num_classes']
+        # print(f'num_classes:{num_classes}')
+        self.set_num_classes(num_classes)
+        train_cfg(self.__model, cfg)
+
+    def set_num_classes(self, num_classes):
+        in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+        self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+        in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
+        hidden_layer = 256
+        self.__model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer,
+                                                                  num_classes=num_classes)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        self.__model.load_state_dict(state_dict)
+        # return super().load_state_dict(state_dict, strict)
+
+    def predict(self, src, show_box=True, show_mask=True):
+        self.__model.eval()
+
+        img = read_image(src)
+        img = self.transforms(img)
+        img = img.to(self.device)
+        result = self.__model([img])
+        print(f'result:{result}')
+        masks = result[0]['masks']
+        boxes = result[0]['boxes']
+        # cv2.imshow('mask',masks[0].cpu().detach().numpy())
+        boxes = boxes.cpu().detach()
+        drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
+        print(f'drawn_boxes:{drawn_boxes.shape}')
+        boxed_img = drawn_boxes.permute(1, 2, 0).numpy()
+        # boxed_img=cv2.resize(boxed_img,(800,800))
+        # cv2.imshow('boxes',boxed_img)
+
+        mask = masks[0].cpu().detach().permute(1, 2, 0).numpy()
+
+        mask = cv2.resize(mask, (800, 800))
+        # cv2.imshow('mask',mask)
+        img = img.cpu().detach().permute(1, 2, 0).numpy()
+
+        masked_img = self.overlay_masks_on_image(boxed_img, masks)
+        masked_img = cv2.resize(masked_img, (800, 800))
+        cv2.imshow('img_masks', masked_img)
+        # show_img_boxes_masks(img, boxes, masks)
+        cv2.waitKey(0)
+
+    def generate_colors(self, n):
+        """
+        生成n个均匀分布在HSV色彩空间中的颜色,并转换成BGR色彩空间。
+
+        :param n: 需要的颜色数量
+        :return: 一个包含n个颜色的列表,每个颜色为BGR格式的元组
+        """
+        hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
+        bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0])) for hsv in hsv_colors]
+        return bgr_colors
+
+    def overlay_masks_on_image(self, image, masks, alpha=0.6):
+        """
+        在原图上叠加多个掩码,每个掩码使用不同的颜色。
+
+        :param image: 原图 (NumPy 数组)
+        :param masks: 掩码列表 (每个都是 NumPy 数组,二值图像)
+        :param colors: 颜色列表 (每个颜色都是 (B, G, R) 格式的元组)
+        :param alpha: 掩码的透明度 (0.0 到 1.0)
+        :return: 叠加了多个掩码的图像
+        """
+        colors = self.generate_colors(len(masks))
+        if len(masks) != len(colors):
+            raise ValueError("The number of masks and colors must be the same.")
+
+        # 复制原图,避免修改原始图像
+        overlay = image.copy()
+
+        for mask, color in zip(masks, colors):
+            # 确保掩码是二值图像
+            mask = mask.cpu().detach().permute(1, 2, 0).numpy()
+            binary_mask = (mask > 0).astype(np.uint8) * 255  # 你可以根据实际情况调整阈值
+
+            # 创建彩色掩码
+            colored_mask = np.zeros_like(image)
+
+            colored_mask[:] = color
+            colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
+
+            # 将彩色掩码与当前的叠加图像混合
+            overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
+
+        return overlay
+
+
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    ins_model = MaskRCNNModel()
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    ins_model.train(cfg='train.yaml')

+ 93 - 0
models/ins/maskrcnn_dataset.py

@@ -0,0 +1,93 @@
+import os
+
+import PIL
+import cv2
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.data import Dataset
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.dataset_tool import masks_to_boxes, read_masks_from_txt, read_masks_from_pixels
+
+
+class MaskRCNNDataset(Dataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='polygon'):
+        self.data_path = dataset_path
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        self.deafult_transform= MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        # print('maskrcnn inited!')
+
+    def __getitem__(self, item):
+        # print('__getitem__')
+        img_path = os.path.join(self.img_path, self.imgs[item])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[item][:-3] + 'txt')
+        img = PIL.Image.open(img_path).convert('RGB')
+        # h, w = np.array(img).shape[:2]
+        w, h = img.size
+        # print(f'h,w:{h, w}')
+        target = self.read_target(item=item, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img,target)
+        else:
+            img=self.deafult_transform(img)
+        # print(f'img:{img.shape},target:{target}')
+        return img, target
+
+    def create_masks_from_polygons(self, polygons, image_shape):
+        """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
+        colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
+        masks = []
+
+        for polygon_data, col in zip(polygons, colors):
+            mask = np.zeros(image_shape[:2], dtype=np.uint8)
+            # 将多边形顶点转换为 NumPy 数组
+            _, polygon = polygon_data
+            pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
+
+            # 使用 OpenCV 的 fillPoly 函数填充多边形
+            # print(f'color:{col[:3]}')
+            cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
+            mask = torch.from_numpy(mask)
+            mask[mask != 0] = 1
+            masks.append(mask)
+
+        return masks
+
+    def read_target(self, item, lbl_path, shape):
+        # print(f'lbl_path:{lbl_path}')
+        h, w = shape
+        labels = []
+        masks = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels, masks = read_masks_from_pixels(lbl_path, shape)
+
+        target = {}
+        target["boxes"] = masks_to_boxes(torch.stack(masks))
+        target["labels"] = torch.stack(labels)
+        target["masks"] = torch.stack(masks)
+        target["image_id"] = torch.tensor(item)
+        target["area"] = torch.zeros(len(masks))
+        target["iscrowd"] = torch.zeros(len(masks))
+        return target
+
+    def heatmap_enhance(self, img):
+        # 直方图均衡化
+        img_eq = cv2.equalizeHist(img)
+
+        # 自适应直方图均衡化
+        # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+        # img_clahe = clahe.apply(img)
+
+        # 将灰度图转换为热力图
+        heatmap = cv2.applyColorMap(img_eq, cv2.COLORMAP_HOT)
+
+    def __len__(self):
+        return len(self.imgs)

+ 31 - 0
models/ins/train.yaml

@@ -0,0 +1,31 @@
+
+
+dataset_path: F:\DevTools\datasets\renyaun\1012\spilt
+
+#train parameters
+num_classes: 5
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.0005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: polygon
+enable_logs: True
+augmentation: True
+checkpoint: None
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+

+ 220 - 0
models/ins/trainer.py

@@ -0,0 +1,220 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.config.config_tool import read_yaml
+from models.ins.maskrcnn_dataset import MaskRCNNDataset
+from tools import utils, presets
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+
+def load_train_parameter(cfg):
+    parameters = read_yaml(cfg)
+    return parameters
+
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+
+def train(model, **kwargs):
+    # 默认参数
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 2,
+        'num_keypoints':2,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint':None
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # 设置设备
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_ptath, 'weights')
+    tb_path = os.path.join(train_result_ptath, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = MaskRCNNDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    # data_loader_test = torch.utils.data.DataLoader(
+    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    # )
+
+    img_results_path = os.path.join(train_result_ptath, 'img_results')
+    if os.path.exists(train_result_ptath):
+        pass
+    #     os.remove(train_result_ptath)
+    else:
+        os.mkdir(train_result_ptath)
+
+    if os.path.exists(train_result_ptath):
+        os.mkdir(wts_path)
+        os.mkdir(img_results_path)
+
+    for epoch in range(epochs):
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0:
+            best_loss = losses;
+        if best_loss >= losses:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
+    writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)

+ 312 - 0
models/keypoint/keypoint_dataset.py

@@ -1,3 +1,107 @@
+<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
+========
+# import glob
+# import json
+# import math
+# import os
+# import random
+#
+# import numpy as np
+# import numpy.linalg as LA
+# import torch
+# from skimage import io
+# from torch.utils.data import Dataset
+# from torch.utils.data.dataloader import default_collate
+#
+# from lcnn.config import M
+#
+# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
+#
+#
+# class WireframeDataset(Dataset):
+#     def __init__(self, rootdir, split):
+#         self.rootdir = rootdir
+#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
+#         filelist.sort()
+#
+#         # print(f"n{split}:", len(filelist))
+#         self.split = split
+#         self.filelist = filelist
+#
+#     def __len__(self):
+#         return len(self.filelist)
+#
+#     def __getitem__(self, idx):
+#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
+#         image = io.imread(iname).astype(float)[:, :, :3]
+#         if "a1" in self.filelist[idx]:
+#             image = image[:, ::-1, :]
+#         image = (image - M.image.mean) / M.image.stddev
+#         image = np.rollaxis(image, 2).copy()
+#
+#         with np.load(self.filelist[idx]) as npz:
+#             target = {
+#                 name: torch.from_numpy(npz[name]).float()
+#                 for name in ["jmap", "joff", "lmap"]
+#             }
+#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
+#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
+#             npos, nneg = len(lpos), len(lneg)
+#             lpre = np.concatenate([lpos, lneg], 0)
+#             for i in range(len(lpre)):
+#                 if random.random() > 0.5:
+#                     lpre[i] = lpre[i, ::-1]
+#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+#             feat = [
+#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
+#                 ldir * M.use_slop,
+#                 lpre[:, :, 2],
+#             ]
+#             feat = np.concatenate(feat, 1)
+#             meta = {
+#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
+#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
+#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
+#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
+#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
+#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
+#                 "lpre_feat": torch.from_numpy(feat),
+#             }
+#
+#         labels = []
+#         labels = read_masks_from_pixels_wire(iname, (512, 512))
+#         # if self.target_type == 'polygon':
+#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
+#         # elif self.target_type == 'pixel':
+#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
+#
+#         target["labels"] = torch.stack(labels)
+#         target["boxes"] = line_boxes_faster(meta)
+#
+#
+#         return torch.from_numpy(image).float(), meta, target
+#
+#     def adjacency_matrix(self, n, link):
+#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
+#         link = torch.from_numpy(link)
+#         if len(link) > 0:
+#             mat[link[:, 0], link[:, 1]] = 1
+#             mat[link[:, 1], link[:, 0]] = 1
+#         return mat
+#
+#
+# def collate(batch):
+#     return (
+#         default_collate([b[0] for b in batch]),
+#         [b[1] for b in batch],
+#         default_collate([b[2] for b in batch]),
+#     )
+
+
+# 原LCNN数据格式,改了属性名,加了box相关
+
+>>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
 from torch.utils.data.dataset import T_co
 
 from models.base.base_dataset import BaseDataset
@@ -30,8 +134,12 @@ def validate_keypoints(keypoints, image_width, image_height):
         if not (0 <= x < image_width and 0 <= y < image_height):
             raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
 
+<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
 
 class KeypointDataset(BaseDataset):
+========
+class  WireframeDataset(BaseDataset):
+>>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
     def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
         super().__init__(dataset_path)
 
@@ -199,7 +307,211 @@ class KeypointDataset(BaseDataset):
 
 
 
+<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
 if __name__ == '__main__':
     path=r"I:\datasets\wirenet_1000"
     dataset= KeypointDataset(dataset_path=path, dataset_type='train')
     dataset.show(7)
+========
+
+'''
+# 使用roi_head数据格式有要求,更改数据格式
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+import matplotlib.pyplot as plt
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+from tools.presets import DetectionPresetTrain
+
+
+class WireframeDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+        self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip")  # multiscale会改变图像大小
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        # if self.transforms:
+        #     img, target = self.transforms(img, target)
+        # else:
+        #     img = self.default_transform(img)
+
+        img, target = self.data_augmentation(img, target)
+
+        print(f'img:{img.shape}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # 字典
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # 真实存在线条的邻接矩阵
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+            # 不存在线条的临界矩阵
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        # if self.target_type == 'polygon':
+        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        # elif self.target_type == 'pixel':
+        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [线段数, 512, 512]
+        target = {}
+        # target["labels"] = torch.stack(labels)
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        target["boxes"] = line_boxes(target)
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+        img_path = os.path.join(self.img_path, self.imgs[idx])
+        self._draw_vecl(img_path, target)
+
+    def show_img(self, img_path):
+
+        """根据给定的图片路径展示图像及其标注信息"""
+        # 获取对应的标签文件路径
+        img_name = os.path.basename(img_path)
+        img_path = os.path.join(self.img_path, img_name)
+        print(img_path)
+        lbl_name = img_name[:-3] + 'json'
+        lbl_path = os.path.join(self.lbl_path, lbl_name)
+        print(lbl_path)
+
+        if not os.path.exists(lbl_path):
+            raise FileNotFoundError(f"Label file {lbl_path} does not exist.")
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        target = self.read_target(0, lbl_path, shape=(h, w))
+
+        # 调用绘图函数
+        self._draw_vecl(img_path, target)
+
+
+    def _draw_vecl(self, img_path, target, fn=None):
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        lines = lpre
+        sline = np.ones(lpre.shape[0])
+        imshow(io.imread(img_path))
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                if i > 0 and (lines[i] == lines[0]).all():
+                    break
+                plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+        if not (juncs[0] == 0).all():
+            for i, j in enumerate(juncs):
+                if i > 0 and (i == juncs[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                          colors="yellow", width=1)
+        plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+        plt.show()
+
+        if fn != None:
+            plt.savefig(fn)
+
+'''
+>>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py

+ 1 - 1
models/keypoint/trainer.py

@@ -19,7 +19,7 @@ from tools.coco_eval import CocoEvaluator
 import time
 
 from models.config.config_tool import read_yaml
-from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
+from models.ins.maskrcnn_dataset import MaskRCNNDataset
 from models.keypoint.keypoint_dataset import KeypointDataset
 from tools import utils, presets
 

+ 178 - 0
models/line_detect/dataset_LD.py

@@ -0,0 +1,178 @@
+# 使用roi_head数据格式有要求,更改数据格式
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+import matplotlib.pyplot as plt
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+from tools.presets import DetectionPresetTrain
+
+
+class WirePointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # 字典
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # 真实存在线条的邻接矩阵
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+            # 不存在线条的临界矩阵
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [线段数, 512, 512]
+        target = {}
+        target["labels"] = torch.stack(labels)
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        target["boxes"] = line_boxes(target)
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (i == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64
+
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            plt.show()
+
+            plt.show()
+            if fn != None:
+                plt.savefig(fn)
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass

+ 120 - 0
models/line_detect/fasterrcnn_resnet50.py

@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+import torchvision
+from typing import Dict, List, Optional, Tuple
+import torch.nn.functional as F
+from torchvision.ops import MultiScaleRoIAlign
+from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
+from torchvision.models.detection.transform import GeneralizedRCNNTransform
+
+
+def get_model(num_classes):
+    # 加载预训练的ResNet-50 FPN backbone
+    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+
+    # 获取分类器的输入特征数
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+
+    # 替换分类器以适应新的类别数量
+    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
+
+    return model
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+class Fasterrcnn_resnet50(nn.Module):
+    def __init__(self, num_classes=5, num_stacks=1):
+        super(Fasterrcnn_resnet50, self).__init__()
+
+        self.model = get_model(num_classes=5)
+        self.backbone = self.model.backbone
+
+        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
+
+        out_channels = self.backbone.out_channels
+        resolution = self.box_roi_pool.output_size[0]
+        representation_size = 1024
+        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+
+        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        # 多任务输出层
+        self.score_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=3, padding=1),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, num_classes, kernel_size=1)
+            )
+            for _ in range(num_stacks)
+        ])
+
+    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
+
+        transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
+                                             image_std=[0.229, 0.224, 0.225])
+        images, targets = transform(x, target1)
+        x_ = self.backbone(images.tensors)
+
+        # x_ = self.backbone(x)  # '0'  '1'  '2'  '3'   'pool'
+        # print(f'backbone:{self.backbone}')
+        # print(f'Fasterrcnn_resnet50 x_:{x_}')
+        feature_ = x_['0']  # 图片特征
+        outputs = []
+        for score_layer in self.score_layers:
+            output = score_layer(feature_)
+            outputs.append(output)  # 多头
+
+        if train_or_val == "training":
+            loss_box = self.model(x, target1)
+            return outputs, feature_, loss_box
+        else:
+            box_all = self.model(x, target1)
+            return outputs, feature_, box_all
+
+
+def fasterrcnn_resnet50(**kwargs):
+    model = Fasterrcnn_resnet50(
+        num_classes=kwargs.get("num_classes", 5),
+        num_stacks=kwargs.get("num_stacks", 1)
+    )
+    return model

+ 364 - 48
models/line_detect/line_rcnn.py

@@ -7,13 +7,20 @@ from torchvision.ops import MultiScaleRoIAlign
 from libs.vision_libs.ops import misc as misc_nn_ops
 from libs.vision_libs.transforms._presets import ObjectDetection
 from .roi_heads import RoIHeads
-from libs.vision_libs.models._api  import register_model, Weights, WeightsEnum
+from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
 from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
 from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
 from libs.vision_libs.models.detection._utils import overwrite_eps
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
-from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+
+from models.config.config_tool import read_yaml
+import numpy as np
+import torch.nn.functional as F
+
+FEATURE_DIM = 8
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 __all__ = [
     "LineRCNN",
@@ -22,6 +29,33 @@ __all__ = [
 ]
 
 
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+
+class Bottleneck1D(nn.Module):
+    def __init__(self, inplanes, outplanes):
+        super(Bottleneck1D, self).__init__()
+
+        planes = outplanes // 2
+        self.op = nn.Sequential(
+            nn.BatchNorm1d(inplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(inplanes, planes, kernel_size=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, outplanes, kernel_size=1),
+        )
+
+    def forward(self, x):
+        return x + self.op(x)
+
+
 class LineRCNN(FasterRCNN):
     """
     Implements Keypoint R-CNN.
@@ -164,7 +198,7 @@ class LineRCNN(FasterRCNN):
             backbone,
             num_classes=None,
             # transform parameters
-            min_size=None,
+            min_size=512,  # 原为None
             max_size=1333,
             image_mean=None,
             image_std=None,
@@ -216,12 +250,11 @@ class LineRCNN(FasterRCNN):
         out_channels = backbone.out_channels
 
         if line_head is None:
-            keypoint_layers = tuple(512 for _ in range(8))
-            line_head = LineRCNNHeads(out_channels, keypoint_layers)
+            num_class = 5
+            line_head = LineRCNNHeads(out_channels, num_class)
 
         if line_predictor is None:
-            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            line_predictor = LineRCNNPredictor(keypoint_dim_reduced)
+            line_predictor = LineRCNNPredictor()
 
         super().__init__(
             backbone,
@@ -259,6 +292,18 @@ class LineRCNN(FasterRCNN):
             **kwargs,
         )
 
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
         roi_heads = RoIHeads(
             # Box
             box_roi_pool,
@@ -276,52 +321,324 @@ class LineRCNN(FasterRCNN):
             box_detections_per_img,
 
         )
-        super().roi_heads = roi_heads
-        # self.roi_heads = roi_heads
-
-        # self.roi_heads.line_head = line_head
-        # self.roi_heads.line_predictor = line_predictor
+        # super().roi_heads = roi_heads
+        self.roi_heads = roi_heads
+        self.roi_heads.line_head = line_head
+        self.roi_heads.line_predictor = line_predictor
 
 
 class LineRCNNHeads(nn.Sequential):
-    pass
-    # def __init__(self, in_channels, layers):
-    #     d = []
-    #     next_feature = in_channels
-    #     for out_channels in layers:
-    #         d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
-    #         d.append(nn.ReLU(inplace=True))
-    #         next_feature = out_channels
-    #     super().__init__(*d)
-    #     for m in self.children():
-    #         if isinstance(m, nn.Conv2d):
-    #             nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
-    #             nn.init.constant_(m.bias, 0)
+    def __init__(self, input_channels, num_class):
+        super(LineRCNNHeads, self).__init__()
+        # print("输入的维度是:", input_channels)
+        m = int(input_channels / 4)
+        heads = []
+        self.head_size = [[2], [1], [2]]
+        for output_channels in sum(self.head_size, []):
+            heads.append(
+                nn.Sequential(
+                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(m, output_channels, kernel_size=1),
+                )
+            )
+        self.heads = nn.ModuleList(heads)
+        assert num_class == sum(sum(self.head_size, []))
+
+    def forward(self, x):
+        return torch.cat([head(x) for head in self.heads], dim=1)
+
 
 
 class LineRCNNPredictor(nn.Module):
-    pass
-    # def __init__(self, in_channels, num_keypoints):
-    #     super().__init__()
-    #     input_features = in_channels
-    #     deconv_kernel = 4
-    #     self.kps_score_lowres = nn.ConvTranspose2d(
-    #         input_features,
-    #         num_keypoints,
-    #         deconv_kernel,
-    #         stride=2,
-    #         padding=deconv_kernel // 2 - 1,
-    #     )
-    #     nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
-    #     nn.init.constant_(self.kps_score_lowres.bias, 0)
-    #     self.up_scale = 2
-    #     self.out_channels = num_keypoints
-    #
-    # def forward(self, x):
-    #     x = self.kps_score_lowres(x)
-    #     return torch.nn.functional.interpolate(
-    #         x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
-    #     )
+    def __init__(self):
+        super().__init__()
+        # self.backbone = backbone
+        # self.cfg = read_yaml(cfg)
+        self.cfg = read_yaml(r'./config/wireframe.yaml')
+        self.n_pts0 = self.cfg['model']['n_pts0']
+        self.n_pts1 = self.cfg['model']['n_pts1']
+        self.n_stc_posl = self.cfg['model']['n_stc_posl']
+        self.dim_loi = self.cfg['model']['dim_loi']
+        self.use_conv = self.cfg['model']['use_conv']
+        self.dim_fc = self.cfg['model']['dim_fc']
+        self.n_out_line = self.cfg['model']['n_out_line']
+        self.n_out_junc = self.cfg['model']['n_out_junc']
+        self.loss_weight = self.cfg['model']['loss_weight']
+        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
+        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
+        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
+        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
+        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
+        self.use_cood = self.cfg['model']['use_cood']
+        self.use_slop = self.cfg['model']['use_slop']
+        self.n_stc_negl = self.cfg['model']['n_stc_negl']
+        self.head_size = self.cfg['model']['head_size']
+        self.num_class = sum(sum(self.head_size, []))
+        self.head_off = np.cumsum([sum(h) for h in self.head_size])
+
+        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
+        self.register_buffer("lambda_", lambda_)
+        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0
+
+        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
+        scale_factor = self.n_pts0 // self.n_pts1
+        if self.use_conv:
+            self.pooling = nn.Sequential(
+                nn.MaxPool1d(scale_factor, scale_factor),
+                Bottleneck1D(self.dim_loi, self.dim_loi),
+            )
+            self.fc2 = nn.Sequential(
+                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
+            )
+        else:
+            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
+            self.fc2 = nn.Sequential(
+                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, 1),
+            )
+        self.loss = nn.BCEWithLogitsLoss(reduction="none")
+
+    def forward(self, inputs, features, targets=None):
+
+        # outputs, features = input
+        # for out in outputs:
+        #     print(f'out:{out.shape}')
+        # outputs=merge_features(outputs,100)
+        batch, channel, row, col = inputs.shape
+        # print(f'outputs:{inputs.shape}')
+        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
+
+        if targets is not None:
+            self.training = True
+            # print(f'target:{targets}')
+            wires_targets = [t["wires"] for t in targets]
+            # print(f'wires_target:{wires_targets}')
+            # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
+            junc_maps = [d["junc_map"] for d in wires_targets]
+            junc_offsets = [d["junc_offset"] for d in wires_targets]
+            line_maps = [d["line_map"] for d in wires_targets]
+
+            junc_map_tensor = torch.stack(junc_maps, dim=0)
+            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+            line_map_tensor = torch.stack(line_maps, dim=0)
+
+            wires_meta = {
+                "junc_map": junc_map_tensor,
+                "junc_offset": junc_offset_tensor,
+                # "line_map": line_map_tensor,
+            }
+        else:
+            self.training = False
+            t = {
+                "junc_coords": torch.zeros(1, 2),
+                "jtyp": torch.zeros(1, dtype=torch.uint8),
+                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
+            wires_targets = [t for b in range(inputs.size(0))]
+
+            wires_meta = {
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
+
+        T = wires_meta.copy()
+        n_jtyp = T["junc_map"].shape[1]
+        offset = self.head_off
+        result = {}
+        for stack, output in enumerate([inputs]):
+            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+            # print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
+            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+            lmap = output[offset[0]: offset[1]].squeeze(0)
+            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+
+            if stack == 0:
+                result["preds"] = {
+                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                    "lmap": lmap.sigmoid(),
+                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+                }
+                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
+                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
+                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
+
+        h = result["preds"]
+        # print(f'features shape:{features.shape}')
+        x = self.fc1(features)
+
+        # print(f'x:{x.shape}')
+
+        n_batch, n_channel, row, col = x.shape
+
+        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
+
+        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
+
+        for i, meta in enumerate(wires_targets):
+            p, label, feat, jc = self.sample_lines(
+                meta, h["jmap"][i], h["joff"][i],
+            )
+            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
+            ys.append(label)
+            if self.training and self.do_static_sampling:
+                p = torch.cat([p, meta["lpre"]])
+                feat = torch.cat([feat, meta["lpre_feat"]])
+                ys.append(meta["lpre_label"])
+                del jc
+            else:
+                jcs.append(jc)
+                ps.append(p)
+            fs.append(feat)
+
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
+            xp = (
+                (
+                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                .reshape(n_channel, -1, self.n_pts0)
+                .permute(1, 0, 2)
+            )
+            xp = self.pooling(xp)
+            # print(f'xp.shape:{xp.shape}')
+            xs.append(xp)
+            idx.append(idx[-1] + xp.shape[0])
+            # print(f'idx__:{idx}')
+
+        x, y = torch.cat(xs), torch.cat(ys)
+        f = torch.cat(fs)
+        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+
+        # print("Weight dtype:", self.fc2.weight.dtype)
+        x = torch.cat([x, f], 1)
+        # print("Input dtype:", x.dtype)
+        x = x.to(dtype=torch.float32)
+        # print("Input dtype1:", x.dtype)
+        x = self.fc2(x).flatten()
+
+        # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
+        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
+
+        # if mode != "training":
+        # self.inference(x, idx, jcs, n_batch, ps)
+
+        # return result
+
+    def sample_lines(self, meta, jmap, joff):
+        with torch.no_grad():
+            junc = meta["junc_coords"]  # [N, 2]
+            jtyp = meta["jtyp"]  # [N]
+            Lpos = meta["line_pos_idx"]
+            Lneg = meta["line_neg_idx"]
+
+            n_type = jmap.shape[0]
+            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+            joff = joff.reshape(n_type, 2, -1)
+            max_K = self.n_dyn_junc // n_type
+            N = len(junc)
+            # if mode != "training":
+            if not self.training:
+                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
+            else:
+                K = min(int(N * 2 + 2), max_K)
+            if K < 2:
+                K = 2
+            device = jmap.device
+
+            # index: [N_TYPE, K]
+            score, index = torch.topk(jmap, k=K)
+            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+
+            # xy: [N_TYPE, K, 2]
+            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
+            xy_ = xy[..., None, :]
+            del x, y, index
+
+            # dist: [N_TYPE, K, N]
+            dist = torch.sum((xy_ - junc) ** 2, -1)
+            cost, match = torch.min(dist, -1)
+
+            # xy: [N_TYPE * K, 2]
+            # match: [N_TYPE, K]
+            for t in range(n_type):
+                match[t, jtyp[match[t]] != t] = N
+            match[cost > 1.5 * 1.5] = N
+            match = match.flatten()
+
+            _ = torch.arange(n_type * K, device=device)
+            u, v = torch.meshgrid(_, _)
+            u, v = u.flatten(), v.flatten()
+            up, vp = match[u], match[v]
+            label = Lpos[up, vp]
+
+            # if mode == "training":
+            if self.training:
+                c = torch.zeros_like(label, dtype=torch.bool)
+
+                # sample positive lines
+                cdx = label.nonzero().flatten()
+                if len(cdx) > self.n_dyn_posl:
+                    # print("too many positive lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample negative lines
+                cdx = Lneg[up, vp].nonzero().flatten()
+                if len(cdx) > self.n_dyn_negl:
+                    # print("too many negative lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample other (unmatched) lines
+                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
+                c[cdx] = 1
+            else:
+                c = (u < v).flatten()
+
+            # sample lines
+            u, v, label = u[c], v[c], label[c]
+            xy = xy.reshape(n_type * K, 2)
+            xyu, xyv = xy[u], xy[v]
+
+            u2v = xyu - xyv
+            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
+            feat = torch.cat(
+                [
+                    xyu / 128 * self.use_cood,
+                    xyv / 128 * self.use_cood,
+                    u2v * self.use_slop,
+                    (u[:, None] > K).float(),
+                    (v[:, None] > K).float(),
+                ],
+                1,
+            )
+            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+
+            xy = xy.reshape(n_type, K, 2)
+            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
+            return line, label.float(), feat, jcs
+
 
 
 _COMMON_META = {
@@ -462,7 +779,6 @@ def linercnn_resnet50_fpn(
     """
     weights = LineRCNN_ResNet50_FPN_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
     if weights is not None:
         weights_backbone = None
         num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))

+ 334 - 67
models/line_detect/roi_heads.py

@@ -6,11 +6,256 @@ import torchvision
 from torch import nn, Tensor
 from torchvision.ops import boxes as box_ops, roi_align
 
-from libs.vision_libs.models.detection  import _utils as det_utils
+import libs.vision_libs.models.detection._utils as det_utils
+
+from collections import OrderedDict
+
+
+def l2loss(input, target):
+    return ((target - input) ** 2).mean(2).mean(1)
+
+
+def cross_entropy_loss(logits, positive):
+    nlogp = -F.log_softmax(logits, dim=0)
+    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
+
+
+def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
+    logp = torch.sigmoid(logits) + offset
+    loss = torch.abs(logp - target)
+    if mask is not None:
+        w = mask.mean(2, True).mean(1, True)
+        w[w == 0] = 1
+        loss = loss * (mask / w)
+
+    return loss.mean(2).mean(1)
+
 
 ###计算多头损失
-def line_loss():
-    pass
+def line_head_loss(input_dict, outputs, feature, loss_weight, mode_train):
+    # image = input_dict["image"]
+    # target_b = input_dict["target_b"]
+    # outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # train时aaa是损失,val时是box
+
+    result = {"feature": feature}
+    batch, channel, row, col = outputs[0].shape
+
+    T = input_dict["target"].copy()
+    n_jtyp = T["junc_map"].shape[1]
+
+    # switch to CNHW
+    for task in ["junc_map"]:
+        T[task] = T[task].permute(1, 0, 2, 3)
+    for task in ["junc_offset"]:
+        T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+    offset = [2, 3, 5]
+    losses = []
+    for stack, output in enumerate(outputs):
+        output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+        jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+        lmap = output[offset[0]: offset[1]].squeeze(0)
+        # print(f"lmap:{lmap.shape}")
+        joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+        if stack == 0:
+            result["preds"] = {
+                "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                "lmap": lmap.sigmoid(),
+                "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+            }
+            if mode_train == False:
+                return result
+
+        L = OrderedDict()
+        L["jmap"] = sum(
+            cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+        )
+        L["lmap"] = (
+            F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+                .mean(2)
+                .mean(1)
+        )
+        L["joff"] = sum(
+            sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+            for i in range(n_jtyp)
+            for j in range(2)
+        )
+        for loss_name in L:
+            L[loss_name].mul_(loss_weight[loss_name])
+        losses.append(L)
+    result["losses"] = losses
+    # result["aaa"] = aaa
+    return result
+
+
+#  计算线性损失
+def line_vectorizer_loss(result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc, loss_weight, mode_train):
+    if mode_train == False:
+        p = torch.cat(ps)
+        s = torch.sigmoid(x)
+        b = s > 0.5
+        lines = []
+        score = []
+        for i in range(n_batch):
+            p0 = p[idx[i]: idx[i + 1]]
+            s0 = s[idx[i]: idx[i + 1]]
+            mask = b[idx[i]: idx[i + 1]]
+            p0 = p0[mask]
+            s0 = s0[mask]
+            if len(p0) == 0:
+                lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+                score.append(torch.zeros([1, n_out_line], device=p.device))
+            else:
+                arg = torch.argsort(s0, descending=True)
+                p0, s0 = p0[arg], s0[arg]
+                lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+                score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+            for j in range(len(jcs[i])):
+                if len(jcs[i][j]) == 0:
+                    jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
+                jcs[i][j] = jcs[i][j][
+                    None, torch.arange(n_out_junc) % len(jcs[i][j])
+                ]
+        result["preds"]["lines"] = torch.cat(lines)
+        result["preds"]["score"] = torch.cat(score)
+        result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+        if len(jcs[i]) > 1:
+            result["preds"]["junts"] = torch.cat(
+                [jcs[i][1] for i in range(n_batch)]
+            )
+
+    # if input_dict["mode"] != "testing":
+    y = torch.cat(ys)
+    loss = nn.BCEWithLogitsLoss(reduction="none")
+    loss = loss(x, y)
+    lpos_mask, lneg_mask = y, 1 - y
+    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+    def sum_batch(x):
+        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
+        return torch.cat(xs)
+
+    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
+    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
+
+    if mode_train == True:
+        del result["preds"]
+
+    return result
+
+
+
+def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
+    # output, feature: head返回结果
+    # x, y, idx : line中间生成结果
+    result = {}
+    batch, channel, row, col = output.shape
+
+    wires_targets = [t["wires"] for t in targets]
+    wires_targets = wires_targets.copy()
+    # print(f'wires_target:{wires_targets}')
+    # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
+    junc_maps = [d["junc_map"] for d in wires_targets]
+    junc_offsets = [d["junc_offset"] for d in wires_targets]
+    line_maps = [d["line_map"] for d in wires_targets]
+
+    junc_map_tensor = torch.stack(junc_maps, dim=0)
+    junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+    line_map_tensor = torch.stack(line_maps, dim=0)
+    T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
+
+    n_jtyp = T["junc_map"].shape[1]
+
+    for task in ["junc_map"]:
+        T[task] = T[task].permute(1, 0, 2, 3)
+    for task in ["junc_offset"]:
+        T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+    offset = [2, 3, 5]
+    losses = []
+    output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+    jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+    lmap = output[offset[0]: offset[1]].squeeze(0)
+    joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+    L = OrderedDict()
+    L["junc_map"] = sum(
+        cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+    )
+    L["line_map"] = (
+        F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+            .mean(2)
+            .mean(1)
+    )
+    L["junc_offset"] = sum(
+        sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+        for i in range(n_jtyp)
+        for j in range(2)
+    )
+    for loss_name in L:
+        L[loss_name].mul_(loss_weight[loss_name])
+    losses.append(L)
+    result["losses"] = losses
+
+    loss = nn.BCEWithLogitsLoss(reduction="none")
+    loss = loss(x, y)
+    lpos_mask, lneg_mask = y, 1 - y
+    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+    def sum_batch(x):
+        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
+        return torch.cat(xs)
+
+    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
+    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
+
+    return result
+
+
+def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
+    result = {}
+    result["wires"] = {}
+    p = torch.cat(ps)
+    s = torch.sigmoid(input)
+    b = s > 0.5
+    lines = []
+    score = []
+    # print(f"n_batch:{n_batch}")
+    for i in range(n_batch):
+        # print(f"idx:{idx}")
+        p0 = p[idx[i]: idx[i + 1]]
+        s0 = s[idx[i]: idx[i + 1]]
+        mask = b[idx[i]: idx[i + 1]]
+        p0 = p0[mask]
+        s0 = s0[mask]
+        if len(p0) == 0:
+            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+            score.append(torch.zeros([1, n_out_line], device=p.device))
+        else:
+            arg = torch.argsort(s0, descending=True)
+            p0, s0 = p0[arg], s0[arg]
+            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+        for j in range(len(jcs[i])):
+            if len(jcs[i][j]) == 0:
+                jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
+            jcs[i][j] = jcs[i][j][
+                None, torch.arange(n_out_junc) % len(jcs[i][j])
+            ]
+    result["wires"]["lines"] = torch.cat(lines)
+    result["wires"]["score"] = torch.cat(score)
+    result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+
+    if len(jcs[i]) > 1:
+        result["preds"]["junts"] = torch.cat(
+            [jcs[i][1] for i in range(n_batch)]
+        )
+
+    return result
+
 
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
     # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
@@ -166,7 +411,7 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
 
 
 def _onnx_heatmaps_to_keypoints(
-    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+        maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
 ):
     num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
 
@@ -208,9 +453,9 @@ def _onnx_heatmaps_to_keypoints(
     ind = ind.to(dtype=torch.int64) * base
     end_scores_i = (
         roi_map.index_select(1, y_int.to(dtype=torch.int64))
-        .index_select(2, x_int.to(dtype=torch.int64))
-        .view(-1)
-        .index_select(0, ind.to(dtype=torch.int64))
+            .index_select(2, x_int.to(dtype=torch.int64))
+            .view(-1)
+            .index_select(0, ind.to(dtype=torch.int64))
     )
 
     return xy_preds_i, end_scores_i
@@ -218,7 +463,7 @@ def _onnx_heatmaps_to_keypoints(
 
 @torch.jit._script_if_tracing
 def _onnx_heatmaps_to_keypoints_loop(
-    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+        maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
 ):
     xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
     end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
@@ -424,7 +669,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
     y_0 = max(box[1], 0)
     y_1 = min(box[3] + 1, im_h)
 
-    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
     return im_mask
 
 
@@ -449,7 +694,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
     y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
     y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
 
-    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
 
     # TODO : replace below with a dynamic padding when support is added in ONNX
 
@@ -500,29 +745,29 @@ class RoIHeads(nn.Module):
     }
 
     def __init__(
-        self,
-        box_roi_pool,
-        box_head,
-        box_predictor,
-        line_head,
-        line_predictor,
-        # Faster R-CNN training
-        fg_iou_thresh,
-        bg_iou_thresh,
-        batch_size_per_image,
-        positive_fraction,
-        bbox_reg_weights,
-        # Faster R-CNN inference
-        score_thresh,
-        nms_thresh,
-        detections_per_img,
-        # Mask
-        mask_roi_pool=None,
-        mask_head=None,
-        mask_predictor=None,
-        keypoint_roi_pool=None,
-        keypoint_head=None,
-        keypoint_predictor=None,
+            self,
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            line_head,
+            line_predictor,
+            # Faster R-CNN training
+            fg_iou_thresh,
+            bg_iou_thresh,
+            batch_size_per_image,
+            positive_fraction,
+            bbox_reg_weights,
+            # Faster R-CNN inference
+            score_thresh,
+            nms_thresh,
+            detections_per_img,
+            # Mask
+            mask_roi_pool=None,
+            mask_head=None,
+            mask_predictor=None,
+            keypoint_roi_pool=None,
+            keypoint_head=None,
+            keypoint_predictor=None,
     ):
         super().__init__()
 
@@ -540,8 +785,8 @@ class RoIHeads(nn.Module):
         self.box_head = box_head
         self.box_predictor = box_predictor
 
-        self.line_head=line_head
-        self.line_predictor=line_predictor
+        self.line_head = line_head
+        self.line_predictor = line_predictor
 
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
@@ -556,7 +801,14 @@ class RoIHeads(nn.Module):
         self.keypoint_predictor = keypoint_predictor
 
     def has_line(self):
-        pass
+        # if self.mask_roi_pool is None:
+        #     return False
+        if self.line_head is None:
+            return False
+        if self.line_predictor is None:
+            return False
+        return True
+
     def has_mask(self):
         if self.mask_roi_pool is None:
             return False
@@ -638,9 +890,9 @@ class RoIHeads(nn.Module):
                 raise ValueError("Every element of targets should have a masks key")
 
     def select_training_samples(
-        self,
-        proposals,  # type: List[Tensor]
-        targets,  # type: Optional[List[Dict[str, Tensor]]]
+            self,
+            proposals,  # type: List[Tensor]
+            targets,  # type: Optional[List[Dict[str, Tensor]]]
     ):
         # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
         self.check_targets(targets)
@@ -676,11 +928,11 @@ class RoIHeads(nn.Module):
         return proposals, matched_idxs, labels, regression_targets
 
     def postprocess_detections(
-        self,
-        class_logits,  # type: Tensor
-        box_regression,  # type: Tensor
-        proposals,  # type: List[Tensor]
-        image_shapes,  # type: List[Tuple[int, int]]
+            self,
+            class_logits,  # type: Tensor
+            box_regression,  # type: Tensor
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
     ):
         # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
         device = class_logits.device
@@ -734,13 +986,12 @@ class RoIHeads(nn.Module):
 
         return all_boxes, all_scores, all_labels
 
-
     def forward(
-        self,
-        features,  # type: Dict[str, Tensor]
-        proposals,  # type: List[Tensor]
-        image_shapes,  # type: List[Tuple[int, int]]
-        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+            self,
+            features,  # type: Dict[str, Tensor]
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
+            targets=None,  # type: Optional[List[Dict[str, Tensor]]]
     ):
         # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
         """
@@ -750,6 +1001,12 @@ class RoIHeads(nn.Module):
             image_shapes (List[Tuple[H, W]])
             targets (List[Dict])
         """
+        if targets is not None:
+            self.training = True
+
+        else:
+            self.training = False
+
         if targets is not None:
             for t in targets:
                 # TODO: https://github.com/pytorch/pytorch/issues/26731
@@ -771,10 +1028,6 @@ class RoIHeads(nn.Module):
 
         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
-
-
-
-
         class_logits, box_regression = self.box_predictor(box_features)
 
         result: List[Dict[str, torch.Tensor]] = []
@@ -785,8 +1038,6 @@ class RoIHeads(nn.Module):
             if regression_targets is None:
                 raise ValueError("regression_targets cannot be None")
             loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
-
-
             losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
         else:
             boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
@@ -799,17 +1050,33 @@ class RoIHeads(nn.Module):
                         "scores": scores[i],
                     }
                 )
-        if self.has_line():
 
-            line_features = self.line_head(features)
-            _ = self.line_predictor(line_features)
-            ### line_loss(multitasklearner)
-
-
-            ### infer
+        features_lcnn = features['0']
+        if self.has_line():
+            outputs = self.line_head(features_lcnn)
+            loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
+            x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
+                inputs=outputs, features=features_lcnn, targets=targets)
+
+            # # line_loss(multitasklearner)
+            # if self.training:
+            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=True)
+            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
+            #                                        loss_weight, mode_train=True)
+            # else:
+            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=False)
+            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
+            #                                        loss_weight, mode_train=False)
 
+            if self.training:
+                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+                loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+            else:
+                pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
+                result.append(pred)
+                loss_wirepoint = {}
+            losses.update(loss_wirepoint)
 
-            pass
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
             if self.training:
@@ -854,9 +1121,9 @@ class RoIHeads(nn.Module):
         # keep none checks in if conditional so torchscript will conditionally
         # compile each branch
         if (
-            self.keypoint_roi_pool is not None
-            and self.keypoint_head is not None
-            and self.keypoint_predictor is not None
+                self.keypoint_roi_pool is not None
+                and self.keypoint_head is not None
+                and self.keypoint_predictor is not None
         ):
             keypoint_proposals = [p["boxes"] for p in result]
             if self.training:

+ 1 - 1
models/wirenet/wirepoint_rcnn.py

@@ -22,7 +22,7 @@ from torchvision.ops import misc as misc_nn_ops
 
 from models.config import config_tool
 from models.config.config_tool import read_yaml
-from models.ins_detect.trainer import get_transform
+from models.ins.trainer import get_transform
 from models.wirenet.head import RoIHeads
 from models.wirenet.wirepoint_dataset import WirePointDataset
 from tools import utils

+ 1 - 1
models/wirenet2/trainer.py

@@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
 from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 
 from models.config.config_tool import read_yaml
-from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
+from models.ins.maskrcnn_dataset import MaskRCNNDataset
 from models.keypoint.keypoint_dataset import KeypointDataset
 from tools import utils, presets
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):

+ 360 - 0
train——line_rcnn.py

@@ -0,0 +1,360 @@
+# 根据LCNN写的train    2025/2/7
+'''
+#!/usr/bin/env python3
+import datetime
+import glob
+import os
+import os.path as osp
+import platform
+import pprint
+import random
+import shlex
+import shutil
+import subprocess
+import sys
+import numpy as np
+import torch
+import torchvision
+import yaml
+import lcnn
+from lcnn.config import C, M
+from lcnn.datasets import WireframeDataset, collate
+from lcnn.models.line_vectorizer import LineVectorizer
+from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
+from torchvision.models import resnet50
+
+from models.line_detect.line_rcnn import linercnn_resnet50_fpn
+
+
+
+def main():
+
+    # 训练配置参数
+    config = {
+        # 数据集配置
+        'datadir': r'D:\python\PycharmProjects\data',  # 数据集目录
+        'config_file': 'config/wireframe.yaml',  # 配置文件路径
+
+        # GPU配置
+        'devices': '0',  # 使用的GPU设备
+        'identifier': 'fasterrcnn_resnet50',  # 训练标识符 stacked_hourglass unet
+
+        # 预训练模型路径
+        # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth',  # 预训练模型路径
+    }
+
+    # 更新配置
+    C.update(C.from_yaml(filename=config['config_file']))
+    M.update(C.model)
+
+    # 设置随机数种子
+    random.seed(0)
+    np.random.seed(0)
+    torch.manual_seed(0)
+
+    # 设备配置
+    device_name = "cpu"
+    os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
+
+    if torch.cuda.is_available():
+        device_name = "cuda"
+        torch.backends.cudnn.deterministic = True
+        torch.cuda.manual_seed(0)
+        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
+    else:
+        print("CUDA is not available")
+
+    device = torch.device(device_name)
+
+    # 数据加载
+    kwargs = {
+        "collate_fn": collate,
+        "num_workers": C.io.num_workers if os.name != "nt" else 0,
+        "pin_memory": True,
+    }
+
+    train_loader = torch.utils.data.DataLoader(
+        WireframeDataset(config['datadir'], dataset_type="train"),
+        shuffle=True,
+        batch_size=M.batch_size,
+        **kwargs,
+    )
+
+    val_loader = torch.utils.data.DataLoader(
+        WireframeDataset(config['datadir'], dataset_type="val"),
+        shuffle=False,
+        batch_size=M.batch_size_eval,
+        **kwargs,
+    )
+
+    model = linercnn_resnet50_fpn().to(device)
+
+    # 加载预训练权重
+
+    try:
+        # 加载模型权重
+        checkpoint = torch.load(config['pretrained_model'], map_location=device)
+
+        # 根据实际的检查点结构选择加载方式
+        if 'model_state_dict' in checkpoint:
+            # 如果是完整的检查点
+            model.load_state_dict(checkpoint['model_state_dict'])
+        elif 'state_dict' in checkpoint:
+            # 如果是只有状态字典的检查点
+            model.load_state_dict(checkpoint['state_dict'])
+        else:
+            # 直接加载权重字典
+            model.load_state_dict(checkpoint)
+
+        print("Successfully loaded pre-trained model weights.")
+    except Exception as e:
+        print(f"Error loading model weights: {e}")
+
+
+    # 优化器配置
+    if C.optim.name == "Adam":
+        optim = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=C.optim.lr,
+            weight_decay=C.optim.weight_decay,
+            amsgrad=C.optim.amsgrad,
+        )
+    elif C.optim.name == "SGD":
+        optim = torch.optim.SGD(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=C.optim.lr,
+            weight_decay=C.optim.weight_decay,
+            momentum=C.optim.momentum,
+        )
+    else:
+        raise NotImplementedError
+
+    # 输出目录
+    outdir = osp.join(
+        osp.expanduser(C.io.logdir),
+        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
+    )
+    os.makedirs(outdir, exist_ok=True)
+
+    try:
+        trainer = lcnn.trainer.Trainer(
+            device=device,
+            model=model,
+            optimizer=optim,
+            train_loader=train_loader,
+            val_loader=val_loader,
+            out=outdir,
+        )
+
+        print("Starting training...")
+        trainer.train()
+        print("Training completed.")
+
+    except BaseException:
+        if len(glob.glob(f"{outdir}/viz/*")) <= 1:
+            shutil.rmtree(outdir)
+        raise
+
+
+if __name__ == "__main__":
+    main()
+'''
+
+
+# 2025/2/9
+import os
+from typing import Optional, Any
+
+import cv2
+import numpy as np
+import torch
+
+from models.config.config_tool import read_yaml
+from models.line_detect.dataset_LD import WirePointDataset
+from tools import utils
+
+from torch.utils.tensorboard import SummaryWriter
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from skimage import io
+
+from models.line_detect.line_rcnn import linercnn_resnet50_fpn
+from torchvision.utils import draw_bounding_boxes
+from models.wirenet.postprocess import postprocess
+from torchvision import transforms
+from collections import OrderedDict
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def _loss(losses):
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        total_loss += loss
+
+    return total_loss
+
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+
+
+def show_line(img, pred,  epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    H = pred[1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.8]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
+
+
+if __name__ == '__main__':
+    cfg = r'./config/wireframe.yaml'
+    cfg = read_yaml(cfg)
+    print(f'cfg:{cfg}')
+    print(cfg['model']['n_dyn_negl'])
+    # net = WirepointPredictor()
+
+    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
+    train_sampler = torch.utils.data.RandomSampler(dataset_train)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
+    train_collate_fn = utils.collate_fn_wirepoint
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
+    )
+
+    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+    val_sampler = torch.utils.data.RandomSampler(dataset_val)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
+    val_collate_fn = utils.collate_fn_wirepoint
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
+    )
+
+    model = linercnn_resnet50_fpn().to(device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
+    writer = SummaryWriter(cfg['io']['logdir'])
+
+    def move_to_device(data, device):
+        if isinstance(data, (list, tuple)):
+            return type(data)(move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # 对于非张量类型的数据不做任何改变
+
+
+    # def writer_loss(writer, losses, epoch):
+    #     try:
+    #         for key, value in losses.items():
+    #             if key == 'loss_wirepoint':
+    #                 for subdict in losses['loss_wirepoint']['losses']:
+    #                     for subkey, subvalue in subdict.items():
+    #                         writer.add_scalar(f'loss_wirepoint/{subkey}',
+    #                                           subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+    #                                           epoch)
+    #             elif isinstance(value, torch.Tensor):
+    #                 writer.add_scalar(key, value.item(), epoch)
+    #     except Exception as e:
+    #         print(f"TensorBoard logging error: {e}")
+    def writer_loss(writer, losses, epoch):
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            writer.add_scalar(f'loss/{subkey}',
+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                              epoch)
+                elif isinstance(value, torch.Tensor):
+                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
+
+
+    for epoch in range(cfg['optim']['max_epoch']):
+        print(f"epoch:{epoch}")
+        model.train()
+
+        for imgs, targets in data_loader_train:
+            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+            # print(losses)
+            loss = _loss(losses)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            writer_loss(writer, losses, epoch)
+
+        model.eval()
+        with torch.no_grad():
+            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                pred = model(move_to_device(imgs, device))
+                if batch_idx == 0:
+                    show_line(imgs[0], pred, epoch, writer)
+                break
+
+