
Change LineDetect to a keypoint-based approach

RenLiqiang · 5 months ago
commit a59b471b64

+ 16 - 16
libs/vision_libs/models/detection/rpn.py

@@ -361,11 +361,11 @@ class RegionProposalNetwork(torch.nn.Module):
         features = list(features.values())
 
         objectness, pred_bbox_deltas = self.head(features)
-        for obj in objectness:
-            print(f'objectness:{obj.shape}')
+        # for obj in objectness:
+            # print(f'objectness:{obj.shape}')
 
-        for pred_bbox in pred_bbox_deltas:
-            print(f'pred_bbox:{pred_bbox.shape}')
+        # for pred_bbox in pred_bbox_deltas:
+            # print(f'pred_bbox:{pred_bbox.shape}')
 
         anchors = self.anchor_generator(images, features)
 
@@ -381,19 +381,19 @@ class RegionProposalNetwork(torch.nn.Module):
         # note that we detach the deltas because Faster R-CNN do not backprop through
         # the proposals
         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
-        print(f'box_coder.decode proposals:{proposals.shape}')
+        # print(f'box_coder.decode proposals:{proposals.shape}')
         proposals = proposals.view(num_images, -1, 4)
         boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
-        print(f'boxes:{boxes[0].shape},scores:{scores[0].shape}')
-
-        lines=self.lines_generator(features,300)
-
-        # Merge all line segments into a single Tensor (assuming batch_size=2)
-        lines_all = torch.cat(lines, dim=0)  # [Total_Lines, 4]
-
-        # Keep only the line segments that lie inside the boxes
-        lines =self.filter_lines_inside_boxes(lines_all, boxes)
-        print(f'filter_lines:{lines}')
+        # print(f'boxes:{boxes[0].shape},scores:{scores[0].shape}')
+        #
+        # lines=self.lines_generator(features,300)
+        #
+        # # Merge all line segments into a single Tensor (assuming batch_size=2)
+        # lines_all = torch.cat(lines, dim=0)  # [Total_Lines, 4]
+        #
+        # # Keep only the line segments that lie inside the boxes
+        # lines =self.filter_lines_inside_boxes(lines_all, boxes)
+        # print(f'filter_lines:{lines}')
 
 
         losses = {}
@@ -410,7 +410,7 @@ class RegionProposalNetwork(torch.nn.Module):
                 "loss_rpn_box_reg": loss_rpn_box_reg,
             }
         # print(f'boxes:{boxes[0].shape}')
-        return boxes,losses,lines
+        return boxes,losses
 
     def lines_generator(self, features: torch.Tensor, topk=300):
         """

+ 3 - 2
models/base/base_detection_net.py

@@ -44,6 +44,7 @@ class BaseDetectionNet(BaseModel):
             return  losses
         else:
             if targets is not None:
+                print(f'returned (detections,losses):{losses}')
                 return detections,losses
             else:
                 return detections
@@ -112,14 +113,14 @@ class BaseDetectionNet(BaseModel):
 
         if isinstance(features, torch.Tensor):
             features = OrderedDict([("0", features)])
-        proposals, proposal_losses,lines = self.rpn(images, features, targets)
+        proposals, proposal_losses = self.rpn(images, features, targets)
 
 
         # print(f'proposals:{proposals[0].shape}')
 
 
 
-        detections, detector_losses = self.roi_heads(features, proposals, lines, images.image_sizes, targets)
+        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
 
         detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
 

BIN
models/line_detect/color_img.jpg


+ 208 - 0
models/line_detect/line_dataset.py

@@ -0,0 +1,208 @@
+from torch.utils.data.dataset import T_co
+
+from libs.vision_libs.utils import draw_keypoints
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+def validate_keypoints(keypoints, image_width, image_height):
+    for kp in keypoints:
+        x, y, v = kp
+        if not (0 <= x < image_width and 0 <= y < image_height):
+            raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
+
+
+class LineDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'shape:{shape}')
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            label_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = label_all["wires"][0]  # dict of wire annotations
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # if fewer than n_stc_posl are available, take them all
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative sample coordinates concatenated together
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of the lines that actually exist
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_lines, 512, 512]
+        target = {}
+
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+
+        # target["labels"] = torch.stack(labels)
+
+        # print(f'labels:{target["labels"]}')
+        # target["boxes"] = line_boxes(target)
+        target["boxes"], keypoints = line_boxes(target)
+        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
+        # keypoints=keypoints/512
+        # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
+
+        # keypoints= wire_labels["junc_coords"]
+        a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
+        keypoints = torch.cat((keypoints, a), dim=1)
+
+        target["keypoints"] = keypoints.to(torch.float32).view(-1,2,3)
+        # print(f'boxes:{target["boxes"].shape}')
+        # validate the keypoints against the image bounds for every loaded sample
+        validate_keypoints(keypoints, shape[1], shape[0])  # shape is (h, w), so pass (width, height)
+        # print(f'keypoints:{target["keypoints"].shape}')
+        # print(f'target:{target}')
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        img_path = os.path.join(self.img_path, self.imgs[idx])
+        img = PIL.Image.open(img_path).convert('RGB')
+        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+        keypoint_img=draw_keypoints(boxed_image,target['keypoints'],colors='red',width=3)
+        plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
+        plt.show()
+
+
+
+
+
+    def show_img(self, img_path):
+        pass
+
+def line_boxes(target):
+    boxs = []
+    lpre = target['wires']["lpre"].cpu().numpy()
+    vecl_target = target['wires']["lpre_label"].cpu().numpy()
+    lpre = lpre[vecl_target == 1]
+
+    lines = lpre
+    sline = np.ones(lpre.shape[0])
+
+    keypoints = []
+
+    if len(lines) > 0 and not (lines[0] == 0).all():
+        for i, ((a, b), s) in enumerate(zip(lines, sline)):
+            if i > 0 and (lines[i] == lines[0]).all():
+                break
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] are not guaranteed to be ordered
+
+            keypoints.append([a[1], a[0]])
+            keypoints.append([b[1], b[0]])
+
+            if a[1] > b[1]:
+                ymax = a[1] + 1
+                ymin = b[1] - 1
+            else:
+                ymin = a[1] - 1
+                ymax = b[1] + 1
+            if a[0] > b[0]:
+                xmax = a[0] + 1
+                xmin = b[0] - 1
+            else:
+                xmin = a[0] - 1
+                xmax = b[0] + 1
+            boxs.append([ymin, xmin, ymax, xmax])
+
+    return torch.tensor(boxs), torch.tensor(keypoints)
+
+if __name__ == '__main__':
+    path=r"\\192.168.50.222/share/lm/Dataset_all"
+    dataset= LineDataset(dataset_path=path, dataset_type='train')
+    dataset.show(10)
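The new dataset is what puts the commit message into practice: every ground-truth line becomes one box plus two endpoint keypoints (line_boxes pads the endpoint extent by one pixel for the box, and each endpoint gets visibility flag 2), which is the target format a Keypoint R-CNN style head expects. A minimal usage sketch; the dataset path is a placeholder, and it assumes the images/<split> and labels/<split> layout used above plus a default transform supplied by BaseDataset.

from torch.utils.data import DataLoader
from models.line_detect.line_dataset import LineDataset

dataset = LineDataset(dataset_path="/data/Dataset_all", dataset_type="train")  # placeholder path

def collate_fn(batch):
    # detection-style collate: keep images and per-image target dicts as parallel tuples
    return tuple(zip(*batch))

loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
images, targets = next(iter(loader))

print(targets[0]["boxes"].shape)      # [num_lines, 4]
print(targets[0]["keypoints"].shape)  # [num_lines, 2, 3], two (x, y, v) endpoints per line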

+ 603 - 0
models/line_detect/line_detect.py

@@ -0,0 +1,603 @@
+import os
+from typing import Any, Callable, List, Optional, Tuple, Union
+import torch
+from torch import nn
+
+
+from libs.vision_libs import ops
+from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large, EfficientNet_V2_S_Weights, \
+    efficientnet_v2_s, detection, EfficientNet_V2_L_Weights, efficientnet_v2_l, EfficientNet_V2_M_Weights, \
+    efficientnet_v2_m
+from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
+from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
+from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
+from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
+from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
+from libs.vision_libs.transforms._presets import ObjectDetection
+from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
+from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
+    BackboneWithFPN
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+from .roi_heads import RoIHeads
+
+from .trainer import Trainer
+from ..base import backbone_factory
+from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
+# from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
+from ..base.base_detection_net import BaseDetectionNet
+import torch.nn.functional as F
+
+from ..base.high_reso_resnet import resnet50fpn, resnet18fpn
+
+__all__ = [
+    "LineDetect",
+    "LineDetect_ResNet50_FPN_Weights",
+    "linedetect_resnet50_fpn",
+]
+
+def _default_anchorgen():
+    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    return AnchorGenerator(anchor_sizes, aspect_ratios)
+
+
+class LineDetect(BaseDetectionNet):
+
+
+    def __init__(
+            self,
+            backbone,
+            num_classes=None,
+            # transform parameters
+            min_size=512,
+            max_size=1333,
+            image_mean=None,
+            image_std=None,
+            # RPN parameters
+            rpn_anchor_generator=None,
+            rpn_head=None,
+            rpn_pre_nms_top_n_train=2000,
+            rpn_pre_nms_top_n_test=1000,
+            rpn_post_nms_top_n_train=2000,
+            rpn_post_nms_top_n_test=1000,
+            rpn_nms_thresh=0.7,
+            rpn_fg_iou_thresh=0.7,
+            rpn_bg_iou_thresh=0.3,
+            rpn_batch_size_per_image=256,
+            rpn_positive_fraction=0.5,
+            rpn_score_thresh=0.0,
+            # Box parameters
+            box_roi_pool=None,
+            box_head=None,
+            box_predictor=None,
+            box_score_thresh=0.05,
+            box_nms_thresh=0.5,
+            box_detections_per_img=100,
+            box_fg_iou_thresh=0.5,
+            box_bg_iou_thresh=0.5,
+            box_batch_size_per_image=512,
+            box_positive_fraction=0.25,
+            bbox_reg_weights=None,
+            # keypoint parameters
+            line_roi_pool=None,
+            line_head=None,
+            line_predictor=None,
+            num_keypoints=None,
+            **kwargs,
+    ):
+
+        out_channels = backbone.out_channels
+
+        if rpn_anchor_generator is None:
+            rpn_anchor_generator = _default_anchorgen()
+        if rpn_head is None:
+            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+
+        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+
+        rpn = RegionProposalNetwork(
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_pre_nms_top_n,
+            rpn_post_nms_top_n,
+            rpn_nms_thresh,
+            score_thresh=rpn_score_thresh,
+        )
+
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = ObjectionPredictor(representation_size, num_classes)
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+        )
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+
+        super().__init__(backbone, rpn, roi_heads, transform)
+
+
+
+        if not isinstance(line_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
+            )
+        if min_size is None:
+            min_size = (640, 672, 704, 736, 768, 800)
+
+        if num_keypoints is not None:
+            if line_predictor is not None:
+                raise ValueError("num_keypoints should be None when line_predictor is specified")
+        else:
+            num_keypoints = 2
+
+
+        if line_roi_pool is None:
+            line_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if line_head is None:
+            keypoint_layers = tuple(512 for _ in range(8))
+            line_head = LineHeads(out_channels, keypoint_layers)
+
+        if line_predictor is None:
+            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+            line_predictor = LinePredictor(keypoint_dim_reduced, num_keypoints)
+
+
+        self.roi_heads.keypoint_roi_pool = line_roi_pool
+        self.roi_heads.keypoint_head = line_head
+        self.roi_heads.keypoint_predictor = line_predictor
+
+    def start_train(self, cfg):
+        # cfg = read_yaml(cfg)
+        self.trainer = Trainer()
+        self.trainer.train_from_cfg(model=self, cfg=cfg)
+
+    def load_weights(self, save_path, device='cuda'):
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+
+            self.load_state_dict(checkpoint['model_state_dict'])
+            # if optimizer is not None:
+            #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            # epoch = checkpoint['epoch']
+            # loss = checkpoint['loss']
+            # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+            print(f"Loaded model from {save_path}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return self
+
+
+class TwoMLPHead(nn.Module):
+    """
+    Standard heads for FPN-based models
+
+    Args:
+        in_channels (int): number of input channels
+        representation_size (int): size of the intermediate representation
+    """
+
+    def __init__(self, in_channels, representation_size):
+        super().__init__()
+
+        self.fc6 = nn.Linear(in_channels, representation_size)
+        self.fc7 = nn.Linear(representation_size, representation_size)
+
+    def forward(self, x):
+        x = x.flatten(start_dim=1)
+
+        x = F.relu(self.fc6(x))
+        x = F.relu(self.fc7(x))
+
+        return x
+
+
+class ObjectionConvFCHead(nn.Sequential):
+    def __init__(
+        self,
+        input_size: Tuple[int, int, int],
+        conv_layers: List[int],
+        fc_layers: List[int],
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Args:
+            input_size (Tuple[int, int, int]): the input size in CHW format.
+            conv_layers (list): feature dimensions of each Convolution layer
+            fc_layers (list): feature dimensions of each FCN layer
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        in_channels, in_height, in_width = input_size
+
+        blocks = []
+        previous_channels = in_channels
+        for current_channels in conv_layers:
+            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
+            previous_channels = current_channels
+        blocks.append(nn.Flatten())
+        previous_channels = previous_channels * in_height * in_width
+        for current_channels in fc_layers:
+            blocks.append(nn.Linear(previous_channels, current_channels))
+            blocks.append(nn.ReLU(inplace=True))
+            previous_channels = current_channels
+
+        super().__init__(*blocks)
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+
+class ObjectionPredictor(nn.Module):
+    """
+    Standard classification + bounding box regression layers
+    for Fast R-CNN.
+
+    Args:
+        in_channels (int): number of input channels
+        num_classes (int): number of output classes (including background)
+    """
+
+    def __init__(self, in_channels, num_classes):
+        super().__init__()
+        self.cls_score = nn.Linear(in_channels, num_classes)
+        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
+
+    def forward(self, x):
+        if x.dim() == 4:
+            torch._assert(
+                list(x.shape[2:]) == [1, 1],
+                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
+            )
+        x = x.flatten(start_dim=1)
+        scores = self.cls_score(x)
+        bbox_deltas = self.bbox_pred(x)
+
+        return scores, bbox_deltas
+
+class LineHeads(nn.Sequential):
+    def __init__(self, in_channels, layers):
+        d = []
+        next_feature = in_channels
+        for out_channels in layers:
+            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+            d.append(nn.ReLU(inplace=True))
+            next_feature = out_channels
+        super().__init__(*d)
+        for m in self.children():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+
+class LinePredictor(nn.Module):
+    def __init__(self, in_channels, num_keypoints):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            num_keypoints,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = num_keypoints
+
+    def forward(self, x):
+        x = self.kps_score_lowres(x)
+        return torch.nn.functional.interpolate(
+            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        )
+
+
+_COMMON_META = {
+    "categories": _COCO_PERSON_CATEGORIES,
+    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
+    "min_size": (1, 1),
+}
+
+
+class LineDetect_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_LEGACY = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/issues/1606",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 50.6,
+                    "kp_map": 61.1,
+                }
+            },
+            "_ops": 133.924,
+            "_file_size": 226.054,
+            "_docs": """
+                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
+                from an early epoch.
+            """,
+        },
+    )
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 54.6,
+                    "kp_map": 65.0,
+                }
+            },
+            "_ops": 137.42,
+            "_file_size": 226.054,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=(
+            "pretrained",
+            lambda kwargs: LineDetect_ResNet50_FPN_Weights.COCO_LEGACY
+            if kwargs["pretrained"] == "legacy"
+            else LineDetect_ResNet50_FPN_Weights.COCO_V1,
+    ),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+
+def lineDetect_resnet18_fpn(
+        *,
+        weights: Optional[LineDetect_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> LineDetect:
+    """
+    Constructs a LineDetect (Keypoint R-CNN style) model with a ResNet-18-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores for each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = lineDetect_resnet18_fpn(weights=LineDetect_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`LineDetect_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`LineDetect_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = LineDetect_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    # if weights_backbone is None:
+
+    weights_backbone = ResNet18_Weights.IMAGENET1K_V1
+
+    if weights is not None:
+        # weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 2
+        if num_keypoints is None:
+            num_keypoints = 2
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = LineDetect(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == LineDetect_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+def linedetect_resnet50_fpn(
+        *,
+        weights: Optional[LineDetect_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> LineDetect:
+    """
+    Constructs a LineDetect (Keypoint R-CNN style) model with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores for each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.linedetect_resnet50_fpn(weights=LineDetect_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`LineDetect_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`LineDetect_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = LineDetect_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 2
+        if num_keypoints is None:
+            num_keypoints = 17
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = LineDetect(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == LineDetect_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
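LineDetect wires a standard Faster R-CNN box branch to Keypoint R-CNN style heads with num_keypoints=2, one per line endpoint. A minimal construction and inference sketch; no pretrained weights are loaded, the inputs are random, and whether a "keypoints" field shows up in the detections depends on how the line modules end up attached to RoIHeads.

import torch
from models.line_detect.line_detect import linedetect_resnet50_fpn

# num_classes=2 (background + line), num_keypoints=2 (the two line endpoints)
model = linedetect_resnet50_fpn(weights=None, weights_backbone=None,
                                num_classes=2, num_keypoints=2)

model.eval()
x = [torch.rand(3, 512, 512), torch.rand(3, 512, 512)]
with torch.no_grad():
    predictions = model(x)          # list of dicts with "boxes", "labels", "scores"
print(predictions[0].keys())

# in training mode, model(images, targets) returns the loss dict instead,
# with targets shaped like the LineDataset output above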

+ 111 - 398
models/line_detect/roi_heads.py

@@ -11,277 +11,6 @@ import libs.vision_libs.models.detection._utils as det_utils
 from collections import OrderedDict
 
 
-def l2loss(input, target):
-    return ((target - input) ** 2).mean(2).mean(1)
-
-
-def cross_entropy_loss(logits, positive):
-    nlogp = -F.log_softmax(logits, dim=0)
-    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
-
-
-def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
-    logp = torch.sigmoid(logits) + offset
-    loss = torch.abs(logp - target)
-    if mask is not None:
-        w = mask.mean(2, True).mean(1, True)
-        w[w == 0] = 1
-        loss = loss * (mask / w)
-
-    return loss.mean(2).mean(1)
-class DiceLoss(nn.Module):
-    def __init__(self, smooth=1.):
-        super(DiceLoss, self).__init__()
-        self.smooth = smooth
-
-    def forward(self, logits, targets):
-        probs = torch.sigmoid(logits)
-        probs = probs.view(-1)
-        targets = targets.view(-1).float()
-
-        intersection = (probs * targets).sum()
-        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
-        return 1. - dice
-
-
-
-bce_loss = nn.BCEWithLogitsLoss()
-dice_loss = DiceLoss()
-
-
-def combined_loss(preds, targets, alpha=0.5):
-    bce = bce_loss(preds, targets)
-    d = dice_loss(preds, targets)
-    return alpha * bce + (1 - alpha) * d
-
-### compute the multi-head losses
-def line_head_loss(input_dict, outputs, feature, loss_weight, mode_train):
-    # image = input_dict["image"]
-    # target_b = input_dict["target_b"]
-    # outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # during training aaa is the loss, during val it is the boxes
-
-    result = {"feature": feature}
-    batch, channel, row, col = outputs[0].shape
-
-    T = input_dict["target"].copy()
-    n_jtyp = T["junc_map"].shape[1]
-
-    # switch to CNHW
-    for task in ["junc_map"]:
-        T[task] = T[task].permute(1, 0, 2, 3)
-    for task in ["junc_offset"]:
-        T[task] = T[task].permute(1, 2, 0, 3, 4)
-
-    offset = [2, 3, 5]
-    losses = []
-    for stack, output in enumerate(outputs):
-        output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-        jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-        lmap = output[offset[0]: offset[1]].squeeze(0)
-        # print(f"lmap:{lmap.shape}")
-        joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
-        if stack == 0:
-            result["preds"] = {
-                "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
-                "lmap": lmap.sigmoid(),
-                "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
-            }
-            if mode_train == False:
-                return result
-
-        L = OrderedDict()
-        L["jmap"] = sum(
-            cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
-        )
-        L["lmap"] = (
-            F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
-                .mean(2)
-                .mean(1)
-        )
-        L["joff"] = sum(
-            sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
-            for i in range(n_jtyp)
-            for j in range(2)
-        )
-        for loss_name in L:
-            L[loss_name].mul_(loss_weight[loss_name])
-        losses.append(L)
-    result["losses"] = losses
-    # result["aaa"] = aaa
-    return result
-
-
-#  compute the line (vectorizer) loss
-def line_vectorizer_loss(result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc, loss_weight, mode_train):
-    if mode_train == False:
-        p = torch.cat(ps)
-        s = torch.sigmoid(x)
-        b = s > 0.5
-        lines = []
-        score = []
-        for i in range(n_batch):
-            p0 = p[idx[i]: idx[i + 1]]
-            s0 = s[idx[i]: idx[i + 1]]
-            mask = b[idx[i]: idx[i + 1]]
-            p0 = p0[mask]
-            s0 = s0[mask]
-            if len(p0) == 0:
-                lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
-                score.append(torch.zeros([1, n_out_line], device=p.device))
-            else:
-                arg = torch.argsort(s0, descending=True)
-                p0, s0 = p0[arg], s0[arg]
-                lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
-                score.append(s0[None, torch.arange(n_out_line) % len(s0)])
-            for j in range(len(jcs[i])):
-                if len(jcs[i][j]) == 0:
-                    jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
-                jcs[i][j] = jcs[i][j][
-                    None, torch.arange(n_out_junc) % len(jcs[i][j])
-                ]
-        result["preds"]["lines"] = torch.cat(lines)
-        result["preds"]["score"] = torch.cat(score)
-        result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
-        if len(jcs[i]) > 1:
-            result["preds"]["junts"] = torch.cat(
-                [jcs[i][1] for i in range(n_batch)]
-            )
-
-    # if input_dict["mode"] != "testing":
-    y = torch.cat(ys)
-    loss = nn.BCEWithLogitsLoss(reduction="none")
-    loss = loss(x, y)
-    lpos_mask, lneg_mask = y, 1 - y
-    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
-
-    def sum_batch(x):
-        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
-        return torch.cat(xs)
-
-    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
-    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
-    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
-    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
-
-    if mode_train == True:
-        del result["preds"]
-
-    return result
-
-
-def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
-    # output, feature: results returned by the head
-    # x, y, idx : intermediate results from the line branch
-    result = {}
-    batch, channel, row, col = output.shape
-
-    wires_targets = [t["wires"] for t in targets]
-    wires_targets = wires_targets.copy()
-    # print(f'wires_target:{wires_targets}')
-    # extract all the 'junc_map', 'junc_offset', 'line_map' tensors
-    junc_maps = [d["junc_map"] for d in wires_targets]
-    junc_offsets = [d["junc_offset"] for d in wires_targets]
-    line_maps = [d["line_map"] for d in wires_targets]
-
-    junc_map_tensor = torch.stack(junc_maps, dim=0)
-    junc_offset_tensor = torch.stack(junc_offsets, dim=0)
-    line_map_tensor = torch.stack(line_maps, dim=0)
-    T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
-
-    n_jtyp = T["junc_map"].shape[1]
-
-    for task in ["junc_map"]:
-        T[task] = T[task].permute(1, 0, 2, 3)
-    for task in ["junc_offset"]:
-        T[task] = T[task].permute(1, 2, 0, 3, 4)
-
-    offset = [2, 3, 5]
-    losses = []
-    output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-    jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-    lmap = output[offset[0]: offset[1]].squeeze(0)
-    joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
-    L = OrderedDict()
-    # L["junc_map"] = sum(
-    #     cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
-    # ).mean()
-    # L["line_map"] = (
-    #     F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
-    #         .mean(2)
-    #         .mean(1)
-    # ).mean()
-    L["junc_map"] = combined_loss(jmap[:, 1, :, :, :], T["junc_map"])
-
-    L["line_map"] = combined_loss(lmap, T["line_map"])
-    L["junc_offset"] = sum(
-        sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
-        for i in range(n_jtyp)
-        for j in range(2)
-    ).mean()
-    for loss_name in L:
-        L[loss_name].mul_(loss_weight[loss_name])
-    losses.append(L)
-    result["losses"] = losses
-
-    loss = nn.BCEWithLogitsLoss(reduction="none")
-    loss = loss(x, y)
-    lpos_mask, lneg_mask = y, 1 - y
-    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
-
-    def sum_batch(x):
-        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
-        return torch.cat(xs)
-
-    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
-    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
-    result["losses"][0]["lpos"] = (lpos * loss_weight["lpos"]).mean()
-    result["losses"][0]["lneg"] = (lneg * loss_weight["lneg"]).mean()
-
-    return result
-
-
-def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
-    result = {}
-    result["wires"] = {}
-    p = torch.cat(ps)
-    s = torch.sigmoid(input)
-    b = s > 0.5
-    lines = []
-    score = []
-    # print(f"n_batch:{n_batch}")
-    for i in range(n_batch):
-        # print(f"idx:{idx}")
-        p0 = p[idx[i]: idx[i + 1]]
-        s0 = s[idx[i]: idx[i + 1]]
-        mask = b[idx[i]: idx[i + 1]]
-        p0 = p0[mask]
-        s0 = s0[mask]
-        if len(p0) == 0:
-            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
-            score.append(torch.zeros([1, n_out_line], device=p.device))
-        else:
-            arg = torch.argsort(s0, descending=True)
-            p0, s0 = p0[arg], s0[arg]
-            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
-            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
-        for j in range(len(jcs[i])):
-            if len(jcs[i][j]) == 0:
-                jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
-            jcs[i][j] = jcs[i][j][
-                None, torch.arange(n_out_junc) % len(jcs[i][j])
-            ]
-    result["wires"]["lines"] = torch.cat(lines)
-    result["wires"]["score"] = torch.cat(score)
-    result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
-
-    if len(jcs[i]) > 1:
-        result["preds"]["junts"] = torch.cat(
-            [jcs[i][1] for i in range(n_batch)]
-        )
-    # print(f'predic result:{result}')
-    return result
-
-
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
     # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
     """
@@ -297,7 +26,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
         classification_loss (Tensor)
         box_loss (Tensor)
     """
-
+    # print(f'compute fastrcnn_loss:{labels}')
     labels = torch.cat(labels, dim=0)
     regression_targets = torch.cat(regression_targets, dim=0)
 
@@ -436,7 +165,7 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
 
 
 def _onnx_heatmaps_to_keypoints(
-        maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
 ):
     num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
 
@@ -478,9 +207,9 @@ def _onnx_heatmaps_to_keypoints(
     ind = ind.to(dtype=torch.int64) * base
     end_scores_i = (
         roi_map.index_select(1, y_int.to(dtype=torch.int64))
-            .index_select(2, x_int.to(dtype=torch.int64))
-            .view(-1)
-            .index_select(0, ind.to(dtype=torch.int64))
+        .index_select(2, x_int.to(dtype=torch.int64))
+        .view(-1)
+        .index_select(0, ind.to(dtype=torch.int64))
     )
 
     return xy_preds_i, end_scores_i
@@ -488,7 +217,7 @@ def _onnx_heatmaps_to_keypoints(
 
 @torch.jit._script_if_tracing
 def _onnx_heatmaps_to_keypoints_loop(
-        maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
 ):
     xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
     end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
@@ -694,7 +423,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
     y_0 = max(box[1], 0)
     y_1 = min(box[3] + 1, im_h)
 
-    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
     return im_mask
 
 
@@ -719,7 +448,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
     y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
     y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
 
-    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
 
     # TODO : replace below with a dynamic padding when support is added in ONNX
 
@@ -770,29 +499,31 @@ class RoIHeads(nn.Module):
     }
 
     def __init__(
-            self,
-            box_roi_pool,
-            box_head,
-            box_predictor,
-            line_head,
-            line_predictor,
-            # Faster R-CNN training
-            fg_iou_thresh,
-            bg_iou_thresh,
-            batch_size_per_image,
-            positive_fraction,
-            bbox_reg_weights,
-            # Faster R-CNN inference
-            score_thresh,
-            nms_thresh,
-            detections_per_img,
-            # Mask
-            mask_roi_pool=None,
-            mask_head=None,
-            mask_predictor=None,
-            keypoint_roi_pool=None,
-            keypoint_head=None,
-            keypoint_predictor=None,
+        self,
+        box_roi_pool,
+        box_head,
+        box_predictor,
+        # Faster R-CNN training
+        fg_iou_thresh,
+        bg_iou_thresh,
+        batch_size_per_image,
+        positive_fraction,
+        bbox_reg_weights,
+        # Faster R-CNN inference
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        # Line
+        line_roi_pool=None,
+        line_head=None,
+        line_predictor=None,
+        # Mask
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
     ):
         super().__init__()
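The constructor now exposes the line modules as optional keyword slots next to the mask and keypoint ones; LineDetect in this commit still builds RoIHeads with only the box arguments and attaches its line/keypoint modules afterwards. A construction sketch under the new signature, assuming 256-channel FPN features and the helper classes from line_detect.py.

from libs.vision_libs.ops import MultiScaleRoIAlign
from models.line_detect.line_detect import TwoMLPHead, ObjectionPredictor
from models.line_detect.roi_heads import RoIHeads

# assumed sizes: 256-channel FPN features, 7x7 box pooling, 2 classes
box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
box_head = TwoMLPHead(256 * 7 * 7, 1024)
box_predictor = ObjectionPredictor(1024, num_classes=2)

roi_heads = RoIHeads(
    box_roi_pool, box_head, box_predictor,
    fg_iou_thresh=0.5, bg_iou_thresh=0.5,
    batch_size_per_image=512, positive_fraction=0.25,
    bbox_reg_weights=None,
    score_thresh=0.05, nms_thresh=0.5, detections_per_img=100,
    line_roi_pool=None, line_head=None, line_predictor=None,  # optional line slots
)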
 
@@ -810,13 +541,14 @@ class RoIHeads(nn.Module):
         self.box_head = box_head
         self.box_predictor = box_predictor
 
-        self.line_head = line_head
-        self.line_predictor = line_predictor
-
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img
 
+        self.line_roi_pool = line_roi_pool
+        self.line_head = line_head
+        self.line_predictor = line_predictor
+
         self.mask_roi_pool = mask_roi_pool
         self.mask_head = mask_head
         self.mask_predictor = mask_predictor
@@ -825,15 +557,6 @@ class RoIHeads(nn.Module):
         self.keypoint_head = keypoint_head
         self.keypoint_predictor = keypoint_predictor
 
-    def has_line(self):
-        # if self.mask_roi_pool is None:
-        #     return False
-        if self.line_head is None:
-            return False
-        if self.line_predictor is None:
-            return False
-        return True
-
     def has_mask(self):
         if self.mask_roi_pool is None:
             return False
@@ -852,6 +575,15 @@ class RoIHeads(nn.Module):
             return False
         return True
 
+    def has_line(self):
+        if self.line_roi_pool is None:
+            return False
+        if self.line_head is None:
+            return False
+        if self.line_predictor is None:
+            return False
+        return True
+
     def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
         # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
         matched_idxs = []
@@ -915,9 +647,9 @@ class RoIHeads(nn.Module):
                 raise ValueError("Every element of targets should have a masks key")
 
     def select_training_samples(
-            self,
-            proposals,  # type: List[Tensor]
-            targets,  # type: Optional[List[Dict[str, Tensor]]]
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
     ):
         # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
         self.check_targets(targets)
@@ -953,11 +685,11 @@ class RoIHeads(nn.Module):
         return proposals, matched_idxs, labels, regression_targets
 
     def postprocess_detections(
-            self,
-            class_logits,  # type: Tensor
-            box_regression,  # type: Tensor
-            proposals,  # type: List[Tensor]
-            image_shapes,  # type: List[Tuple[int, int]]
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
     ):
         # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
         device = class_logits.device
@@ -1012,12 +744,11 @@ class RoIHeads(nn.Module):
         return all_boxes, all_scores, all_labels
 
     def forward(
-            self,
-            features,  # type: Dict[str, Tensor]
-            proposals,  # type: List[Tensor]
-            lines,
-            image_shapes,  # type: List[Tuple[int, int]]
-            targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+        self,
+        features,  # type: Dict[str, Tensor]
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
     ):
         # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
         """
@@ -1027,13 +758,6 @@ class RoIHeads(nn.Module):
             image_shapes (List[Tuple[H, W]])
             targets (List[Dict])
         """
-        # if targets is not None:
-        #     self.training = True
-        #     # print(f'targets is not None')
-        #
-        # else:
-        #     self.training = False
-        #     # print(f'targets is None')
         print(f'roihead forward!!!')
         if targets is not None:
             for t in targets:
@@ -1050,13 +774,9 @@ class RoIHeads(nn.Module):
         if self.training:
             proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
         else:
-            if targets is not None:
-                proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
-            else:
-                labels = None
-                regression_targets = None
-                matched_idxs = None
-
+            labels = None
+            regression_targets = None
+            matched_idxs = None
 
         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
@@ -1069,68 +789,66 @@ class RoIHeads(nn.Module):
                 raise ValueError("labels cannot be None")
             if regression_targets is None:
                 raise ValueError("regression_targets cannot be None")
+            print(f'boxes compute losses')
             loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
             losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
         else:
-            if targets is not None:
-                loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
-                losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
-            else:
-                boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
-                num_images = len(boxes)
-                for i in range(num_images):
-                    result.append(
-                        {
-                            "boxes": boxes[i],
-                            "labels": labels[i],
-                            "scores": scores[i],
-                            "lines":lines[i],
-                        }
-                    )
+            print(f'boxes postprocess')
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
 
-        line_features = features['0']
-        if self.has_line():
-            # print('has line_head')
-            # outputs = self.line_head(features_lcnn)
-            # outputs = line_features[:, 0:5, :, :]
 
+        if self.has_line():
+            print(f'roi_heads forward has_line()!!!!')
+            line_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                line_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
 
-            loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
-            x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
-                inputs=line_features, features=line_features, targets=targets)
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    line_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
 
-            # # line_loss(multitasklearner)
-            # if self.training:
-            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=True)
-            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
-            #                                        loss_weight, mode_train=True)
-            # else:
-            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=False)
-            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
-            #                                        loss_weight, mode_train=False)
+            line_features = self.keypoint_roi_pool(features, line_proposals, image_shapes)
+            line_features = self.keypoint_head(line_features)
+            line_logits = self.keypoint_predictor(line_features)
 
+            loss_keypoint = {}
             if self.training:
-                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
-                loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
-                # print(f'loss_wirepoint:{loss_wirepoint}')
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    line_logits, line_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
             else:
-                # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
-                if targets is not None:
-                    rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
-                    loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
-                else:
-                    print(f'model inference!!!')
-                    pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
-                    result.append(line_features)
-                    result.append(pred)
-                    loss_wirepoint = {}
-
-            losses.update(loss_wirepoint)
-        else:
-            pass
-            # print('has not line_head')
+                if line_logits is None or line_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
 
+                keypoints_probs, kp_scores = keypointrcnn_inference(line_logits, line_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
             if self.training:
@@ -1174,11 +892,7 @@ class RoIHeads(nn.Module):
 
         # keep none checks in if conditional so torchscript will conditionally
         # compile each branch
-        if (
-                self.keypoint_roi_pool is not None
-                and self.keypoint_head is not None
-                and self.keypoint_predictor is not None
-        ):
+        if self.has_keypoint():
             keypoint_proposals = [p["boxes"] for p in result]
             if self.training:
                 # during training, only focus on positive boxes
@@ -1221,5 +935,4 @@ class RoIHeads(nn.Module):
                     r["keypoints_scores"] = kps
             losses.update(loss_keypoint)
 
-        # print(f'roi losses:{losses}')
         return result, losses
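
Note: the rewritten branch above treats each line segment as a detection instance with two keypoints (its endpoints), so ground-truth targets follow the Keypoint R-CNN layout. A minimal sketch of that layout, using a hypothetical helper name line_to_target that is not part of this diff:

    import torch

    def line_to_target(x1, y1, x2, y2):
        # one instance: an enclosing box plus K=2 keypoints in (x, y, visibility) form;
        # the tiny expansion keeps axis-aligned lines from producing zero-area boxes
        box = torch.tensor([[min(x1, x2), min(y1, y2), max(x1, x2) + 1e-3, max(y1, y2) + 1e-3]])
        keypoints = torch.tensor([[[x1, y1, 1.0], [x2, y2, 1.0]]])  # shape [1, 2, 3]
        return {"boxes": box, "labels": torch.ones(1, dtype=torch.int64), "keypoints": keypoints}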

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/lm/Dataset_all
+  datadir: \\192.168.50.222\share\rlq\datasets\250612
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 5 - 6
models/line_detect/train_demo.py

@@ -1,8 +1,9 @@
 import torch
 
-from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_newresnet50fpn, \
+from models.line_detect.line_detect import lineDetect_resnet18_fpn
+from models.line_net.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_newresnet50fpn, \
     get_line_net_convnext_fpn, linenet_newresnet18fpn
-from models.line_detect.trainer import Trainer
+from models.line_net.trainer import Trainer
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
@@ -12,8 +13,6 @@ if __name__ == '__main__':
     # model = linenet_resnet18_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
     # model=linenet_newresnet50fpn()
-    model = linenet_newresnet18fpn()
-    # model.load_best_model('train_results/20250622_140412/weights/best_val.pth')
-    # trainer = Trainer()
-    # trainer.train_cfg(model,cfg='./train.yaml')
+    model = lineDetect_resnet18_fpn()
+
     model.start_train(cfg='train.yaml')
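
For reference, a hedged inference sketch for the keypoint-style model, assuming the lineDetect_resnet18_fpn factory and import path used above, and eval-mode behaviour following the Keypoint R-CNN convention documented in line_detect.py later in this diff:

    import torch
    from models.line_detect.line_detect import lineDetect_resnet18_fpn

    model = lineDetect_resnet18_fpn()
    model.eval()
    with torch.no_grad():
        preds = model([torch.rand(3, 512, 512)])
    # each prediction dict carries "boxes", "labels", "scores" and per-instance
    # "keypoints" of shape [N, K, 3], whose K entries are the line endpoints (x, y, v)
    print(preds[0].keys())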

+ 26 - 69
models/line_detect/trainer.py

@@ -7,11 +7,12 @@ import torch
 from matplotlib import pyplot as plt
 from torch.utils.tensorboard import SummaryWriter
 
-from libs.vision_libs.utils import draw_bounding_boxes
+from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
 from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
-from models.line_detect.dataset_LD import WirePointDataset
+from models.line_detect.line_dataset import LineDataset
+from models.line_net.dataset_LD import WirePointDataset
 from models.wirenet.postprocess import postprocess
 from tools import utils
 from torchvision import transforms
@@ -148,68 +149,19 @@ class Trainer(BaseTrainer):
 
     def writer_predict_result(self, img, result, epoch):
         img = img.cpu().detach()
-        im = img.permute(1, 2, 0)
-        self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
+        im = img.permute(1, 2, 0)  # [512, 512, 3]
+        self.writer.add_image("ori", im, epoch, dataformats="HWC")
 
-        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result[0]["boxes"],
+        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result["boxes"],
                                           colors="yellow", width=1)
-        self.writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
-
-        PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
-        # print(f'pred[1]:{pred[1]}')
-        heatmaps = result[-2][0]
-        print(f'heatmaps:{heatmaps.shape}')
-        jmap = heatmaps[1: 2].cpu().detach()
-        lmap = heatmaps[2: 3].cpu().detach()
-        self.writer.add_image("z-jmap", jmap, epoch)
-        self.writer.add_image("z-lmap", lmap, epoch)
-        # plt.imshow(lmap)
+
+        # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
         # plt.show()
-        H = result[-1]['wires']
-        # lines = H["lines"][0].cpu().numpy()
-        lines=result[0]["lines"]
-        scores =100
-        for i in range(1, len(lines)):
-            if (lines[i] == lines[0]).all():
-                lines = lines[:i]
-                scores = scores[:i]
-                break
-
-        # postprocess lines to remove overlapped lines
-        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
-
-        for i, t in enumerate([0]):
-            plt.gca().set_axis_off()
-            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-            plt.margins(0, 0)
-            for (a, b), s in zip(nlines, nscores):
-                if s < t:
-                    continue
-                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
-                plt.scatter(a[1], a[0], **PLTOPTS)
-                plt.scatter(b[1], b[0], **PLTOPTS)
-            plt.gca().xaxis.set_major_locator(plt.NullLocator())
-            plt.gca().yaxis.set_major_locator(plt.NullLocator())
-            plt.imshow(im)
-            plt.tight_layout()
-            fig = plt.gcf()
-            fig.canvas.draw()
-
-            width, height = fig.get_size_inches() * fig.get_dpi()  # get the figure size in pixels
-            tmp_img = fig.canvas.tostring_argb()
-            tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
-            tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
-
-            img_rgb = tmp_img_np[:, :, 1:]  # keep the RGB channels, drop the alpha channel
-
-            # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
-            #     fig.canvas.get_width_height()[::-1] + (3,))
-            plt.close()
-
-            img2 = transforms.ToTensor()(img_rgb)
-
-            self.writer.add_image("z-output", img2, epoch)
+
+        self.writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+        keypoint_img = draw_keypoints(boxed_image, result['keypoints'], colors='red', width=3)
+
+        self.writer.add_image("output", keypoint_img, epoch)
 
     def writer_loss(self, losses, epoch, phase='train'):
         try:
@@ -236,8 +188,8 @@ class Trainer(BaseTrainer):
 
         self.init_params(**kwargs)
 
-        dataset_train = WirePointDataset(dataset_path=self.dataset_path, dataset_type='train')
-        dataset_val = WirePointDataset(dataset_path=self.dataset_path, dataset_type='val')
+        dataset_train = LineDataset(dataset_path=self.dataset_path, dataset_type='train')
+        dataset_val = LineDataset(dataset_path=self.dataset_path, dataset_type='val')
 
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
@@ -293,27 +245,32 @@ class Trainer(BaseTrainer):
             imgs = self.move_to_device(imgs, device)
             targets = self.move_to_device(targets, device)
             if phase== 'val':
-
-                result,losses = model(imgs, targets)
+                result,loss_dict = model(imgs, targets)
+                losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
+                print(f'val losses:{losses}')
             else:
-                losses = model(imgs, targets)
+                loss_dict = model(imgs, targets)
+                losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
+                print(f'train losses:{losses}')
 
-            loss = _loss(losses)
+            # loss = _loss(losses)
+            loss = losses
             total_loss += loss.item()
             if phase == 'train':
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
-            self.writer_loss(losses, global_step, phase=phase)
+            self.writer_loss(loss_dict, global_step, phase=phase)
             global_step += 1
 
             if epoch_step == 0 and phase == 'val':
                 t_start = time.time()
                 print(f'start to predict:{t_start}')
                 result = model(self.move_to_device(imgs, self.device))
+                print(f'result:{result}')
                 t_end = time.time()
                 print(f'predict used:{t_end - t_start}')
-                self.writer_predict_result(img=imgs[0], result=result, epoch=epoch)
+                self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)
                 epoch_step+=1
 
         avg_loss = total_loss / len(data_loader)
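
The epoch loop above now backpropagates the plain sum of the loss dictionary returned by the model. A small sketch of that summation, with dummy values standing in for the loss keys seen in this diff:

    import torch

    loss_dict = {
        "loss_classifier": torch.tensor(0.9, requires_grad=True),
        "loss_box_reg": torch.tensor(0.4, requires_grad=True),
        "loss_keypoint": torch.tensor(1.1, requires_grad=True),
    }
    losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0)
    losses.backward()  # a single scalar, so backward()/optimizer.step() work as in the loop above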

+ 0 - 0
models/line_detect/untitled.py → models/line_net/__init__.py


+ 2 - 2
models/line_detect/aaa.py → models/line_net/aaa.py

@@ -10,13 +10,13 @@ import skimage.color
 from torchvision import transforms
 import shutil
 import matplotlib.pyplot as plt
-from models.line_detect.line_net import linenet_resnet50_fpn
+from models.line_net.line_net import linenet_resnet50_fpn
 from models.wirenet.postprocess import postprocess
 from rtree import index
 import time
 import multiprocessing as mp
 
-# from code.ubuntu2204.VisionWeldRobotMK.multivisionmodels.models.line_detect.boxline import show_box
+# from code.ubuntu2204.VisionWeldRobotMK.multivisionmodels.models.line_net.boxline import show_box
 
 # Set the multiprocessing start method to 'spawn'
 mp.set_start_method('spawn', force=True)

+ 0 - 0
models/line_detect/dataset_LD.py → models/line_net/dataset_LD.py


+ 2 - 2
models/line_detect/infer.py → models/line_net/infer.py

@@ -10,13 +10,13 @@ import skimage.color
 from torchvision import transforms
 import shutil
 import matplotlib.pyplot as plt
-from models.line_detect.line_net import linenet_resnet50_fpn
+from models.line_net.line_net import linenet_resnet50_fpn
 from models.wirenet.postprocess import postprocess
 from rtree import index
 import time
 import multiprocessing as mp
 
-# from code.ubuntu2204.VisionWeldRobotMK.multivisionmodels.models.line_detect.boxline import show_box
+# from code.ubuntu2204.VisionWeldRobotMK.multivisionmodels.models.line_net.boxline import show_box
 
 # Set the multiprocessing start method to 'spawn'
 mp.set_start_method('spawn', force=True)

+ 480 - 0
models/line_net/line_detect.py

@@ -0,0 +1,480 @@
+import os
+from typing import Any, Callable, List, Optional, Tuple, Union
+import torch
+from torch import nn
+
+
+from libs.vision_libs import ops
+from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large, EfficientNet_V2_S_Weights, \
+    efficientnet_v2_s, detection, EfficientNet_V2_L_Weights, efficientnet_v2_l, EfficientNet_V2_M_Weights, \
+    efficientnet_v2_m
+from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
+from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
+from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
+from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
+from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
+from libs.vision_libs.transforms._presets import ObjectDetection
+from .line_head import LineRCNNHeads
+from .line_predictor import LineRCNNPredictor
+from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
+from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
+    BackboneWithFPN
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+from .roi_heads import RoIHeads
+
+from .trainer import Trainer
+from ..base import backbone_factory
+from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
+# from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
+from ..base.base_detection_net import BaseDetectionNet
+import torch.nn.functional as F
+
+from .predict import Predict1, Predict
+from ..base.high_reso_resnet import resnet50fpn, resnet18fpn
+
+__all__ = [
+    "LineDetect",
+    "LineDetect_ResNet50_FPN_Weights",
+    "linedetect_resnet50_fpn",
+]
+
+def _default_anchorgen():
+    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    return AnchorGenerator(anchor_sizes, aspect_ratios)
+
+
+class LineDetect(BaseDetectionNet):
+
+
+    def __init__(
+            self,
+            backbone,
+            num_classes=None,
+            # transform parameters
+            min_size=512,
+            max_size=1333,
+            image_mean=None,
+            image_std=None,
+            # RPN parameters
+            rpn_anchor_generator=None,
+            rpn_head=None,
+            rpn_pre_nms_top_n_train=2000,
+            rpn_pre_nms_top_n_test=1000,
+            rpn_post_nms_top_n_train=2000,
+            rpn_post_nms_top_n_test=1000,
+            rpn_nms_thresh=0.7,
+            rpn_fg_iou_thresh=0.7,
+            rpn_bg_iou_thresh=0.3,
+            rpn_batch_size_per_image=256,
+            rpn_positive_fraction=0.5,
+            rpn_score_thresh=0.0,
+            # Box parameters
+            box_roi_pool=None,
+            box_head=None,
+            box_predictor=None,
+            box_score_thresh=0.05,
+            box_nms_thresh=0.5,
+            box_detections_per_img=100,
+            box_fg_iou_thresh=0.5,
+            box_bg_iou_thresh=0.5,
+            box_batch_size_per_image=512,
+            box_positive_fraction=0.25,
+            bbox_reg_weights=None,
+            # keypoint parameters
+            line_roi_pool=None,
+            line_head=None,
+            line_predictor=None,
+            num_keypoints=None,
+            **kwargs,
+    ):
+
+        out_channels = backbone.out_channels
+
+        if rpn_anchor_generator is None:
+            rpn_anchor_generator = _default_anchorgen()
+        if rpn_head is None:
+            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+
+        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+
+        rpn = RegionProposalNetwork(
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_pre_nms_top_n,
+            rpn_post_nms_top_n,
+            rpn_nms_thresh,
+            score_thresh=rpn_score_thresh,
+        )
+
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = ObjectionPredictor(representation_size, num_classes)
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+        )
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+
+        super().__init__(backbone, rpn, roi_heads, transform)
+
+
+
+        if not isinstance(line_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
+            )
+        if min_size is None:
+            min_size = (640, 672, 704, 736, 768, 800)
+
+        if num_keypoints is not None:
+            if line_predictor is not None:
+                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
+        else:
+            num_keypoints = 2
+
+
+        if line_roi_pool is None:
+            line_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if line_head is None:
+            keypoint_layers = tuple(512 for _ in range(8))
+            line_head = LineHeads(out_channels, keypoint_layers)
+
+        if line_predictor is None:
+            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+            line_predictor = LinePredictor(keypoint_dim_reduced, num_keypoints)
+
+
+        self.roi_heads.keypoint_roi_pool = line_roi_pool
+        self.roi_heads.keypoint_head = line_head
+        self.roi_heads.keypoint_predictor = line_predictor
+
+
+class TwoMLPHead(nn.Module):
+    """
+    Standard heads for FPN-based models
+
+    Args:
+        in_channels (int): number of input channels
+        representation_size (int): size of the intermediate representation
+    """
+
+    def __init__(self, in_channels, representation_size):
+        super().__init__()
+
+        self.fc6 = nn.Linear(in_channels, representation_size)
+        self.fc7 = nn.Linear(representation_size, representation_size)
+
+    def forward(self, x):
+        x = x.flatten(start_dim=1)
+
+        x = F.relu(self.fc6(x))
+        x = F.relu(self.fc7(x))
+
+        return x
+
+
+class ObjectionConvFCHead(nn.Sequential):
+    def __init__(
+        self,
+        input_size: Tuple[int, int, int],
+        conv_layers: List[int],
+        fc_layers: List[int],
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Args:
+            input_size (Tuple[int, int, int]): the input size in CHW format.
+            conv_layers (list): feature dimensions of each Convolution layer
+            fc_layers (list): feature dimensions of each FCN layer
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        in_channels, in_height, in_width = input_size
+
+        blocks = []
+        previous_channels = in_channels
+        for current_channels in conv_layers:
+            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
+            previous_channels = current_channels
+        blocks.append(nn.Flatten())
+        previous_channels = previous_channels * in_height * in_width
+        for current_channels in fc_layers:
+            blocks.append(nn.Linear(previous_channels, current_channels))
+            blocks.append(nn.ReLU(inplace=True))
+            previous_channels = current_channels
+
+        super().__init__(*blocks)
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+
+class ObjectionPredictor(nn.Module):
+    """
+    Standard classification + bounding box regression layers
+    for Fast R-CNN.
+
+    Args:
+        in_channels (int): number of input channels
+        num_classes (int): number of output classes (including background)
+    """
+
+    def __init__(self, in_channels, num_classes):
+        super().__init__()
+        self.cls_score = nn.Linear(in_channels, num_classes)
+        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
+
+    def forward(self, x):
+        if x.dim() == 4:
+            torch._assert(
+                list(x.shape[2:]) == [1, 1],
+                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
+            )
+        x = x.flatten(start_dim=1)
+        scores = self.cls_score(x)
+        bbox_deltas = self.bbox_pred(x)
+
+        return scores, bbox_deltas
+
+class LineHeads(nn.Sequential):
+    def __init__(self, in_channels, layers):
+        d = []
+        next_feature = in_channels
+        for out_channels in layers:
+            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+            d.append(nn.ReLU(inplace=True))
+            next_feature = out_channels
+        super().__init__(*d)
+        for m in self.children():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+
+class LinePredictor(nn.Module):
+    def __init__(self, in_channels, num_keypoints):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            num_keypoints,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = num_keypoints
+
+    def forward(self, x):
+        x = self.kps_score_lowres(x)
+        return torch.nn.functional.interpolate(
+            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        )
+
+
+_COMMON_META = {
+    "categories": _COCO_PERSON_CATEGORIES,
+    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
+    "min_size": (1, 1),
+}
+
+
+class LineDetect_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_LEGACY = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/issues/1606",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 50.6,
+                    "kp_map": 61.1,
+                }
+            },
+            "_ops": 133.924,
+            "_file_size": 226.054,
+            "_docs": """
+                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
+                from an early epoch.
+            """,
+        },
+    )
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 54.6,
+                    "kp_map": 65.0,
+                }
+            },
+            "_ops": 137.42,
+            "_file_size": 226.054,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=(
+            "pretrained",
+            lambda kwargs: LineDetect_ResNet50_FPN_Weights.COCO_LEGACY
+            if kwargs["pretrained"] == "legacy"
+            else LineDetect_ResNet50_FPN_Weights.COCO_V1,
+    ),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+
+
+
+def linedetect_resnet50_fpn(
+        *,
+        weights: Optional[LineDetect_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> LineDetect:
+    """
+    Constructs a LineDetect model (Keypoint R-CNN style) with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores for each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.linedetect_resnet50_fpn(weights=LineDetect_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = LineDetect_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 2
+        if num_keypoints is None:
+            num_keypoints = 17
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = LineDetect(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == LineDetect_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
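
A hedged construction sketch for the line task: num_classes=2 covers background plus the line class, and num_keypoints=2 gives one heatmap channel per endpoint. The module lands under models/line_net/ in this diff while train_demo.py imports it as models.line_detect.line_detect, so adjust the path to your layout:

    from models.line_net.line_detect import linedetect_resnet50_fpn

    model = linedetect_resnet50_fpn(weights=None, num_classes=2, num_keypoints=2)
    # the line predictor installed on roi_heads exposes one output channel per endpoint
    print(model.roi_heads.keypoint_predictor.out_channels)  # -> 2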

+ 0 - 0
models/line_detect/line_head.py → models/line_net/line_head.py


+ 0 - 0
models/line_detect/line_net.py → models/line_net/line_net.py


+ 0 - 0
models/line_detect/line_net.yaml → models/line_net/line_net.yaml


+ 0 - 0
models/line_detect/line_predictor.py → models/line_net/line_predictor.py


+ 0 - 0
models/line_detect/postprocess.py → models/line_net/postprocess.py


+ 0 - 0
models/line_detect/predict.py → models/line_net/predict.py


+ 1 - 1
models/line_detect/predict2.py → models/line_net/predict2.py

@@ -10,7 +10,7 @@ from PIL import Image
 import matplotlib.pyplot as plt
 import matplotlib as mpl
 import numpy as np
-from models.line_detect.line_net import linenet_resnet50_fpn, get_line_net_efficientnetv2, get_line_net_convnext_fpn
+from models.line_net.line_net import linenet_resnet50_fpn, get_line_net_efficientnetv2, get_line_net_convnext_fpn
 from torchvision import transforms
 
 # from models.wirenet.postprocess import postprocess

+ 1 - 1
models/line_detect/predict_demo.py → models/line_net/predict_demo.py

@@ -1,4 +1,4 @@
-from models.line_detect.line_net import linenet_resnet18_fpn, linenet_resnet50_fpn
+from models.line_net.line_net import linenet_resnet18_fpn, linenet_resnet50_fpn
 
 if __name__ == '__main__':
     # model=linenet_resnet18_fpn()

+ 1225 - 0
models/line_net/roi_heads.py

@@ -0,0 +1,1225 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from  libs.vision_libs.ops import boxes as box_ops, roi_align
+
+import libs.vision_libs.models.detection._utils as det_utils
+
+from collections import OrderedDict
+
+
+def l2loss(input, target):
+    return ((target - input) ** 2).mean(2).mean(1)
+
+
+def cross_entropy_loss(logits, positive):
+    nlogp = -F.log_softmax(logits, dim=0)
+    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
+
+
+def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
+    logp = torch.sigmoid(logits) + offset
+    loss = torch.abs(logp - target)
+    if mask is not None:
+        w = mask.mean(2, True).mean(1, True)
+        w[w == 0] = 1
+        loss = loss * (mask / w)
+
+    return loss.mean(2).mean(1)
+class DiceLoss(nn.Module):
+    def __init__(self, smooth=1.):
+        super(DiceLoss, self).__init__()
+        self.smooth = smooth
+
+    def forward(self, logits, targets):
+        probs = torch.sigmoid(logits)
+        probs = probs.view(-1)
+        targets = targets.view(-1).float()
+
+        intersection = (probs * targets).sum()
+        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
+        return 1. - dice
+
+
+
+bce_loss = nn.BCEWithLogitsLoss()
+dice_loss = DiceLoss()
+
+
+def combined_loss(preds, targets, alpha=0.5):
+    bce = bce_loss(preds, targets)
+    d = dice_loss(preds, targets)
+    return alpha * bce + (1 - alpha) * d
+
+### compute the multi-head losses
+def line_head_loss(input_dict, outputs, feature, loss_weight, mode_train):
+    # image = input_dict["image"]
+    # target_b = input_dict["target_b"]
+    # outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # in train mode aaa is the loss, in val mode it is the boxes
+
+    result = {"feature": feature}
+    batch, channel, row, col = outputs[0].shape
+
+    T = input_dict["target"].copy()
+    n_jtyp = T["junc_map"].shape[1]
+
+    # switch to CNHW
+    for task in ["junc_map"]:
+        T[task] = T[task].permute(1, 0, 2, 3)
+    for task in ["junc_offset"]:
+        T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+    offset = [2, 3, 5]
+    losses = []
+    for stack, output in enumerate(outputs):
+        output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+        jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+        lmap = output[offset[0]: offset[1]].squeeze(0)
+        # print(f"lmap:{lmap.shape}")
+        joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+        if stack == 0:
+            result["preds"] = {
+                "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                "lmap": lmap.sigmoid(),
+                "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+            }
+            if mode_train == False:
+                return result
+
+        L = OrderedDict()
+        L["jmap"] = sum(
+            cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+        )
+        L["lmap"] = (
+            F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+                .mean(2)
+                .mean(1)
+        )
+        L["joff"] = sum(
+            sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+            for i in range(n_jtyp)
+            for j in range(2)
+        )
+        for loss_name in L:
+            L[loss_name].mul_(loss_weight[loss_name])
+        losses.append(L)
+    result["losses"] = losses
+    # result["aaa"] = aaa
+    return result
+
+
+# compute the line vectorizer loss
+def line_vectorizer_loss(result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc, loss_weight, mode_train):
+    if mode_train == False:
+        p = torch.cat(ps)
+        s = torch.sigmoid(x)
+        b = s > 0.5
+        lines = []
+        score = []
+        for i in range(n_batch):
+            p0 = p[idx[i]: idx[i + 1]]
+            s0 = s[idx[i]: idx[i + 1]]
+            mask = b[idx[i]: idx[i + 1]]
+            p0 = p0[mask]
+            s0 = s0[mask]
+            if len(p0) == 0:
+                lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+                score.append(torch.zeros([1, n_out_line], device=p.device))
+            else:
+                arg = torch.argsort(s0, descending=True)
+                p0, s0 = p0[arg], s0[arg]
+                lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+                score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+            for j in range(len(jcs[i])):
+                if len(jcs[i][j]) == 0:
+                    jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
+                jcs[i][j] = jcs[i][j][
+                    None, torch.arange(n_out_junc) % len(jcs[i][j])
+                ]
+        result["preds"]["lines"] = torch.cat(lines)
+        result["preds"]["score"] = torch.cat(score)
+        result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+        if len(jcs[i]) > 1:
+            result["preds"]["junts"] = torch.cat(
+                [jcs[i][1] for i in range(n_batch)]
+            )
+
+    # if input_dict["mode"] != "testing":
+    y = torch.cat(ys)
+    loss = nn.BCEWithLogitsLoss(reduction="none")
+    loss = loss(x, y)
+    lpos_mask, lneg_mask = y, 1 - y
+    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+    def sum_batch(x):
+        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
+        return torch.cat(xs)
+
+    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
+    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
+
+    if mode_train == True:
+        del result["preds"]
+
+    return result
+
+
+def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
+    # output, feature: results returned by the head
+    # x, y, idx: intermediate results produced by the line vectorizer
+    result = {}
+    batch, channel, row, col = output.shape
+
+    wires_targets = [t["wires"] for t in targets]
+    wires_targets = wires_targets.copy()
+    # print(f'wires_target:{wires_targets}')
+    # extract all 'junc_map', 'junc_offset' and 'line_map' tensors
+    junc_maps = [d["junc_map"] for d in wires_targets]
+    junc_offsets = [d["junc_offset"] for d in wires_targets]
+    line_maps = [d["line_map"] for d in wires_targets]
+
+    junc_map_tensor = torch.stack(junc_maps, dim=0)
+    junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+    line_map_tensor = torch.stack(line_maps, dim=0)
+    T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
+
+    n_jtyp = T["junc_map"].shape[1]
+
+    for task in ["junc_map"]:
+        T[task] = T[task].permute(1, 0, 2, 3)
+    for task in ["junc_offset"]:
+        T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+    offset = [2, 3, 5]
+    losses = []
+    output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+    jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+    lmap = output[offset[0]: offset[1]].squeeze(0)
+    joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+    L = OrderedDict()
+    # L["junc_map"] = sum(
+    #     cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+    # ).mean()
+    # L["line_map"] = (
+    #     F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+    #         .mean(2)
+    #         .mean(1)
+    # ).mean()
+    L["junc_map"] = combined_loss(jmap[:, 1, :, :, :], T["junc_map"])
+
+    L["line_map"] = combined_loss(lmap, T["line_map"])
+    L["junc_offset"] = sum(
+        sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+        for i in range(n_jtyp)
+        for j in range(2)
+    ).mean()
+    for loss_name in L:
+        L[loss_name].mul_(loss_weight[loss_name])
+    losses.append(L)
+    result["losses"] = losses
+
+    loss = nn.BCEWithLogitsLoss(reduction="none")
+    loss = loss(x, y)
+    lpos_mask, lneg_mask = y, 1 - y
+    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+    def sum_batch(x):
+        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
+        return torch.cat(xs)
+
+    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+    result["losses"][0]["lpos"] = (lpos * loss_weight["lpos"]).mean()
+    result["losses"][0]["lneg"] = (lneg * loss_weight["lneg"]).mean()
+
+    return result
+
+
+def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
+    result = {}
+    result["wires"] = {}
+    p = torch.cat(ps)
+    s = torch.sigmoid(input)
+    b = s > 0.5
+    lines = []
+    score = []
+    # print(f"n_batch:{n_batch}")
+    for i in range(n_batch):
+        # print(f"idx:{idx}")
+        p0 = p[idx[i]: idx[i + 1]]
+        s0 = s[idx[i]: idx[i + 1]]
+        mask = b[idx[i]: idx[i + 1]]
+        p0 = p0[mask]
+        s0 = s0[mask]
+        if len(p0) == 0:
+            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+            score.append(torch.zeros([1, n_out_line], device=p.device))
+        else:
+            arg = torch.argsort(s0, descending=True)
+            p0, s0 = p0[arg], s0[arg]
+            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+        for j in range(len(jcs[i])):
+            if len(jcs[i][j]) == 0:
+                jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
+            jcs[i][j] = jcs[i][j][
+                None, torch.arange(n_out_junc) % len(jcs[i][j])
+            ]
+    result["wires"]["lines"] = torch.cat(lines)
+    result["wires"]["score"] = torch.cat(score)
+    result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+
+    if len(jcs[i]) > 1:
+        result["preds"]["junts"] = torch.cat(
+            [jcs[i][1] for i in range(n_batch)]
+        )
+    # print(f'predic result:{result}')
+    return result
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for ech image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+        maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+            .index_select(2, x_int.to(dtype=torch.int64))
+            .view(-1)
+            .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+        maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
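+    # keypoints_to_heatmap discretizes each matched GT keypoint into a single cell
+    # of the H x W grid; the loss below is a cross entropy over those H * W cells
+    # per (proposal, keypoint) pair, restricted to the valid (visible) keypoints.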
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
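+    # e.g. a 28x28 mask with padding = 1 is padded to 30x30, so scale = 30 / 28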
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+            self,
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            line_head,
+            line_predictor,
+            # Faster R-CNN training
+            fg_iou_thresh,
+            bg_iou_thresh,
+            batch_size_per_image,
+            positive_fraction,
+            bbox_reg_weights,
+            # Faster R-CNN inference
+            score_thresh,
+            nms_thresh,
+            detections_per_img,
+            # Mask
+            mask_roi_pool=None,
+            mask_head=None,
+            mask_predictor=None,
+            keypoint_roi_pool=None,
+            keypoint_head=None,
+            keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.line_head = line_head
+        self.line_predictor = line_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_line(self):
+        # if self.mask_roi_pool is None:
+        #     return False
+        if self.line_head is None:
+            return False
+        if self.line_predictor is None:
+            return False
+        return True
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+            self,
+            proposals,  # type: List[Tensor]
+            targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to proposals
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+            self,
+            class_logits,  # type: Tensor
+            box_regression,  # type: Tensor
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+            self,
+            features,  # type: Dict[str, Tensor]
+            proposals,  # type: List[Tensor]
+            lines,
+            image_shapes,  # type: List[Tuple[int, int]]
+            targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        # if targets is not None:
+        #     self.training = True
+        #     # print(f'targets is not None')
+        #
+        # else:
+        #     self.training = False
+        #     # print(f'targets is None')
+        print(f'roihead forward!!!')
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            if targets is not None:
+                proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+            else:
+                labels = None
+                regression_targets = None
+                matched_idxs = None
+
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            if targets is not None:
+                loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+                losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+            else:
+                boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+                num_images = len(boxes)
+                for i in range(num_images):
+                    result.append(
+                        {
+                            "boxes": boxes[i],
+                            "labels": labels[i],
+                            "scores": scores[i],
+                            "lines":lines[i],
+                        }
+                    )
+
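+        # the line branch consumes the "0" feature map directly (assumed here to be
+        # the finest FPN level) rather than RoI-pooled features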
+        line_features = features['0']
+        if self.has_line():
+            # print('has line_head')
+            # outputs = self.line_head(features_lcnn)
+            # outputs = line_features[:, 0:5, :, :]
+
+
+            loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
+            x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
+                inputs=line_features, features=line_features, targets=targets)
+
+            # # line_loss(multitasklearner)
+            # if self.training:
+            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=True)
+            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
+            #                                        loss_weight, mode_train=True)
+            # else:
+            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=False)
+            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
+            #                                        loss_weight, mode_train=False)
+
+            if self.training:
+                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
+                loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+                # print(f'loss_wirepoint:{loss_wirepoint}')
+
+            else:
+                # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
+                if targets is not None:
+                    rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
+                    loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+                else:
+                    print(f'model inference!!!')
+                    pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
+                    result.append(line_features)
+                    result.append(pred)
+                    loss_wirepoint = {}
+
+            losses.update(loss_wirepoint)
+        else:
+            pass
+            # print('has not line_head')
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+                self.keypoint_roi_pool is not None
+                and self.keypoint_head is not None
+                and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            keypoint_features = self.keypoint_head(keypoint_features)
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        # print(f'roi losses:{losses}')
+        return result, losses

+ 0 - 0
models/line_detect/test_tiff.py → models/line_net/test_tiff.py


+ 38 - 0
models/line_net/train.yaml

@@ -0,0 +1,38 @@
+io:
+  logdir: train_results
+  datadir: \\192.168.50.222/share/lm/Dataset_all
+#  datadir: D:\python\PycharmProjects\data_20250223\0423_
+#  datadir: I:\datasets\wirenet_1000
+
+  tensorboard_port: 6000
+  validation_interval: 300
+
+train_params:
+  resume_from:
+  num_workers: 8
+  batch_size: 2
+  max_epoch: 80000
+  optim:
+    name: Adam
+    lr: 4.0e-4
+    amsgrad: True
+    weight_decay: 1.0e-4
+    lr_decay_epoch: 10
+
+#  parameter freezing
+  freeze_params:
+    backbone: False
+    rpn: False
+    roi_heads:
+      box_head: False
+      box_predictor: False
+      line_head: False
+      line_predictor:
+        fc1: False
+        fc2:
+          '0': False
+          '2': False
+          '4': False
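+#  the nesting above mirrors model.named_children(): True freezes that module's
+#  parameters, False leaves them trainable (handled in Trainer.freeze_params)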
+
+
+

+ 19 - 0
models/line_net/train_demo.py

@@ -0,0 +1,19 @@
+import torch
+
+from models.line_net.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_newresnet50fpn, \
+    get_line_net_convnext_fpn, linenet_newresnet18fpn
+from models.line_net.trainer import Trainer
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if __name__ == '__main__':
+
+    # model = LineNet('line_net.yaml')
+    # model=linenet_resnet50_fpn()
+    # model = linenet_resnet18_fpn()
+    # model=get_line_net_convnext_fpn(num_classes=2).to(device)
+    # model=linenet_newresnet50fpn()
+    model = linenet_newresnet18fpn()
+    # model.load_best_model('train_results/20250622_140412/weights/best_val.pth')
+    # trainer = Trainer()
+    # trainer.train_cfg(model,cfg='./train.yaml')
+    model.start_train(cfg='train.yaml')

+ 362 - 0
models/line_net/trainer.py

@@ -0,0 +1,362 @@
+import os
+import time
+from datetime import datetime
+
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.tensorboard import SummaryWriter
+
+from libs.vision_libs.utils import draw_bounding_boxes
+from models.base.base_model import BaseModel
+from models.base.base_trainer import BaseTrainer
+from models.config.config_tool import read_yaml
+from models.line_net.dataset_LD import WirePointDataset
+from models.wirenet.postprocess import postprocess
+from tools import utils
+from torchvision import transforms
+import matplotlib as mpl
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
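+
+
+# A minimal sketch of the losses dict that _loss below expects (assumed from the
+# RPN / RoIHeads outputs in this repo): plain scalar-tensor entries such as
+# "loss_classifier" are summed directly, while "loss_wirepoint" nests its
+# sub-losses as {"losses": [{"<sub-loss name>": Tensor, ...}]}.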
+def _loss(losses):
+    """Sum all detection losses plus, if present, the nested wirepoint sub-losses."""
+    total_loss = 0
+    for key, value in losses.items():
+        if key != "loss_wirepoint":
+            total_loss += value
+    wirepoint = losses.get("loss_wirepoint")
+    if wirepoint is not None:
+        for name, loss in wirepoint["losses"][0].items():
+            total_loss += loss.mean()
+    return total_loss
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+class Trainer(BaseTrainer):
+    def __init__(self, model=None, **kwargs):
+        super().__init__(model, device, **kwargs)
+        self.model = model
+        # print(f'kwargs:{kwargs}')
+        self.init_params(**kwargs)
+
+    def init_params(self, **kwargs):
+        if kwargs != {}:
+            print(f'train_params:{kwargs["train_params"]}')
+            self.freeze_config = kwargs['train_params']['freeze_params']
+            print(f'freeze_config:{self.freeze_config}')
+            self.dataset_path = kwargs['io']['datadir']
+            self.batch_size = kwargs['train_params']['batch_size']
+            self.num_workers = kwargs['train_params']['num_workers']
+            self.logdir = kwargs['io']['logdir']
+            self.resume_from = kwargs['train_params']['resume_from']
+            self.optim = ''
+            self.train_result_path = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
+            self.wts_path = os.path.join(self.train_result_path, 'weights')
+            self.tb_path = os.path.join(self.train_result_path, 'logs')
+            self.writer = SummaryWriter(self.tb_path)
+            self.last_model_path = os.path.join(self.wts_path, 'last.pth')
+            self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
+            self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
+            self.max_epoch = kwargs['train_params']['max_epoch']
+
+    def move_to_device(self, data, device):
+        if isinstance(data, (list, tuple)):
+            return type(data)(self.move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: self.move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # non-tensor data is returned unchanged
+
+    def freeze_params(self, model):
+        """根据配置冻结模型参数"""
+        default_config = {
+            'backbone': True,  # freeze the backbone
+            'rpn': False,  # keep the rpn trainable
+            'roi_heads': {
+                'box_head': False,
+                'box_predictor': False,
+                'line_head': False,
+                'line_predictor': {
+                    'fc1': False,
+                    'fc2': {
+                        '0': False,
+                        '2': False,
+                        '4': False
+                    }
+                }
+            }
+        }
+
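+        # Example (assumed usage): a freeze_params of {'backbone': True, 'rpn': True}
+        # freezes the backbone and RPN while leaving roi_heads trainable. Note that
+        # dict.update() below is shallow, so a nested 'roi_heads' entry replaces the
+        # whole default sub-dict rather than merging into it.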
+        # update the defaults with the user-supplied config
+        default_config.update(self.freeze_config)
+        config = default_config
+
+        print("\n===== Parameter Freezing Configuration =====")
+        for name, module in model.named_children():
+            if name in config:
+                if isinstance(config[name], bool):
+                    for param in module.parameters():
+                        param.requires_grad = not config[name]
+                    print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
+
+                elif isinstance(config[name], dict):
+                    for subname, submodule in module.named_children():
+                        if subname in config[name]:
+                            if isinstance(config[name][subname], bool):
+                                for param in submodule.parameters():
+                                    param.requires_grad = not config[name][subname]
+                                print(
+                                    f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
+
+                            elif isinstance(config[name][subname], dict):
+                                for subsubname, subsubmodule in submodule.named_children():
+                                    if subsubname in config[name][subname]:
+                                        for param in subsubmodule.parameters():
+                                            param.requires_grad = not config[name][subname][subsubname]
+                                        print(
+                                            f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
+
+        # print parameter statistics
+        total_params = sum(p.numel() for p in model.parameters())
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"\nTotal Parameters: {total_params:,}")
+        print(f"Trainable Parameters: {trainable_params:,}")
+        print(f"Frozen Parameters: {total_params - trainable_params:,}")
+
+    def load_best_model(self, model, optimizer, save_path, device):
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            if optimizer is not None:
+                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            epoch = checkpoint['epoch']
+            loss = checkpoint['loss']
+            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return model, optimizer
+
+    def writer_predict_result(self, img, result, epoch):
+        img = img.cpu().detach()
+        im = img.permute(1, 2, 0)
+        self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
+
+        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result[0]["boxes"],
+                                          colors="yellow", width=1)
+        self.writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+        PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+        # print(f'pred[1]:{pred[1]}')
+        heatmaps = result[-2][0]
+        print(f'heatmaps:{heatmaps.shape}')
+        jmap = heatmaps[1: 2].cpu().detach()
+        lmap = heatmaps[2: 3].cpu().detach()
+        self.writer.add_image("z-jmap", jmap, epoch)
+        self.writer.add_image("z-lmap", lmap, epoch)
+        # plt.imshow(lmap)
+        # plt.show()
+        H = result[-1]['wires']
+        # lines = H["lines"][0].cpu().numpy()
+        lines=result[0]["lines"]
+        scores =100
+        for i in range(1, len(lines)):
+            if (lines[i] == lines[0]).all():
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+
+        # postprocess lines to remove overlapped lines
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+        for i, t in enumerate([0]):
+            plt.gca().set_axis_off()
+            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+            plt.margins(0, 0)
+            for (a, b), s in zip(nlines, nscores):
+                if s < t:
+                    continue
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+                plt.scatter(a[1], a[0], **PLTOPTS)
+                plt.scatter(b[1], b[0], **PLTOPTS)
+            plt.gca().xaxis.set_major_locator(plt.NullLocator())
+            plt.gca().yaxis.set_major_locator(plt.NullLocator())
+            plt.imshow(im)
+            plt.tight_layout()
+            fig = plt.gcf()
+            fig.canvas.draw()
+
+            width, height = fig.get_size_inches() * fig.get_dpi()  # figure size in pixels
+            tmp_img = fig.canvas.tostring_argb()
+            tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
+            tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
+
+            img_rgb = tmp_img_np[:, :, 1:]  # extract the RGB channels, dropping the alpha channel
+
+            # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
+            #     fig.canvas.get_width_height()[::-1] + (3,))
+            plt.close()
+
+            img2 = transforms.ToTensor()(img_rgb)
+
+            self.writer.add_image("z-output", img2, epoch)
+
+    def writer_loss(self, losses, epoch, phase='train'):
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            self.writer.add_scalar(f'{phase}/loss/{subkey}',
+                                                   subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                                   epoch)
+                elif isinstance(value, torch.Tensor):
+                    self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
+
+    def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None):  # new: allows passing a freeze configuration
+        cfg = read_yaml(cfg)
+        # print(f'cfg:{cfg}')
+        # self.freeze_config = freeze_config or {}  # update the freeze configuration
+
+        self.train(model, **cfg)
+
+    def train(self, model, **kwargs):
+
+        self.init_params(**kwargs)
+
+        dataset_train = WirePointDataset(dataset_path=self.dataset_path, dataset_type='train')
+        dataset_val = WirePointDataset(dataset_path=self.dataset_path, dataset_type='val')
+
+        train_sampler = torch.utils.data.RandomSampler(dataset_train)
+        val_sampler = torch.utils.data.RandomSampler(dataset_val)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=self.batch_size, drop_last=True)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=self.batch_size, drop_last=True)
+        train_collate_fn = utils.collate_fn
+        val_collate_fn = utils.collate_fn
+
+        data_loader_train = torch.utils.data.DataLoader(
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
+        )
+        data_loader_val = torch.utils.data.DataLoader(
+            dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn
+        )
+
+        model.to(device)
+
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=kwargs['train_params']['optim']['lr']
+        )
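+        # NOTE: only optim.lr from train.yaml is applied here; amsgrad, weight_decay
+        # and lr_decay_epoch in the config are currently not passed to the optimizer.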
+
+        for epoch in range(self.max_epoch):
+            print(f"train epoch:{epoch}")
+
+            model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
+
+            # ========== Validation ==========
+            with torch.no_grad():
+                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
+
+            if epoch == 0:
+                best_train_loss = epoch_train_loss
+                best_val_loss = epoch_val_loss
+
+            self.save_last_model(model,self.last_model_path, epoch, optimizer)
+            best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
+                                                   best_train_loss,
+                                                   optimizer)
+            best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch, epoch_val_loss, best_val_loss,
+                                                 optimizer)
+
+    def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
+        if phase == 'train':
+            model.train()
+        if phase == 'val':
+            model.eval()
+
+        total_loss = 0
+        epoch_step = 0
+        global_step = epoch * len(data_loader)
+        for imgs, targets in data_loader:
+            imgs = self.move_to_device(imgs, device)
+            targets = self.move_to_device(targets, device)
+            if phase == 'val':
+                result, losses = model(imgs, targets)
+            else:
+                losses = model(imgs, targets)
+
+            loss = _loss(losses)
+            total_loss += loss.item()
+            if phase == 'train':
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+            self.writer_loss(losses, global_step, phase=phase)
+            global_step += 1
+
+            if epoch_step == 0 and phase == 'val':
+                t_start = time.time()
+                print(f'start to predict:{t_start}')
+                result = model(self.move_to_device(imgs, self.device))
+                t_end = time.time()
+                print(f'predict used:{t_end - t_start}')
+                self.writer_predict_result(img=imgs[0], result=result, epoch=epoch)
+                epoch_step += 1
+
+        avg_loss = total_loss / len(data_loader)
+        print(f'{phase}/loss epoch {epoch}: {avg_loss:.4f}')
+        self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
+        return model, avg_loss
+
+    def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+        if current_loss <= best_loss:
+            checkpoint = {
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'loss': current_loss
+            }
+            if optimizer is not None:
+                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+            torch.save(checkpoint, save_path)
+            print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
+
+            return current_loss
+
+        return best_loss
+
+    def save_last_model(self, model, save_path, epoch, optimizer=None):
+
+        if os.path.exists(save_path):
+            os.remove(save_path)
+
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+        checkpoint = {
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+        }
+
+        if optimizer is not None:
+            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+        torch.save(checkpoint, save_path)
+
+
+if __name__ == '__main__':
+    print('')

+ 0 - 0
models/line_net/untitled.py


+ 3 - 3
train——line_rcnn.py

@@ -5,20 +5,20 @@ import numpy as np
 import torch
 
 from models.config.config_tool import read_yaml
-from models.line_detect.dataset_LD import WirePointDataset
+from models.line_net.dataset_LD import WirePointDataset
 from tools import utils
 
 from torch.utils.tensorboard import SummaryWriter
 import matplotlib as mpl
 
-from models.line_detect.line_net import linenet_resnet50_fpn
+from models.line_net.line_net import linenet_resnet50_fpn
 from torchvision.utils import draw_bounding_boxes
 from models.wirenet.postprocess import postprocess
 from torchvision import transforms
 
 from PIL import Image
 
-from models.line_detect.postprocess import box_line_, show_
+from models.line_net.postprocess import box_line_, show_
 import matplotlib.pyplot as plt
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')