فهرست منبع

合并重构部分模块代码

RenLiqiang 3 ماه پیش
والد
کامیت
0c48d83264
47فایلهای تغییر یافته به همراه8979 افزوده شده و 0 حذف شده
  1. 0 0
      models/__init__.py
  2. 0 0
      models/base/__init__.py
  3. 48 0
      models/base/base_dataset.py
  4. 0 0
      models/config/__init__.py
  5. 22 0
      models/config/config_tool.py
  6. 42 0
      models/config/test_config.py
  7. 34 0
      models/config/train.yaml
  8. 318 0
      models/dataset_tool.py
  9. 0 0
      models/ins/__init__.py
  10. 143 0
      models/ins/maskrcnn.py
  11. 93 0
      models/ins/maskrcnn_dataset.py
  12. 31 0
      models/ins/train.yaml
  13. 220 0
      models/ins/trainer.py
  14. 0 0
      models/keypoint/__init__.py
  15. 290 0
      models/keypoint/kepointrcnn.py
  16. 205 0
      models/keypoint/keypoint_dataset.py
  17. 65 0
      models/keypoint/test.py
  18. 77 0
      models/keypoint/test_predict.py
  19. 32 0
      models/keypoint/train.yaml
  20. 380 0
      models/keypoint/trainer.py
  21. 0 0
      models/line_net/__init__.py
  22. 120 0
      models/line_net/fasterrcnn_resnet50.py
  23. 0 0
      models/obj/__init__.py
  24. 44 0
      models/utils.py
  25. 46 0
      models/wirenet/TestPointMap.py
  26. 0 0
      models/wirenet/__init__.py
  27. 548 0
      models/wirenet/_utils.py
  28. 1290 0
      models/wirenet/head.py
  29. 126 0
      models/wirenet/postprocess.py
  30. 896 0
      models/wirenet/roi_head.py
  31. 23 0
      models/wirenet/test.py
  32. 14 0
      models/wirenet/test_mask.py
  33. 20 0
      models/wirenet/train.py
  34. 69 0
      models/wirenet/wirenet.yaml
  35. 178 0
      models/wirenet/wirepoint_dataset.py
  36. 847 0
      models/wirenet/wirepoint_rcnn.py
  37. 70 0
      models/wirenet2/WirePredictor.py
  38. 0 0
      models/wirenet2/__init__.py
  39. 548 0
      models/wirenet2/_utils.py
  40. 82 0
      models/wirenet2/kepointrcnn.py
  41. 203 0
      models/wirenet2/keypoint_dataset.py
  42. 879 0
      models/wirenet2/roi_heads.py
  43. 65 0
      models/wirenet2/test.py
  44. 50 0
      models/wirenet2/test_linemap.py
  45. 32 0
      models/wirenet2/train.yaml
  46. 212 0
      models/wirenet2/trainer.py
  47. 617 0
      models/wirenet2/wirenet_rcnn.py

+ 0 - 0
models/__init__.py


+ 0 - 0
models/base/__init__.py


+ 48 - 0
models/base/base_dataset.py

@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod
+
+import torch
+from torch import nn, Tensor
+from torch.utils.data import Dataset
+from torch.utils.data.dataset import T_co
+
+from torchvision.transforms import  functional as F
+
+class BaseDataset(Dataset, ABC):
+    def __init__(self,dataset_path):
+        self.default_transform=DefaultTransform()
+        pass
+
+    def __getitem__(self, index) -> T_co:
+        pass
+
+    @abstractmethod
+    def read_target(self,item,lbl_path,extra=None):
+        pass
+
+    """显示数据集指定图片"""
+    @abstractmethod
+    def show(self,idx):
+        pass
+
+    """
+    显示数据集指定名字的图片
+    """
+
+    @abstractmethod
+    def show_img(self,img_path):
+        pass
+
+class DefaultTransform(nn.Module):
+    def forward(self, img: Tensor) -> Tensor:
+        if not isinstance(img, Tensor):
+            img = F.pil_to_tensor(img)
+        return F.convert_image_dtype(img, torch.float)
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + "()"
+
+    def describe(self) -> str:
+        return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
+            "The images are rescaled to ``[0.0, 1.0]``."
+        )

+ 0 - 0
models/config/__init__.py


+ 22 - 0
models/config/config_tool.py

@@ -0,0 +1,22 @@
+import yaml
+
+
+def read_yaml(path='application.yaml'):
+    try:
+        with open(path, 'r') as file:
+            data = file.read()
+            # result = yaml.load(data)
+            result = yaml.load(data, Loader=yaml.FullLoader)
+
+            return result
+    except Exception as e:
+        print(e)
+        return None
+
+
+def write_yaml(path='application.yaml', data=None):
+    try:
+        with open(path, 'w', encoding='utf-8') as f:
+            yaml.dump(data=data, stream=f, allow_unicode=True)
+    except Exception as e:
+        print(e)

+ 42 - 0
models/config/test_config.py

@@ -0,0 +1,42 @@
+import yaml
+
+test_data = {
+    'cameras': [{
+        'id': 1,
+        'ip': "192.168.1.2"
+    }, {
+        'id': 2,
+        'ip': "192.168.1.3"
+    }]
+}
+
+
+def read_yaml(path):
+    try:
+        with open(path, 'r') as file:
+            data = file.read()
+            # result = yaml.load(data)
+            result = yaml.load(data, Loader=yaml.FullLoader)
+
+            return result
+    except Exception as e:
+        print(e)
+        return None
+
+
+def write_yaml(path):
+    try:
+        with open('path', 'w', encoding='utf-8') as f:
+            yaml.dump(data=test_data, stream=f, allow_unicode=True)
+    except Exception as e:
+        print(e)
+
+
+if __name__ == '__main__':
+    p = 'train.yaml'
+    result = read_yaml(p)
+    # j=json.load(result)
+    print('result', result)
+    # print('cameras', result['cameras'])
+    # print('json',j)
+

+ 34 - 0
models/config/train.yaml

@@ -0,0 +1,34 @@
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+dataset_path: F:\DevTools\datasets\renyaun\1012\spilt
+#train: images/train  # train images (relative to 'path') 128 images
+#val: images/train  # val images (relative to 'path') 128 images
+#test: images/test  # test images (optional)
+
+#train parameters
+num_classes: 5
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: polygon
+enable_logs: True
+augmentation: True
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+

+ 318 - 0
models/dataset_tool.py

@@ -0,0 +1,318 @@
+import cv2
+import numpy as np
+import torch
+import torchvision
+from matplotlib import pyplot as plt
+import tools.transforms as reference_transforms
+from collections import defaultdict
+
+from tools import presets
+
+import json
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.transforms.v2
+        import torchvision.tv_tensors
+
+        return torchvision.transforms.v2, torchvision.tv_tensors
+    else:
+        return reference_transforms, None
+
+
+class Augmentation:
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter.
+    def __init__(
+            self,
+            *,
+            data_augmentation,
+            hflip_prob=0.5,
+            mean=(123.0, 117.0, 104.0),
+            backend="pil",
+            use_v2=False,
+    ):
+
+        T, tv_tensors = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "tv_tensor":
+            transforms.append(T.ToImage())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
+
+        if data_augmentation == "hflip":
+            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
+        elif data_augmentation == "lsj":
+            transforms += [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                # TODO: FixedSizeCrop below doesn't work on tensors!
+                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "multiscale":
+            transforms += [
+                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "ssd":
+            fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
+            transforms += [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=fill),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "ssdlite":
+            transforms += [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        else:
+            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
+
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2.
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
+
+        transforms += [T.ToDtype(torch.float, scale=True)]
+
+        if use_v2:
+            transforms += [
+                T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
+                T.SanitizeBoundingBoxes(),
+                T.ToPureTensor(),
+            ]
+
+        self.transforms = T.Compose(transforms)
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
+
+
+def read_polygon_points(lbl_path, shape):
+    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
+    polygon_points = []
+    w, h = shape[:2]
+    with open(lbl_path, 'r') as f:
+        lines = f.readlines()
+
+    for line in lines:
+        parts = line.strip().split()
+        class_id = int(parts[0])
+        points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2)  # 读取点坐标
+        points[:, 0] *= h
+        points[:, 1] *= w
+
+        polygon_points.append((class_id, points))
+
+    return polygon_points
+
+
+def read_masks_from_pixels(lbl_path, shape):
+    """读取纯像素点格式的文件,不是轮廓像素点"""
+    h, w = shape
+    masks = []
+    labels = []
+
+    with open(lbl_path, 'r') as reader:
+        lines = reader.readlines()
+        mask_points = []
+        for line in lines:
+            mask = torch.zeros((h, w), dtype=torch.uint8)
+            parts = line.strip().split()
+            # print(f'parts:{parts}')
+            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
+            labels.append(cls)
+            x_array = parts[1::2]
+            y_array = parts[2::2]
+
+            for x, y in zip(x_array, y_array):
+                x = float(x)
+                y = float(y)
+                mask_points.append((int(y * h), int(x * w)))
+
+            for p in mask_points:
+                mask[p] = 1
+            masks.append(mask)
+    reader.close()
+    return labels, masks
+
+
+def create_masks_from_polygons(polygons, image_shape):
+    """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
+    colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
+    masks = []
+
+    for polygon_data, col in zip(polygons, colors):
+        mask = np.zeros(image_shape[:2], dtype=np.uint8)
+        # 将多边形顶点转换为 NumPy 数组
+        _, polygon = polygon_data
+        pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
+
+        # 使用 OpenCV 的 fillPoly 函数填充多边形
+        # print(f'color:{col[:3]}')
+        cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
+        mask = torch.from_numpy(mask)
+        mask[mask != 0] = 1
+        masks.append(mask)
+
+    return masks
+
+
+def read_masks_from_txt(label_path, shape):
+    polygon_points = read_polygon_points(label_path, shape)
+    masks = create_masks_from_polygons(polygon_points, shape)
+    labels = [torch.tensor(item[0]) for item in polygon_points]
+
+    return labels, masks
+
+
+def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
+    """
+    Compute the bounding boxes around the provided masks.
+
+    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
+    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+
+    Args:
+        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
+            and (H, W) are the spatial dimensions.
+
+    Returns:
+        Tensor[N, 4]: bounding boxes
+    """
+    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+    #     _log_api_usage_once(masks_to_boxes)
+    if masks.numel() == 0:
+        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
+
+    n = masks.shape[0]
+
+    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
+
+    for index, mask in enumerate(masks):
+        y, x = torch.where(mask != 0)
+        bounding_boxes[index, 0] = torch.min(x)
+        bounding_boxes[index, 1] = torch.min(y)
+        bounding_boxes[index, 2] = torch.max(x)
+        bounding_boxes[index, 3] = torch.max(y)
+        # debug to pixel datasets
+
+        if bounding_boxes[index, 0] == bounding_boxes[index, 2]:
+            bounding_boxes[index, 2] = bounding_boxes[index, 2] + 1
+            bounding_boxes[index, 0] = bounding_boxes[index, 0] - 1
+
+        if bounding_boxes[index, 1] == bounding_boxes[index, 3]:
+            bounding_boxes[index, 3] = bounding_boxes[index, 3] + 1
+            bounding_boxes[index, 1] = bounding_boxes[index, 1] - 1
+
+    return bounding_boxes
+
+
+def line_boxes(target):
+    boxs = []
+    lpre = target['wires']["lpre"].cpu().numpy() * 4
+    vecl_target = target['wires']["lpre_label"].cpu().numpy()
+    lpre = lpre[vecl_target == 1]
+
+    lines = lpre
+    sline = np.ones(lpre.shape[0])
+
+    keypoints = []
+
+    if len(lines) > 0 and not (lines[0] == 0).all():
+        for i, ((a, b), s) in enumerate(zip(lines, sline)):
+            if i > 0 and (lines[i] == lines[0]).all():
+                break
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+
+            keypoints.append([a[0], b[0]])
+            keypoints.append([a[1], b[1]])
+
+            if a[1] > b[1]:
+                ymax = a[1] + 10
+                ymin = b[1] - 10
+            else:
+                ymin = a[1] - 10
+                ymax = b[1] + 10
+            if a[0] > b[0]:
+                xmax = a[0] + 10
+                xmin = b[0] - 10
+            else:
+                xmin = a[0] - 10
+                xmax = b[0] + 10
+            boxs.append([ymin, xmin, ymax, xmax])
+
+    return torch.tensor(boxs), torch.tensor(keypoints)
+
+
+def read_polygon_points_wire(lbl_path, shape):
+    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
+    polygon_points = []
+    w, h = shape[:2]
+    with open(lbl_path, 'r') as f:
+        lines = json.load(f)
+
+    for line in lines["segmentations"]:
+        parts = line["data"]
+        class_id = int(line["cls_id"])
+        points = np.array(parts, dtype=np.float32).reshape(-1, 2)  # 读取点坐标
+        points[:, 0] *= h
+        points[:, 1] *= w
+
+        polygon_points.append((class_id, points))
+
+    return polygon_points
+
+
+def read_masks_from_txt_wire(label_path, shape):
+    polygon_points = read_polygon_points_wire(label_path, shape)
+    masks = create_masks_from_polygons(polygon_points, shape)
+    labels = [torch.tensor(item[0]) for item in polygon_points]
+
+    return labels, masks
+
+
+def read_masks_from_pixels_wire(lbl_path, shape):
+    """读取纯像素点格式的文件,不是轮廓像素点"""
+    h, w = shape
+    masks = []
+    labels = []
+
+    with open(lbl_path, 'r') as reader:
+        lines = json.load(reader)
+        mask_points = []
+        for line in lines["segmentations"]:
+            # mask = torch.zeros((h, w), dtype=torch.uint8)
+            # parts = line["data"]
+            # print(f'parts:{parts}')
+            cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
+            labels.append(cls)
+            # x_array = parts[0::2]
+            # y_array = parts[1::2]
+            # 
+            # for x, y in zip(x_array, y_array):
+            #     x = float(x)
+            #     y = float(y)
+            #     mask_points.append((int(y * h), int(x * w)))
+
+            # for p in mask_points:
+            #     mask[p] = 1
+            # masks.append(mask)
+    reader.close()
+    return labels
+
+
+def adjacency_matrix(n, link):  # 邻接矩阵
+    mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
+    link = torch.tensor(link)
+    if len(link) > 0:
+        mat[link[:, 0], link[:, 1]] = 1
+        mat[link[:, 1], link[:, 0]] = 1
+    return mat

+ 0 - 0
models/ins/__init__.py


+ 143 - 0
models/ins/maskrcnn.py

@@ -0,0 +1,143 @@
+import math
+import os
+import sys
+from datetime import datetime
+from typing import Mapping, Any
+import cv2
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision.io import read_image
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.utils import draw_bounding_boxes
+
+from models.config.config_tool import read_yaml
+from models.ins.trainer import train_cfg
+from tools import utils
+
+
+class MaskRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=0, transforms=None):
+        super(MaskRCNNModel, self).__init__()
+        self.__model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
+            weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
+        if transforms is None:
+            self.transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        if num_classes != 0:
+            self.set_num_classes(num_classes)
+            # self.__num_classes=0
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        outputs = self.__model(inputs)
+        return outputs
+
+    def train(self, cfg):
+        parameters = read_yaml(cfg)
+        num_classes=parameters['num_classes']
+        # print(f'num_classes:{num_classes}')
+        self.set_num_classes(num_classes)
+        train_cfg(self.__model, cfg)
+
+    def set_num_classes(self, num_classes):
+        in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+        self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+        in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
+        hidden_layer = 256
+        self.__model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer,
+                                                                  num_classes=num_classes)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        self.__model.load_state_dict(state_dict)
+        # return super().load_state_dict(state_dict, strict)
+
+    def predict(self, src, show_box=True, show_mask=True):
+        self.__model.eval()
+
+        img = read_image(src)
+        img = self.transforms(img)
+        img = img.to(self.device)
+        result = self.__model([img])
+        print(f'result:{result}')
+        masks = result[0]['masks']
+        boxes = result[0]['boxes']
+        # cv2.imshow('mask',masks[0].cpu().detach().numpy())
+        boxes = boxes.cpu().detach()
+        drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
+        print(f'drawn_boxes:{drawn_boxes.shape}')
+        boxed_img = drawn_boxes.permute(1, 2, 0).numpy()
+        # boxed_img=cv2.resize(boxed_img,(800,800))
+        # cv2.imshow('boxes',boxed_img)
+
+        mask = masks[0].cpu().detach().permute(1, 2, 0).numpy()
+
+        mask = cv2.resize(mask, (800, 800))
+        # cv2.imshow('mask',mask)
+        img = img.cpu().detach().permute(1, 2, 0).numpy()
+
+        masked_img = self.overlay_masks_on_image(boxed_img, masks)
+        masked_img = cv2.resize(masked_img, (800, 800))
+        cv2.imshow('img_masks', masked_img)
+        # show_img_boxes_masks(img, boxes, masks)
+        cv2.waitKey(0)
+
+    def generate_colors(self, n):
+        """
+        生成n个均匀分布在HSV色彩空间中的颜色,并转换成BGR色彩空间。
+
+        :param n: 需要的颜色数量
+        :return: 一个包含n个颜色的列表,每个颜色为BGR格式的元组
+        """
+        hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
+        bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0])) for hsv in hsv_colors]
+        return bgr_colors
+
+    def overlay_masks_on_image(self, image, masks, alpha=0.6):
+        """
+        在原图上叠加多个掩码,每个掩码使用不同的颜色。
+
+        :param image: 原图 (NumPy 数组)
+        :param masks: 掩码列表 (每个都是 NumPy 数组,二值图像)
+        :param colors: 颜色列表 (每个颜色都是 (B, G, R) 格式的元组)
+        :param alpha: 掩码的透明度 (0.0 到 1.0)
+        :return: 叠加了多个掩码的图像
+        """
+        colors = self.generate_colors(len(masks))
+        if len(masks) != len(colors):
+            raise ValueError("The number of masks and colors must be the same.")
+
+        # 复制原图,避免修改原始图像
+        overlay = image.copy()
+
+        for mask, color in zip(masks, colors):
+            # 确保掩码是二值图像
+            mask = mask.cpu().detach().permute(1, 2, 0).numpy()
+            binary_mask = (mask > 0).astype(np.uint8) * 255  # 你可以根据实际情况调整阈值
+
+            # 创建彩色掩码
+            colored_mask = np.zeros_like(image)
+
+            colored_mask[:] = color
+            colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
+
+            # 将彩色掩码与当前的叠加图像混合
+            overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
+
+        return overlay
+
+
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    ins_model = MaskRCNNModel()
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    ins_model.train(cfg='train.yaml')

+ 93 - 0
models/ins/maskrcnn_dataset.py

@@ -0,0 +1,93 @@
+import os
+
+import PIL
+import cv2
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.data import Dataset
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.dataset_tool import masks_to_boxes, read_masks_from_txt, read_masks_from_pixels
+
+
+class MaskRCNNDataset(Dataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='polygon'):
+        self.data_path = dataset_path
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        self.deafult_transform= MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        # print('maskrcnn inited!')
+
+    def __getitem__(self, item):
+        # print('__getitem__')
+        img_path = os.path.join(self.img_path, self.imgs[item])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[item][:-3] + 'txt')
+        img = PIL.Image.open(img_path).convert('RGB')
+        # h, w = np.array(img).shape[:2]
+        w, h = img.size
+        # print(f'h,w:{h, w}')
+        target = self.read_target(item=item, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img,target)
+        else:
+            img=self.deafult_transform(img)
+        # print(f'img:{img.shape},target:{target}')
+        return img, target
+
+    def create_masks_from_polygons(self, polygons, image_shape):
+        """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
+        colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
+        masks = []
+
+        for polygon_data, col in zip(polygons, colors):
+            mask = np.zeros(image_shape[:2], dtype=np.uint8)
+            # 将多边形顶点转换为 NumPy 数组
+            _, polygon = polygon_data
+            pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
+
+            # 使用 OpenCV 的 fillPoly 函数填充多边形
+            # print(f'color:{col[:3]}')
+            cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
+            mask = torch.from_numpy(mask)
+            mask[mask != 0] = 1
+            masks.append(mask)
+
+        return masks
+
+    def read_target(self, item, lbl_path, shape):
+        # print(f'lbl_path:{lbl_path}')
+        h, w = shape
+        labels = []
+        masks = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels, masks = read_masks_from_pixels(lbl_path, shape)
+
+        target = {}
+        target["boxes"] = masks_to_boxes(torch.stack(masks))
+        target["labels"] = torch.stack(labels)
+        target["masks"] = torch.stack(masks)
+        target["image_id"] = torch.tensor(item)
+        target["area"] = torch.zeros(len(masks))
+        target["iscrowd"] = torch.zeros(len(masks))
+        return target
+
+    def heatmap_enhance(self, img):
+        # 直方图均衡化
+        img_eq = cv2.equalizeHist(img)
+
+        # 自适应直方图均衡化
+        # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+        # img_clahe = clahe.apply(img)
+
+        # 将灰度图转换为热力图
+        heatmap = cv2.applyColorMap(img_eq, cv2.COLORMAP_HOT)
+
+    def __len__(self):
+        return len(self.imgs)

+ 31 - 0
models/ins/train.yaml

@@ -0,0 +1,31 @@
+
+
+dataset_path: F:\DevTools\datasets\renyaun\1012\spilt
+
+#train parameters
+num_classes: 5
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.0005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: polygon
+enable_logs: True
+augmentation: True
+checkpoint: None
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+

+ 220 - 0
models/ins/trainer.py

@@ -0,0 +1,220 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.config.config_tool import read_yaml
+from models.ins.maskrcnn_dataset import MaskRCNNDataset
+from tools import utils, presets
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+
+def load_train_parameter(cfg):
+    parameters = read_yaml(cfg)
+    return parameters
+
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+
+def train(model, **kwargs):
+    # 默认参数
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 2,
+        'num_keypoints':2,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint':None
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # 设置设备
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_ptath, 'weights')
+    tb_path = os.path.join(train_result_ptath, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = MaskRCNNDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    # data_loader_test = torch.utils.data.DataLoader(
+    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    # )
+
+    img_results_path = os.path.join(train_result_ptath, 'img_results')
+    if os.path.exists(train_result_ptath):
+        pass
+    #     os.remove(train_result_ptath)
+    else:
+        os.mkdir(train_result_ptath)
+
+    if os.path.exists(train_result_ptath):
+        os.mkdir(wts_path)
+        os.mkdir(img_results_path)
+
+    for epoch in range(epochs):
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0:
+            best_loss = losses;
+        if best_loss >= losses:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
+    writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)

+ 0 - 0
models/keypoint/__init__.py


+ 290 - 0
models/keypoint/kepointrcnn.py

@@ -0,0 +1,290 @@
+import math
+import os
+import sys
+from collections import OrderedDict
+from datetime import datetime
+from typing import Mapping
+import cv2
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from matplotlib import pyplot as plt
+from torch import nn
+from torch.nn.modules.module import T
+from torchvision.io import read_image
+from torchvision.models import resnet50, ResNet50_Weights, resnet18, ResNet18_Weights
+from torchvision.models._utils import _ovewrite_value_param
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.anchor_utils import AnchorGenerator
+from torchvision.models.detection.backbone_utils import _validate_trainable_layers, _resnet_fpn_extractor
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor, KeypointRCNN, \
+    KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.ops.feature_pyramid_network import LastLevelMaxPool
+from torchvision.utils import draw_bounding_boxes
+from torchvision.ops import misc as misc_nn_ops, FeaturePyramidNetwork
+from typing import Optional, Any
+from models.config.config_tool import read_yaml
+from models.keypoint.trainer import train_cfg
+from models.wirenet._utils import overwrite_eps
+# from timm import create_model
+from  torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from tools import utils
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+class KeypointRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=2,num_keypoints=2, transforms=None):
+        super(KeypointRCNNModel, self).__init__()
+
+        ####mobile net
+       # backbone = torchvision.models.mobilenet_v2(weights=None).features
+       # backbone.out_channels = 1280
+       # anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),aspect_ratios = ((0.5, 1.0, 2.0),))
+       # roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],output_size = 7,sampling_ratio = 2)
+       #keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],output_size = 14,sampling_ratio = 2)
+       # self.__model= KeypointRCNN(backbone, num_classes=2, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,keypoint_roi_pool=keypoint_roi_pooler)
+        ####
+
+        # 加载 EfficientNet 模型并移除分类头
+        # backbone = create_model('tf_efficientnet_b0', pretrained=True, features_only=True)
+        # backbone_out_channels =backbone.feature_info.channels()  # 获取所有阶段的通道数
+        #
+        #
+        # # 构建 FPN
+        # fpn = FeaturePyramidNetwork(
+        #     in_channels_list=backbone_out_channels,
+        #     out_channels=256,
+        #     extra_blocks=LastLevelMaxPool()
+        # )
+        #
+        # # 将 EfficientNet 和 FPN 组合成一个新的 backbone
+        # self.body = nn.Sequential(
+        #     backbone,
+        #     fpn
+        # )
+        default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+
+        self.__model = keypointrcnn_resnet18_fpn(weights=None,num_classes=num_classes,
+                                                                              num_keypoints=num_keypoints,
+                                                                              progress=False)
+        # self.__model.backbone.body = nn.Sequential(OrderedDict([
+        #     ('body', self.body),
+        #     ('fpn', fpn)
+        # ]))
+
+        if transforms is None:
+            self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
+        # if num_classes != 0:
+        #     self.set_num_classes(num_classes)
+            # self.__num_classes=0
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        outputs = self.__model(inputs)
+        return outputs
+
+    def train(self, cfg):
+        parameters = read_yaml(cfg)
+        num_classes = parameters['num_classes']
+        num_keypoints = parameters['num_keypoints']
+        # print(f'num_classes:{num_classes}')
+        # self.set_num_classes(num_classes)
+        self.num_keypoints = num_keypoints
+        train_cfg(self.__model, cfg)
+
+    # def set_num_classes(self, num_classes):
+    #     in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+    #     self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+    #
+    #     # in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
+    #     in_channels = self.__model.roi_heads.keypoint_predictor.
+    #     hidden_layer = 256
+    #     self.__model.roi_heads.mask_predictor = KeypointRCNNPredictor(in_channels, hidden_layer,
+    #                                                               num_classes=num_classes)
+    #     self.__model.roi_heads.keypoint_predictor=KeypointRCNNPredictor(in_channels, num_keypoints=num_classes)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        self.__model.load_state_dict(state_dict)
+        # return super().load_state_dict(state_dict, strict)
+
+    def eval(self: T) -> T:
+        self.__model.eval()
+        # return super().eval()
+    def predict(self, img, show=True, save=False, save_path=None):
+        """
+         对输入图像进行关键点检测预测。
+
+         参数:
+             img (str or PIL.Image): 输入图像的路径或 PIL.Image 对象。
+             show (bool): 是否显示预测结果,默认为 True。
+             save (bool): 是否保存预测结果,默认为 False。
+
+         返回:
+             dict: 包含预测结果的字典。
+         """
+        if isinstance(img, str):
+            img = Image.open(img).convert("RGB")
+
+        self.__model.eval()
+
+        # 预处理图像
+        img_tensor = self.transforms(img)
+        with torch.no_grad():
+            predictions = self.__model([img_tensor])
+
+        print(f'predictions:{predictions}')
+
+        # 后处理预测结果
+        boxes = predictions[0]['boxes'].cpu().numpy()
+        keypoints = predictions[0]['keypoints'].cpu().numpy()
+
+        # 可视化预测结果
+        if show or save:
+            fig, ax = plt.subplots(figsize=(10, 10))
+            ax.imshow(np.array(img))
+
+            for box in boxes:
+                x0, y0, x1, y1 = box
+                ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1))
+
+            for (a, b) in keypoints:
+                ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=1)
+                ax.scatter(a[0], a[1], c='red', s=2)
+                ax.scatter(b[0], b[1], c='red', s=2)
+
+            if show:
+                plt.show()
+
+            if save:
+                fig.savefig(save_path)
+                print(f"Prediction saved to {save_path}")
+            plt.close(fig)
+
+def keypointrcnn_resnet18_fpn(
+        *,
+        weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> KeypointRCNN:
+    """
+    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores or each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    # if weights_backbone is None:
+
+    weights_backbone = ResNet18_Weights.IMAGENET1K_V1
+
+    if weights is not None:
+        # weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 2
+        if num_keypoints is None:
+            num_keypoints = 17
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    keypoint_model = KeypointRCNNModel(num_keypoints=2)
+    wts_path='./train_results/20241227_231659/weights/best.pt'
+
+
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    keypoint_model.train(cfg='train.yaml')
+
+    # keypoint_model.load_weight(wts_path)
+    # img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-10-02-15_SaveImage.png"
+    # keypoint_model.predict(img_path)

+ 205 - 0
models/keypoint/keypoint_dataset.py

@@ -0,0 +1,205 @@
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+import matplotlib.pyplot as plt
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+def validate_keypoints(keypoints, image_width, image_height):
+    for kp in keypoints:
+        x, y, v = kp
+        if not (0 <= x < image_width and 0 <= y < image_height):
+            raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
+
+
+class KeypointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'shape:{shape}')
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # 字典
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # 真实存在线条的邻接矩阵
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [线段数, 512, 512]
+        target = {}
+
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+
+        target["labels"] = torch.stack(labels)
+        # print(f'labels:{target["labels"]}')
+        # target["boxes"] = line_boxes(target)
+        target["boxes"], keypoints = line_boxes(target)
+        # keypoints=keypoints/512
+        # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
+
+        # keypoints= wire_labels["junc_coords"]
+        a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
+        keypoints = torch.cat((keypoints, a), dim=1)
+
+        target["keypoints"] = keypoints.to(torch.float32).view(-1,2,3)
+        # print(f'boxes:{target["boxes"].shape}')
+        # 在 __getitem__ 方法中调用此函数
+        validate_keypoints(keypoints, shape[0], shape[1])
+        # print(f'keypoints:{target["keypoints"].shape}')
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (i == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64
+
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            plt.show()
+
+            plt.show()
+            if fn != None:
+                plt.savefig(fn)
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass
+
+
+
+if __name__ == '__main__':
+    path=r"I:\datasets\wirenet_1000"
+    dataset= KeypointDataset(dataset_path=path, dataset_type='train')
+    dataset.show(7)

+ 65 - 0
models/keypoint/test.py

@@ -0,0 +1,65 @@
+import time
+
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.io import decode_image, read_image
+import torchvision.transforms.functional as F
+from torchvision.utils import draw_keypoints
+def show(imgs):
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[0, i].imshow(np.asarray(img))
+        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+
+img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
+# img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
+img_int = read_image(img_path)
+
+
+# person_int = decode_image(r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg")
+
+weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+print(f'transforms:{transforms}')
+img = transforms(img_int)
+
+person_float = transforms(img)
+
+model = keypointrcnn_resnet50_fpn(weights=None, progress=False)
+model = model.eval()
+t1=time.time()
+# img = torch.ones((3, 3, 512, 512))
+
+
+outputs = model([img])
+t2=time.time()
+print(f'time:{t2-t1}')
+# print(f'outputs:{outputs}')
+
+kpts = outputs[0]['keypoints']
+scores = outputs[0]['scores']
+
+print(f'kpts:{kpts}')
+print(f'scores:{scores}')
+
+detect_threshold = 0.75
+idx = torch.where(scores > detect_threshold)
+keypoints = kpts[idx]
+
+# print(f'keypoints:{keypoints}')
+
+
+
+res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
+show(res)
+plt.show()
+
+
+
+

+ 77 - 0
models/keypoint/test_predict.py

@@ -0,0 +1,77 @@
+import time
+
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.io import decode_image, read_image
+import torchvision.transforms.functional as F
+from torchvision.utils import draw_keypoints, draw_bounding_boxes
+
+from models.keypoint.kepointrcnn import KeypointRCNNModel
+
+
+def show(imgs):
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[0, i].imshow(np.asarray(img))
+        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+
+
+# img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
+# img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
+img_path = r"I:\datasets\wirenet_1000\images\train\00031644_0.png"
+img_int = read_image(img_path)
+
+# person_int = decode_image(r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg")
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+
+transforms = weights.transforms()
+
+print(f'transforms:{transforms}')
+img = transforms(img_int)
+
+person_float = transforms(img)
+
+model = KeypointRCNNModel(num_keypoints=2)
+
+print(f'start to load pretraine weight!')
+model.load_weight('./train_results/20241226_171710/weights/best.pt')
+print(f'loaded weight !!!')
+
+# model.to(device)
+model.eval()
+# model = keypointrcnn_resnet50_fpn(weights=None, progress=False)
+# model = model.eval()
+t1 = time.time()
+# img = torch.ones((3, 3, 512, 512))
+
+print(f't1:{t1}')
+outputs = model([img])
+t2 = time.time()
+print(f'time:{t2 - t1}')
+# print(f'outputs:{outputs}')
+
+kpts = outputs[0]['keypoints']
+scores = outputs[0]['scores']
+boxes= outputs[0]['boxes']
+print(f'kpts:{kpts}')
+print(f'scores:{scores}')
+
+detect_threshold = 0.001
+idx = torch.where(scores > detect_threshold)
+keypoints = kpts[idx]
+
+# print(f'keypoints:{keypoints}')
+
+
+res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
+res_box=draw_bounding_boxes(img_int,boxes)
+show(res_box)
+plt.show()

+ 32 - 0
models/keypoint/train.yaml

@@ -0,0 +1,32 @@
+
+
+dataset_path: I:\datasets\wirenet_1000
+
+#train parameters
+num_classes: 2
+num_keypoints: 2
+opt: 'adamw'
+batch_size: 2
+epochs: 50000
+lr: 0.0002
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: pixel
+enable_logs: True
+augmentation: False
+checkpoint: None
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+

+ 380 - 0
models/keypoint/trainer.py

@@ -0,0 +1,380 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.wirenet.postprocess import postprocess_keypoint
+from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+import matplotlib.pyplot as plt
+import numpy as np
+import matplotlib as mpl
+from tools.coco_utils import get_coco_api_from_dataset
+from tools.coco_eval import CocoEvaluator
+import time
+
+from models.config.config_tool import read_yaml
+from models.ins.maskrcnn_dataset import MaskRCNNDataset
+from models.keypoint.keypoint_dataset import KeypointDataset
+from tools import utils, presets
+
+
+def log_losses_to_tensorboard(writer, result, step):
+    writer.add_scalar('Loss/classifier', result['loss_classifier'].item(), step)
+    writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
+    writer.add_scalar('Loss/keypoint', result['loss_keypoint'].item(), step)
+    writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
+    writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+    total_train_loss=0
+    for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+        global_step = epoch * len(data_loader) + batch_idx
+        # print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            # print(f'loss_dict:{loss_dict}')
+
+            losses = sum(loss for loss in loss_dict.values())
+
+            total_train_loss += losses.item()
+            log_losses_to_tensorboard(writer, loss_dict, global_step)
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger, total_train_loss
+
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0)   # [512, 512, 3]
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
+                                      colors="yellow", width=1)
+
+    # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
+    # plt.show()
+
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    lines = pred["keypoints"].detach().cpu().numpy()
+    scores = pred["keypoints_scores"].detach().cpu().numpy()
+    # for i in range(1, len(lines)):
+    #     if (lines[i] == lines[0]).all():
+    #         lines = lines[:i]
+    #         scores = scores[:i]
+    #         break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)
+    # print(f'nscores:{nscores}')
+
+    for i, t in enumerate([0.5]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            # plt.scatter(a[1], a[0], **PLTOPTS)
+            # plt.scatter(b[1], b[0], **PLTOPTS)
+            plt.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=2, zorder=s)
+            plt.scatter(a[0], a[1], **PLTOPTS)
+            plt.scatter(b[0], b[1], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im.cpu())
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
+
+
+
+def _get_iou_types(model):
+    model_without_ddp = model
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        model_without_ddp = model.module
+    iou_types = ["bbox"]
+    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
+        iou_types.append("segm")
+    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
+        iou_types.append("keypoints")
+    return iou_types
+
+
+def evaluate(model, data_loader, epoch, writer, device):
+    n_threads = torch.get_num_threads()
+    # FIXME remove this and make paste_masks_in_image run on the GPU
+    torch.set_num_threads(1)
+    cpu_device = torch.device("cpu")
+    model.eval()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    header = "Test:"
+
+    coco = get_coco_api_from_dataset(data_loader.dataset)
+    iou_types = _get_iou_types(model)
+    coco_evaluator = CocoEvaluator(coco, iou_types)
+
+    print(f'start to evaluate!!!')
+    for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)):
+        images = list(img.to(device) for img in images)
+
+        model_time = time.time()
+        outputs = model(images)
+        # print(f'outputs:{outputs}')
+
+        if batch_idx == 0:
+            show_line(images[0], outputs[0], epoch, writer)
+
+        # print(f'outputs:{outputs}')
+        # print(f'outputs[0]:{outputs[0]}')
+
+
+    #     outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+    #     model_time = time.time() - model_time
+    #
+    #     res = {target["image_id"]: output for target, output in zip(targets, outputs)}
+    #     evaluator_time = time.time()
+    #     coco_evaluator.update(res)
+    #     evaluator_time = time.time() - evaluator_time
+    #     metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+    #
+    # # gather the stats from all processes
+    # metric_logger.synchronize_between_processes()
+    # print("Averaged stats:", metric_logger)
+    # coco_evaluator.synchronize_between_processes()
+    #
+    # # accumulate predictions from all images
+    # coco_evaluator.accumulate()
+    # coco_evaluator.summarize()
+    # torch.set_num_threads(n_threads)
+    # return coco_evaluator
+
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+
+def train(model, **kwargs):
+    # 默认参数
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 2,
+        'num_keypoints': 2,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint': None
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # 设置设备
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_ptath, 'weights')
+    tb_path = os.path.join(train_result_ptath, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = KeypointDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = KeypointDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.RandomSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    data_loader_test = torch.utils.data.DataLoader(
+        dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    )
+
+
+    img_results_path = os.path.join(train_result_ptath, 'img_results')
+    if os.path.exists(train_result_ptath):
+        pass
+    #     os.remove(train_result_ptath)
+    else:
+        os.mkdir(train_result_ptath)
+
+    if os.path.exists(train_result_ptath):
+        os.mkdir(wts_path)
+        os.mkdir(img_results_path)
+
+    for epoch in range(epochs):
+        metric_logger, total_train_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        # write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0:
+            best_loss = losses;
+        if best_loss >= losses:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+        evaluate(model, data_loader_test, epoch, writer, device=device)
+        avg_train_loss = total_train_loss / len(data_loader)
+
+        writer.add_scalar('Loss/train', avg_train_loss, epoch)
+
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    # writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
+    writer.add_scalar('Loss/box_reg', metric_logger.meters['loss_keypoint'].global_avg, epoch)
+    writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
+
+# def log_losses_to_tensorboard(writer, result, step):
+#     writer.add_scalar('Loss/classifier', result['loss_classifier'].item(), step)
+#     writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
+#     writer.add_scalar('Loss/box_reg', result['loss_keypoint'].item(), step)
+#     writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
+#     writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)

+ 0 - 0
models/line_net/__init__.py


+ 120 - 0
models/line_net/fasterrcnn_resnet50.py

@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+import torchvision
+from typing import Dict, List, Optional, Tuple
+import torch.nn.functional as F
+from torchvision.ops import MultiScaleRoIAlign
+from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
+from torchvision.models.detection.transform import GeneralizedRCNNTransform
+
+
+def get_model(num_classes):
+    # 加载预训练的ResNet-50 FPN backbone
+    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+
+    # 获取分类器的输入特征数
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+
+    # 替换分类器以适应新的类别数量
+    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
+
+    return model
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+class Fasterrcnn_resnet50(nn.Module):
+    def __init__(self, num_classes=5, num_stacks=1):
+        super(Fasterrcnn_resnet50, self).__init__()
+
+        self.model = get_model(num_classes=5)
+        self.backbone = self.model.backbone
+
+        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
+
+        out_channels = self.backbone.out_channels
+        resolution = self.box_roi_pool.output_size[0]
+        representation_size = 1024
+        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+
+        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        # 多任务输出层
+        self.score_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=3, padding=1),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, num_classes, kernel_size=1)
+            )
+            for _ in range(num_stacks)
+        ])
+
+    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
+
+        transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
+                                             image_std=[0.229, 0.224, 0.225])
+        images, targets = transform(x, target1)
+        x_ = self.backbone(images.tensors)
+
+        # x_ = self.backbone(x)  # '0'  '1'  '2'  '3'   'pool'
+        # print(f'backbone:{self.backbone}')
+        # print(f'Fasterrcnn_resnet50 x_:{x_}')
+        feature_ = x_['0']  # 图片特征
+        outputs = []
+        for score_layer in self.score_layers:
+            output = score_layer(feature_)
+            outputs.append(output)  # 多头
+
+        if train_or_val == "training":
+            loss_box = self.model(x, target1)
+            return outputs, feature_, loss_box
+        else:
+            box_all = self.model(x, target1)
+            return outputs, feature_, box_all
+
+
+def fasterrcnn_resnet50(**kwargs):
+    model = Fasterrcnn_resnet50(
+        num_classes=kwargs.get("num_classes", 5),
+        num_stacks=kwargs.get("num_stacks", 1)
+    )
+    return model

+ 0 - 0
models/obj/__init__.py


+ 44 - 0
models/utils.py

@@ -0,0 +1,44 @@
+# import torch
+#
+#
+# def evaluate(model, data_loader, device):
+#     n_threads = torch.get_num_threads()
+#     # FIXME remove this and make paste_masks_in_image run on the GPU
+#     torch.set_num_threads(1)
+#     cpu_device = torch.device("cpu")
+#     model.eval()
+#     metric_logger = utils.MetricLogger(delimiter="  ")
+#     header = "Test:"
+#
+#     coco = get_coco_api_from_dataset(data_loader.dataset)
+#     iou_types = _get_iou_types(model)
+#     coco_evaluator = CocoEvaluator(coco, iou_types)
+#
+#     print(f'start to evaluate!!!')
+#     for images, targets in metric_logger.log_every(data_loader, 10, header):
+#         images = list(img.to(device) for img in images)
+#
+#         if torch.cuda.is_available():
+#             torch.cuda.synchronize()
+#         model_time = time.time()
+#         outputs = model(images)
+#
+#         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+#         model_time = time.time() - model_time
+#
+#         res = {target["image_id"]: output for target, output in zip(targets, outputs)}
+#         evaluator_time = time.time()
+#         coco_evaluator.update(res)
+#         evaluator_time = time.time() - evaluator_time
+#         metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+#
+#     # gather the stats from all processes
+#     metric_logger.synchronize_between_processes()
+#     print("Averaged stats:", metric_logger)
+#     coco_evaluator.synchronize_between_processes()
+#
+#     # accumulate predictions from all images
+#     coco_evaluator.accumulate()
+#     coco_evaluator.summarize()
+#     torch.set_num_threads(n_threads)
+#     return coco_evaluator

+ 46 - 0
models/wirenet/TestPointMap.py

@@ -0,0 +1,46 @@
+def map_heatmap_keypoints_to_original_image(heatmap, rois, downsample_ratio=4, joff=None):
+    """
+    将热力图中的关键点映射回原始图像的位置。
+
+    参数:
+    heatmap (torch.Tensor): 热力图,形状为 [H, W]
+    rois (list of tuples): 每个ROI的坐标列表 [(x_min, y_min, x_max, y_max), ...]
+    downsample_ratio (int): 下采样比例,默认为4
+    joff (torch.Tensor, optional): 偏移图,形状为 [2, H, W]
+
+    返回:
+    list of tuples: 每个ROI对应的关键点在原始图像中的坐标 [(x, y), ...]
+    """
+    keypoints_in_original_image = []
+
+    for i, (x_min, y_min, x_max, y_max) in enumerate(rois):
+        roi_width = x_max - x_min
+        roi_height = y_max - y_min
+
+        # 获取热力图中的关键点位置
+        heatmap_roi = heatmap[i] if len(heatmap.shape) == 4 else heatmap
+        y_prime, x_prime = torch.where(heatmap_roi == torch.max(heatmap_roi))
+
+        if len(y_prime) > 0 and len(x_prime) > 0:
+            y_prime, x_prime = y_prime.item(), x_prime.item()
+
+            # 如果有偏移图,则应用偏移修正
+            if joff is not None:
+                offset_x = joff[0, y_prime, x_prime].item()
+                offset_y = joff[1, y_prime, x_prime].item()
+                x_prime += offset_x
+                y_prime += offset_y
+
+            # 计算ROI内的相对坐标
+            relative_x = x_prime / 128 * roi_width
+            relative_y = y_prime / 128 * roi_height
+
+            # 映射回原始图像坐标
+            final_x = relative_x + x_min
+            final_y = relative_y + y_min
+
+            keypoints_in_original_image.append((final_x.item(), final_y.item()))
+        else:
+            keypoints_in_original_image.append(None)  # 如果没有找到关键点
+
+    return keypoints_in_original_image

+ 0 - 0
models/wirenet/__init__.py


+ 548 - 0
models/wirenet/_utils.py

@@ -0,0 +1,548 @@
+import math
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
+
+
+class BalancedPositiveNegativeSampler:
+    """
+    This class samples batches, ensuring that they contain a fixed proportion of positives
+    """
+
+    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
+        """
+        Args:
+            batch_size_per_image (int): number of elements to be selected per image
+            positive_fraction (float): percentage of positive elements per batch
+        """
+        self.batch_size_per_image = batch_size_per_image
+        self.positive_fraction = positive_fraction
+
+    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+        """
+        Args:
+            matched_idxs: list of tensors containing -1, 0 or positive values.
+                Each tensor corresponds to a specific image.
+                -1 values are ignored, 0 are considered as negatives and > 0 as
+                positives.
+
+        Returns:
+            pos_idx (list[tensor])
+            neg_idx (list[tensor])
+
+        Returns two lists of binary masks for each image.
+        The first list contains the positive elements that were selected,
+        and the second list the negative example.
+        """
+        pos_idx = []
+        neg_idx = []
+        for matched_idxs_per_image in matched_idxs:
+            positive = torch.where(matched_idxs_per_image >= 1)[0]
+            negative = torch.where(matched_idxs_per_image == 0)[0]
+
+            num_pos = int(self.batch_size_per_image * self.positive_fraction)
+            # protect against not enough positive examples
+            num_pos = min(positive.numel(), num_pos)
+            num_neg = self.batch_size_per_image - num_pos
+            # protect against not enough negative examples
+            num_neg = min(negative.numel(), num_neg)
+
+            # randomly select positive and negative examples
+            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
+            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
+
+            pos_idx_per_image = positive[perm1]
+            neg_idx_per_image = negative[perm2]
+
+            # create binary mask from indices
+            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+
+            pos_idx_per_image_mask[pos_idx_per_image] = 1
+            neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+            pos_idx.append(pos_idx_per_image_mask)
+            neg_idx.append(neg_idx_per_image_mask)
+
+        return pos_idx, neg_idx
+
+
+@torch.jit._script_if_tracing
+def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
+    """
+    Encode a set of proposals with respect to some
+    reference boxes
+
+    Args:
+        reference_boxes (Tensor): reference boxes
+        proposals (Tensor): boxes to be encoded
+        weights (Tensor[4]): the weights for ``(x, y, w, h)``
+    """
+
+    # perform some unpacking to make it JIT-fusion friendly
+    wx = weights[0]
+    wy = weights[1]
+    ww = weights[2]
+    wh = weights[3]
+
+    proposals_x1 = proposals[:, 0].unsqueeze(1)
+    proposals_y1 = proposals[:, 1].unsqueeze(1)
+    proposals_x2 = proposals[:, 2].unsqueeze(1)
+    proposals_y2 = proposals[:, 3].unsqueeze(1)
+
+    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
+    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
+    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
+    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
+
+    # implementation starts here
+    ex_widths = proposals_x2 - proposals_x1
+    ex_heights = proposals_y2 - proposals_y1
+    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
+    ex_ctr_y = proposals_y1 + 0.5 * ex_heights
+
+    gt_widths = reference_boxes_x2 - reference_boxes_x1
+    gt_heights = reference_boxes_y2 - reference_boxes_y1
+    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
+    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
+
+    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+    targets_dw = ww * torch.log(gt_widths / ex_widths)
+    targets_dh = wh * torch.log(gt_heights / ex_heights)
+
+    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+    return targets
+
+
+class BoxCoder:
+    """
+    This class encodes and decodes a set of bounding boxes into
+    the representation used for training the regressors.
+    """
+
+    def __init__(
+        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
+    ) -> None:
+        """
+        Args:
+            weights (4-element tuple)
+            bbox_xform_clip (float)
+        """
+        self.weights = weights
+        self.bbox_xform_clip = bbox_xform_clip
+
+    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
+        boxes_per_image = [len(b) for b in reference_boxes]
+        reference_boxes = torch.cat(reference_boxes, dim=0)
+        proposals = torch.cat(proposals, dim=0)
+        targets = self.encode_single(reference_boxes, proposals)
+        return targets.split(boxes_per_image, 0)
+
+    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some
+        reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+        """
+        dtype = reference_boxes.dtype
+        device = reference_boxes.device
+        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
+        targets = encode_boxes(reference_boxes, proposals, weights)
+
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+        torch._assert(
+            isinstance(boxes, (list, tuple)),
+            "This function expects boxes of type list or tuple.",
+        )
+        torch._assert(
+            isinstance(rel_codes, torch.Tensor),
+            "This function expects rel_codes of type torch.Tensor.",
+        )
+        boxes_per_image = [b.size(0) for b in boxes]
+        concat_boxes = torch.cat(boxes, dim=0)
+        box_sum = 0
+        for val in boxes_per_image:
+            box_sum += val
+        if box_sum > 0:
+            rel_codes = rel_codes.reshape(box_sum, -1)
+        pred_boxes = self.decode_single(rel_codes, concat_boxes)
+        if box_sum > 0:
+            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+        return pred_boxes
+
+    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+        """
+
+        boxes = boxes.to(rel_codes.dtype)
+
+        widths = boxes[:, 2] - boxes[:, 0]
+        heights = boxes[:, 3] - boxes[:, 1]
+        ctr_x = boxes[:, 0] + 0.5 * widths
+        ctr_y = boxes[:, 1] + 0.5 * heights
+
+        wx, wy, ww, wh = self.weights
+        dx = rel_codes[:, 0::4] / wx
+        dy = rel_codes[:, 1::4] / wy
+        dw = rel_codes[:, 2::4] / ww
+        dh = rel_codes[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=self.bbox_xform_clip)
+        dh = torch.clamp(dh, max=self.bbox_xform_clip)
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        # Distance from center to box's corner.
+        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+
+        pred_boxes1 = pred_ctr_x - c_to_c_w
+        pred_boxes2 = pred_ctr_y - c_to_c_h
+        pred_boxes3 = pred_ctr_x + c_to_c_w
+        pred_boxes4 = pred_ctr_y + c_to_c_h
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
+        return pred_boxes
+
+
+class BoxLinearCoder:
+    """
+    The linear box-to-box transform defined in FCOS. The transformation is parameterized
+    by the distance from the center of (square) src box to 4 edges of the target box.
+    """
+
+    def __init__(self, normalize_by_size: bool = True) -> None:
+        """
+        Args:
+            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
+        """
+        self.normalize_by_size = normalize_by_size
+
+    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+
+        Returns:
+            Tensor: the encoded relative box offsets that can be used to
+            decode the boxes.
+
+        """
+
+        # get the center of reference_boxes
+        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
+        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
+
+        # get box regression transformation deltas
+        target_l = reference_boxes_ctr_x - proposals[..., 0]
+        target_t = reference_boxes_ctr_y - proposals[..., 1]
+        target_r = proposals[..., 2] - reference_boxes_ctr_x
+        target_b = proposals[..., 3] - reference_boxes_ctr_y
+
+        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
+
+        if self.normalize_by_size:
+            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
+            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
+            reference_boxes_size = torch.stack(
+                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
+            )
+            targets = targets / reference_boxes_size
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+
+        Returns:
+            Tensor: the predicted boxes with the encoded relative box offsets.
+
+        .. note::
+            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
+
+        """
+
+        boxes = boxes.to(dtype=rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+
+        if self.normalize_by_size:
+            boxes_w = boxes[..., 2] - boxes[..., 0]
+            boxes_h = boxes[..., 3] - boxes[..., 1]
+
+            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+            rel_codes = rel_codes * list_box_size
+
+        pred_boxes1 = ctr_x - rel_codes[..., 0]
+        pred_boxes2 = ctr_y - rel_codes[..., 1]
+        pred_boxes3 = ctr_x + rel_codes[..., 2]
+        pred_boxes4 = ctr_y + rel_codes[..., 3]
+
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+        return pred_boxes
+
+
+class Matcher:
+    """
+    This class assigns to each predicted "element" (e.g., a box) a ground-truth
+    element. Each predicted element will have exactly zero or one matches; each
+    ground-truth element may be assigned to zero or more predicted elements.
+
+    Matching is based on the MxN match_quality_matrix, that characterizes how well
+    each (ground-truth, predicted)-pair match. For example, if the elements are
+    boxes, the matrix may contain box IoU overlap values.
+
+    The matcher returns a tensor of size N containing the index of the ground-truth
+    element m that matches to prediction n. If there is no match, a negative value
+    is returned.
+    """
+
+    BELOW_LOW_THRESHOLD = -1
+    BETWEEN_THRESHOLDS = -2
+
+    __annotations__ = {
+        "BELOW_LOW_THRESHOLD": int,
+        "BETWEEN_THRESHOLDS": int,
+    }
+
+    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
+        """
+        Args:
+            high_threshold (float): quality values greater than or equal to
+                this value are candidate matches.
+            low_threshold (float): a lower quality threshold used to stratify
+                matches into three levels:
+                1) matches >= high_threshold
+                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
+                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
+            allow_low_quality_matches (bool): if True, produce additional matches
+                for predictions that have only low-quality match candidates. See
+                set_low_quality_matches_ for more details.
+        """
+        self.BELOW_LOW_THRESHOLD = -1
+        self.BETWEEN_THRESHOLDS = -2
+        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
+        self.high_threshold = high_threshold
+        self.low_threshold = low_threshold
+        self.allow_low_quality_matches = allow_low_quality_matches
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        """
+        Args:
+            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
+            pairwise quality between M ground-truth elements and N predicted elements.
+
+        Returns:
+            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
+            [0, M - 1] or a negative value indicating that prediction i could not
+            be matched.
+        """
+        if match_quality_matrix.numel() == 0:
+            # empty targets or proposals not supported during training
+            if match_quality_matrix.shape[0] == 0:
+                raise ValueError("No ground-truth boxes available for one of the images during training")
+            else:
+                raise ValueError("No proposal boxes available for one of the images during training")
+
+        # match_quality_matrix is M (gt) x N (predicted)
+        # Max over gt elements (dim 0) to find best gt candidate for each prediction
+        matched_vals, matches = match_quality_matrix.max(dim=0)
+        if self.allow_low_quality_matches:
+            all_matches = matches.clone()
+        else:
+            all_matches = None  # type: ignore[assignment]
+
+        # Assign candidate matches with low quality to negative (unassigned) values
+        below_low_threshold = matched_vals < self.low_threshold
+        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
+        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
+        matches[between_thresholds] = self.BETWEEN_THRESHOLDS
+
+        if self.allow_low_quality_matches:
+            if all_matches is None:
+                torch._assert(False, "all_matches should not be None")
+            else:
+                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+
+        return matches
+
+    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
+        """
+        Produce additional matches for predictions that have only low-quality matches.
+        Specifically, for each ground-truth find the set of predictions that have
+        maximum overlap with it (including ties); for each prediction in that set, if
+        it is unmatched, then match it to the ground-truth with which it has the highest
+        quality value.
+        """
+        # For each gt, find the prediction with which it has the highest quality
+        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+        # Find the highest quality match available, even if it is low, including ties
+        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
+        # Example gt_pred_pairs_of_highest_quality:
+        #   tensor([[    0, 39796],
+        #           [    1, 32055],
+        #           [    1, 32070],
+        #           [    2, 39190],
+        #           [    2, 40255],
+        #           [    3, 40390],
+        #           [    3, 41455],
+        #           [    4, 45470],
+        #           [    5, 45325],
+        #           [    5, 46390]])
+        # Each row is a (gt index, prediction index)
+        # Note how gt items 1, 2, 3, and 5 each have two ties
+
+        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
+        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
+
+
+class SSDMatcher(Matcher):
+    def __init__(self, threshold: float) -> None:
+        super().__init__(threshold, threshold, allow_low_quality_matches=False)
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        matches = super().__call__(match_quality_matrix)
+
+        # For each gt, find the prediction with which it has the highest quality
+        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
+        matches[highest_quality_pred_foreach_gt] = torch.arange(
+            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
+        )
+
+        return matches
+
+
+def overwrite_eps(model: nn.Module, eps: float) -> None:
+    """
+    This method overwrites the default eps values of all the
+    FrozenBatchNorm2d layers of the model with the provided value.
+    This is necessary to address the BC-breaking change introduced
+    by the bug-fix at pytorch/vision#2933. The overwrite is applied
+    only when the pretrained weights are loaded to maintain compatibility
+    with previous versions.
+
+    Args:
+        model (nn.Module): The model on which we perform the overwrite.
+        eps (float): The new value of eps.
+    """
+    for module in model.modules():
+        if isinstance(module, FrozenBatchNorm2d):
+            module.eps = eps
+
+
+def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
+    """
+    This method retrieves the number of output channels of a specific model.
+
+    Args:
+        model (nn.Module): The model for which we estimate the out_channels.
+            It should return a single Tensor or an OrderedDict[Tensor].
+        size (Tuple[int, int]): The size (wxh) of the input.
+
+    Returns:
+        out_channels (List[int]): A list of the output channels of the model.
+    """
+    in_training = model.training
+    model.eval()
+
+    with torch.no_grad():
+        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
+        device = next(model.parameters()).device
+        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
+        features = model(tmp_img)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+        out_channels = [x.size(1) for x in features.values()]
+
+    if in_training:
+        model.train()
+
+    return out_channels
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> int:
+    return v  # type: ignore[return-value]
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
+    """
+    ONNX spec requires the k-value to be less than or equal to the number of inputs along
+    provided dim. Certain models use the number of elements along a particular axis instead of K
+    if K exceeds the number of elements along that axis. Previously, python's min() function was
+    used to determine whether to use the provided k-value or the specified dim axis value.
+
+    However, in cases where the model is being exported in tracing mode, python min() is
+    static causing the model to be traced incorrectly and eventually fail at the topk node.
+    In order to avoid this situation, in tracing mode, torch.min() is used instead.
+
+    Args:
+        input (Tensor): The original input tensor.
+        orig_kval (int): The provided k-value.
+        axis(int): Axis along which we retrieve the input size.
+
+    Returns:
+        min_kval (int): Appropriately selected k-value.
+    """
+    if not torch.jit.is_tracing():
+        return min(orig_kval, input.size(axis))
+    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+    return _fake_cast_onnx(min_kval)
+
+
+def _box_loss(
+    type: str,
+    box_coder: BoxCoder,
+    anchors_per_image: Tensor,
+    matched_gt_boxes_per_image: Tensor,
+    bbox_regression_per_image: Tensor,
+    cnf: Optional[Dict[str, float]] = None,
+) -> Tensor:
+    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
+
+    if type == "l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+    elif type == "smooth_l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
+        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
+    else:
+        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
+        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
+        if type == "ciou":
+            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        if type == "diou":
+            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        # otherwise giou
+        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)

+ 1290 - 0
models/wirenet/head.py

@@ -0,0 +1,1290 @@
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+from torch.utils.data.dataloader import default_collate
+
+
+def l2loss(input, target):
+    return ((target - input) ** 2).mean(2).mean(1)
+
+
+def cross_entropy_loss(logits, positive):
+    nlogp = -F.log_softmax(logits, dim=0)
+    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
+
+
+def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
+    logp = torch.sigmoid(logits) + offset
+    loss = torch.abs(logp - target)
+    if mask is not None:
+        w = mask.mean(2, True).mean(1, True)
+        w[w == 0] = 1
+        loss = loss * (mask / w)
+
+    return loss.mean(2).mean(1)
+
+
+# def wirepoint_loss(target, outputs, feature, loss_weight,mode):
+#     wires = target['wires']
+#     result = {"feature": feature}
+#     batch, channel, row, col = outputs[0].shape
+#     print(f"Initial Output[0] shape: {outputs[0].shape}")  # 打印初始输出形状
+#     print(f"Total Stacks: {len(outputs)}")  # 打印堆栈数
+#
+#     T = wires.copy()
+#     n_jtyp = T["junc_map"].shape[1]
+#     for task in ["junc_map"]:
+#         T[task] = T[task].permute(1, 0, 2, 3)
+#     for task in ["junc_offset"]:
+#         T[task] = T[task].permute(1, 2, 0, 3, 4)
+#
+#     offset = self.head_off
+#     loss_weight = loss_weight
+#     losses = []
+#
+#     for stack, output in enumerate(outputs):
+#         output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+#         print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
+#         jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+#         lmap = output[offset[0]: offset[1]].squeeze(0)
+#         joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+#
+#         if stack == 0:
+#             result["preds"] = {
+#                 "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+#                 "lmap": lmap.sigmoid(),
+#                 "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+#             }
+#             # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
+#             # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
+#             # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
+#
+#             if mode == "testing":
+#                 return result
+#
+#         L = OrderedDict()
+#         L["junc_map"] = sum(
+#             cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+#         )
+#         L["line_map"] = (
+#             F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+#             .mean(2)
+#             .mean(1)
+#         )
+#         L["junc_offset"] = sum(
+#             sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+#             for i in range(n_jtyp)
+#             for j in range(2)
+#         )
+#         for loss_name in L:
+#             L[loss_name].mul_(loss_weight[loss_name])
+#         losses.append(L)
+#
+#     result["losses"] = losses
+#     return result
+
+def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
+    # output, feature: head返回结果
+    # x, y, idx : line中间生成结果
+    result = {}
+    batch, channel, row, col = output.shape
+
+    wires_targets = [t["wires"] for t in targets]
+    wires_targets = wires_targets.copy()
+    # print(f'wires_target:{wires_targets}')
+    # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
+    junc_maps = [d["junc_map"] for d in wires_targets]
+    junc_offsets = [d["junc_offset"] for d in wires_targets]
+    line_maps = [d["line_map"] for d in wires_targets]
+
+    junc_map_tensor = torch.stack(junc_maps, dim=0)
+    junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+    line_map_tensor = torch.stack(line_maps, dim=0)
+    T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
+
+    n_jtyp = T["junc_map"].shape[1]
+
+    for task in ["junc_map"]:
+        T[task] = T[task].permute(1, 0, 2, 3)
+    for task in ["junc_offset"]:
+        T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+    offset = [2, 3, 5]
+    losses = []
+    output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+    jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+    lmap = output[offset[0]: offset[1]].squeeze(0)
+    joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+    L = OrderedDict()
+    L["junc_map"] = sum(
+        cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+    )
+    L["line_map"] = (
+        F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+            .mean(2)
+            .mean(1)
+    )
+    L["junc_offset"] = sum(
+        sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+        for i in range(n_jtyp)
+        for j in range(2)
+    )
+    for loss_name in L:
+        L[loss_name].mul_(loss_weight[loss_name])
+    losses.append(L)
+    result["losses"] = losses
+
+    loss = nn.BCEWithLogitsLoss(reduction="none")
+    loss = loss(x, y)
+    lpos_mask, lneg_mask = y, 1 - y
+    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+    def sum_batch(x):
+        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
+        return torch.cat(xs)
+
+    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
+    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
+
+    return result
+
+
+def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
+    result = {}
+    result["wires"] = {}
+    p = torch.cat(ps)
+    s = torch.sigmoid(input)
+    b = s > 0.5
+    lines = []
+    score = []
+    # print(f"n_batch:{n_batch}")
+    for i in range(n_batch):
+        # print(f"idx:{idx}")
+        p0 = p[idx[i]: idx[i + 1]]
+        s0 = s[idx[i]: idx[i + 1]]
+        mask = b[idx[i]: idx[i + 1]]
+        p0 = p0[mask]
+        s0 = s0[mask]
+        if len(p0) == 0:
+            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+            score.append(torch.zeros([1, n_out_line], device=p.device))
+        else:
+            arg = torch.argsort(s0, descending=True)
+            p0, s0 = p0[arg], s0[arg]
+            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+        for j in range(len(jcs[i])):
+            if len(jcs[i][j]) == 0:
+                jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
+            jcs[i][j] = jcs[i][j][
+                None, torch.arange(n_out_junc) % len(jcs[i][j])
+            ]
+    result["wires"]["lines"] = torch.cat(lines)
+    result["wires"]["score"] = torch.cat(score)
+    result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+
+    if len(jcs[i]) > 1:
+        result["preds"]["junts"] = torch.cat(
+            [jcs[i][1] for i in range(n_batch)]
+        )
+
+    return result
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for ech image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
+    # print(f'mask discretization_size:{discretization_size}')
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    # print(f'mask labels:{labels}')
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    # print(f'mask labels1:{labels}')
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+    # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
+    # print(f'mask_targets:{mask_targets}')
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    # print(f'mask_loss:{mask_loss}')
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+        maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+            .index_select(2, x_int.to(dtype=torch.int64))
+            .view(-1)
+            .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+        maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it sepaartely
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    # print(f'x:{x.shape}')
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+    # print(f'x2:{x2}')
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+            self,
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            # Faster R-CNN training
+            fg_iou_thresh,
+            bg_iou_thresh,
+            batch_size_per_image,
+            positive_fraction,
+            bbox_reg_weights,
+            # Faster R-CNN inference
+            score_thresh,
+            nms_thresh,
+            detections_per_img,
+            # Mask
+            mask_roi_pool=None,
+            mask_head=None,
+            mask_predictor=None,
+            keypoint_roi_pool=None,
+            keypoint_head=None,
+            keypoint_predictor=None,
+            wirepoint_roi_pool=None,
+            wirepoint_head=None,
+            wirepoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+        self.wirepoint_roi_pool = wirepoint_roi_pool
+        self.wirepoint_head = wirepoint_head
+        self.wirepoint_predictor = wirepoint_predictor
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def has_wirepoint(self):
+        if self.wirepoint_roi_pool is None:
+            print(f'wirepoint_roi_pool is None')
+            return False
+        if self.wirepoint_head is None:
+            print(f'wirepoint_head is None')
+            return False
+        if self.wirepoint_predictor is None:
+            print(f'wirepoint_roi_predictor is None')
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+            self,
+            proposals,  # type: List[Tensor]
+            targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to propos
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+            self,
+            class_logits,  # type: Tensor
+            box_regression,  # type: Tensor
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+            self,
+            features,  # type: Dict[str, Tensor]
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
+            targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            print('result append boxes!!!')
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if self.has_keypoint():
+
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            # tmp = keypoint_features[0][0]
+            # plt.imshow(tmp.detach().numpy())
+            # print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
+            keypoint_features = self.keypoint_head(keypoint_features)
+
+            # print(f'keypoint_features:{keypoint_features.shape}')
+            tmp = keypoint_features[0][0]
+            plt.imshow(tmp.detach().numpy())
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+            # print(f'keypoint_logits:{keypoint_logits.shape}')
+            """
+            接wirenet
+            """
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        if self.has_wirepoint():
+            print(f'wirepoint result:{result}')
+            wirepoint_proposals = [p["boxes"] for p in result]
+
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                wirepoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    wirepoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            # print(f'proposals:{len(proposals)}')
+            print(f'wirepoint_proposals:{wirepoint_proposals}')
+            wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
+
+            # tmp = keypoint_features[0][0]
+            # plt.imshow(tmp.detach().numpy())
+            print(f'wirepoint_features from roi_pool:{wirepoint_features.shape}')
+            outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
+            print(f'outputs1 from head:{outputs.shape}')
+
+            outputs = merge_features(outputs, wirepoint_proposals)
+
+
+
+            # wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
+
+            print(f'outpust:{outputs.shape}')
+
+            wirepoint_logits = self.wirepoint_predictor(inputs=outputs, features=wirepoint_features, targets=targets)
+            x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
+
+            # print(f'keypoint_features:{wirepoint_features.shape}')
+            if self.training:
+
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
+                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+
+                loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+
+            else:
+                pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
+                result.append(pred)
+
+                loss_wirepoint = {}
+
+                # loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
+                # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+                # loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+
+            # tmp = wirepoint_features[0][0]
+            # plt.imshow(tmp.detach().numpy())
+            # wirepoint_logits = self.wirepoint_predictor((outputs,wirepoint_features))
+            # print(f'keypoint_logits:{wirepoint_logits.shape}')
+
+            # loss_wirepoint = {}    lm
+            # result=wirepoint_logits
+
+            # result.append(pred)    lm
+            losses.update(loss_wirepoint)
+        # print(f"result{result}")
+        # print(f"losses{losses}")
+
+        return result, losses
+
+
+# def merge_features(features, proposals):
+#     # 假设 roi_pool_features 是你的输入张量,形状为 [600, 256, 128, 128]
+#
+#     # 使用 torch.split 按照每个图像的提议数量分割 features
+#     proposals_count = sum([p.size(0) for p in proposals])
+#     features_size = features.size(0)
+#     # (f'proposals sum:{proposals_count},features batch:{features.size(0)}')
+#     if proposals_count != features_size:
+#         raise ValueError("The length of proposals must match the batch size of features.")
+#
+#     split_features = []
+#     start_idx = 0
+#     print(f"proposals:{proposals}")
+#     for proposal in proposals:
+#         # 提取当前图像的特征
+#         current_features = features[start_idx:start_idx + proposal.size(0)]
+#         # print(f'current_features:{current_features.shape}')
+#         split_features.append(current_features)
+#         start_idx += 1
+#
+#     features_imgs = []
+#     for features_per_img in split_features:
+#         features_per_img, _ = torch.max(features_per_img, dim=0, keepdim=True)
+#         features_imgs.append(features_per_img)
+#
+#     merged_features = torch.cat(features_imgs, dim=0)
+#     # print(f' merged_features:{merged_features.shape}')
+#     return merged_features
+
+def merge_features(features, proposals):
+    print(f'features in merge_features:{features.shape}')
+    print(f'proposals:{len(proposals)}')
+    def diagnose_input(features, proposals):
+        """诊断输入数据"""
+        print("Input Diagnostics:")
+        print(f"Features type: {type(features)}, shape: {features.shape}")
+        print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
+        for i, p in enumerate(proposals):
+            print(f"Proposal {i} shape: {p.shape}")
+
+    def validate_inputs(features, proposals):
+        """验证输入的有效性"""
+        if features is None or proposals is None:
+            raise ValueError("Features or proposals cannot be None")
+
+        proposals_count = sum([p.size(0) for p in proposals])
+        features_size = features.size(0)
+
+        if proposals_count != features_size:
+            raise ValueError(
+                f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
+            )
+
+    def safe_max_reduction(features_per_img,proposals):
+
+        print(f'proposal:{proposals.shape},features_per_img:{features_per_img.shape}')
+        """安全的最大值压缩"""
+        if features_per_img.numel() == 0:
+            return torch.zeros_like(features_per_img).unsqueeze(0)
+
+        for feature_map,roi in zip(features_per_img,proposals):
+            print(f'feature_map:{feature_map.shape},roi:{roi}')
+            roi_off_x=roi[0]
+            roi_off_y=roi[1]
+
+
+        try:
+            # 沿着第0维求最大值,保持维度
+            max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
+            return max_features
+        except Exception as e:
+            print(f"Max reduction error: {e}")
+            return features_per_img.unsqueeze(0)
+
+    try:
+        # 诊断输入(可选)
+        # diagnose_input(features, proposals)
+
+        # 验证输入
+        validate_inputs(features, proposals)
+
+        # 分割特征
+        split_features = []
+        start_idx = 0
+
+        for proposal in proposals:
+            # 提取当前图像的特征
+            current_features = features[start_idx:start_idx + proposal.size(0)]
+            split_features.append(current_features)
+            start_idx += proposal.size(0)
+
+        # 每张图像特征压缩
+        features_imgs = []
+
+        print(f'split_features:{len(split_features)}')
+        for features_per_img,proposal in zip(split_features,proposals):
+            compressed_features = safe_max_reduction(features_per_img,proposal)
+            features_imgs.append(compressed_features)
+
+        # 合并特征
+        merged_features = torch.cat(features_imgs, dim=0)
+
+        return merged_features
+
+    except Exception as e:
+        print(f"Error in merge_features: {e}")
+        # 返回原始特征或None
+        return features
+

+ 126 - 0
models/wirenet/postprocess.py

@@ -0,0 +1,126 @@
+import numpy as np
+
+
+def pline(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def psegment(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = max(min(((x - x1) * px + (y - y1) * py) / float(dd), 1), 0)
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def plambda(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+
+
+def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+    nlines, nscores = [], []
+    for (p, q), score in zip(lines, scores):
+        start, end = 0, 1
+        for a, b in nlines:
+            if (
+                    min(
+                        max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                        max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                    )
+                    > threshold ** 2
+            ):
+                continue
+            lambda_a = plambda(*p, *q, *a)
+            lambda_b = plambda(*p, *q, *b)
+            if lambda_a > lambda_b:
+                lambda_a, lambda_b = lambda_b, lambda_a
+            lambda_a -= tol
+            lambda_b += tol
+
+            # case 1: skip (if not do_clip)
+            if start < lambda_a and lambda_b < end:
+                continue
+
+            # not intersect
+            if lambda_b < start or lambda_a > end:
+                continue
+
+            # cover
+            if lambda_a <= start and end <= lambda_b:
+                start = 10
+                break
+
+            # case 2 & 3:
+            if lambda_a <= start and start <= lambda_b:
+                start = lambda_b
+            if lambda_a <= end and end <= lambda_b:
+                end = lambda_a
+
+            if start >= end:
+                break
+
+        if start >= end:
+            continue
+        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
+        nscores.append(score)
+    return np.array(nlines), np.array(nscores)
+
+
+def postprocess_keypoint(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+    nlines, nscores = [], []
+    for (p, q), score in zip(lines, scores):
+        start, end = 0, 1
+        for a, b in nlines:
+            if (
+                    min(
+                        max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                        max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                    )
+                    > threshold ** 2
+            ):
+                continue
+            lambda_a = plambda(*p, *q, *a)
+            lambda_b = plambda(*p, *q, *b)
+            if lambda_a > lambda_b:
+                lambda_a, lambda_b = lambda_b, lambda_a
+            lambda_a -= tol
+            lambda_b += tol
+
+            # case 1: skip (if not do_clip)
+            if start < lambda_a and lambda_b < end:
+                continue
+
+            # not intersect
+            if lambda_b < start or lambda_a > end:
+                continue
+
+            # cover
+            if lambda_a <= start and end <= lambda_b:
+                start = 10
+                break
+
+            # case 2 & 3:
+            if lambda_a <= start and start <= lambda_b:
+                start = lambda_b
+            if lambda_a <= end and end <= lambda_b:
+                end = lambda_a
+
+            if start >= end:
+                break
+
+        if start >= end:
+            continue
+        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
+        nscores.append(min(score[0],score[1]))
+    return np.array(nlines), np.array(nscores)

+ 896 - 0
models/wirenet/roi_head.py

@@ -0,0 +1,896 @@
+from typing import Dict, List, Optional, Tuple
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for ech image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
+    # print(f'mask discretization_size:{discretization_size}')
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    # print(f'mask labels:{labels}')
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    # print(f'mask labels1:{labels}')
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+    # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
+    # print(f'mask_targets:{mask_targets}')
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    # print(f'mask_loss:{mask_loss}')
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+        .index_select(2, x_int.to(dtype=torch.int64))
+        .view(-1)
+        .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it sepaartely
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    print(f'x:{x.shape}')
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+    print(f'x2:{x2}')
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+        self,
+        box_roi_pool,
+        box_head,
+        box_predictor,
+        # Faster R-CNN training
+        fg_iou_thresh,
+        bg_iou_thresh,
+        batch_size_per_image,
+        positive_fraction,
+        bbox_reg_weights,
+        # Faster R-CNN inference
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        # Mask
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to propos
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+        self,
+        features,  # type: Dict[str, Tensor]
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+            self.keypoint_roi_pool is not None
+            and self.keypoint_head is not None
+            and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            # tmp = keypoint_features[0][0]
+            # plt.imshow(tmp.detach().numpy())
+            print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
+            keypoint_features = self.keypoint_head(keypoint_features)
+
+            print(f'keypoint_features:{keypoint_features.shape}')
+            tmp=keypoint_features[0][0]
+            plt.imshow(tmp.detach().numpy())
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+            print(f'keypoint_logits:{keypoint_logits.shape}')
+            """
+            接wirenet
+            """
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        return result, losses

+ 23 - 0
models/wirenet/test.py

@@ -0,0 +1,23 @@
+from models.wirenet.wirepoint_dataset import WirePointDataset
+from models.config.config_tool import read_yaml
+
+# image_file = "D:/python/PycharmProjects/data"
+#
+# label_file = "D:/python/PycharmProjects/data/labels/train"
+# dataset_test = WireDataset(image_file)
+# dataset_test.show(0)
+# for i in dataset_test:
+#     print(i)
+cfg = 'wirenet.yaml'
+cfg = read_yaml(cfg)
+print(f'cfg:{cfg}')
+print(cfg['model']['n_dyn_negl'])
+# net = WirepointPredictor()
+
+dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+# dataset.show(0)
+
+for i in range(len(dataset)):
+    dataset.show(i)
+
+

+ 14 - 0
models/wirenet/test_mask.py

@@ -0,0 +1,14 @@
+import torch
+from matplotlib import pyplot as plt
+
+img=torch.ones((128,128,3))
+mask=torch.zeros((128,128,3))
+
+mask[0:30,:,:]=1
+
+
+img[mask==1]=0
+
+
+plt.imshow(img)
+plt.show()

+ 20 - 0
models/wirenet/train.py

@@ -0,0 +1,20 @@
+def train_epoch(model):
+    pass
+
+
+def _loss(losses):
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        print(f"{name}:{loss}")
+        total_loss += loss
+
+    return total_loss
+

+ 69 - 0
models/wirenet/wirenet.yaml

@@ -0,0 +1,69 @@
+io:
+  logdir: logs/
+  datadir: I:/wirenet_dateset
+  resume_from:
+  num_workers: 4
+  tensorboard_port: 0
+  validation_interval: 24000
+
+model:
+  image:
+      mean: [109.730, 103.832, 98.681]
+      stddev: [22.275, 22.124, 23.229]
+
+  batch_size: 2
+  batch_size_eval: 2
+
+  # backbone multi-task parameters
+  head_size: [[2], [1], [2]]
+  loss_weight:
+    jmap: 8.0
+    lmap: 0.5
+    joff: 0.25
+    lpos: 1
+    lneg: 1
+
+  # backbone parameters
+  backbone: stacked_hourglass
+  depth: 4
+  num_stacks: 2
+  num_blocks: 1
+
+  # sampler parameters
+  ## static sampler
+  n_stc_posl: 300
+  n_stc_negl: 40
+
+  ## dynamic sampler
+  n_dyn_junc: 300
+  n_dyn_posl: 300
+  n_dyn_negl: 80
+  n_dyn_othr: 600
+
+  # LOIPool layer parameters
+  n_pts0: 32
+  n_pts1: 8
+
+  # line verification network parameters
+  dim_loi: 128
+  dim_fc: 1024
+
+  # maximum junction and line outputs
+  n_out_junc: 250
+  n_out_line: 2500
+
+  # additional ablation study parameters
+  use_cood: 0
+  use_slop: 0
+  use_conv: 0
+
+  # junction threashold for evaluation (See #5)
+  eval_junc_thres: 0.008
+
+optim:
+  name: Adam
+  lr: 4.0e-4
+  amsgrad: True
+  weight_decay: 1.0e-4
+  max_epoch: 24
+  lr_decay_epoch: 10

+ 178 - 0
models/wirenet/wirepoint_dataset.py

@@ -0,0 +1,178 @@
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+import matplotlib.pyplot as plt
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+
+class WirePointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # 字典
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # 真实存在线条的邻接矩阵
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+            # 不存在线条的临界矩阵
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [线段数, 512, 512]
+        target = {}
+        target["labels"] = torch.stack(labels)
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        target["boxes"] = line_boxes(target)
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (i == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64
+
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            plt.show()
+
+            plt.show()
+            if fn != None:
+                plt.savefig(fn)
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass
+
+
+

+ 847 - 0
models/wirenet/wirepoint_rcnn.py

@@ -0,0 +1,847 @@
+import os
+from typing import Optional, Any
+
+import cv2
+import numpy as np
+import torch
+from tensorboardX import SummaryWriter
+from torch import nn
+import torch.nn.functional as F
+# from torchinfo import summary
+from torchvision.io import read_image
+from torchvision.models import resnet50, ResNet50_Weights
+from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection._utils import overwrite_eps
+from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
+from torchvision.models.detection.keypoint_rcnn import KeypointRCNNHeads, KeypointRCNNPredictor, \
+    KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.ops import MultiScaleRoIAlign
+from torchvision.ops import misc as misc_nn_ops
+# from visdom import Visdom
+
+from models.config import config_tool
+from models.config.config_tool import read_yaml
+from models.ins.trainer import get_transform
+from models.wirenet.head import RoIHeads
+from models.wirenet.wirepoint_dataset import WirePointDataset
+from tools import utils
+
+from torch.utils.tensorboard import SummaryWriter
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from skimage import io
+import os.path as osp
+from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+from models.wirenet.postprocess import postprocess
+
+FEATURE_DIM = 8
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print(f"Using device: {device}")
+
+
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+
+class Bottleneck1D(nn.Module):
+    def __init__(self, inplanes, outplanes):
+        super(Bottleneck1D, self).__init__()
+
+        planes = outplanes // 2
+        self.op = nn.Sequential(
+            nn.BatchNorm1d(inplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(inplanes, planes, kernel_size=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, outplanes, kernel_size=1),
+        )
+
+    def forward(self, x):
+        return x + self.op(x)
+
+
+class WirepointRCNN(FasterRCNN):
+    def __init__(
+            self,
+            backbone,
+            num_classes=None,
+            # transform parameters
+            min_size=None,
+            max_size=1333,
+            image_mean=None,
+            image_std=None,
+            # RPN parameters
+            rpn_anchor_generator=None,
+            rpn_head=None,
+            rpn_pre_nms_top_n_train=2000,
+            rpn_pre_nms_top_n_test=1000,
+            rpn_post_nms_top_n_train=2000,
+            rpn_post_nms_top_n_test=1000,
+            rpn_nms_thresh=0.7,
+            rpn_fg_iou_thresh=0.7,
+            rpn_bg_iou_thresh=0.3,
+            rpn_batch_size_per_image=256,
+            rpn_positive_fraction=0.5,
+            rpn_score_thresh=0.0,
+            # Box parameters
+            box_roi_pool=None,
+            box_head=None,
+            box_predictor=None,
+            box_score_thresh=0.05,
+            box_nms_thresh=0.5,
+            box_detections_per_img=100,
+            box_fg_iou_thresh=0.5,
+            box_bg_iou_thresh=0.5,
+            box_batch_size_per_image=512,
+            box_positive_fraction=0.25,
+            bbox_reg_weights=None,
+            # keypoint parameters
+            keypoint_roi_pool=None,
+            keypoint_head=None,
+            keypoint_predictor=None,
+            num_keypoints=None,
+            wirepoint_roi_pool=None,
+            wirepoint_head=None,
+            wirepoint_predictor=None,
+            **kwargs,
+    ):
+        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
+            )
+        if min_size is None:
+            min_size = (640, 672, 704, 736, 768, 800)
+
+        if num_keypoints is not None:
+            if keypoint_predictor is not None:
+                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
+        else:
+            num_keypoints = 17
+
+        out_channels = backbone.out_channels
+
+        if wirepoint_roi_pool is None:
+            wirepoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=128,
+                                                    sampling_ratio=2, )
+
+        if wirepoint_head is None:
+            keypoint_layers = tuple(512 for _ in range(8))
+            # print(f'keypoinyrcnnHeads inchannels:{out_channels},layers{keypoint_layers}')
+            wirepoint_head = WirepointHead(out_channels, keypoint_layers)
+
+        if wirepoint_predictor is None:
+            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+            wirepoint_predictor = WirepointPredictor()
+
+        super().__init__(
+            backbone,
+            num_classes,
+            # transform parameters
+            min_size,
+            max_size,
+            image_mean,
+            image_std,
+            # RPN-specific parameters
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_pre_nms_top_n_train,
+            rpn_pre_nms_top_n_test,
+            rpn_post_nms_top_n_train,
+            rpn_post_nms_top_n_test,
+            rpn_nms_thresh,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_score_thresh,
+            # Box parameters
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            **kwargs,
+        )
+
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+            # wirepoint_roi_pool=wirepoint_roi_pool,
+            # wirepoint_head=wirepoint_head,
+            # wirepoint_predictor=wirepoint_predictor,
+        )
+        self.roi_heads = roi_heads
+
+        self.roi_heads.wirepoint_roi_pool = wirepoint_roi_pool
+        self.roi_heads.wirepoint_head = wirepoint_head
+        self.roi_heads.wirepoint_predictor = wirepoint_predictor
+
+
+class WirepointHead(nn.Module):
+    def __init__(self, input_channels, num_class):
+        super(WirepointHead, self).__init__()
+        self.head_size = [[2], [1], [2]]
+        m = int(input_channels / 4)
+        heads = []
+        # print(f'M.head_size:{M.head_size}')
+        # for output_channels in sum(M.head_size, []):
+        for output_channels in sum(self.head_size, []):
+            heads.append(
+                nn.Sequential(
+                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(m, output_channels, kernel_size=1),
+                )
+            )
+        self.heads = nn.ModuleList(heads)
+
+    def forward(self, x):
+        # for idx, head in enumerate(self.heads):
+        #     print(f'{idx},multitask head:{head(x).shape},input x:{x.shape}')
+
+        outputs = torch.cat([head(x) for head in self.heads], dim=1)
+
+        features = x
+        return outputs, features
+
+
+class WirepointPredictor(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        # self.backbone = backbone
+        # self.cfg = read_yaml(cfg)
+        self.cfg = read_yaml('wirenet.yaml')
+        self.n_pts0 = self.cfg['model']['n_pts0']
+        self.n_pts1 = self.cfg['model']['n_pts1']
+        self.n_stc_posl = self.cfg['model']['n_stc_posl']
+        self.dim_loi = self.cfg['model']['dim_loi']
+        self.use_conv = self.cfg['model']['use_conv']
+        self.dim_fc = self.cfg['model']['dim_fc']
+        self.n_out_line = self.cfg['model']['n_out_line']
+        self.n_out_junc = self.cfg['model']['n_out_junc']
+        self.loss_weight = self.cfg['model']['loss_weight']
+        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
+        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
+        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
+        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
+        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
+        self.use_cood = self.cfg['model']['use_cood']
+        self.use_slop = self.cfg['model']['use_slop']
+        self.n_stc_negl = self.cfg['model']['n_stc_negl']
+        self.head_size = self.cfg['model']['head_size']
+        self.num_class = sum(sum(self.head_size, []))
+        self.head_off = np.cumsum([sum(h) for h in self.head_size])
+
+        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
+        self.register_buffer("lambda_", lambda_)
+        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0
+
+        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
+        scale_factor = self.n_pts0 // self.n_pts1
+        if self.use_conv:
+            self.pooling = nn.Sequential(
+                nn.MaxPool1d(scale_factor, scale_factor),
+                Bottleneck1D(self.dim_loi, self.dim_loi),
+            )
+            self.fc2 = nn.Sequential(
+                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
+            )
+        else:
+            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
+            self.fc2 = nn.Sequential(
+                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, 1),
+            )
+        self.loss = nn.BCEWithLogitsLoss(reduction="none")
+
+    def forward(self, inputs, features, targets=None):
+
+        # outputs, features = input
+        # for out in outputs:
+        #     print(f'out:{out.shape}')
+        # outputs=merge_features(outputs,100)
+        batch, channel, row, col = inputs.shape
+        # print(f'outputs:{inputs.shape}')
+        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
+
+        if targets is not None:
+            self.training = True
+            # print(f'target:{targets}')
+            wires_targets = [t["wires"] for t in targets]
+            # print(f'wires_target:{wires_targets}')
+            # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
+            junc_maps = [d["junc_map"] for d in wires_targets]
+            junc_offsets = [d["junc_offset"] for d in wires_targets]
+            line_maps = [d["line_map"] for d in wires_targets]
+
+            junc_map_tensor = torch.stack(junc_maps, dim=0)
+            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+            line_map_tensor = torch.stack(line_maps, dim=0)
+
+            wires_meta = {
+                "junc_map": junc_map_tensor,
+                "junc_offset": junc_offset_tensor,
+                # "line_map": line_map_tensor,
+            }
+        else:
+            self.training = False
+            t = {
+                "junc_coords": torch.zeros(1, 2).to(device),
+                "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
+                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+                "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
+            }
+            wires_targets = [t for b in range(inputs.size(0))]
+
+            wires_meta = {
+                "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
+            }
+
+        T = wires_meta.copy()
+        n_jtyp = T["junc_map"].shape[1]
+        offset = self.head_off
+        result = {}
+        for stack, output in enumerate([inputs]):
+            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+            # print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
+            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+            lmap = output[offset[0]: offset[1]].squeeze(0)
+            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+
+            if stack == 0:
+                result["preds"] = {
+                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                    "lmap": lmap.sigmoid(),
+                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+                }
+
+        h = result["preds"]
+        # print(f'features shape:{features.shape}')
+        x = self.fc1(features)
+        n_batch, n_channel, row, col = x.shape
+        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
+
+        for i, meta in enumerate(wires_targets):
+            p, label, feat, jc = self.sample_lines(
+                meta, h["jmap"][i], h["joff"][i],
+            )
+            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
+            ys.append(label)
+            if self.training and self.do_static_sampling:
+                p = torch.cat([p, meta["lpre"]])
+                feat = torch.cat([feat, meta["lpre_feat"]])
+                ys.append(meta["lpre_label"])
+                del jc
+            else:
+                jcs.append(jc)
+                ps.append(p)
+            fs.append(feat)
+
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
+            xp = (
+                (
+                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                    .reshape(n_channel, -1, self.n_pts0)
+                    .permute(1, 0, 2)
+            )
+            xp = self.pooling(xp)
+            # print(f'xp.shape:{xp.shape}')
+            xs.append(xp)
+            idx.append(idx[-1] + xp.shape[0])
+            # print(f'idx__:{idx}')
+
+        x, y = torch.cat(xs), torch.cat(ys)
+        f = torch.cat(fs)
+        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+        x = torch.cat([x, f], 1)
+        x = x.to(dtype=torch.float32)
+        x = self.fc2(x).flatten()
+
+        # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
+        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
+
+        # if mode != "training":
+        # self.inference(x, idx, jcs, n_batch, ps)
+
+        # return result
+
+    def sample_lines(self, meta, jmap, joff):
+        with torch.no_grad():
+            junc = meta["junc_coords"]  # [N, 2]
+            jtyp = meta["jtyp"]  # [N]
+            Lpos = meta["line_pos_idx"]
+            Lneg = meta["line_neg_idx"]
+
+            n_type = jmap.shape[0]
+            print(f'jmap:{jmap.shape}')
+            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+
+            joff = joff.reshape(n_type, 2, -1)
+            max_K = self.n_dyn_junc // n_type
+            N = len(junc)
+            # if mode != "training":
+            if not self.training:
+                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
+            else:
+                K = min(int(N * 2 + 2), max_K)
+            if K < 2:
+                K = 2
+            device = jmap.device
+
+            # index: [N_TYPE, K]
+            score, index = torch.topk(jmap, k=K)
+            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+
+            # xy: [N_TYPE, K, 2]
+            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
+            xy_ = xy[..., None, :]
+            del x, y, index
+
+            # print(f"xy_.is_cuda: {xy_.is_cuda}")
+            # print(f"junc.is_cuda: {junc.is_cuda}")
+
+            # dist: [N_TYPE, K, N]
+            dist = torch.sum((xy_ - junc) ** 2, -1)
+            cost, match = torch.min(dist, -1)
+
+            # xy: [N_TYPE * K, 2]
+            # match: [N_TYPE, K]
+            for t in range(n_type):
+                match[t, jtyp[match[t]] != t] = N
+            match[cost > 1.5 * 1.5] = N
+            match = match.flatten()
+
+            _ = torch.arange(n_type * K, device=device)
+            u, v = torch.meshgrid(_, _)
+            u, v = u.flatten(), v.flatten()
+            up, vp = match[u], match[v]
+            label = Lpos[up, vp]
+
+            # if mode == "training":
+            if self.training:
+                c = torch.zeros_like(label, dtype=torch.bool)
+
+                # sample positive lines
+                cdx = label.nonzero().flatten()
+                if len(cdx) > self.n_dyn_posl:
+                    # print("too many positive lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample negative lines
+                cdx = Lneg[up, vp].nonzero().flatten()
+                if len(cdx) > self.n_dyn_negl:
+                    # print("too many negative lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample other (unmatched) lines
+                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
+                c[cdx] = 1
+            else:
+                c = (u < v).flatten()
+
+            # sample lines
+            u, v, label = u[c], v[c], label[c]
+            xy = xy.reshape(n_type * K, 2)
+            xyu, xyv = xy[u], xy[v]
+
+            u2v = xyu - xyv
+            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
+            feat = torch.cat(
+                [
+                    xyu / 128 * self.use_cood,
+                    xyv / 128 * self.use_cood,
+                    u2v * self.use_slop,
+                    (u[:, None] > K).float(),
+                    (v[:, None] > K).float(),
+                ],
+                1,
+            )
+            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+
+            xy = xy.reshape(n_type, K, 2)
+            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
+            return line, label.float(), feat, jcs
+
+
+def wirepointrcnn_resnet50_fpn(
+        *,
+        weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> WirepointRCNN:
+    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = WirepointRCNN(backbone, num_classes=5, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+
+def _loss(losses):
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        total_loss += loss
+
+    return total_loss
+
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+    # plt.show()
+
+
+# def _plot_samples(img, i, result, prefix, epoch):
+#     print(f"prefix:{prefix}")
+#     def draw_vecl(lines, sline, juncs, junts, fn):
+#         directory = os.path.dirname(fn)
+#         if not os.path.exists(directory):
+#             os.makedirs(directory)
+#         imshow(img.permute(1, 2, 0))
+#         if len(lines) > 0 and not (lines[0] == 0).all():
+#             for i, ((a, b), s) in enumerate(zip(lines, sline)):
+#                 if i > 0 and (lines[i] == lines[0]).all():
+#                     break
+#                 plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+#         if not (juncs[0] == 0).all():
+#             for i, j in enumerate(juncs):
+#                 if i > 0 and (i == juncs[0]).all():
+#                     break
+#                 plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+#         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+#             for i, j in enumerate(junts):
+#                 if i > 0 and (i == junts[0]).all():
+#                     break
+#                 plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
+#         plt.savefig(fn), plt.close()
+#
+#     rjuncs = result["juncs"][i].cpu().numpy() * 4
+#     rjunts = None
+#     if "junts" in result:
+#         rjunts = result["junts"][i].cpu().numpy() * 4
+#
+#     vecl_result = result["lines"][i].cpu().numpy() * 4
+#     score = result["score"][i].cpu().numpy()
+#
+#     draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
+#
+#     img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
+#     writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
+
+def _plot_samples(img, i, result, prefix, epoch, writer):
+    # print(f"prefix:{prefix}")
+
+    def draw_vecl(lines, sline, juncs, junts, fn):
+        # 确保目录存在
+        directory = os.path.dirname(fn)
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+
+        # 绘制图像
+        plt.figure()
+        plt.imshow(img.permute(1, 2, 0).cpu().numpy())
+        plt.axis('off')  # 可选:关闭坐标轴
+
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for idx, ((a, b), s) in enumerate(zip(lines, sline)):
+                if idx > 0 and (lines[idx] == lines[0]).all():
+                    break
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1)
+
+        if not (juncs[0] == 0).all():
+            for idx, j in enumerate(juncs):
+                if idx > 0 and (j == juncs[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="red", s=20, zorder=100)
+
+        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+            for idx, j in enumerate(junts):
+                if idx > 0 and (j == junts[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="blue", s=20, zorder=100)
+
+        # plt.show()
+
+        # 将matplotlib图像转换为numpy数组
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+
+        return image_from_plot
+
+    # 获取结果数据并转换为numpy数组
+    rjuncs = result["juncs"][i].cpu().numpy() * 4
+    rjunts = None
+    if "junts" in result:
+        rjunts = result["junts"][i].cpu().numpy() * 4
+
+    vecl_result = result["lines"][i].cpu().numpy() * 4
+    score = result["score"][i].cpu().numpy()
+
+    # 调用绘图函数并获取图像
+    image_path = f"{prefix}_vecl_b.jpg"
+    image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path)
+
+    # 将numpy数组转换为torch tensor,并写入TensorBoard
+    image_tensor = transforms.ToTensor()(image_array)
+    writer.add_image(f'output_epoch', image_tensor, global_step=epoch)
+    writer.add_image(f'ori_epoch', img, global_step=epoch)
+
+
+def show_line(img, pred, prefix, epoch, write):
+    fn = f"{prefix}_line.jpg"
+    directory = os.path.dirname(fn)
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    print(fn)
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    H = pred
+
+    im = img.permute(1, 2, 0)
+
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.5]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.savefig(fn, bbox_inches="tight")
+        plt.show()
+        plt.close()
+
+
+        img2 = cv2.imread(fn)  # 预测图
+        # img1 = im.resize(img2.shape)  # 原图
+
+        # writer.add_images(f"{epoch}", torch.tensor([img1, img2]), dataformats='NHWC')
+        writer.add_image("output", img2, epoch)
+
+
+if __name__ == '__main__':
+    cfg = 'wirenet.yaml'
+    cfg = read_yaml(cfg)
+    print(f'cfg:{cfg}')
+    print(cfg['model']['n_dyn_negl'])
+    # net = WirepointPredictor()
+
+    # if torch.cuda.is_available():
+    #     device_name = "cuda"
+    #     torch.backends.cudnn.deterministic = True
+    #     torch.cuda.manual_seed(0)
+    #     print("Let's use", torch.cuda.device_count(), "GPU(s)!")
+    # else:
+    #     print("CUDA is not available")
+    #
+    # device = torch.device(device_name)
+
+    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
+    train_sampler = torch.utils.data.RandomSampler(dataset_train)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
+    train_collate_fn = utils.collate_fn_wirepoint
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
+    )
+
+    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+    val_sampler = torch.utils.data.RandomSampler(dataset_val)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
+    val_collate_fn = utils.collate_fn_wirepoint
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
+    )
+
+    model = wirepointrcnn_resnet50_fpn().to(device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
+    writer = SummaryWriter(cfg['io']['logdir'])
+
+
+    def move_to_device(data, device):
+        if isinstance(data, (list, tuple)):
+            return type(data)(move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # 对于非张量类型的数据不做任何改变
+
+
+    def writer_loss(writer, losses, epoch):
+        # ??????
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    # ?? wirepoint ??????
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            # ?? .item() ?????
+                            writer.add_scalar(f'loss_wirepoint/{subkey}',
+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                              epoch)
+                elif isinstance(value, torch.Tensor):
+                    # ????????
+                    writer.add_scalar(key, value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
+
+
+    for epoch in range(cfg['optim']['max_epoch']):
+        print(f"epoch:{epoch}")
+        model.train()
+
+        for imgs, targets in data_loader_train:
+            print(f'targets:{targets[0]["wires"]["line_map"].shape}')
+            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+            loss = _loss(losses)
+            print(loss)
+        # optimizer.zero_grad()
+        # loss.backward()
+        # optimizer.step()
+        # writer_loss(writer, losses, epoch)
+
+        # model.eval()
+        # with torch.no_grad():
+        #     for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+        #         pred = model(move_to_device(imgs, device))
+        #         # print(f"pred:{pred}")
+        #
+        #         if batch_idx == 0:
+        #             result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
+        #             print(imgs[0].shape)  # [3,512,512]
+        #             # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3)
+        #             _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
+                    # show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)
+
+# imgs, targets = next(iter(data_loader))
+#
+# model.train()
+# pred = model(imgs, targets)
+# print(f'pred:{pred}')
+
+# result, losses = model(imgs, targets)
+# print(f'result:{result}')
+# print(f'pred:{losses}')

+ 70 - 0
models/wirenet2/WirePredictor.py

@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class WirePredictor(nn.Module):
+    def __init__(self, in_channels=4, out_channels=1, init_features=32):
+        super(WirePredictor, self).__init__()
+
+        features = init_features
+        self.encoder1 = self._block(in_channels, features, name="enc1")
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.encoder2 = self._block(features, features * 2, name="enc2")
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        self.bottleneck = self._block(features * 2, features * 4, name="bottleneck")
+
+        self.upconv2 = nn.ConvTranspose2d(
+            features * 4, features * 2, kernel_size=2, stride=2
+        )
+        self.decoder2 = self._block((features * 2) * 2, features * 2, name="dec2")
+        self.upconv1 = nn.ConvTranspose2d(
+            features * 2, features, kernel_size=2, stride=2
+        )
+        self.decoder1 = self._block(features * 2, features, name="dec1")
+
+        # Output for line segment mask
+        self.conv_mask = nn.Conv2d(
+            in_channels=features, out_channels=out_channels, kernel_size=1
+        )
+
+        # Output for normal vectors (2 channels for x and y components)
+        self.conv_normals = nn.Conv2d(
+            in_channels=features, out_channels=2, kernel_size=1
+        )
+
+    def forward(self, x):
+        enc1 = self.encoder1(x)
+        enc2 = self.encoder2(self.pool1(enc1))
+
+        bottleneck = self.bottleneck(self.pool2(enc2))
+
+        dec2 = self.upconv2(bottleneck)
+        dec2 = torch.cat((dec2, enc2), dim=1)
+        dec2 = self.decoder2(dec2)
+        dec1 = self.upconv1(dec2)
+        dec1 = torch.cat((dec1, enc1), dim=1)
+        dec1 = self.decoder1(dec1)
+
+        mask = torch.sigmoid(self.conv_mask(dec1))
+        normals = torch.tanh(self.conv_normals(dec1))  # Normalize to [-1, 1]
+
+        return mask, normals
+
+    def _block(self, in_channels, features, name):
+        return nn.Sequential(
+            nn.Conv2d(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(num_features=features),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(num_features=features),
+            nn.ReLU(inplace=True),
+        )
+
+# 测试模型
+if __name__ == "__main__":
+    model = WirePredictor()
+    x = torch.randn((1, 4, 128, 128))  # 包含法向量信息的输入大小为 128x128
+    with torch.no_grad():
+        output_mask, output_normals = model(x)
+        print(output_mask.shape, output_normals.shape)  # 应输出 (1, 1, 128, 128) 和 (1, 2, 128, 128)

+ 0 - 0
models/wirenet2/__init__.py


+ 548 - 0
models/wirenet2/_utils.py

@@ -0,0 +1,548 @@
+import math
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
+
+
+class BalancedPositiveNegativeSampler:
+    """
+    This class samples batches, ensuring that they contain a fixed proportion of positives
+    """
+
+    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
+        """
+        Args:
+            batch_size_per_image (int): number of elements to be selected per image
+            positive_fraction (float): percentage of positive elements per batch
+        """
+        self.batch_size_per_image = batch_size_per_image
+        self.positive_fraction = positive_fraction
+
+    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+        """
+        Args:
+            matched_idxs: list of tensors containing -1, 0 or positive values.
+                Each tensor corresponds to a specific image.
+                -1 values are ignored, 0 are considered as negatives and > 0 as
+                positives.
+
+        Returns:
+            pos_idx (list[tensor])
+            neg_idx (list[tensor])
+
+        Returns two lists of binary masks for each image.
+        The first list contains the positive elements that were selected,
+        and the second list the negative example.
+        """
+        pos_idx = []
+        neg_idx = []
+        for matched_idxs_per_image in matched_idxs:
+            positive = torch.where(matched_idxs_per_image >= 1)[0]
+            negative = torch.where(matched_idxs_per_image == 0)[0]
+
+            num_pos = int(self.batch_size_per_image * self.positive_fraction)
+            # protect against not enough positive examples
+            num_pos = min(positive.numel(), num_pos)
+            num_neg = self.batch_size_per_image - num_pos
+            # protect against not enough negative examples
+            num_neg = min(negative.numel(), num_neg)
+
+            # randomly select positive and negative examples
+            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
+            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
+
+            pos_idx_per_image = positive[perm1]
+            neg_idx_per_image = negative[perm2]
+
+            # create binary mask from indices
+            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+
+            pos_idx_per_image_mask[pos_idx_per_image] = 1
+            neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+            pos_idx.append(pos_idx_per_image_mask)
+            neg_idx.append(neg_idx_per_image_mask)
+
+        return pos_idx, neg_idx
+
+
+@torch.jit._script_if_tracing
+def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
+    """
+    Encode a set of proposals with respect to some
+    reference boxes
+
+    Args:
+        reference_boxes (Tensor): reference boxes
+        proposals (Tensor): boxes to be encoded
+        weights (Tensor[4]): the weights for ``(x, y, w, h)``
+    """
+
+    # perform some unpacking to make it JIT-fusion friendly
+    wx = weights[0]
+    wy = weights[1]
+    ww = weights[2]
+    wh = weights[3]
+
+    proposals_x1 = proposals[:, 0].unsqueeze(1)
+    proposals_y1 = proposals[:, 1].unsqueeze(1)
+    proposals_x2 = proposals[:, 2].unsqueeze(1)
+    proposals_y2 = proposals[:, 3].unsqueeze(1)
+
+    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
+    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
+    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
+    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
+
+    # implementation starts here
+    ex_widths = proposals_x2 - proposals_x1
+    ex_heights = proposals_y2 - proposals_y1
+    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
+    ex_ctr_y = proposals_y1 + 0.5 * ex_heights
+
+    gt_widths = reference_boxes_x2 - reference_boxes_x1
+    gt_heights = reference_boxes_y2 - reference_boxes_y1
+    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
+    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
+
+    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+    targets_dw = ww * torch.log(gt_widths / ex_widths)
+    targets_dh = wh * torch.log(gt_heights / ex_heights)
+
+    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+    return targets
+
+
+class BoxCoder:
+    """
+    This class encodes and decodes a set of bounding boxes into
+    the representation used for training the regressors.
+    """
+
+    def __init__(
+        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
+    ) -> None:
+        """
+        Args:
+            weights (4-element tuple)
+            bbox_xform_clip (float)
+        """
+        self.weights = weights
+        self.bbox_xform_clip = bbox_xform_clip
+
+    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
+        boxes_per_image = [len(b) for b in reference_boxes]
+        reference_boxes = torch.cat(reference_boxes, dim=0)
+        proposals = torch.cat(proposals, dim=0)
+        targets = self.encode_single(reference_boxes, proposals)
+        return targets.split(boxes_per_image, 0)
+
+    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some
+        reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+        """
+        dtype = reference_boxes.dtype
+        device = reference_boxes.device
+        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
+        targets = encode_boxes(reference_boxes, proposals, weights)
+
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+        torch._assert(
+            isinstance(boxes, (list, tuple)),
+            "This function expects boxes of type list or tuple.",
+        )
+        torch._assert(
+            isinstance(rel_codes, torch.Tensor),
+            "This function expects rel_codes of type torch.Tensor.",
+        )
+        boxes_per_image = [b.size(0) for b in boxes]
+        concat_boxes = torch.cat(boxes, dim=0)
+        box_sum = 0
+        for val in boxes_per_image:
+            box_sum += val
+        if box_sum > 0:
+            rel_codes = rel_codes.reshape(box_sum, -1)
+        pred_boxes = self.decode_single(rel_codes, concat_boxes)
+        if box_sum > 0:
+            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+        return pred_boxes
+
+    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+        """
+
+        boxes = boxes.to(rel_codes.dtype)
+
+        widths = boxes[:, 2] - boxes[:, 0]
+        heights = boxes[:, 3] - boxes[:, 1]
+        ctr_x = boxes[:, 0] + 0.5 * widths
+        ctr_y = boxes[:, 1] + 0.5 * heights
+
+        wx, wy, ww, wh = self.weights
+        dx = rel_codes[:, 0::4] / wx
+        dy = rel_codes[:, 1::4] / wy
+        dw = rel_codes[:, 2::4] / ww
+        dh = rel_codes[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=self.bbox_xform_clip)
+        dh = torch.clamp(dh, max=self.bbox_xform_clip)
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        # Distance from center to box's corner.
+        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+
+        pred_boxes1 = pred_ctr_x - c_to_c_w
+        pred_boxes2 = pred_ctr_y - c_to_c_h
+        pred_boxes3 = pred_ctr_x + c_to_c_w
+        pred_boxes4 = pred_ctr_y + c_to_c_h
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
+        return pred_boxes
+
+
+class BoxLinearCoder:
+    """
+    The linear box-to-box transform defined in FCOS. The transformation is parameterized
+    by the distance from the center of (square) src box to 4 edges of the target box.
+    """
+
+    def __init__(self, normalize_by_size: bool = True) -> None:
+        """
+        Args:
+            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
+        """
+        self.normalize_by_size = normalize_by_size
+
+    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+
+        Returns:
+            Tensor: the encoded relative box offsets that can be used to
+            decode the boxes.
+
+        """
+
+        # get the center of reference_boxes
+        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
+        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
+
+        # get box regression transformation deltas
+        target_l = reference_boxes_ctr_x - proposals[..., 0]
+        target_t = reference_boxes_ctr_y - proposals[..., 1]
+        target_r = proposals[..., 2] - reference_boxes_ctr_x
+        target_b = proposals[..., 3] - reference_boxes_ctr_y
+
+        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
+
+        if self.normalize_by_size:
+            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
+            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
+            reference_boxes_size = torch.stack(
+                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
+            )
+            targets = targets / reference_boxes_size
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+
+        Returns:
+            Tensor: the predicted boxes with the encoded relative box offsets.
+
+        .. note::
+            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
+
+        """
+
+        boxes = boxes.to(dtype=rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+
+        if self.normalize_by_size:
+            boxes_w = boxes[..., 2] - boxes[..., 0]
+            boxes_h = boxes[..., 3] - boxes[..., 1]
+
+            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+            rel_codes = rel_codes * list_box_size
+
+        pred_boxes1 = ctr_x - rel_codes[..., 0]
+        pred_boxes2 = ctr_y - rel_codes[..., 1]
+        pred_boxes3 = ctr_x + rel_codes[..., 2]
+        pred_boxes4 = ctr_y + rel_codes[..., 3]
+
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+        return pred_boxes
+
+
+class Matcher:
+    """
+    This class assigns to each predicted "element" (e.g., a box) a ground-truth
+    element. Each predicted element will have exactly zero or one matches; each
+    ground-truth element may be assigned to zero or more predicted elements.
+
+    Matching is based on the MxN match_quality_matrix, that characterizes how well
+    each (ground-truth, predicted)-pair match. For example, if the elements are
+    boxes, the matrix may contain box IoU overlap values.
+
+    The matcher returns a tensor of size N containing the index of the ground-truth
+    element m that matches to prediction n. If there is no match, a negative value
+    is returned.
+    """
+
+    BELOW_LOW_THRESHOLD = -1
+    BETWEEN_THRESHOLDS = -2
+
+    __annotations__ = {
+        "BELOW_LOW_THRESHOLD": int,
+        "BETWEEN_THRESHOLDS": int,
+    }
+
+    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
+        """
+        Args:
+            high_threshold (float): quality values greater than or equal to
+                this value are candidate matches.
+            low_threshold (float): a lower quality threshold used to stratify
+                matches into three levels:
+                1) matches >= high_threshold
+                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
+                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
+            allow_low_quality_matches (bool): if True, produce additional matches
+                for predictions that have only low-quality match candidates. See
+                set_low_quality_matches_ for more details.
+        """
+        self.BELOW_LOW_THRESHOLD = -1
+        self.BETWEEN_THRESHOLDS = -2
+        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
+        self.high_threshold = high_threshold
+        self.low_threshold = low_threshold
+        self.allow_low_quality_matches = allow_low_quality_matches
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        """
+        Args:
+            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
+            pairwise quality between M ground-truth elements and N predicted elements.
+
+        Returns:
+            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
+            [0, M - 1] or a negative value indicating that prediction i could not
+            be matched.
+        """
+        if match_quality_matrix.numel() == 0:
+            # empty targets or proposals not supported during training
+            if match_quality_matrix.shape[0] == 0:
+                raise ValueError("No ground-truth boxes available for one of the images during training")
+            else:
+                raise ValueError("No proposal boxes available for one of the images during training")
+
+        # match_quality_matrix is M (gt) x N (predicted)
+        # Max over gt elements (dim 0) to find best gt candidate for each prediction
+        matched_vals, matches = match_quality_matrix.max(dim=0)
+        if self.allow_low_quality_matches:
+            all_matches = matches.clone()
+        else:
+            all_matches = None  # type: ignore[assignment]
+
+        # Assign candidate matches with low quality to negative (unassigned) values
+        below_low_threshold = matched_vals < self.low_threshold
+        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
+        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
+        matches[between_thresholds] = self.BETWEEN_THRESHOLDS
+
+        if self.allow_low_quality_matches:
+            if all_matches is None:
+                torch._assert(False, "all_matches should not be None")
+            else:
+                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+
+        return matches
+
+    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
+        """
+        Produce additional matches for predictions that have only low-quality matches.
+        Specifically, for each ground-truth find the set of predictions that have
+        maximum overlap with it (including ties); for each prediction in that set, if
+        it is unmatched, then match it to the ground-truth with which it has the highest
+        quality value.
+        """
+        # For each gt, find the prediction with which it has the highest quality
+        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+        # Find the highest quality match available, even if it is low, including ties
+        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
+        # Example gt_pred_pairs_of_highest_quality:
+        #   tensor([[    0, 39796],
+        #           [    1, 32055],
+        #           [    1, 32070],
+        #           [    2, 39190],
+        #           [    2, 40255],
+        #           [    3, 40390],
+        #           [    3, 41455],
+        #           [    4, 45470],
+        #           [    5, 45325],
+        #           [    5, 46390]])
+        # Each row is a (gt index, prediction index)
+        # Note how gt items 1, 2, 3, and 5 each have two ties
+
+        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
+        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
+
+
+class SSDMatcher(Matcher):
+    def __init__(self, threshold: float) -> None:
+        super().__init__(threshold, threshold, allow_low_quality_matches=False)
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        matches = super().__call__(match_quality_matrix)
+
+        # For each gt, find the prediction with which it has the highest quality
+        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
+        matches[highest_quality_pred_foreach_gt] = torch.arange(
+            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
+        )
+
+        return matches
+
+
+def overwrite_eps(model: nn.Module, eps: float) -> None:
+    """
+    This method overwrites the default eps values of all the
+    FrozenBatchNorm2d layers of the model with the provided value.
+    This is necessary to address the BC-breaking change introduced
+    by the bug-fix at pytorch/vision#2933. The overwrite is applied
+    only when the pretrained weights are loaded to maintain compatibility
+    with previous versions.
+
+    Args:
+        model (nn.Module): The model on which we perform the overwrite.
+        eps (float): The new value of eps.
+    """
+    for module in model.modules():
+        if isinstance(module, FrozenBatchNorm2d):
+            module.eps = eps
+
+
+def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
+    """
+    This method retrieves the number of output channels of a specific model.
+
+    Args:
+        model (nn.Module): The model for which we estimate the out_channels.
+            It should return a single Tensor or an OrderedDict[Tensor].
+        size (Tuple[int, int]): The size (wxh) of the input.
+
+    Returns:
+        out_channels (List[int]): A list of the output channels of the model.
+    """
+    in_training = model.training
+    model.eval()
+
+    with torch.no_grad():
+        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
+        device = next(model.parameters()).device
+        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
+        features = model(tmp_img)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+        out_channels = [x.size(1) for x in features.values()]
+
+    if in_training:
+        model.train()
+
+    return out_channels
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> int:
+    return v  # type: ignore[return-value]
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
+    """
+    ONNX spec requires the k-value to be less than or equal to the number of inputs along
+    provided dim. Certain models use the number of elements along a particular axis instead of K
+    if K exceeds the number of elements along that axis. Previously, python's min() function was
+    used to determine whether to use the provided k-value or the specified dim axis value.
+
+    However, in cases where the model is being exported in tracing mode, python min() is
+    static causing the model to be traced incorrectly and eventually fail at the topk node.
+    In order to avoid this situation, in tracing mode, torch.min() is used instead.
+
+    Args:
+        input (Tensor): The original input tensor.
+        orig_kval (int): The provided k-value.
+        axis(int): Axis along which we retrieve the input size.
+
+    Returns:
+        min_kval (int): Appropriately selected k-value.
+    """
+    if not torch.jit.is_tracing():
+        return min(orig_kval, input.size(axis))
+    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+    return _fake_cast_onnx(min_kval)
+
+
+def _box_loss(
+    type: str,
+    box_coder: BoxCoder,
+    anchors_per_image: Tensor,
+    matched_gt_boxes_per_image: Tensor,
+    bbox_regression_per_image: Tensor,
+    cnf: Optional[Dict[str, float]] = None,
+) -> Tensor:
+    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
+
+    if type == "l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+    elif type == "smooth_l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
+        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
+    else:
+        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
+        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
+        if type == "ciou":
+            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        if type == "diou":
+            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        # otherwise giou
+        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)

+ 82 - 0
models/wirenet2/kepointrcnn.py

@@ -0,0 +1,82 @@
+import math
+import os
+import sys
+from datetime import datetime
+from typing import Mapping, Any
+import cv2
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision.io import read_image
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.utils import draw_bounding_boxes
+from  torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from models.config.config_tool import read_yaml
+from models.keypoint.trainer import train_cfg
+
+from tools import utils
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+
+
+
+
+class KeypointRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=2,num_keypoints=2, transforms=None):
+        super(KeypointRCNNModel, self).__init__()
+        default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+        self.__model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=None,num_classes=num_classes,
+                                                                              num_keypoints=num_keypoints,
+                                                                              progress=False)
+        if transforms is None:
+            self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
+        # if num_classes != 0:
+        #     self.set_num_classes(num_classes)
+            # self.__num_classes=0
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        outputs = self.__model(inputs)
+        return outputs
+
+    def train(self, cfg):
+        parameters = read_yaml(cfg)
+        num_classes = parameters['num_classes']
+        num_keypoints = parameters['num_keypoints']
+        # print(f'num_classes:{num_classes}')
+        # self.set_num_classes(num_classes)
+        self.num_keypoints = num_keypoints
+        train_cfg(self.__model, cfg)
+
+    # def set_num_classes(self, num_classes):
+    #     in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+    #     self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+    #
+    #     # in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
+    #     in_channels = self.__model.roi_heads.keypoint_predictor.
+    #     hidden_layer = 256
+    #     self.__model.roi_heads.mask_predictor = KeypointRCNNPredictor(in_channels, hidden_layer,
+    #                                                               num_classes=num_classes)
+    #     self.__model.roi_heads.keypoint_predictor=KeypointRCNNPredictor(in_channels, num_keypoints=num_classes)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        self.__model.load_state_dict(state_dict)
+        # return super().load_state_dict(state_dict, strict)
+
+
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    keypoint_model = KeypointRCNNModel(num_keypoints=2)
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    keypoint_model.train(cfg='train.yaml')

+ 203 - 0
models/wirenet2/keypoint_dataset.py

@@ -0,0 +1,203 @@
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+import matplotlib.pyplot as plt
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+def validate_keypoints(keypoints, image_width, image_height):
+    for kp in keypoints:
+        x, y, v = kp
+        if not (0 <= x < image_width and 0 <= y < image_height):
+            raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
+
+
+class KeypointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'shape:{shape}')
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # 字典
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # 真实存在线条的邻接矩阵
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [线段数, 512, 512]
+        target = {}
+
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+
+        target["labels"] = torch.stack(labels)
+        # print(f'labels:{target["labels"]}')
+        # target["boxes"] = line_boxes(target)
+        target["boxes"], keypoints = line_boxes(target)
+        # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
+
+        # keypoints= wire_labels["junc_coords"]
+        a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
+        keypoints = torch.cat((keypoints, a), dim=1)
+        target["keypoints"] = keypoints.to(torch.float32).view(-1,2,3)
+        # print(f'boxes:{target["boxes"].shape}')
+        # 在 __getitem__ 方法中调用此函数
+        validate_keypoints(keypoints, shape[0], shape[1])
+        # print(f'keypoints:{target["keypoints"].shape}')
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (i == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64
+
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            plt.show()
+
+            plt.show()
+            if fn != None:
+                plt.savefig(fn)
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass
+
+
+
+if __name__ == '__main__':
+    path=r"I:\wirenet_dateset"
+    dataset= KeypointDataset(dataset_path=path, dataset_type='train')
+    dataset.show(0)

+ 879 - 0
models/wirenet2/roi_heads.py

@@ -0,0 +1,879 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for ech image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+        .index_select(2, x_int.to(dtype=torch.int64))
+        .view(-1)
+        .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+def roi_line_loss(keypoints, rois, heatmap_size):
+
+    pass
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it sepaartely
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+        self,
+        box_roi_pool,
+        box_head,
+        box_predictor,
+        # Faster R-CNN training
+        fg_iou_thresh,
+        bg_iou_thresh,
+        batch_size_per_image,
+        positive_fraction,
+        bbox_reg_weights,
+        # Faster R-CNN inference
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        # Mask
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to propos
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+        self,
+        features,  # type: Dict[str, Tensor]
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+            self.keypoint_roi_pool is not None
+            and self.keypoint_head is not None
+            and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            keypoint_features = self.keypoint_head(keypoint_features)
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        return result, losses

+ 65 - 0
models/wirenet2/test.py

@@ -0,0 +1,65 @@
+import time
+
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.io import decode_image, read_image
+import torchvision.transforms.functional as F
+from torchvision.utils import draw_keypoints
+def show(imgs):
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[0, i].imshow(np.asarray(img))
+        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+
+img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
+# img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
+img_int = read_image(img_path)
+
+
+# person_int = decode_image(r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg")
+
+weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+print(f'transforms:{transforms}')
+img = transforms(img_int)
+
+person_float = transforms(img)
+
+model = keypointrcnn_resnet50_fpn(weights=None, progress=False)
+model = model.eval()
+t1=time.time()
+# img = torch.ones((3, 3, 512, 512))
+
+
+outputs = model([img])
+t2=time.time()
+print(f'time:{t2-t1}')
+# print(f'outputs:{outputs}')
+
+kpts = outputs[0]['keypoints']
+scores = outputs[0]['scores']
+
+print(f'kpts:{kpts}')
+print(f'scores:{scores}')
+
+detect_threshold = 0.75
+idx = torch.where(scores > detect_threshold)
+keypoints = kpts[idx]
+
+# print(f'keypoints:{keypoints}')
+
+
+
+res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
+show(res)
+plt.show()
+
+
+
+

+ 50 - 0
models/wirenet2/test_linemap.py

@@ -0,0 +1,50 @@
+import numpy as np
+import cv2
+
+
+def draw_line_heatmap(image_shape, pt1, pt2, sigma=1):
+    """
+    根据给定的两个端点生成线段的热度图。
+
+    参数:
+    - image_shape: (height, width) 输出热度图的形状
+    - pt1: (x1, y1) 线段的第一个端点
+    - pt2: (x2, y2) 线段的第二个端点
+    - sigma: 高斯核的标准差,用于控制热度扩散的程度
+
+    返回:
+    - heatmap: 生成的热度图
+    """
+    # 创建空白热度图
+    heatmap = np.zeros(image_shape, dtype=np.float32)
+
+    # 绘制线段
+    cv2.line(heatmap, pt1, pt2, color=1, thickness=1)
+
+    # 应用高斯模糊以生成热度效果
+    if sigma > 0:
+        heatmap = cv2.GaussianBlur(heatmap, (0, 0), sigmaX=sigma, sigmaY=sigma)
+
+    # 归一化热度图
+    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)
+
+    return heatmap
+
+
+# 测试函数
+if __name__ == "__main__":
+    # 定义图像尺寸和线段端点
+    image_shape = (256, 256)  # 图像的高度和宽度
+    pt1 = (50, 50)  # 第一个端点
+    pt2 = (200, 200)  # 第二个端点
+    sigma = 2  # 控制热度扩散程度
+
+    # 生成热度图
+    heatmap = draw_line_heatmap(image_shape, pt1, pt2, sigma)
+
+    # 显示结果
+    import matplotlib.pyplot as plt
+
+    plt.imshow(heatmap, cmap='hot', interpolation='nearest')
+    plt.colorbar()
+    plt.show()

+ 32 - 0
models/wirenet2/train.yaml

@@ -0,0 +1,32 @@
+
+
+dataset_path: I:/wirenet_dateset
+
+#train parameters
+num_classes: 2
+num_keypoints: 2
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: pixel
+enable_logs: True
+augmentation: False
+checkpoint: None
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+

+ 212 - 0
models/wirenet2/trainer.py

@@ -0,0 +1,212 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.config.config_tool import read_yaml
+from models.ins.maskrcnn_dataset import MaskRCNNDataset
+from models.keypoint.keypoint_dataset import KeypointDataset
+from tools import utils, presets
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        # print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            # print(f'loss_dict:{loss_dict}')
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+def train(model, **kwargs):
+    # 默认参数
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 2,
+        'num_keypoints':2,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint':None
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # 设置设备
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_ptath, 'weights')
+    tb_path = os.path.join(train_result_ptath, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = KeypointDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = KeypointDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    # data_loader_test = torch.utils.data.DataLoader(
+    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    # )
+
+    img_results_path = os.path.join(train_result_ptath, 'img_results')
+    if os.path.exists(train_result_ptath):
+        pass
+    #     os.remove(train_result_ptath)
+    else:
+        os.mkdir(train_result_ptath)
+
+    if os.path.exists(train_result_ptath):
+        os.mkdir(wts_path)
+        os.mkdir(img_results_path)
+
+    for epoch in range(epochs):
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0:
+            best_loss = losses;
+        if best_loss >= losses:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
+    writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)

+ 617 - 0
models/wirenet2/wirenet_rcnn.py

@@ -0,0 +1,617 @@
+import os
+from typing import Optional, Any
+
+import cv2
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+# from torchinfo import summary
+from torchvision.io import read_image
+from torchvision.models import resnet50, ResNet50_Weights, WeightsEnum, Weights, resnet18, ResNet18_Weights
+from torchvision.models._api import register_model
+from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
+from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection._utils import overwrite_eps
+from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from  torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from torchvision.ops import MultiScaleRoIAlign
+from torchvision.ops import misc as misc_nn_ops
+
+__all__ = [
+    "WirenetRCNN",
+    "WirenetRCNN_ResNet50_FPN_Weights",
+    "wirenetrcnn_resnet50_fpn",
+]
+
+from torchvision.transforms._presets import ObjectDetection
+
+
+class WirenetRCNN(FasterRCNN):
+    """
+    Implements Keypoint R-CNN.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
+          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores or each prediction
+        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or and OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+            If box_predictor is specified, num_classes should be None.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
+        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+            the locations indicated by the bounding boxes
+        box_head (nn.Module): module that takes the cropped feature maps as input
+        box_predictor (nn.Module): module that takes the output of box_head and returns the
+            classification logits and box regression deltas.
+        box_score_thresh (float): during inference, only return proposals with a classification score
+            greater than box_score_thresh
+        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+        box_detections_per_img (int): maximum number of detections per image, for all classes.
+        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+            considered as positive during training of the classification head
+        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+            considered as negative during training of the classification head
+        box_batch_size_per_image (int): number of proposals that are sampled during training of the
+            classification head
+        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+            of the classification head
+        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+            bounding boxes
+        wirenet_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+             the locations indicated by the bounding boxes, which will be used for the keypoint head.
+        wirenet_head (nn.Module): module that takes the cropped feature maps as input
+        wirenet_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
+            heatmap logits
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import KeypointRCNN
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
+        >>>
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # KeypointRCNN needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the RPN generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # let's define what are the feature maps that we will
+        >>> # use to perform the region of interest cropping, as well as
+        >>> # the size of the crop after rescaling.
+        >>> # if your backbone returns a Tensor, featmap_names is expected to
+        >>> # be ['0']. More generally, the backbone should return an
+        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+        >>> # feature maps to use.
+        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                 output_size=7,
+        >>>                                                 sampling_ratio=2)
+        >>>
+        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                          output_size=14,
+        >>>                                                          sampling_ratio=2)
+        >>> # put the pieces together inside a KeypointRCNN model
+        >>> model = KeypointRCNN(backbone,
+        >>>                      num_classes=2,
+        >>>                      rpn_anchor_generator=anchor_generator,
+        >>>                      box_roi_pool=roi_pooler,
+        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
+        >>> model.eval()
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(
+            self,
+            backbone,
+            num_classes=None,
+            # transform parameters
+            min_size=512,
+            max_size=1333,
+            image_mean=None,
+            image_std=None,
+            # RPN parameters
+            rpn_anchor_generator=None,
+            rpn_head=None,
+            rpn_pre_nms_top_n_train=2000,
+            rpn_pre_nms_top_n_test=1000,
+            rpn_post_nms_top_n_train=2000,
+            rpn_post_nms_top_n_test=1000,
+            rpn_nms_thresh=0.7,
+            rpn_fg_iou_thresh=0.7,
+            rpn_bg_iou_thresh=0.3,
+            rpn_batch_size_per_image=256,
+            rpn_positive_fraction=0.5,
+            rpn_score_thresh=0.0,
+            # Box parameters
+            box_roi_pool=None,
+            box_head=None,
+            box_predictor=None,
+            box_score_thresh=0.05,
+            box_nms_thresh=0.5,
+            box_detections_per_img=100,
+            box_fg_iou_thresh=0.5,
+            box_bg_iou_thresh=0.5,
+            box_batch_size_per_image=512,
+            box_positive_fraction=0.25,
+            bbox_reg_weights=None,
+            # keypoint parameters
+            wirenet_roi_pool=None,
+            wirenet_head=None,
+            wirenet_predictor=None,
+            num_keypoints=None,
+            **kwargs,
+    ):
+
+        if not isinstance(wirenet_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
+            )
+        if min_size is None:
+            min_size = (640, 672, 704, 736, 768, 800)
+
+        if num_keypoints is not None:
+            if wirenet_predictor is not None:
+                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
+        else:
+            num_keypoints = 2
+
+        out_channels = backbone.out_channels
+
+        if wirenet_roi_pool is None:
+            wirenet_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if wirenet_head is None:
+            keypoint_layers = tuple(512 for _ in range(8))
+            wirenet_head = WirenetRCNNHeads(out_channels, keypoint_layers)
+
+        if wirenet_predictor is None:
+            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+            wirenet_predictor = WirenetRCNNPredictor(keypoint_dim_reduced, num_keypoints)
+
+        super().__init__(
+            backbone,
+            num_classes,
+            # transform parameters
+            min_size,
+            max_size,
+            image_mean,
+            image_std,
+            # RPN-specific parameters
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_pre_nms_top_n_train,
+            rpn_pre_nms_top_n_test,
+            rpn_post_nms_top_n_train,
+            rpn_post_nms_top_n_test,
+            rpn_nms_thresh,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_score_thresh,
+            # Box parameters
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            **kwargs,
+        )
+
+        self.roi_heads.keypoint_roi_pool = wirenet_roi_pool
+        self.roi_heads.keypoint_head = wirenet_head
+        self.roi_heads.keypoint_predictor = wirenet_predictor
+
+
+class WirenetRCNNHeads(nn.Module):
+    def __init__(self, in_channels, layers, num_keypoints=3):
+        super().__init__()
+        d = []
+        next_feature = in_channels
+        for out_channels in layers:
+            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+            d.append(nn.ReLU(inplace=True))
+            next_feature = out_channels
+        # super().__init__(*d)
+        self.feature_layers = nn.Sequential(*d)
+        for m in self.feature_layers.children():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+        input_features = next_feature
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            num_keypoints,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = num_keypoints
+
+    def forward(self, x):
+        x = self.feature_layers(x)
+        x = self.kps_score_lowres(x)
+        return torch.nn.functional.interpolate(
+            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        )
+
+
+class WirenetRCNNPredictor(nn.Module):
+    def __init__(self, in_channels, num_keypoints):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            num_keypoints,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = num_keypoints
+
+    def forward(self, x):
+        x = self.kps_score_lowres(x)
+        x=torch.nn.functional.interpolate(
+            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        )
+        print(f'x.shape:{x.shape}')
+        return x
+
+
+_COMMON_META = {
+    "categories": _COCO_PERSON_CATEGORIES,
+    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
+    "min_size": (1, 1),
+}
+
+
+class WirenetRCNN_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_LEGACY = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/issues/1606",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 50.6,
+                    "kp_map": 61.1,
+                }
+            },
+            "_ops": 133.924,
+            "_file_size": 226.054,
+            "_docs": """
+                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
+                from an early epoch.
+            """,
+        },
+    )
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 54.6,
+                    "kp_map": 65.0,
+                }
+            },
+            "_ops": 137.42,
+            "_file_size": 226.054,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=(
+            "pretrained",
+            lambda kwargs: WirenetRCNN_ResNet50_FPN_Weights.COCO_LEGACY
+            if kwargs["pretrained"] == "legacy"
+            else WirenetRCNN_ResNet50_FPN_Weights.COCO_V1,
+    ),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def wirenetrcnn_resnet50_fpn(
+        *,
+        weights: Optional[WirenetRCNN_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> WirenetRCNN:
+    """
+    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores or each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=WirenetRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = WirenetRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 2
+        if num_keypoints is None:
+            num_keypoints = 17
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = WirenetRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == WirenetRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+def wirenetrcnn_resnet18_fpn(
+        *,
+        weights: Optional[WirenetRCNN_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> WirenetRCNN:
+    """
+    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores or each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=WirenetRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = WirenetRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    # if weights_backbone is None:
+
+    weights_backbone = ResNet18_Weights.IMAGENET1K_V1
+
+    if weights is not None:
+        # weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 3
+        if num_keypoints is None:
+            num_keypoints = 17
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = WirenetRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == WirenetRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+if __name__ == '__main__':
+    model=wirenetrcnn_resnet18_fpn(num_keypoints=3)
+    img = torch.ones((3, 3, 512, 512))
+    model.eval()
+    model(img)
+    # model.train(cfg='train.yaml')