Ver Fonte

WireDataset

xue50 há 5 meses atrás
commit
79c7342e0f

+ 5 - 0
.gitignore

@@ -0,0 +1,5 @@
+.idea
+*.pt
+runs
+__pycache__
+train_results

BIN
assets/dog1.jpg


BIN
assets/dog1.png


BIN
assets/dog2.jpg


BIN
assets/dog2.png


+ 494 - 0
main.py

@@ -0,0 +1,494 @@
+import math
+import os.path
+import re
+import sys
+
+import PIL.Image
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import torchvision.transforms
+import torchvision.transforms.functional as F
+from torch.utils.data import DataLoader
+from torchvision.transforms import v2
+
+from torchvision.utils import make_grid, draw_bounding_boxes
+from torchvision.io import read_image
+from pathlib import Path
+from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
+# PyTorch TensorBoard support
+from torch.utils.tensorboard import SummaryWriter
+import cv2
+from sklearn.cluster import DBSCAN
+from test.MaskRCNN import MaskRCNNDataset
+from tools import utils
+import pandas as pd
+
+plt.rcParams["savefig.bbox"] = 'tight'
+orig_path = r'F:\Downloads\severstal-steel-defect-detection'
+dst_path = r'F:\Downloads\severstal-steel-defect-detection'
+
+
+def show(imgs):
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[0, i].imshow(np.asarray(img))
+        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+    plt.show()
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+
+def train():
+    pass
+
+
+def trans_datasets_format():
+    # 使用pandas的read_csv函数读取文件
+    df = pd.read_csv(os.path.join(orig_path, 'train.csv'))
+
+    # 显示数据的前几行
+    print(df.head())
+    for row in df.itertuples():
+        # print(f"Row index: {row.Index}")
+        # print(getattr(row, 'ImageId'))  # 输出特定列的值
+        img_name = getattr(row, 'ImageId')
+        img_path = os.path.join(orig_path + '/train_images', img_name)
+        dst_img_path = os.path.join(dst_path + '/images/train', img_name)
+        dst_label_path = os.path.join(dst_path + '/labels/train', img_name[:-3] + 'txt')
+        print(f'dst label:{dst_label_path}')
+        im = cv2.imread(img_path)
+        # cv2.imshow('test',im)
+        cv2.imwrite(dst_img_path, im)
+        img = PIL.Image.open(img_path)
+        height, width = im.shape[:2]
+        print(f'cv2 size:{im.shape}')
+        label, mask = compute_mask(row, img.size)
+        lbls, ins_masks=cluster_dbscan(mask,img)
+
+
+
+        with open(dst_label_path, 'a+') as writer:
+            # writer.write(label)
+            for ins_mask in ins_masks:
+                lbl_data = str(label) + ' '
+                for mp in ins_mask:
+                    h,w=mp
+                    lbl_data += str(w / width) + ' ' + str(h / height) + ' '
+
+                # non_zero_coords = np.nonzero(inm.reshape(width,height).T)
+                # coords_list = list(zip(non_zero_coords[0], non_zero_coords[1]))
+                # # print(f'mask:{mask[0,333]}')
+                # print(f'mask pixels:{coords_list}')
+                #
+                #
+                # for coord in coords_list:
+                #     h, w = coord
+                #     lbl_data += str(w / width) + ' ' + str(h / height) + ' '
+
+                writer.write(lbl_data + '\n')
+                print(f'lbl_data:{lbl_data}')
+        writer.close()
+        print(f'label:{label}')
+        # plt.imshow(img)
+        # plt.imshow(mask, cmap='Reds', alpha=0.3)
+        # plt.show()
+
+
+def compute_mask(row, shape):
+    width, height = shape
+    print(f'shape:{shape}')
+    mask = np.zeros(width * height, dtype=np.uint8)
+    pixels = np.array(list(map(int, row.EncodedPixels.split())))
+    label = row.ClassId
+    # print(f'pixels:{pixels}')
+    mask_start = pixels[0::2]
+    mask_length = pixels[1::2]
+
+    for s, l in zip(mask_start, mask_length):
+        mask[s:s + l] = 255
+    mask = mask.reshape((width, height)).T
+
+    # mask = np.flipud(np.rot90(mask.reshape((height, width))))
+    return label, mask
+
+def cluster_dbscan(mask,image):
+    # 将 mask 转换为二值图像
+    _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
+
+    # 将 mask 一维化
+    mask_flattened = mask_binary.flatten()
+
+    # 获取 mask 中的前景像素坐标
+    foreground_pixels = np.argwhere(mask_flattened == 255)
+
+    # 将像素坐标转换为二维坐标
+    foreground_pixels_2d = np.column_stack(
+        (foreground_pixels // mask_binary.shape[1], foreground_pixels % mask_binary.shape[1]))
+
+    # 定义 DBSCAN 参数
+    eps = 3  # 邻域半径
+    min_samples = 10  # 最少样本数量
+
+    # 应用 DBSCAN
+    dbscan = DBSCAN(eps=eps, min_samples=min_samples).fit(foreground_pixels_2d)
+
+    # 获取聚类标签
+    labels = dbscan.labels_
+    print(f'labels:{labels}')
+    # 获取唯一的标签
+    unique_labels = set(labels)
+
+    print(f'unique_labels:{unique_labels}')
+    # 创建一个空的图像来保存聚类结果
+    clustered_image = np.zeros_like(image)
+    # print(f'clustered_image shape:{clustered_image.shape}')
+
+
+    # 将每个像素分配给相应的簇
+    clustered_points=[]
+    for k in unique_labels:
+
+
+        class_member_mask = (labels == k)
+        # print(f'class_member_mask:{class_member_mask}')
+        # plt.subplot(132), plt.imshow(class_member_mask), plt.title(str(labels))
+
+        pixel_indices = foreground_pixels_2d[class_member_mask]
+        clustered_points.append(pixel_indices)
+
+    return unique_labels,clustered_points
+
+def show_cluster_dbscan(mask,image,unique_labels,clustered_points,):
+    print(f'mask shape:{mask.shape}')
+    # 将 mask 转换为二值图像
+    _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
+
+    # 将 mask 一维化
+    mask_flattened = mask_binary.flatten()
+
+    # 获取 mask 中的前景像素坐标
+    foreground_pixels = np.argwhere(mask_flattened == 255)
+    # print(f'unique_labels:{unique_labels}')
+    # 创建一个空的图像来保存聚类结果
+    print(f'image shape:{image.shape}')
+    clustered_image = np.zeros_like(image)
+    print(f'clustered_image shape:{clustered_image.shape}')
+
+    # 为每个簇分配颜色
+    colors =np.array( [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))])
+    # print(f'colors:{colors}')
+    plt.figure(figsize=(12, 6))
+    for points_coord,col in  zip(clustered_points,colors):
+        for coord in points_coord:
+
+            clustered_image[coord[0], coord[1]] = (np.array(col[:3]) * 255)
+
+    # # 将每个像素分配给相应的簇
+    # for k, col in zip(unique_labels, colors):
+    #     print(f'col:{col*255}')
+    #     if k == -1:
+    #         # 黑色用于噪声点
+    #         col = [0, 0, 0, 1]
+    #
+    #     class_member_mask = (labels == k)
+    #     # print(f'class_member_mask:{class_member_mask}')
+    #     # plt.subplot(132), plt.imshow(class_member_mask), plt.title(str(labels))
+    #
+    #     pixel_indices = foreground_pixels_2d[class_member_mask]
+    #     clustered_points.append(pixel_indices)
+    #     # print(f'pixel_indices:{pixel_indices}')
+    #     for pixel_index in pixel_indices:
+    #         clustered_image[pixel_index[0], pixel_index[1]] = (np.array(col[:3]) * 255)
+
+    print(f'clustered_points:{len(clustered_points)}')
+    # print(f'clustered_image:{clustered_image}')
+    # 显示原图和聚类结果
+    # plt.figure(figsize=(12, 6))
+    plt.subplot(131), plt.imshow(image), plt.title('Original Image')
+    # print(f'image:{image}')
+    plt.subplot(132), plt.imshow(mask_binary, cmap='gray'), plt.title('Mask')
+    plt.subplot(133), plt.imshow(clustered_image.astype(np.uint8)), plt.title('Clustered Image')
+    plt.show()
+def test():
+    dog1_int = read_image(str(Path('./assets') / 'dog1.jpg'))
+    dog2_int = read_image(str(Path('./assets') / 'dog2.jpg'))
+    dog_list = [dog1_int, dog2_int]
+    grid = make_grid(dog_list)
+
+    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
+    transforms = weights.transforms()
+
+    images = [transforms(d) for d in dog_list]
+    # 假设输入图像的尺寸为 (3, 800, 800)
+    dummy_input = torch.randn(1, 3, 800, 800)
+    model = maskrcnn_resnet50_fpn_v2(weights=weights, progress=False)
+    model = model.eval()
+
+    # 使用 torch.jit.script
+    scripted_model = torch.jit.script(model)
+
+    output = model(dummy_input)
+    print(f'output:{output}')
+
+    writer = SummaryWriter('runs/')
+    writer.add_graph(scripted_model, input_to_model=dummy_input)
+    writer.flush()
+
+    # torch.onnx.export(models,images, f='maskrcnn.onnx')  # 导出 .onnx 文
+    # netron.start('AlexNet.onnx')  # 展示结构图
+
+    show(grid)
+
+
+def test_mask():
+    name = 'fdb7c0397'
+    label_path = os.path.join(dst_path + '/labels/train', name + '.txt')
+    img_path = os.path.join(orig_path + '/train_images', name + '.jpg')
+    mask = np.zeros((256, 1600), dtype=np.uint8)
+    df = pd.read_csv(os.path.join(orig_path, 'train.csv'))
+    # 显示数据的前几行
+    print(df.head())
+    points = []
+    with open(label_path, 'r') as reader:
+        lines = reader.readlines()
+        for line in lines:
+            parts = line.strip().split()
+            # print(f'parts:{parts}')
+            class_id = int(parts[0])
+            x_array = parts[1::2]
+            y_array = parts[2::2]
+
+            for x, y in zip(x_array, y_array):
+                x = float(x)
+                y = float(y)
+                points.append((int(y * 255), int(x * 1600)))
+            # points = np.array([[float(parts[i]), float(parts[i + 1])] for i in range(1, len(parts), 2)])
+            # mask_resized = cv2.resize(points, (1600, 256), interpolation=cv2.INTER_NEAREST)
+            print(f'points:{points}')
+            # mask[points[:,0],points[:,1]]=255
+            for p in points:
+                mask[p] = 255
+            # cv2.fillPoly(mask, points, color=(255,))
+    cv2.imshow('mask', mask)
+    for row in df.itertuples():
+        img_name = name + '.jpg'
+        if img_name == getattr(row, 'ImageId'):
+            img = PIL.Image.open(img_path)
+            height, width = img.size
+            print(f'img size:{img.size}')
+            label, mask = compute_mask(row, img.size)
+            plt.imshow(img)
+            plt.imshow(mask, cmap='Reds', alpha=0.3)
+            plt.show()
+    cv2.waitKey(0)
+
+def show_img_mask(img_path):
+    test_img = PIL.Image.open(img_path)
+
+    w,h=test_img.size
+    test_img=torchvision.transforms.ToTensor()(test_img)
+    test_img=test_img.permute(1, 2, 0)
+    print(f'test_img shape:{test_img.shape}')
+    lbl_path=re.sub(r'\\images\\', r'\\labels\\', img_path[:-3]) + 'txt'
+    # print(f'lbl_path:{lbl_path}')
+    masks = []
+    labels = []
+
+    with open(lbl_path, 'r') as reader:
+        lines = reader.readlines()
+        # 为每个簇分配颜色
+        colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(lines))])
+        print(f'colors:{colors*255}')
+        mask_points = []
+        for line ,col in zip(lines,colors):
+            print(f'col:{np.array(col[:3]) * 255}')
+            mask = torch.zeros(test_img.shape, dtype=torch.uint8)
+            # print(f'mask shape:{mask.shape}')
+            parts = line.strip().split()
+            # print(f'parts:{parts}')
+            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
+            labels.append(cls)
+            x_array = parts[1::2]
+            y_array = parts[2::2]
+
+            for x, y in zip(x_array, y_array):
+                x = float(x)
+                y = float(y)
+                mask_points.append((int(y * h), int(x * w)))
+            for p in mask_points:
+                # print(f'p:{p}')
+                mask[p] = torch.tensor(np.array(col[:3])*255)
+            masks.append(mask)
+    reader.close()
+    target = {}
+
+    # target["boxes"] = masks_to_boxes(torch.stack(masks))
+
+    # target["labels"] = torch.stack(labels)
+
+    target["masks"] = torch.stack(masks)
+    print(f'target:{target}')
+
+    # plt.imshow(test_img.permute(1, 2, 0))
+    fig, axs = plt.subplots(2, 1)
+    print(f'test_img:{test_img*255}')
+    axs[0].imshow(test_img)
+    axs[0].axis('off')
+    axs[1].axis('off')
+    axs[1].imshow(test_img*255)
+    for img_mask in target['masks']:
+        # img_mask=img_mask.unsqueeze(0)
+        # img_mask = img_mask.expand_as(test_img)
+        # print(f'img_mask:{img_mask.shape}')
+        axs[1].imshow(img_mask,alpha=0.3)
+
+        # img_mask=np.array(img_mask)
+        # print(f'img_mask:{img_mask.shape}')
+        # plt.imshow(img_mask,alpha=0.5)
+        # mask_3channel = cv2.merge([np.zeros_like(img_mask), np.zeros_like(img_mask), img_mask])
+        # masked_image = cv2.addWeighted(test_img, 1, mask_3channel, 0.6, 0)
+
+    # cv2.imshow('cv2 mask img', masked_image)
+    # cv2.waitKey(0)
+    plt.show()
+def show_dataset():
+    global transforms, dataset, imgs
+    transforms = v2.Compose([
+        # v2.RandomResizedCrop(size=(224, 224), antialias=True),
+        # v2.RandomPhotometricDistort(p=1),
+        # v2.RandomHorizontalFlip(p=1),
+        v2.ToTensor()
+    ])
+    dataset = MaskRCNNDataset(dataset_path=r'F:\Downloads\severstal-steel-defect-detection', transforms=transforms,
+                              dataset_type='train')
+    dataloader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=utils.collate_fn)
+    imgs, targets = next(iter(dataloader))
+
+    mask = np.array(targets[2]['masks'][0])
+    boxes = targets[2]['boxes']
+    print(f'boxes:{boxes}')
+    # mask[mask == 255] = 1
+    img = np.array(imgs[2].permute(1, 2, 0)) * 255
+    img = img.astype(np.uint8)
+    print(f'img shape:{img.shape}')
+    print(f'mask:{mask.shape}')
+    # print(f'target:{targets}')
+    # print(f'imgs:{imgs[0]}')
+    # print(f'cv2 img shape:{np.array(imgs[0]).shape}')
+    # cv2.imshow('cv2 img',img)
+    # cv2.imshow('cv2 mask', mask)
+    # plt.imshow('mask',mask)
+    mask_3channel = cv2.merge([np.zeros_like(mask), np.zeros_like(mask), mask])
+    # cv2.imshow('mask_3channel',mask_3channel)
+    print(f'mask_3channel:{mask_3channel.shape}')
+    masked_image = cv2.addWeighted(img, 1, mask_3channel, 0.6, 0)
+    # cv2.imshow('cv2 mask img', masked_image)
+    plt.imshow(imgs[0].permute(1, 2, 0))
+    plt.imshow(mask, cmap='Reds', alpha=0.3)
+    drawn_boxes = draw_bounding_boxes((imgs[2] * 255).to(torch.uint8), boxes, colors="red", width=5)
+    plt.imshow(drawn_boxes.permute(1, 2, 0))
+    # show(drawn_boxes)
+    plt.show()
+    cv2.waitKey(0)
+
+def test_cluster(img_path):
+    test_img = PIL.Image.open(img_path)
+    w, h = test_img.size
+    test_img = torchvision.transforms.ToTensor()(test_img)
+    test_img=(test_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
+    # print(f'test_img:{test_img}')
+    lbl_path = re.sub(r'\\images\\', r'\\labels\\', img_path[:-3]) + 'txt'
+    # print(f'lbl_path:{lbl_path}')
+    masks = []
+    labels = []
+    with open(lbl_path, 'r') as reader:
+        lines = reader.readlines()
+        mask_points = []
+        for line in lines:
+            mask = torch.zeros((h, w), dtype=torch.uint8)
+            parts = line.strip().split()
+            # print(f'parts:{parts}')
+            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
+            labels.append(cls)
+            x_array = parts[1::2]
+            y_array = parts[2::2]
+
+            for x, y in zip(x_array, y_array):
+                x = float(x)
+                y = float(y)
+                mask_points.append((int(y * h), int(x * w)))
+            for p in mask_points:
+                mask[p] = 255
+            masks.append(mask)
+    # print(f'masks:{masks}')
+    labels,clustered_points=cluster_dbscan(masks[0].numpy(),test_img)
+    print(f'labels:{labels}')
+    print(f'clustered_points len:{len(clustered_points)}')
+    show_cluster_dbscan(masks[0].numpy(),test_img,labels,clustered_points)
+
+if __name__ == '__main__':
+    # trans_datasets_format()
+    # test_mask()
+    # 定义转换
+    # show_dataset()
+
+    # test_img_path= r"F:\Downloads\severstal-steel-defect-detection\images\train\0025bde0c.jpg"
+    test_img_path = r"F:\DevTools\datasets\renyaun\1012\spilt\images\train\2024-09-27-14-32-53_SaveImage.png"
+    # test_img1_path=r"F:\Downloads\severstal-steel-defect-detection\images\train\1d00226a0.jpg"
+    show_img_mask(test_img_path)
+    #
+    # test_cluster(test_img_path)

+ 0 - 0
models/__init__.py


+ 0 - 0
models/base/__init__.py


+ 48 - 0
models/base/base_dataset.py

@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod
+
+import torch
+from torch import nn, Tensor
+from torch.utils.data import Dataset
+from torch.utils.data.dataset import T_co
+
+from torchvision.transforms import  functional as F
+
+class BaseDataset(Dataset, ABC):
+    def __init__(self,dataset_path):
+        self.default_transform=DefaultTransform()
+        pass
+
+    def __getitem__(self, index) -> T_co:
+        pass
+
+    @abstractmethod
+    def read_target(self,item,lbl_path,extra=None):
+        pass
+
+    """显示数据集指定图片"""
+    @abstractmethod
+    def show(self,idx):
+        pass
+
+    """
+    显示数据集指定名字的图片
+    """
+
+    @abstractmethod
+    def show_img(self,img_path):
+        pass
+
+class DefaultTransform(nn.Module):
+    def forward(self, img: Tensor) -> Tensor:
+        if not isinstance(img, Tensor):
+            img = F.pil_to_tensor(img)
+        return F.convert_image_dtype(img, torch.float)
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + "()"
+
+    def describe(self) -> str:
+        return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
+            "The images are rescaled to ``[0.0, 1.0]``."
+        )

+ 0 - 0
models/config/__init__.py


+ 22 - 0
models/config/config_tool.py

@@ -0,0 +1,22 @@
+import yaml
+
+
+def read_yaml(path='application.yaml'):
+    try:
+        with open(path, 'r') as file:
+            data = file.read()
+            # result = yaml.load(data)
+            result = yaml.load(data, Loader=yaml.FullLoader)
+
+            return result
+    except Exception as e:
+        print(e)
+        return None
+
+
+def write_yaml(path='application.yaml', data=None):
+    try:
+        with open(path, 'w', encoding='utf-8') as f:
+            yaml.dump(data=data, stream=f, allow_unicode=True)
+    except Exception as e:
+        print(e)

+ 42 - 0
models/config/test_config.py

@@ -0,0 +1,42 @@
+import yaml
+
+test_data = {
+    'cameras': [{
+        'id': 1,
+        'ip': "192.168.1.2"
+    }, {
+        'id': 2,
+        'ip': "192.168.1.3"
+    }]
+}
+
+
+def read_yaml(path):
+    try:
+        with open(path, 'r') as file:
+            data = file.read()
+            # result = yaml.load(data)
+            result = yaml.load(data, Loader=yaml.FullLoader)
+
+            return result
+    except Exception as e:
+        print(e)
+        return None
+
+
+def write_yaml(path):
+    try:
+        with open('path', 'w', encoding='utf-8') as f:
+            yaml.dump(data=test_data, stream=f, allow_unicode=True)
+    except Exception as e:
+        print(e)
+
+
+if __name__ == '__main__':
+    p = 'train.yaml'
+    result = read_yaml(p)
+    # j=json.load(result)
+    print('result', result)
+    # print('cameras', result['cameras'])
+    # print('json',j)
+

+ 34 - 0
models/config/train.yaml

@@ -0,0 +1,34 @@
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+dataset_path: F:\DevTools\datasets\renyaun\1012\spilt
+#train: images/train  # train images (relative to 'path') 128 images
+#val: images/train  # val images (relative to 'path') 128 images
+#test: images/test  # test images (optional)
+
+#train parameters
+num_classes: 5
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: polygon
+enable_logs: True
+augmentation: True
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+

+ 281 - 0
models/dataset_tool.py

@@ -0,0 +1,281 @@
+import cv2
+import numpy as np
+import torch
+import torchvision
+from matplotlib import pyplot as plt
+import tools.transforms as reference_transforms
+from collections import defaultdict
+
+from tools import presets
+
+import json
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.transforms.v2
+        import torchvision.tv_tensors
+
+        return torchvision.transforms.v2, torchvision.tv_tensors
+    else:
+        return reference_transforms, None
+
+
+class Augmentation:
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter.
+    def __init__(
+            self,
+            *,
+            data_augmentation,
+            hflip_prob=0.5,
+            mean=(123.0, 117.0, 104.0),
+            backend="pil",
+            use_v2=False,
+    ):
+
+        T, tv_tensors = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "tv_tensor":
+            transforms.append(T.ToImage())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
+
+        if data_augmentation == "hflip":
+            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
+        elif data_augmentation == "lsj":
+            transforms += [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                # TODO: FixedSizeCrop below doesn't work on tensors!
+                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "multiscale":
+            transforms += [
+                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "ssd":
+            fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
+            transforms += [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=fill),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "ssdlite":
+            transforms += [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        else:
+            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
+
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2.
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
+
+        transforms += [T.ToDtype(torch.float, scale=True)]
+
+        if use_v2:
+            transforms += [
+                T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
+                T.SanitizeBoundingBoxes(),
+                T.ToPureTensor(),
+            ]
+
+        self.transforms = T.Compose(transforms)
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
+
+
+def read_polygon_points(lbl_path, shape):
+    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
+    polygon_points = []
+    w, h = shape[:2]
+    with open(lbl_path, 'r') as f:
+        lines = f.readlines()
+
+    for line in lines:
+        parts = line.strip().split()
+        class_id = int(parts[0])
+        points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2)  # 读取点坐标
+        points[:, 0] *= h
+        points[:, 1] *= w
+
+        polygon_points.append((class_id, points))
+
+    return polygon_points
+
+
+def read_masks_from_pixels(lbl_path, shape):
+    """读取纯像素点格式的文件,不是轮廓像素点"""
+    h, w = shape
+    masks = []
+    labels = []
+
+    with open(lbl_path, 'r') as reader:
+        lines = reader.readlines()
+        mask_points = []
+        for line in lines:
+            mask = torch.zeros((h, w), dtype=torch.uint8)
+            parts = line.strip().split()
+            # print(f'parts:{parts}')
+            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
+            labels.append(cls)
+            x_array = parts[1::2]
+            y_array = parts[2::2]
+
+            for x, y in zip(x_array, y_array):
+                x = float(x)
+                y = float(y)
+                mask_points.append((int(y * h), int(x * w)))
+
+            for p in mask_points:
+                mask[p] = 1
+            masks.append(mask)
+    reader.close()
+    return labels, masks
+
+
+def create_masks_from_polygons(polygons, image_shape):
+    """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
+    colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
+    masks = []
+
+    for polygon_data, col in zip(polygons, colors):
+        mask = np.zeros(image_shape[:2], dtype=np.uint8)
+        # 将多边形顶点转换为 NumPy 数组
+        _, polygon = polygon_data
+        pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
+
+        # 使用 OpenCV 的 fillPoly 函数填充多边形
+        # print(f'color:{col[:3]}')
+        cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
+        mask = torch.from_numpy(mask)
+        mask[mask != 0] = 1
+        masks.append(mask)
+
+    return masks
+
+
+def read_masks_from_txt(label_path, shape):
+    polygon_points = read_polygon_points(label_path, shape)
+    masks = create_masks_from_polygons(polygon_points, shape)
+    labels = [torch.tensor(item[0]) for item in polygon_points]
+
+    return labels, masks
+
+
+def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the bounding boxes around the provided masks.
+
+    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
+    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+
+    Args:
+        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
+            and (H, W) are the spatial dimensions.
+
+    Returns:
+        Tensor[N, 4]: bounding boxes
+    """
+    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+    #     _log_api_usage_once(masks_to_boxes)
+    if masks.numel() == 0:
+        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
+
+    n = masks.shape[0]
+
+    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
+
+    for index, mask in enumerate(masks):
+        y, x = torch.where(mask != 0)
+        bounding_boxes[index, 0] = torch.min(x)
+        bounding_boxes[index, 1] = torch.min(y)
+        bounding_boxes[index, 2] = torch.max(x)
+        bounding_boxes[index, 3] = torch.max(y)
+        # debug to pixel datasets
+
+        if bounding_boxes[index, 0] == bounding_boxes[index, 2]:
+            bounding_boxes[index, 2] = bounding_boxes[index, 2] + 1
+            bounding_boxes[index, 0] = bounding_boxes[index, 0] - 1
+
+        if bounding_boxes[index, 1] == bounding_boxes[index, 3]:
+            bounding_boxes[index, 3] = bounding_boxes[index, 3] + 1
+            bounding_boxes[index, 1] = bounding_boxes[index, 1] - 1
+
+    return bounding_boxes
+
+
+def read_polygon_points_wire(lbl_path, shape):
+    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
+    polygon_points = []
+    w, h = shape[:2]
+    with open(lbl_path, 'r') as f:
+        lines = json.load(f)
+
+    for line in lines["segmentations"]:
+        parts = line["data"]
+        class_id = int(line["cls_id"])
+        points = np.array(parts, dtype=np.float32).reshape(-1, 2)  # 读取点坐标
+        points[:, 0] *= h
+        points[:, 1] *= w
+
+        polygon_points.append((class_id, points))
+
+    return polygon_points
+
+
+def read_masks_from_txt_wire(label_path, shape):
+    polygon_points = read_polygon_points_wire(label_path, shape)
+    masks = create_masks_from_polygons(polygon_points, shape)
+    labels = [torch.tensor(item[0]) for item in polygon_points]
+
+    return labels, masks
+
+
+def read_masks_from_pixels_wire(lbl_path, shape):
+    """读取纯像素点格式的文件,不是轮廓像素点"""
+    h, w = shape
+    masks = []
+    labels = []
+
+    with open(lbl_path, 'r') as reader:
+        lines = json.load(reader)
+        mask_points = []
+        for line in lines["segmentations"]:
+            mask = torch.zeros((h, w), dtype=torch.uint8)
+            parts = line["data"]
+            # print(f'parts:{parts}')
+            cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
+            labels.append(cls)
+            x_array = parts[0::2]
+            y_array = parts[1::2]
+
+            for x, y in zip(x_array, y_array):
+                x = float(x)
+                y = float(y)
+                mask_points.append((int(y * h), int(x * w)))
+
+            for p in mask_points:
+                mask[p] = 1
+            masks.append(mask)
+    reader.close()
+    return labels, masks
+
+
+def adjacency_matrix(n, link):  # 邻接矩阵
+    mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
+    link = torch.tensor(link)
+    if len(link) > 0:
+        mat[link[:, 0], link[:, 1]] = 1
+        mat[link[:, 1], link[:, 0]] = 1
+    return mat

+ 0 - 0
models/ins/__init__.py


+ 142 - 0
models/ins/maskrcnn.py

@@ -0,0 +1,142 @@
+import math
+import os
+import sys
+from datetime import datetime
+from typing import Mapping, Any
+import cv2
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision.io import read_image
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.utils import draw_bounding_boxes
+
+from models.config.config_tool import read_yaml
+from models.ins.trainer import train_cfg
+from tools import utils
+
+
+class MaskRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=0, transforms=None):
+        super(MaskRCNNModel, self).__init__()
+        self.__model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
+            weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
+        if transforms is None:
+            self.transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        if num_classes != 0:
+            self.set_num_classes(num_classes)
+            # self.__num_classes=0
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        outputs = self.__model(inputs)
+        return outputs
+
+    def train(self, cfg):
+        parameters = read_yaml(cfg)
+        num_classes=parameters['num_classes']
+        # print(f'num_classes:{num_classes}')
+        self.set_num_classes(num_classes)
+        train_cfg(self.__model, cfg)
+
+    def set_num_classes(self, num_classes):
+        in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+        self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+        in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
+        hidden_layer = 256
+        self.__model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer,
+                                                                  num_classes=num_classes)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        self.__model.load_state_dict(state_dict)
+        # return super().load_state_dict(state_dict, strict)
+
+    def predict(self, src, show_box=True, show_mask=True):
+        self.__model.eval()
+
+        img = read_image(src)
+        img = self.transforms(img)
+        img = img.to(self.device)
+        result = self.__model([img])
+        print(f'result:{result}')
+        masks = result[0]['masks']
+        boxes = result[0]['boxes']
+        # cv2.imshow('mask',masks[0].cpu().detach().numpy())
+        boxes = boxes.cpu().detach()
+        drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
+        print(f'drawn_boxes:{drawn_boxes.shape}')
+        boxed_img = drawn_boxes.permute(1, 2, 0).numpy()
+        # boxed_img=cv2.resize(boxed_img,(800,800))
+        # cv2.imshow('boxes',boxed_img)
+
+        mask = masks[0].cpu().detach().permute(1, 2, 0).numpy()
+
+        mask = cv2.resize(mask, (800, 800))
+        # cv2.imshow('mask',mask)
+        img = img.cpu().detach().permute(1, 2, 0).numpy()
+
+        masked_img = self.overlay_masks_on_image(boxed_img, masks)
+        masked_img = cv2.resize(masked_img, (800, 800))
+        cv2.imshow('img_masks', masked_img)
+        # show_img_boxes_masks(img, boxes, masks)
+        cv2.waitKey(0)
+
+    def generate_colors(self, n):
+        """
+        生成n个均匀分布在HSV色彩空间中的颜色,并转换成BGR色彩空间。
+
+        :param n: 需要的颜色数量
+        :return: 一个包含n个颜色的列表,每个颜色为BGR格式的元组
+        """
+        hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
+        bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0])) for hsv in hsv_colors]
+        return bgr_colors
+
+    def overlay_masks_on_image(self, image, masks, alpha=0.6):
+        """
+        在原图上叠加多个掩码,每个掩码使用不同的颜色。
+
+        :param image: 原图 (NumPy 数组)
+        :param masks: 掩码列表 (每个都是 NumPy 数组,二值图像)
+        :param colors: 颜色列表 (每个颜色都是 (B, G, R) 格式的元组)
+        :param alpha: 掩码的透明度 (0.0 到 1.0)
+        :return: 叠加了多个掩码的图像
+        """
+        colors = self.generate_colors(len(masks))
+        if len(masks) != len(colors):
+            raise ValueError("The number of masks and colors must be the same.")
+
+        # 复制原图,避免修改原始图像
+        overlay = image.copy()
+
+        for mask, color in zip(masks, colors):
+            # 确保掩码是二值图像
+            mask = mask.cpu().detach().permute(1, 2, 0).numpy()
+            binary_mask = (mask > 0).astype(np.uint8) * 255  # 你可以根据实际情况调整阈值
+
+            # 创建彩色掩码
+            colored_mask = np.zeros_like(image)
+            colored_mask[:] = color
+            colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
+
+            # 将彩色掩码与当前的叠加图像混合
+            overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
+
+        return overlay
+
+
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    ins_model = MaskRCNNModel()
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    ins_model.train(cfg='train.yaml')

+ 93 - 0
models/ins/maskrcnn_dataset.py

@@ -0,0 +1,93 @@
+import os
+
+import PIL
+import cv2
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.data import Dataset
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.dataset_tool import masks_to_boxes, read_masks_from_txt, read_masks_from_pixels
+
+
+class MaskRCNNDataset(Dataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='polygon'):
+        self.data_path = dataset_path
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        self.deafult_transform= MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+        # print('maskrcnn inited!')
+
+    def __getitem__(self, item):
+        # print('__getitem__')
+        img_path = os.path.join(self.img_path, self.imgs[item])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[item][:-3] + 'txt')
+        img = PIL.Image.open(img_path).convert('RGB')
+        # h, w = np.array(img).shape[:2]
+        w, h = img.size
+        # print(f'h,w:{h, w}')
+        target = self.read_target(item=item, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img,target)
+        else:
+            img=self.deafult_transform(img)
+        # print(f'img:{img.shape},target:{target}')
+        return img, target
+
+    def create_masks_from_polygons(self, polygons, image_shape):
+        """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
+        colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
+        masks = []
+
+        for polygon_data, col in zip(polygons, colors):
+            mask = np.zeros(image_shape[:2], dtype=np.uint8)
+            # 将多边形顶点转换为 NumPy 数组
+            _, polygon = polygon_data
+            pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
+
+            # 使用 OpenCV 的 fillPoly 函数填充多边形
+            # print(f'color:{col[:3]}')
+            cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
+            mask = torch.from_numpy(mask)
+            mask[mask != 0] = 1
+            masks.append(mask)
+
+        return masks
+
+    def read_target(self, item, lbl_path, shape):
+        # print(f'lbl_path:{lbl_path}')
+        h, w = shape
+        labels = []
+        masks = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels, masks = read_masks_from_pixels(lbl_path, shape)
+
+        target = {}
+        target["boxes"] = masks_to_boxes(torch.stack(masks))
+        target["labels"] = torch.stack(labels)
+        target["masks"] = torch.stack(masks)
+        target["image_id"] = torch.tensor(item)
+        target["area"] = torch.zeros(len(masks))
+        target["iscrowd"] = torch.zeros(len(masks))
+        return target
+
+    def heatmap_enhance(self, img):
+        # 直方图均衡化
+        img_eq = cv2.equalizeHist(img)
+
+        # 自适应直方图均衡化
+        # clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+        # img_clahe = clahe.apply(img)
+
+        # 将灰度图转换为热力图
+        heatmap = cv2.applyColorMap(img_eq, cv2.COLORMAP_HOT)
+
+    def __len__(self):
+        return len(self.imgs)

+ 31 - 0
models/ins/train.yaml

@@ -0,0 +1,31 @@
+
+
+dataset_path: F:\DevTools\datasets\renyaun\1012\spilt
+
+#train parameters
+num_classes: 5
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: polygon
+enable_logs: True
+augmentation: True
+checkpoint: None
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+

+ 219 - 0
models/ins/trainer.py

@@ -0,0 +1,219 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.config.config_tool import read_yaml
+from models.ins.maskrcnn_dataset import MaskRCNNDataset
+from tools import utils, presets
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+
+def load_train_parameter(cfg):
+    parameters = read_yaml(cfg)
+    return parameters
+
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+
+def train(model, **kwargs):
+    # 默认参数
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 10,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint':None
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # 设置设备
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_ptath, 'weights')
+    tb_path = os.path.join(train_result_ptath, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = MaskRCNNDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    # data_loader_test = torch.utils.data.DataLoader(
+    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    # )
+
+    img_results_path = os.path.join(train_result_ptath, 'img_results')
+    if os.path.exists(train_result_ptath):
+        pass
+    #     os.remove(train_result_ptath)
+    else:
+        os.mkdir(train_result_ptath)
+
+    if os.path.exists(train_result_ptath):
+        os.mkdir(wts_path)
+        os.mkdir(img_results_path)
+
+    for epoch in range(epochs):
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0:
+            best_loss = losses;
+        if best_loss >= losses:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
+    writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)

+ 0 - 0
models/wirenet/__init__.py


+ 548 - 0
models/wirenet/_utils.py

@@ -0,0 +1,548 @@
+import math
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
+
+
+class BalancedPositiveNegativeSampler:
+    """
+    This class samples batches, ensuring that they contain a fixed proportion of positives
+    """
+
+    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
+        """
+        Args:
+            batch_size_per_image (int): number of elements to be selected per image
+            positive_fraction (float): percentage of positive elements per batch
+        """
+        self.batch_size_per_image = batch_size_per_image
+        self.positive_fraction = positive_fraction
+
+    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+        """
+        Args:
+            matched_idxs: list of tensors containing -1, 0 or positive values.
+                Each tensor corresponds to a specific image.
+                -1 values are ignored, 0 are considered as negatives and > 0 as
+                positives.
+
+        Returns:
+            pos_idx (list[tensor])
+            neg_idx (list[tensor])
+
+        Returns two lists of binary masks for each image.
+        The first list contains the positive elements that were selected,
+        and the second list the negative example.
+        """
+        pos_idx = []
+        neg_idx = []
+        for matched_idxs_per_image in matched_idxs:
+            positive = torch.where(matched_idxs_per_image >= 1)[0]
+            negative = torch.where(matched_idxs_per_image == 0)[0]
+
+            num_pos = int(self.batch_size_per_image * self.positive_fraction)
+            # protect against not enough positive examples
+            num_pos = min(positive.numel(), num_pos)
+            num_neg = self.batch_size_per_image - num_pos
+            # protect against not enough negative examples
+            num_neg = min(negative.numel(), num_neg)
+
+            # randomly select positive and negative examples
+            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
+            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
+
+            pos_idx_per_image = positive[perm1]
+            neg_idx_per_image = negative[perm2]
+
+            # create binary mask from indices
+            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+
+            pos_idx_per_image_mask[pos_idx_per_image] = 1
+            neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+            pos_idx.append(pos_idx_per_image_mask)
+            neg_idx.append(neg_idx_per_image_mask)
+
+        return pos_idx, neg_idx
+
+
+@torch.jit._script_if_tracing
+def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
+    """
+    Encode a set of proposals with respect to some
+    reference boxes
+
+    Args:
+        reference_boxes (Tensor): reference boxes
+        proposals (Tensor): boxes to be encoded
+        weights (Tensor[4]): the weights for ``(x, y, w, h)``
+    """
+
+    # perform some unpacking to make it JIT-fusion friendly
+    wx = weights[0]
+    wy = weights[1]
+    ww = weights[2]
+    wh = weights[3]
+
+    proposals_x1 = proposals[:, 0].unsqueeze(1)
+    proposals_y1 = proposals[:, 1].unsqueeze(1)
+    proposals_x2 = proposals[:, 2].unsqueeze(1)
+    proposals_y2 = proposals[:, 3].unsqueeze(1)
+
+    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
+    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
+    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
+    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
+
+    # implementation starts here
+    ex_widths = proposals_x2 - proposals_x1
+    ex_heights = proposals_y2 - proposals_y1
+    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
+    ex_ctr_y = proposals_y1 + 0.5 * ex_heights
+
+    gt_widths = reference_boxes_x2 - reference_boxes_x1
+    gt_heights = reference_boxes_y2 - reference_boxes_y1
+    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
+    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
+
+    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+    targets_dw = ww * torch.log(gt_widths / ex_widths)
+    targets_dh = wh * torch.log(gt_heights / ex_heights)
+
+    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+    return targets
+
+
+class BoxCoder:
+    """
+    This class encodes and decodes a set of bounding boxes into
+    the representation used for training the regressors.
+    """
+
+    def __init__(
+        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
+    ) -> None:
+        """
+        Args:
+            weights (4-element tuple)
+            bbox_xform_clip (float)
+        """
+        self.weights = weights
+        self.bbox_xform_clip = bbox_xform_clip
+
+    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
+        boxes_per_image = [len(b) for b in reference_boxes]
+        reference_boxes = torch.cat(reference_boxes, dim=0)
+        proposals = torch.cat(proposals, dim=0)
+        targets = self.encode_single(reference_boxes, proposals)
+        return targets.split(boxes_per_image, 0)
+
+    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some
+        reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+        """
+        dtype = reference_boxes.dtype
+        device = reference_boxes.device
+        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
+        targets = encode_boxes(reference_boxes, proposals, weights)
+
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+        torch._assert(
+            isinstance(boxes, (list, tuple)),
+            "This function expects boxes of type list or tuple.",
+        )
+        torch._assert(
+            isinstance(rel_codes, torch.Tensor),
+            "This function expects rel_codes of type torch.Tensor.",
+        )
+        boxes_per_image = [b.size(0) for b in boxes]
+        concat_boxes = torch.cat(boxes, dim=0)
+        box_sum = 0
+        for val in boxes_per_image:
+            box_sum += val
+        if box_sum > 0:
+            rel_codes = rel_codes.reshape(box_sum, -1)
+        pred_boxes = self.decode_single(rel_codes, concat_boxes)
+        if box_sum > 0:
+            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+        return pred_boxes
+
+    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+        """
+
+        boxes = boxes.to(rel_codes.dtype)
+
+        widths = boxes[:, 2] - boxes[:, 0]
+        heights = boxes[:, 3] - boxes[:, 1]
+        ctr_x = boxes[:, 0] + 0.5 * widths
+        ctr_y = boxes[:, 1] + 0.5 * heights
+
+        wx, wy, ww, wh = self.weights
+        dx = rel_codes[:, 0::4] / wx
+        dy = rel_codes[:, 1::4] / wy
+        dw = rel_codes[:, 2::4] / ww
+        dh = rel_codes[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=self.bbox_xform_clip)
+        dh = torch.clamp(dh, max=self.bbox_xform_clip)
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        # Distance from center to box's corner.
+        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+
+        pred_boxes1 = pred_ctr_x - c_to_c_w
+        pred_boxes2 = pred_ctr_y - c_to_c_h
+        pred_boxes3 = pred_ctr_x + c_to_c_w
+        pred_boxes4 = pred_ctr_y + c_to_c_h
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
+        return pred_boxes
+
+
+class BoxLinearCoder:
+    """
+    The linear box-to-box transform defined in FCOS. The transformation is parameterized
+    by the distance from the center of (square) src box to 4 edges of the target box.
+    """
+
+    def __init__(self, normalize_by_size: bool = True) -> None:
+        """
+        Args:
+            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
+        """
+        self.normalize_by_size = normalize_by_size
+
+    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+
+        Returns:
+            Tensor: the encoded relative box offsets that can be used to
+            decode the boxes.
+
+        """
+
+        # get the center of reference_boxes
+        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
+        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
+
+        # get box regression transformation deltas
+        target_l = reference_boxes_ctr_x - proposals[..., 0]
+        target_t = reference_boxes_ctr_y - proposals[..., 1]
+        target_r = proposals[..., 2] - reference_boxes_ctr_x
+        target_b = proposals[..., 3] - reference_boxes_ctr_y
+
+        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
+
+        if self.normalize_by_size:
+            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
+            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
+            reference_boxes_size = torch.stack(
+                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
+            )
+            targets = targets / reference_boxes_size
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+
+        Returns:
+            Tensor: the predicted boxes with the encoded relative box offsets.
+
+        .. note::
+            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
+
+        """
+
+        boxes = boxes.to(dtype=rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+
+        if self.normalize_by_size:
+            boxes_w = boxes[..., 2] - boxes[..., 0]
+            boxes_h = boxes[..., 3] - boxes[..., 1]
+
+            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+            rel_codes = rel_codes * list_box_size
+
+        pred_boxes1 = ctr_x - rel_codes[..., 0]
+        pred_boxes2 = ctr_y - rel_codes[..., 1]
+        pred_boxes3 = ctr_x + rel_codes[..., 2]
+        pred_boxes4 = ctr_y + rel_codes[..., 3]
+
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+        return pred_boxes
+
+
+class Matcher:
+    """
+    This class assigns to each predicted "element" (e.g., a box) a ground-truth
+    element. Each predicted element will have exactly zero or one matches; each
+    ground-truth element may be assigned to zero or more predicted elements.
+
+    Matching is based on the MxN match_quality_matrix, that characterizes how well
+    each (ground-truth, predicted)-pair match. For example, if the elements are
+    boxes, the matrix may contain box IoU overlap values.
+
+    The matcher returns a tensor of size N containing the index of the ground-truth
+    element m that matches to prediction n. If there is no match, a negative value
+    is returned.
+    """
+
+    BELOW_LOW_THRESHOLD = -1
+    BETWEEN_THRESHOLDS = -2
+
+    __annotations__ = {
+        "BELOW_LOW_THRESHOLD": int,
+        "BETWEEN_THRESHOLDS": int,
+    }
+
+    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
+        """
+        Args:
+            high_threshold (float): quality values greater than or equal to
+                this value are candidate matches.
+            low_threshold (float): a lower quality threshold used to stratify
+                matches into three levels:
+                1) matches >= high_threshold
+                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
+                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
+            allow_low_quality_matches (bool): if True, produce additional matches
+                for predictions that have only low-quality match candidates. See
+                set_low_quality_matches_ for more details.
+        """
+        self.BELOW_LOW_THRESHOLD = -1
+        self.BETWEEN_THRESHOLDS = -2
+        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
+        self.high_threshold = high_threshold
+        self.low_threshold = low_threshold
+        self.allow_low_quality_matches = allow_low_quality_matches
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        """
+        Args:
+            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
+            pairwise quality between M ground-truth elements and N predicted elements.
+
+        Returns:
+            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
+            [0, M - 1] or a negative value indicating that prediction i could not
+            be matched.
+        """
+        if match_quality_matrix.numel() == 0:
+            # empty targets or proposals not supported during training
+            if match_quality_matrix.shape[0] == 0:
+                raise ValueError("No ground-truth boxes available for one of the images during training")
+            else:
+                raise ValueError("No proposal boxes available for one of the images during training")
+
+        # match_quality_matrix is M (gt) x N (predicted)
+        # Max over gt elements (dim 0) to find best gt candidate for each prediction
+        matched_vals, matches = match_quality_matrix.max(dim=0)
+        if self.allow_low_quality_matches:
+            all_matches = matches.clone()
+        else:
+            all_matches = None  # type: ignore[assignment]
+
+        # Assign candidate matches with low quality to negative (unassigned) values
+        below_low_threshold = matched_vals < self.low_threshold
+        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
+        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
+        matches[between_thresholds] = self.BETWEEN_THRESHOLDS
+
+        if self.allow_low_quality_matches:
+            if all_matches is None:
+                torch._assert(False, "all_matches should not be None")
+            else:
+                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+
+        return matches
+
+    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
+        """
+        Produce additional matches for predictions that have only low-quality matches.
+        Specifically, for each ground-truth find the set of predictions that have
+        maximum overlap with it (including ties); for each prediction in that set, if
+        it is unmatched, then match it to the ground-truth with which it has the highest
+        quality value.
+        """
+        # For each gt, find the prediction with which it has the highest quality
+        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+        # Find the highest quality match available, even if it is low, including ties
+        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
+        # Example gt_pred_pairs_of_highest_quality:
+        #   tensor([[    0, 39796],
+        #           [    1, 32055],
+        #           [    1, 32070],
+        #           [    2, 39190],
+        #           [    2, 40255],
+        #           [    3, 40390],
+        #           [    3, 41455],
+        #           [    4, 45470],
+        #           [    5, 45325],
+        #           [    5, 46390]])
+        # Each row is a (gt index, prediction index)
+        # Note how gt items 1, 2, 3, and 5 each have two ties
+
+        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
+        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
+
+
+class SSDMatcher(Matcher):
+    def __init__(self, threshold: float) -> None:
+        super().__init__(threshold, threshold, allow_low_quality_matches=False)
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        matches = super().__call__(match_quality_matrix)
+
+        # For each gt, find the prediction with which it has the highest quality
+        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
+        matches[highest_quality_pred_foreach_gt] = torch.arange(
+            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
+        )
+
+        return matches
+
+
+def overwrite_eps(model: nn.Module, eps: float) -> None:
+    """
+    This method overwrites the default eps values of all the
+    FrozenBatchNorm2d layers of the model with the provided value.
+    This is necessary to address the BC-breaking change introduced
+    by the bug-fix at pytorch/vision#2933. The overwrite is applied
+    only when the pretrained weights are loaded to maintain compatibility
+    with previous versions.
+
+    Args:
+        model (nn.Module): The model on which we perform the overwrite.
+        eps (float): The new value of eps.
+    """
+    for module in model.modules():
+        if isinstance(module, FrozenBatchNorm2d):
+            module.eps = eps
+
+
+def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
+    """
+    This method retrieves the number of output channels of a specific model.
+
+    Args:
+        model (nn.Module): The model for which we estimate the out_channels.
+            It should return a single Tensor or an OrderedDict[Tensor].
+        size (Tuple[int, int]): The size (wxh) of the input.
+
+    Returns:
+        out_channels (List[int]): A list of the output channels of the model.
+    """
+    in_training = model.training
+    model.eval()
+
+    with torch.no_grad():
+        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
+        device = next(model.parameters()).device
+        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
+        features = model(tmp_img)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+        out_channels = [x.size(1) for x in features.values()]
+
+    if in_training:
+        model.train()
+
+    return out_channels
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> int:
+    return v  # type: ignore[return-value]
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
+    """
+    ONNX spec requires the k-value to be less than or equal to the number of inputs along
+    provided dim. Certain models use the number of elements along a particular axis instead of K
+    if K exceeds the number of elements along that axis. Previously, python's min() function was
+    used to determine whether to use the provided k-value or the specified dim axis value.
+
+    However, in cases where the model is being exported in tracing mode, python min() is
+    static causing the model to be traced incorrectly and eventually fail at the topk node.
+    In order to avoid this situation, in tracing mode, torch.min() is used instead.
+
+    Args:
+        input (Tensor): The original input tensor.
+        orig_kval (int): The provided k-value.
+        axis(int): Axis along which we retrieve the input size.
+
+    Returns:
+        min_kval (int): Appropriately selected k-value.
+    """
+    if not torch.jit.is_tracing():
+        return min(orig_kval, input.size(axis))
+    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+    return _fake_cast_onnx(min_kval)
+
+
+def _box_loss(
+    type: str,
+    box_coder: BoxCoder,
+    anchors_per_image: Tensor,
+    matched_gt_boxes_per_image: Tensor,
+    bbox_regression_per_image: Tensor,
+    cnf: Optional[Dict[str, float]] = None,
+) -> Tensor:
+    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
+
+    if type == "l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+    elif type == "smooth_l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
+        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
+    else:
+        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
+        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
+        if type == "ciou":
+            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        if type == "diou":
+            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        # otherwise giou
+        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)

+ 1193 - 0
models/wirenet/head.py

@@ -0,0 +1,1193 @@
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+from torch.utils.data.dataloader import default_collate
+
+
+def l2loss(input, target):
+    return ((target - input) ** 2).mean(2).mean(1)
+
+
+def cross_entropy_loss(logits, positive):
+    nlogp = -F.log_softmax(logits, dim=0)
+    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
+
+
+def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
+    logp = torch.sigmoid(logits) + offset
+    loss = torch.abs(logp - target)
+    if mask is not None:
+        w = mask.mean(2, True).mean(1, True)
+        w[w == 0] = 1
+        loss = loss * (mask / w)
+
+    return loss.mean(2).mean(1)
+
+
+# def wirepoint_loss(target, outputs, feature, loss_weight,mode):
+#     wires = target['wires']
+#     result = {"feature": feature}
+#     batch, channel, row, col = outputs[0].shape
+#     print(f"Initial Output[0] shape: {outputs[0].shape}")  # 打印初始输出形状
+#     print(f"Total Stacks: {len(outputs)}")  # 打印堆栈数
+#
+#     T = wires.copy()
+#     n_jtyp = T["junc_map"].shape[1]
+#     for task in ["junc_map"]:
+#         T[task] = T[task].permute(1, 0, 2, 3)
+#     for task in ["junc_offset"]:
+#         T[task] = T[task].permute(1, 2, 0, 3, 4)
+#
+#     offset = self.head_off
+#     loss_weight = loss_weight
+#     losses = []
+#
+#     for stack, output in enumerate(outputs):
+#         output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+#         print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
+#         jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+#         lmap = output[offset[0]: offset[1]].squeeze(0)
+#         joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+#
+#         if stack == 0:
+#             result["preds"] = {
+#                 "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+#                 "lmap": lmap.sigmoid(),
+#                 "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+#             }
+#             # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
+#             # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
+#             # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
+#
+#             if mode == "testing":
+#                 return result
+#
+#         L = OrderedDict()
+#         L["junc_map"] = sum(
+#             cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+#         )
+#         L["line_map"] = (
+#             F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+#             .mean(2)
+#             .mean(1)
+#         )
+#         L["junc_offset"] = sum(
+#             sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+#             for i in range(n_jtyp)
+#             for j in range(2)
+#         )
+#         for loss_name in L:
+#             L[loss_name].mul_(loss_weight[loss_name])
+#         losses.append(L)
+#
+#     result["losses"] = losses
+#     return result
+
+def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
+    # output, feature: head返回结果
+    # x, y, idx : line中间生成结果
+    result = {}
+    batch, channel, row, col = output.shape
+
+    wires_targets = [t["wires"] for t in targets]
+    wires_targets = wires_targets.copy()
+    # print(f'wires_target:{wires_targets}')
+    # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
+    junc_maps = [d["junc_map"] for d in wires_targets]
+    junc_offsets = [d["junc_offset"] for d in wires_targets]
+    line_maps = [d["line_map"] for d in wires_targets]
+
+    junc_map_tensor = torch.stack(junc_maps, dim=0)
+    junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+    line_map_tensor = torch.stack(line_maps, dim=0)
+    T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
+
+    n_jtyp = T["junc_map"].shape[1]
+
+    for task in ["junc_map"]:
+        T[task] = T[task].permute(1, 0, 2, 3)
+    for task in ["junc_offset"]:
+        T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+    offset = [2, 3, 5]
+    losses = []
+    output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+    jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+    lmap = output[offset[0]: offset[1]].squeeze(0)
+    joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+    L = OrderedDict()
+    L["junc_map"] = sum(
+        cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+    )
+    L["line_map"] = (
+        F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+            .mean(2)
+            .mean(1)
+    )
+    L["junc_offset"] = sum(
+        sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+        for i in range(n_jtyp)
+        for j in range(2)
+    )
+    for loss_name in L:
+        L[loss_name].mul_(loss_weight[loss_name])
+    losses.append(L)
+    result["losses"] = losses
+
+    loss = nn.BCEWithLogitsLoss(reduction="none")
+    loss = loss(x, y)
+    lpos_mask, lneg_mask = y, 1 - y
+    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+    def sum_batch(x):
+        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
+        return torch.cat(xs)
+
+    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
+    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
+
+    return result
+
+
+def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
+    result = {}
+    result["wires"] = {}
+    p = torch.cat(ps)
+    s = torch.sigmoid(input)
+    b = s > 0.5
+    lines = []
+    score = []
+    print(f"n_batch:{n_batch}")
+    for i in range(n_batch):
+        print(f"idx:{idx}")
+        p0 = p[idx[i]: idx[i + 1]]
+        s0 = s[idx[i]: idx[i + 1]]
+        mask = b[idx[i]: idx[i + 1]]
+        p0 = p0[mask]
+        s0 = s0[mask]
+        if len(p0) == 0:
+            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+            score.append(torch.zeros([1, n_out_line], device=p.device))
+        else:
+            arg = torch.argsort(s0, descending=True)
+            p0, s0 = p0[arg], s0[arg]
+            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+        for j in range(len(jcs[i])):
+            if len(jcs[i][j]) == 0:
+                jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
+            jcs[i][j] = jcs[i][j][
+                None, torch.arange(n_out_junc) % len(jcs[i][j])
+            ]
+    result["wires"]["lines"] = torch.cat(lines)
+    result["wires"]["score"] = torch.cat(score)
+    result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+
+    if len(jcs[i]) > 1:
+        result["preds"]["junts"] = torch.cat(
+            [jcs[i][1] for i in range(n_batch)]
+        )
+
+    return result
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for ech image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
+    # print(f'mask discretization_size:{discretization_size}')
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    # print(f'mask labels:{labels}')
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    # print(f'mask labels1:{labels}')
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+    # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
+    # print(f'mask_targets:{mask_targets}')
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    # print(f'mask_loss:{mask_loss}')
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+        maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+            .index_select(2, x_int.to(dtype=torch.int64))
+            .view(-1)
+            .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+        maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it sepaartely
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    print(f'x:{x.shape}')
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+    print(f'x2:{x2}')
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+            self,
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            # Faster R-CNN training
+            fg_iou_thresh,
+            bg_iou_thresh,
+            batch_size_per_image,
+            positive_fraction,
+            bbox_reg_weights,
+            # Faster R-CNN inference
+            score_thresh,
+            nms_thresh,
+            detections_per_img,
+            # Mask
+            mask_roi_pool=None,
+            mask_head=None,
+            mask_predictor=None,
+            keypoint_roi_pool=None,
+            keypoint_head=None,
+            keypoint_predictor=None,
+            wirepoint_roi_pool=None,
+            wirepoint_head=None,
+            wirepoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+        self.wirepoint_roi_pool = wirepoint_roi_pool
+        self.wirepoint_head = wirepoint_head
+        self.wirepoint_predictor = wirepoint_predictor
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def has_wirepoint(self):
+        if self.wirepoint_roi_pool is None:
+            print(f'wirepoint_roi_pool is None')
+            return False
+        if self.wirepoint_head is None:
+            print(f'wirepoint_head is None')
+            return False
+        if self.wirepoint_predictor is None:
+            print(f'wirepoint_roi_predictor is None')
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+            self,
+            proposals,  # type: List[Tensor]
+            targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to propos
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+            self,
+            class_logits,  # type: Tensor
+            box_regression,  # type: Tensor
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+            self,
+            features,  # type: Dict[str, Tensor]
+            proposals,  # type: List[Tensor]
+            image_shapes,  # type: List[Tuple[int, int]]
+            targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if self.has_keypoint():
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            # tmp = keypoint_features[0][0]
+            # plt.imshow(tmp.detach().numpy())
+            print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
+            keypoint_features = self.keypoint_head(keypoint_features)
+
+            print(f'keypoint_features:{keypoint_features.shape}')
+            tmp = keypoint_features[0][0]
+            plt.imshow(tmp.detach().numpy())
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+            print(f'keypoint_logits:{keypoint_logits.shape}')
+            """
+            接wirenet
+            """
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        if self.has_wirepoint():
+            wirepoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                wirepoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    wirepoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            print(f'proposals:{len(proposals)}')
+            wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
+
+            # tmp = keypoint_features[0][0]
+            # plt.imshow(tmp.detach().numpy())
+            print(f'keypoint_features from roi_pool:{wirepoint_features.shape}')
+            outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
+
+            outputs = merge_features(outputs, wirepoint_proposals)
+            wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
+
+            print(f'outpust:{outputs.shape}')
+
+            wirepoint_logits = self.wirepoint_predictor(inputs=outputs, features=wirepoint_features, targets=targets)
+            x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
+
+            print(f'keypoint_features:{wirepoint_features.shape}')
+            if self.training:
+
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
+                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+
+                loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+
+            else:
+                pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
+                result.append(pred)
+
+            # tmp = wirepoint_features[0][0]
+            # plt.imshow(tmp.detach().numpy())
+            # wirepoint_logits = self.wirepoint_predictor((outputs,wirepoint_features))
+            # print(f'keypoint_logits:{wirepoint_logits.shape}')
+
+            # loss_wirepoint = {}    lm
+            # result=wirepoint_logits
+
+            # result.append(pred)    lm
+            losses.update(loss_wirepoint)
+        # print(f"result{result}")
+        # print(f"losses{losses}")
+
+        return result, losses
+
+
+def merge_features(features, proposals):
+    # 假设 roi_pool_features 是你的输入张量,形状为 [600, 256, 128, 128]
+
+    # 使用 torch.split 按照每个图像的提议数量分割 features
+    proposals_count = sum([p.size(0) for p in proposals])
+    features_size = features.size(0)
+    print(f'proposals sum:{proposals_count},features batch:{features.size(0)}')
+    if proposals_count != features_size:
+        raise ValueError("The length of proposals must match the batch size of features.")
+
+    split_features = []
+    start_idx = 0
+    for proposal in proposals:
+        # 提取当前图像的特征
+        current_features = features[start_idx:start_idx + proposal.size(0)]
+        print(f'current_features:{current_features.shape}')
+        split_features.append(current_features)
+        start_idx += 1
+
+    features_imgs = []
+    for features_per_img in split_features:
+        features_per_img, _ = torch.max(features_per_img, dim=0, keepdim=True)
+        features_imgs.append(features_per_img)
+
+    merged_features = torch.cat(features_imgs, dim=0)
+    print(f' merged_features:{merged_features.shape}')
+    return merged_features

+ 896 - 0
models/wirenet/roi_head.py

@@ -0,0 +1,896 @@
+from typing import Dict, List, Optional, Tuple
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for ech image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
+    # print(f'mask discretization_size:{discretization_size}')
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    # print(f'mask labels:{labels}')
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    # print(f'mask labels1:{labels}')
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+    # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
+    # print(f'mask_targets:{mask_targets}')
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    # print(f'mask_loss:{mask_loss}')
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+        .index_select(2, x_int.to(dtype=torch.int64))
+        .view(-1)
+        .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it sepaartely
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    print(f'x:{x.shape}')
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+    print(f'x2:{x2}')
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+        self,
+        box_roi_pool,
+        box_head,
+        box_predictor,
+        # Faster R-CNN training
+        fg_iou_thresh,
+        bg_iou_thresh,
+        batch_size_per_image,
+        positive_fraction,
+        bbox_reg_weights,
+        # Faster R-CNN inference
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        # Mask
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to propos
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+        self,
+        features,  # type: Dict[str, Tensor]
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+            self.keypoint_roi_pool is not None
+            and self.keypoint_head is not None
+            and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            # tmp = keypoint_features[0][0]
+            # plt.imshow(tmp.detach().numpy())
+            print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
+            keypoint_features = self.keypoint_head(keypoint_features)
+
+            print(f'keypoint_features:{keypoint_features.shape}')
+            tmp=keypoint_features[0][0]
+            plt.imshow(tmp.detach().numpy())
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+            print(f'keypoint_logits:{keypoint_logits.shape}')
+            """
+            接wirenet
+            """
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        return result, losses

+ 69 - 0
models/wirenet/wirenet.yaml

@@ -0,0 +1,69 @@
+io:
+  logdir: logs/
+  datadir: D:/python/PycharmProjects/data
+  resume_from:
+  num_workers: 4
+  tensorboard_port: 0
+  validation_interval: 24000
+
+model:
+  image:
+      mean: [109.730, 103.832, 98.681]
+      stddev: [22.275, 22.124, 23.229]
+
+  batch_size: 2
+  batch_size_eval: 2
+
+  # backbone multi-task parameters
+  head_size: [[2], [1], [2]]
+  loss_weight:
+    jmap: 8.0
+    lmap: 0.5
+    joff: 0.25
+    lpos: 1
+    lneg: 1
+
+  # backbone parameters
+  backbone: stacked_hourglass
+  depth: 4
+  num_stacks: 2
+  num_blocks: 1
+
+  # sampler parameters
+  ## static sampler
+  n_stc_posl: 300
+  n_stc_negl: 40
+
+  ## dynamic sampler
+  n_dyn_junc: 300
+  n_dyn_posl: 300
+  n_dyn_negl: 80
+  n_dyn_othr: 600
+
+  # LOIPool layer parameters
+  n_pts0: 32
+  n_pts1: 8
+
+  # line verification network parameters
+  dim_loi: 128
+  dim_fc: 1024
+
+  # maximum junction and line outputs
+  n_out_junc: 250
+  n_out_line: 2500
+
+  # additional ablation study parameters
+  use_cood: 0
+  use_slop: 0
+  use_conv: 0
+
+  # junction threashold for evaluation (See #5)
+  eval_junc_thres: 0.008
+
+optim:
+  name: Adam
+  lr: 4.0e-4
+  amsgrad: True
+  weight_decay: 1.0e-4
+  max_epoch: 24
+  lr_decay_epoch: 10

+ 151 - 0
models/wirenet/wirepoint_dataset.py

@@ -0,0 +1,151 @@
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+import matplotlib.pyplot as plt
+from models.dataset_tool import masks_to_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+
+class WirePointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+    def __len__(self):
+        return len(self.imgs)
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # 字典
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # 真实存在线条的邻接矩阵
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+            # 不存在线条的临界矩阵
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        h, w = shape
+        labels = []
+        masks = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels, masks = read_masks_from_pixels_wire(lbl_path, shape)
+
+        target = {}
+        target["boxes"] = masks_to_boxes(torch.stack(masks))
+        target["labels"] = torch.stack(labels)
+        target["masks"] = torch.stack(masks)
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        return target
+
+    def show(self, idx):
+        img_path = os.path.join(self.img_path, self.imgs[idx])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[idx][:-3] + 'json')
+
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        # 可视化图像和标注
+        image = cv2.imread(img_path)  # [H,W,3]  # 默认为BGR格式
+        # print(image.shape)
+        # 绘制每个标注的多边形
+        # for ann in lable_all["segmentations"]:
+        #     segmentation = [[x * 512 for x in ann['data']]]
+        #     # segmentation = [ann['data']]
+        #     # for i in range(len(ann['data'])):
+        #     #     if i % 2 == 0:
+        #     #         segmentation[0][i] *= image.shape[0]
+        #     #     else:
+        #     #         segmentation[0][i] *= image.shape[0]
+        #
+        #     # if isinstance(segmentation, list):
+        #     #     for seg in segmentation:
+        #     #         poly = np.array(seg).reshape((-1, 2)).astype(int)
+        #     #         cv2.polylines(image, [poly], isClosed=True, color=(0, 255, 0), thickness=2)
+        #     #         cv2.fillPoly(image, [poly], color=(0, 255, 0))
+
+
+        #
+        # # 显示图像
+        # cv2.namedWindow('Image with Segmentations', cv2.WINDOW_NORMAL)
+        # cv2.imshow('Image with Segmentations', image)
+        # cv2.waitKey(0)
+        # cv2.destroyAllWindows()
+
+    def show_img(self,img_path):
+        pass
+

+ 703 - 0
models/wirenet/wirepoint_rcnn.py

@@ -0,0 +1,703 @@
+import os
+from typing import Optional, Any
+
+import numpy as np
+import torch
+from tensorboardX import SummaryWriter
+from torch import nn
+import torch.nn.functional as F
+# from torchinfo import summary
+from torchvision.io import read_image
+from torchvision.models import resnet50, ResNet50_Weights
+from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection._utils import overwrite_eps
+from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
+from torchvision.models.detection.keypoint_rcnn import KeypointRCNNHeads, KeypointRCNNPredictor, \
+    KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.ops import MultiScaleRoIAlign
+from torchvision.ops import misc as misc_nn_ops
+# from visdom import Visdom
+
+from models.config import config_tool
+from models.config.config_tool import read_yaml
+from models.ins.trainer import get_transform
+from models.wirenet.head import RoIHeads
+from models.wirenet.wirepoint_dataset import WirePointDataset
+from tools import utils
+
+
+FEATURE_DIM = 8
+
+
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+
+class Bottleneck1D(nn.Module):
+    def __init__(self, inplanes, outplanes):
+        super(Bottleneck1D, self).__init__()
+
+        planes = outplanes // 2
+        self.op = nn.Sequential(
+            nn.BatchNorm1d(inplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(inplanes, planes, kernel_size=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, outplanes, kernel_size=1),
+        )
+
+    def forward(self, x):
+        return x + self.op(x)
+
+
+class WirepointRCNN(FasterRCNN):
+    def __init__(
+            self,
+            backbone,
+            num_classes=None,
+            # transform parameters
+            min_size=None,
+            max_size=1333,
+            image_mean=None,
+            image_std=None,
+            # RPN parameters
+            rpn_anchor_generator=None,
+            rpn_head=None,
+            rpn_pre_nms_top_n_train=2000,
+            rpn_pre_nms_top_n_test=1000,
+            rpn_post_nms_top_n_train=2000,
+            rpn_post_nms_top_n_test=1000,
+            rpn_nms_thresh=0.7,
+            rpn_fg_iou_thresh=0.7,
+            rpn_bg_iou_thresh=0.3,
+            rpn_batch_size_per_image=256,
+            rpn_positive_fraction=0.5,
+            rpn_score_thresh=0.0,
+            # Box parameters
+            box_roi_pool=None,
+            box_head=None,
+            box_predictor=None,
+            box_score_thresh=0.05,
+            box_nms_thresh=0.5,
+            box_detections_per_img=100,
+            box_fg_iou_thresh=0.5,
+            box_bg_iou_thresh=0.5,
+            box_batch_size_per_image=512,
+            box_positive_fraction=0.25,
+            bbox_reg_weights=None,
+            # keypoint parameters
+            keypoint_roi_pool=None,
+            keypoint_head=None,
+            keypoint_predictor=None,
+            num_keypoints=None,
+            wirepoint_roi_pool=None,
+            wirepoint_head=None,
+            wirepoint_predictor=None,
+            **kwargs,
+    ):
+        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
+            )
+        if min_size is None:
+            min_size = (640, 672, 704, 736, 768, 800)
+
+        if num_keypoints is not None:
+            if keypoint_predictor is not None:
+                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
+        else:
+            num_keypoints = 17
+
+        out_channels = backbone.out_channels
+
+        if wirepoint_roi_pool is None:
+            wirepoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=128,
+                                                    sampling_ratio=2,)
+
+        if wirepoint_head is None:
+            keypoint_layers = tuple(512 for _ in range(8))
+            print(f'keypoinyrcnnHeads inchannels:{out_channels},layers{keypoint_layers}')
+            wirepoint_head = WirepointHead(out_channels, keypoint_layers)
+
+        if wirepoint_predictor is None:
+            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+            wirepoint_predictor = WirepointPredictor()
+
+        super().__init__(
+            backbone,
+            num_classes,
+            # transform parameters
+            min_size,
+            max_size,
+            image_mean,
+            image_std,
+            # RPN-specific parameters
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_pre_nms_top_n_train,
+            rpn_pre_nms_top_n_test,
+            rpn_post_nms_top_n_train,
+            rpn_post_nms_top_n_test,
+            rpn_nms_thresh,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_score_thresh,
+            # Box parameters
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            **kwargs,
+        )
+
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+            # wirepoint_roi_pool=wirepoint_roi_pool,
+            # wirepoint_head=wirepoint_head,
+            # wirepoint_predictor=wirepoint_predictor,
+        )
+        self.roi_heads = roi_heads
+
+        self.roi_heads.wirepoint_roi_pool = wirepoint_roi_pool
+        self.roi_heads.wirepoint_head = wirepoint_head
+        self.roi_heads.wirepoint_predictor = wirepoint_predictor
+
+
+class WirepointHead(nn.Module):
+    def __init__(self, input_channels, num_class):
+        super(WirepointHead, self).__init__()
+        self.head_size = [[2], [1], [2]]
+        m = int(input_channels / 4)
+        heads = []
+        # print(f'M.head_size:{M.head_size}')
+        # for output_channels in sum(M.head_size, []):
+        for output_channels in sum(self.head_size, []):
+            heads.append(
+                nn.Sequential(
+                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(m, output_channels, kernel_size=1),
+                )
+            )
+        self.heads = nn.ModuleList(heads)
+
+    def forward(self, x):
+        # for idx, head in enumerate(self.heads):
+        #     print(f'{idx},multitask head:{head(x).shape},input x:{x.shape}')
+
+        outputs = torch.cat([head(x) for head in self.heads], dim=1)
+
+        features = x
+        return outputs, features
+
+
+class WirepointPredictor(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        # self.backbone = backbone
+        # self.cfg = read_yaml(cfg)
+        self.cfg = read_yaml('wirenet.yaml')
+        self.n_pts0 = self.cfg['model']['n_pts0']
+        self.n_pts1 = self.cfg['model']['n_pts1']
+        self.n_stc_posl = self.cfg['model']['n_stc_posl']
+        self.dim_loi = self.cfg['model']['dim_loi']
+        self.use_conv = self.cfg['model']['use_conv']
+        self.dim_fc = self.cfg['model']['dim_fc']
+        self.n_out_line = self.cfg['model']['n_out_line']
+        self.n_out_junc = self.cfg['model']['n_out_junc']
+        self.loss_weight = self.cfg['model']['loss_weight']
+        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
+        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
+        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
+        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
+        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
+        self.use_cood = self.cfg['model']['use_cood']
+        self.use_slop = self.cfg['model']['use_slop']
+        self.n_stc_negl = self.cfg['model']['n_stc_negl']
+        self.head_size = self.cfg['model']['head_size']
+        self.num_class = sum(sum(self.head_size, []))
+        self.head_off = np.cumsum([sum(h) for h in self.head_size])
+
+        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
+        self.register_buffer("lambda_", lambda_)
+        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0
+
+        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
+        scale_factor = self.n_pts0 // self.n_pts1
+        if self.use_conv:
+            self.pooling = nn.Sequential(
+                nn.MaxPool1d(scale_factor, scale_factor),
+                Bottleneck1D(self.dim_loi, self.dim_loi),
+            )
+            self.fc2 = nn.Sequential(
+                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
+            )
+        else:
+            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
+            self.fc2 = nn.Sequential(
+                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, 1),
+            )
+        self.loss = nn.BCEWithLogitsLoss(reduction="none")
+
+    def forward(self, inputs,features, targets=None):
+
+        # outputs, features = input
+        # for out in outputs:
+        #     print(f'out:{out.shape}')
+        # outputs=merge_features(outputs,100)
+        batch, channel, row, col = inputs.shape
+        print(f'outputs:{inputs.shape}')
+        print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
+
+        if targets is not None:
+            self.training = True
+            # print(f'target:{targets}')
+            wires_targets = [t["wires"] for t in targets]
+            # print(f'wires_target:{wires_targets}')
+            # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
+            junc_maps = [d["junc_map"] for d in wires_targets]
+            junc_offsets = [d["junc_offset"] for d in wires_targets]
+            line_maps = [d["line_map"] for d in wires_targets]
+
+            junc_map_tensor = torch.stack(junc_maps, dim=0)
+            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+            line_map_tensor = torch.stack(line_maps, dim=0)
+
+            wires_meta = {
+                "junc_map": junc_map_tensor,
+                "junc_offset": junc_offset_tensor,
+                # "line_map": line_map_tensor,
+            }
+        else:
+            self.training = False
+            t = {
+                    "junc_coords": torch.zeros(1, 2),
+                    "jtyp": torch.zeros(1, dtype=torch.uint8),
+                    "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                    "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                    "junc_map": torch.zeros([1, 1, 128, 128]),
+                    "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+                }
+            wires_targets=[t for b in range(inputs.size(0))]
+
+            wires_meta={
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
+
+
+        T = wires_meta.copy()
+        n_jtyp = T["junc_map"].shape[1]
+        offset = self.head_off
+        result={}
+        for stack, output in enumerate([inputs]):
+            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+            print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
+            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+            lmap = output[offset[0]: offset[1]].squeeze(0)
+            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+
+            if stack == 0:
+                result["preds"] = {
+                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                    "lmap": lmap.sigmoid(),
+                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+                }
+                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
+                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
+                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
+
+        h = result["preds"]
+        print(f'features shape:{features.shape}')
+        x = self.fc1(features)
+
+        print(f'x:{x.shape}')
+
+        n_batch, n_channel, row, col = x.shape
+
+        print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
+
+        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
+
+        for i, meta in enumerate(wires_targets):
+            p, label, feat, jc = self.sample_lines(
+                meta, h["jmap"][i], h["joff"][i],
+            )
+            print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
+            ys.append(label)
+            if self.training and self.do_static_sampling:
+                p = torch.cat([p, meta["lpre"]])
+                feat = torch.cat([feat, meta["lpre_feat"]])
+                ys.append(meta["lpre_label"])
+                del jc
+            else:
+                jcs.append(jc)
+                ps.append(p)
+            fs.append(feat)
+
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
+            xp = (
+                (
+                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                .reshape(n_channel, -1, self.n_pts0)
+                .permute(1, 0, 2)
+            )
+            xp = self.pooling(xp)
+            print(f'xp.shape:{xp.shape}')
+            xs.append(xp)
+            idx.append(idx[-1] + xp.shape[0])
+            print(f'idx__:{idx}')
+
+        x, y = torch.cat(xs), torch.cat(ys)
+        f = torch.cat(fs)
+        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+
+        # print("Weight dtype:", self.fc2.weight.dtype)
+        x = torch.cat([x, f], 1)
+        print("Input dtype:", x.dtype)
+        x = x.to(dtype=torch.float32)
+        print("Input dtype1:", x.dtype)
+        x = self.fc2(x).flatten()
+
+        # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
+        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
+
+
+        # if mode != "training":
+        # self.inference(x, idx, jcs, n_batch, ps)
+
+        # return result
+
+
+    ####deprecated
+    # def inference(self,input, idx, jcs, n_batch, ps):
+    #     if not self.training:
+    #         p = torch.cat(ps)
+    #         s = torch.sigmoid(input)
+    #         b = s > 0.5
+    #         lines = []
+    #         score = []
+    #         print(f"n_batch:{n_batch}")
+    #         for i in range(n_batch):
+    #             print(f"idx:{idx}")
+    #             p0 = p[idx[i]: idx[i + 1]]
+    #             s0 = s[idx[i]: idx[i + 1]]
+    #             mask = b[idx[i]: idx[i + 1]]
+    #             p0 = p0[mask]
+    #             s0 = s0[mask]
+    #             if len(p0) == 0:
+    #                 lines.append(torch.zeros([1, self.n_out_line, 2, 2], device=p.device))
+    #                 score.append(torch.zeros([1, self.n_out_line], device=p.device))
+    #             else:
+    #                 arg = torch.argsort(s0, descending=True)
+    #                 p0, s0 = p0[arg], s0[arg]
+    #                 lines.append(p0[None, torch.arange(self.n_out_line) % len(p0)])
+    #                 score.append(s0[None, torch.arange(self.n_out_line) % len(s0)])
+    #             for j in range(len(jcs[i])):
+    #                 if len(jcs[i][j]) == 0:
+    #                     jcs[i][j] = torch.zeros([self.n_out_junc, 2], device=p.device)
+    #                 jcs[i][j] = jcs[i][j][
+    #                     None, torch.arange(self.n_out_junc) % len(jcs[i][j])
+    #                 ]
+    #         result["preds"]["lines"] = torch.cat(lines)
+    #         result["preds"]["score"] = torch.cat(score)
+    #         result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+    #
+    #         if len(jcs[i]) > 1:
+    #             result["preds"]["junts"] = torch.cat(
+    #                 [jcs[i][1] for i in range(n_batch)]
+    #             )
+    #     if self.training:
+    #         del result["preds"]
+
+    def sample_lines(self, meta, jmap, joff):
+        with torch.no_grad():
+            junc = meta["junc_coords"]  # [N, 2]
+            jtyp = meta["jtyp"]  # [N]
+            Lpos = meta["line_pos_idx"]
+            Lneg = meta["line_neg_idx"]
+
+            n_type = jmap.shape[0]
+            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+            joff = joff.reshape(n_type, 2, -1)
+            max_K = self.n_dyn_junc // n_type
+            N = len(junc)
+            # if mode != "training":
+            if not self.training:
+                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
+            else:
+                K = min(int(N * 2 + 2), max_K)
+            if K < 2:
+                K = 2
+            device = jmap.device
+
+            # index: [N_TYPE, K]
+            score, index = torch.topk(jmap, k=K)
+            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+
+            # xy: [N_TYPE, K, 2]
+            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
+            xy_ = xy[..., None, :]
+            del x, y, index
+
+            # dist: [N_TYPE, K, N]
+            dist = torch.sum((xy_ - junc) ** 2, -1)
+            cost, match = torch.min(dist, -1)
+
+            # xy: [N_TYPE * K, 2]
+            # match: [N_TYPE, K]
+            for t in range(n_type):
+                match[t, jtyp[match[t]] != t] = N
+            match[cost > 1.5 * 1.5] = N
+            match = match.flatten()
+
+            _ = torch.arange(n_type * K, device=device)
+            u, v = torch.meshgrid(_, _)
+            u, v = u.flatten(), v.flatten()
+            up, vp = match[u], match[v]
+            label = Lpos[up, vp]
+
+            # if mode == "training":
+            if self.training:
+                c = torch.zeros_like(label, dtype=torch.bool)
+
+                # sample positive lines
+                cdx = label.nonzero().flatten()
+                if len(cdx) > self.n_dyn_posl:
+                    # print("too many positive lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample negative lines
+                cdx = Lneg[up, vp].nonzero().flatten()
+                if len(cdx) > self.n_dyn_negl:
+                    # print("too many negative lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample other (unmatched) lines
+                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
+                c[cdx] = 1
+            else:
+                c = (u < v).flatten()
+
+            # sample lines
+            u, v, label = u[c], v[c], label[c]
+            xy = xy.reshape(n_type * K, 2)
+            xyu, xyv = xy[u], xy[v]
+
+            u2v = xyu - xyv
+            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
+            feat = torch.cat(
+                [
+                    xyu / 128 * self.use_cood,
+                    xyv / 128 * self.use_cood,
+                    u2v * self.use_slop,
+                    (u[:, None] > K).float(),
+                    (v[:, None] > K).float(),
+                ],
+                1,
+            )
+            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+
+            xy = xy.reshape(n_type, K, 2)
+            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
+            return line, label.float(), feat, jcs
+
+
+
+def wirepointrcnn_resnet50_fpn(
+        *,
+        weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> WirepointRCNN:
+    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = WirepointRCNN(backbone, num_classes=5, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+
+if __name__ == '__main__':
+    cfg = 'wirenet.yaml'
+    cfg = read_yaml(cfg)
+    print(f'cfg:{cfg}')
+    print(cfg['model']['n_dyn_negl'])
+    # net = WirepointPredictor()
+
+    dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
+    train_collate_fn = utils.collate_fn_wirepoint
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=10, collate_fn=train_collate_fn
+    )
+    model = wirepointrcnn_resnet50_fpn()
+
+    imgs, targets = next(iter(data_loader))
+
+    model.train()
+    pred = model(imgs, targets)
+    print(f'pred:{pred}')
+    # result, losses = model(imgs, targets)
+    # print(f'result:{result}')
+    # print(f'pred:{losses}')
+'''
+########### predict#############
+
+    img_path=r"I:\wirenet_dateset\images\train\00030078_2.png"
+    transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    img = read_image(img_path)
+    img = transforms(img)
+
+    img = torch.ones((2, 3, 512, 512))
+    # print(f'img shape:{img.shape}')
+    model.eval()
+    onnx_file_path = "./wirenet.onnx"
+
+    # 导出模型为ONNX格式
+    # torch.onnx.export(model, img, onnx_file_path, verbose=True, input_names=['input'],
+    #                   output_names=['output'])
+    # torch.save(model,'./wirenet.pt')
+
+
+
+    # 5. 指定输出的 ONNX 文件名
+    # onnx_file_path = "./wirepoint_rcnn.onnx"
+
+    # 准备一个示例输入:Mask R-CNN 需要一个图像列表作为输入,每个图像张量的形状应为 [C, H, W]
+    img = [torch.ones((3, 800, 800))]  # 示例输入图像大小为 800x800,3个通道
+
+
+
+    # 指定输出的 ONNX 文件名
+    # onnx_file_path = "./mask_rcnn.onnx"
+
+
+
+    # model_scripted = torch.jit.script(model)
+    # torch.onnx.export(model_scripted, input, "model.onnx", verbose=True, input_names=["input"],
+    #                   output_names=["output"])
+    #
+    # print(f"Model has been converted to ONNX and saved to {onnx_file_path}")
+
+    pred=model(img)
+    #
+    print(f'pred:{pred}')
+
+
+
+################################################## end predict
+
+
+
+########## traing ###################################
+    # imgs, targets = next(iter(data_loader))
+
+    # model.train()
+    # pred = model(imgs, targets)
+
+    # class WrapperModule(torch.nn.Module):
+    #     def __init__(self, model):
+    #         super(WrapperModule, self).__init__()
+    #         self.model = model
+    #
+    #     def forward(self,img, targets):
+    #         # 在这里处理复杂的输入结构,将其转换为适合追踪的形式
+    #         return self.model(img,targets)
+
+    # torch.save(model.state_dict(),'./wire.pt')
+    # 包装原始模型
+    # wrapped_model = WrapperModule(model)
+    # # model_scripted = torch.jit.trace(wrapped_model,img)
+    # writer = SummaryWriter('./')
+    # writer.add_graph(wrapped_model, (imgs,targets))
+    # writer.close()
+
+
+    #
+    # print(f'pred:{pred}')
+########## end traing ###################################
+    # for imgs,targets in data_loader:
+    #     print(f'imgs:{imgs}')
+    #     print(f'targets:{targets}')
+'''

+ 192 - 0
tools/coco_eval.py

@@ -0,0 +1,192 @@
+import copy
+import io
+from contextlib import redirect_stdout
+
+import numpy as np
+import pycocotools.mask as mask_util
+import torch
+from tools import utils
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+
+
+class CocoEvaluator:
+    def __init__(self, coco_gt, iou_types):
+        if not isinstance(iou_types, (list, tuple)):
+            raise TypeError(f"This constructor expects iou_types of type list or tuple, instead  got {type(iou_types)}")
+        coco_gt = copy.deepcopy(coco_gt)
+        self.coco_gt = coco_gt
+
+        self.iou_types = iou_types
+        self.coco_eval = {}
+        for iou_type in iou_types:
+            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
+
+        self.img_ids = []
+        self.eval_imgs = {k: [] for k in iou_types}
+
+    def update(self, predictions):
+        img_ids = list(np.unique(list(predictions.keys())))
+        self.img_ids.extend(img_ids)
+
+        for iou_type in self.iou_types:
+            results = self.prepare(predictions, iou_type)
+            with redirect_stdout(io.StringIO()):
+                coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
+            coco_eval = self.coco_eval[iou_type]
+
+            coco_eval.cocoDt = coco_dt
+            coco_eval.params.imgIds = list(img_ids)
+            img_ids, eval_imgs = evaluate(coco_eval)
+
+            self.eval_imgs[iou_type].append(eval_imgs)
+
+    def synchronize_between_processes(self):
+        for iou_type in self.iou_types:
+            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
+            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
+
+    def accumulate(self):
+        for coco_eval in self.coco_eval.values():
+            coco_eval.accumulate()
+
+    def summarize(self):
+        for iou_type, coco_eval in self.coco_eval.items():
+            print(f"IoU metric: {iou_type}")
+            coco_eval.summarize()
+
+    def prepare(self, predictions, iou_type):
+        if iou_type == "bbox":
+            return self.prepare_for_coco_detection(predictions)
+        if iou_type == "segm":
+            return self.prepare_for_coco_segmentation(predictions)
+        if iou_type == "keypoints":
+            return self.prepare_for_coco_keypoint(predictions)
+        raise ValueError(f"Unknown iou type {iou_type}")
+
+    def prepare_for_coco_detection(self, predictions):
+        coco_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            boxes = prediction["boxes"]
+            boxes = convert_to_xywh(boxes).tolist()
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            coco_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "bbox": box,
+                        "score": scores[k],
+                    }
+                    for k, box in enumerate(boxes)
+                ]
+            )
+        return coco_results
+
+    def prepare_for_coco_segmentation(self, predictions):
+        coco_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            scores = prediction["scores"]
+            labels = prediction["labels"]
+            masks = prediction["masks"]
+
+            masks = masks > 0.5
+
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            rles = [
+                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
+            ]
+            for rle in rles:
+                rle["counts"] = rle["counts"].decode("utf-8")
+
+            coco_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "segmentation": rle,
+                        "score": scores[k],
+                    }
+                    for k, rle in enumerate(rles)
+                ]
+            )
+        return coco_results
+
+    def prepare_for_coco_keypoint(self, predictions):
+        coco_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            boxes = prediction["boxes"]
+            boxes = convert_to_xywh(boxes).tolist()
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+            keypoints = prediction["keypoints"]
+            keypoints = keypoints.flatten(start_dim=1).tolist()
+
+            coco_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "keypoints": keypoint,
+                        "score": scores[k],
+                    }
+                    for k, keypoint in enumerate(keypoints)
+                ]
+            )
+        return coco_results
+
+
+def convert_to_xywh(boxes):
+    xmin, ymin, xmax, ymax = boxes.unbind(1)
+    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
+
+
+def merge(img_ids, eval_imgs):
+    all_img_ids = utils.all_gather(img_ids)
+    all_eval_imgs = utils.all_gather(eval_imgs)
+
+    merged_img_ids = []
+    for p in all_img_ids:
+        merged_img_ids.extend(p)
+
+    merged_eval_imgs = []
+    for p in all_eval_imgs:
+        merged_eval_imgs.append(p)
+
+    merged_img_ids = np.array(merged_img_ids)
+    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
+
+    # keep only unique (and in sorted order) images
+    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+    merged_eval_imgs = merged_eval_imgs[..., idx]
+
+    return merged_img_ids, merged_eval_imgs
+
+
+def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
+    img_ids, eval_imgs = merge(img_ids, eval_imgs)
+    img_ids = list(img_ids)
+    eval_imgs = list(eval_imgs.flatten())
+
+    coco_eval.evalImgs = eval_imgs
+    coco_eval.params.imgIds = img_ids
+    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
+
+
+def evaluate(imgs):
+    with redirect_stdout(io.StringIO()):
+        imgs.evaluate()
+    return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))

+ 234 - 0
tools/coco_utils.py

@@ -0,0 +1,234 @@
+import os
+
+import torch
+import torch.utils.data
+import torchvision
+from tools import transforms as T
+from pycocotools import mask as coco_mask
+from pycocotools.coco import COCO
+
+
+def convert_coco_poly_to_mask(segmentations, height, width):
+    masks = []
+    for polygons in segmentations:
+        rles = coco_mask.frPyObjects(polygons, height, width)
+        mask = coco_mask.decode(rles)
+        if len(mask.shape) < 3:
+            mask = mask[..., None]
+        mask = torch.as_tensor(mask, dtype=torch.uint8)
+        mask = mask.any(dim=2)
+        masks.append(mask)
+    if masks:
+        masks = torch.stack(masks, dim=0)
+    else:
+        masks = torch.zeros((0, height, width), dtype=torch.uint8)
+    return masks
+
+
+class ConvertCocoPolysToMask:
+    def __call__(self, image, target):
+        w, h = image.size
+
+        image_id = target["image_id"]
+
+        anno = target["annotations"]
+
+        anno = [obj for obj in anno if obj["iscrowd"] == 0]
+
+        boxes = [obj["bbox"] for obj in anno]
+        # guard against no boxes via resizing
+        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+        boxes[:, 2:] += boxes[:, :2]
+        boxes[:, 0::2].clamp_(min=0, max=w)
+        boxes[:, 1::2].clamp_(min=0, max=h)
+
+        classes = [obj["category_id"] for obj in anno]
+        classes = torch.tensor(classes, dtype=torch.int64)
+
+        segmentations = [obj["segmentation"] for obj in anno]
+        masks = convert_coco_poly_to_mask(segmentations, h, w)
+
+        keypoints = None
+        if anno and "keypoints" in anno[0]:
+            keypoints = [obj["keypoints"] for obj in anno]
+            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
+            num_keypoints = keypoints.shape[0]
+            if num_keypoints:
+                keypoints = keypoints.view(num_keypoints, -1, 3)
+
+        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+        boxes = boxes[keep]
+        classes = classes[keep]
+        masks = masks[keep]
+        if keypoints is not None:
+            keypoints = keypoints[keep]
+
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = classes
+        target["masks"] = masks
+        target["image_id"] = image_id
+        if keypoints is not None:
+            target["keypoints"] = keypoints
+
+        # for conversion to coco api
+        area = torch.tensor([obj["area"] for obj in anno])
+        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
+        target["area"] = area
+        target["iscrowd"] = iscrowd
+
+        return image, target
+
+
+def _coco_remove_images_without_annotations(dataset, cat_list=None):
+    def _has_only_empty_bbox(anno):
+        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
+
+    def _count_visible_keypoints(anno):
+        return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
+
+    min_keypoints_per_image = 10
+
+    def _has_valid_annotation(anno):
+        # if it's empty, there is no annotation
+        if len(anno) == 0:
+            return False
+        # if all boxes have close to zero area, there is no annotation
+        if _has_only_empty_bbox(anno):
+            return False
+        # keypoints task have a slight different criteria for considering
+        # if an annotation is valid
+        if "keypoints" not in anno[0]:
+            return True
+        # for keypoint detection tasks, only consider valid images those
+        # containing at least min_keypoints_per_image
+        if _count_visible_keypoints(anno) >= min_keypoints_per_image:
+            return True
+        return False
+
+    ids = []
+    for ds_idx, img_id in enumerate(dataset.ids):
+        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+        anno = dataset.coco.loadAnns(ann_ids)
+        if cat_list:
+            anno = [obj for obj in anno if obj["category_id"] in cat_list]
+        if _has_valid_annotation(anno):
+            ids.append(ds_idx)
+
+    dataset = torch.utils.data.Subset(dataset, ids)
+    return dataset
+
+
+def convert_to_coco_api(ds):
+    coco_ds = COCO()
+    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
+    ann_id = 1
+    dataset = {"images": [], "categories": [], "annotations": []}
+    categories = set()
+    for img_idx in range(len(ds)):
+        # find better way to get target
+        # targets = ds.get_annotations(img_idx)
+        img, targets = ds[img_idx]
+        image_id = targets["image_id"]
+        img_dict = {}
+        img_dict["id"] = image_id
+        img_dict["height"] = img.shape[-2]
+        img_dict["width"] = img.shape[-1]
+        dataset["images"].append(img_dict)
+        bboxes = targets["boxes"].clone()
+        bboxes[:, 2:] -= bboxes[:, :2]
+        bboxes = bboxes.tolist()
+        labels = targets["labels"].tolist()
+        areas = targets["area"].tolist()
+        iscrowd = targets["iscrowd"].tolist()
+        if "masks" in targets:
+            masks = targets["masks"]
+            # make masks Fortran contiguous for coco_mask
+            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
+        if "keypoints" in targets:
+            keypoints = targets["keypoints"]
+            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
+        num_objs = len(bboxes)
+        for i in range(num_objs):
+            ann = {}
+            ann["image_id"] = image_id
+            ann["bbox"] = bboxes[i]
+            ann["category_id"] = labels[i]
+            categories.add(labels[i])
+            ann["area"] = areas[i]
+            ann["iscrowd"] = iscrowd[i]
+            ann["id"] = ann_id
+            if "masks" in targets:
+                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
+            if "keypoints" in targets:
+                ann["keypoints"] = keypoints[i]
+                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
+            dataset["annotations"].append(ann)
+            ann_id += 1
+    dataset["categories"] = [{"id": i} for i in sorted(categories)]
+    coco_ds.dataset = dataset
+    coco_ds.createIndex()
+    return coco_ds
+
+
+def get_coco_api_from_dataset(dataset):
+    # FIXME: This is... awful?
+    for _ in range(10):
+        if isinstance(dataset, torchvision.datasets.CocoDetection):
+            break
+        if isinstance(dataset, torch.utils.data.Subset):
+            dataset = dataset.dataset
+    if isinstance(dataset, torchvision.datasets.CocoDetection):
+        return dataset.coco
+    return convert_to_coco_api(dataset)
+
+
+class CocoDetection(torchvision.datasets.CocoDetection):
+    def __init__(self, img_folder, ann_file, transforms):
+        super().__init__(img_folder, ann_file)
+        self._transforms = transforms
+
+    def __getitem__(self, idx):
+        img, target = super().__getitem__(idx)
+        image_id = self.ids[idx]
+        target = dict(image_id=image_id, annotations=target)
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+        return img, target
+
+
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False, with_masks=False):
+    anno_file_template = "{}_{}2017.json"
+    PATHS = {
+        "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
+        "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
+        # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
+    }
+
+    img_folder, ann_file = PATHS[image_set]
+    img_folder = os.path.join(root, img_folder)
+    ann_file = os.path.join(root, ann_file)
+
+    if use_v2:
+        from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+        target_keys = ["boxes", "labels", "image_id"]
+        if with_masks:
+            target_keys += ["masks"]
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
+    else:
+        # TODO: handle with_masks for V1?
+        t = [ConvertCocoPolysToMask()]
+        if transforms is not None:
+            t.append(transforms)
+        transforms = T.Compose(t)
+
+        dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
+
+    if image_set == "train":
+        dataset = _coco_remove_images_without_annotations(dataset)
+
+    # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
+
+    return dataset

+ 116 - 0
tools/engine.py

@@ -0,0 +1,116 @@
+import math
+import sys
+import time
+
+import torch
+import torchvision.models.detection.mask_rcnn
+from tools import  utils
+from tools.coco_eval import  CocoEvaluator
+from tools.coco_utils import get_coco_api_from_dataset
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+
+def _get_iou_types(model):
+    model_without_ddp = model
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        model_without_ddp = model.module
+    iou_types = ["bbox"]
+    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
+        iou_types.append("segm")
+    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
+        iou_types.append("keypoints")
+    return iou_types
+
+
+@torch.inference_mode()
+def evaluate(model, data_loader, device):
+    n_threads = torch.get_num_threads()
+    # FIXME remove this and make paste_masks_in_image run on the GPU
+    torch.set_num_threads(1)
+    cpu_device = torch.device("cpu")
+    model.eval()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    header = "Test:"
+
+    coco = get_coco_api_from_dataset(data_loader.dataset)
+    iou_types = _get_iou_types(model)
+    coco_evaluator = CocoEvaluator(coco, iou_types)
+
+    print(f'start to evaluate!!!')
+    for images, targets in metric_logger.log_every(data_loader, 10, header):
+        images = list(img.to(device) for img in images)
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        model_time = time.time()
+        outputs = model(images)
+
+        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+        model_time = time.time() - model_time
+
+        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
+        evaluator_time = time.time()
+        coco_evaluator.update(res)
+        evaluator_time = time.time() - evaluator_time
+        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger)
+    coco_evaluator.synchronize_between_processes()
+
+    # accumulate predictions from all images
+    coco_evaluator.accumulate()
+    coco_evaluator.summarize()
+    torch.set_num_threads(n_threads)
+    return coco_evaluator

+ 196 - 0
tools/group_by_aspect_ratio.py

@@ -0,0 +1,196 @@
+import bisect
+import copy
+import math
+from collections import defaultdict
+from itertools import chain, repeat
+
+import numpy as np
+import torch
+import torch.utils.data
+import torchvision
+from PIL import Image
+from torch.utils.data.sampler import BatchSampler, Sampler
+from torch.utils.model_zoo import tqdm
+
+
+def _repeat_to_at_least(iterable, n):
+    repeat_times = math.ceil(n / len(iterable))
+    repeated = chain.from_iterable(repeat(iterable, repeat_times))
+    return list(repeated)
+
+
+class GroupedBatchSampler(BatchSampler):
+    """
+    Wraps another sampler to yield a mini-batch of indices.
+    It enforces that the batch only contain elements from the same group.
+    It also tries to provide mini-batches which follows an ordering which is
+    as close as possible to the ordering from the original sampler.
+    Args:
+        sampler (Sampler): Base sampler.
+        group_ids (list[int]): If the sampler produces indices in range [0, N),
+            `group_ids` must be a list of `N` ints which contains the group id of each sample.
+            The group ids must be a continuous set of integers starting from
+            0, i.e. they must be in the range [0, num_groups).
+        batch_size (int): Size of mini-batch.
+    """
+
+    def __init__(self, sampler, group_ids, batch_size):
+        if not isinstance(sampler, Sampler):
+            raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
+        self.sampler = sampler
+        self.group_ids = group_ids
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        buffer_per_group = defaultdict(list)
+        samples_per_group = defaultdict(list)
+
+        num_batches = 0
+        for idx in self.sampler:
+            group_id = self.group_ids[idx]
+            buffer_per_group[group_id].append(idx)
+            samples_per_group[group_id].append(idx)
+            if len(buffer_per_group[group_id]) == self.batch_size:
+                yield buffer_per_group[group_id]
+                num_batches += 1
+                del buffer_per_group[group_id]
+            assert len(buffer_per_group[group_id]) < self.batch_size
+
+        # now we have run out of elements that satisfy
+        # the group criteria, let's return the remaining
+        # elements so that the size of the sampler is
+        # deterministic
+        expected_num_batches = len(self)
+        num_remaining = expected_num_batches - num_batches
+        if num_remaining > 0:
+            # for the remaining batches, take first the buffers with the largest number
+            # of elements
+            for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True):
+                remaining = self.batch_size - len(buffer_per_group[group_id])
+                samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
+                buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
+                assert len(buffer_per_group[group_id]) == self.batch_size
+                yield buffer_per_group[group_id]
+                num_remaining -= 1
+                if num_remaining == 0:
+                    break
+        assert num_remaining == 0
+
+    def __len__(self):
+        return len(self.sampler) // self.batch_size
+
+
+def _compute_aspect_ratios_slow(dataset, indices=None):
+    print(
+        "Your dataset doesn't support the fast path for "
+        "computing the aspect ratios, so will iterate over "
+        "the full dataset and load every image instead. "
+        "This might take some time..."
+    )
+    if indices is None:
+        indices = range(len(dataset))
+
+    class SubsetSampler(Sampler):
+        def __init__(self, indices):
+            self.indices = indices
+
+        def __iter__(self):
+            return iter(self.indices)
+
+        def __len__(self):
+            return len(self.indices)
+
+    sampler = SubsetSampler(indices)
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=1,
+        sampler=sampler,
+        num_workers=14,  # you might want to increase it for faster processing
+        collate_fn=lambda x: x[0],
+    )
+    aspect_ratios = []
+    with tqdm(total=len(dataset)) as pbar:
+        for _i, (img, _) in enumerate(data_loader):
+            pbar.update(1)
+            height, width = img.shape[-2:]
+            aspect_ratio = float(width) / float(height)
+            aspect_ratios.append(aspect_ratio)
+    return aspect_ratios
+
+
+def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
+    if indices is None:
+        indices = range(len(dataset))
+    aspect_ratios = []
+    for i in indices:
+        height, width = dataset.get_height_and_width(i)
+        aspect_ratio = float(width) / float(height)
+        aspect_ratios.append(aspect_ratio)
+    return aspect_ratios
+
+
+def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
+    if indices is None:
+        indices = range(len(dataset))
+    aspect_ratios = []
+    for i in indices:
+        img_info = dataset.coco.imgs[dataset.ids[i]]
+        aspect_ratio = float(img_info["width"]) / float(img_info["height"])
+        aspect_ratios.append(aspect_ratio)
+    return aspect_ratios
+
+
+def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
+    if indices is None:
+        indices = range(len(dataset))
+    aspect_ratios = []
+    for i in indices:
+        # this doesn't load the data into memory, because PIL loads it lazily
+        width, height = Image.open(dataset.images[i]).size
+        aspect_ratio = float(width) / float(height)
+        aspect_ratios.append(aspect_ratio)
+    return aspect_ratios
+
+
+def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
+    if indices is None:
+        indices = range(len(dataset))
+
+    ds_indices = [dataset.indices[i] for i in indices]
+    return compute_aspect_ratios(dataset.dataset, ds_indices)
+
+
+def compute_aspect_ratios(dataset, indices=None):
+    if hasattr(dataset, "get_height_and_width"):
+        return _compute_aspect_ratios_custom_dataset(dataset, indices)
+
+    if isinstance(dataset, torchvision.datasets.CocoDetection):
+        return _compute_aspect_ratios_coco_dataset(dataset, indices)
+
+    if isinstance(dataset, torchvision.datasets.VOCDetection):
+        return _compute_aspect_ratios_voc_dataset(dataset, indices)
+
+    if isinstance(dataset, torch.utils.data.Subset):
+        return _compute_aspect_ratios_subset_dataset(dataset, indices)
+
+    # slow path
+    return _compute_aspect_ratios_slow(dataset, indices)
+
+
+def _quantize(x, bins):
+    bins = copy.deepcopy(bins)
+    bins = sorted(bins)
+    quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
+    return quantized
+
+
+def create_aspect_ratio_groups(dataset, k=0):
+    aspect_ratios = compute_aspect_ratios(dataset)
+    bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
+    groups = _quantize(aspect_ratios, bins)
+    # count number of elements per group
+    counts = np.unique(groups, return_counts=True)[1]
+    fbins = [0] + bins + [np.inf]
+    print(f"Using {fbins} as bins for aspect ratio quantization")
+    print(f"Count of instances per bin: {counts}")
+    return groups

+ 115 - 0
tools/presets.py

@@ -0,0 +1,115 @@
+from collections import defaultdict
+
+import torch
+from tools import transforms as reference_transforms
+# import transforms as reference_transforms
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.transforms.v2
+        import torchvision.tv_tensors
+
+        return torchvision.transforms.v2, torchvision.tv_tensors
+    else:
+        return reference_transforms, None
+
+
+class DetectionPresetTrain:
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter.
+    def __init__(
+        self,
+        *,
+        data_augmentation,
+        hflip_prob=0.5,
+        mean=(123.0, 117.0, 104.0),
+        backend="pil",
+        use_v2=False,
+    ):
+
+        T, tv_tensors = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "tv_tensor":
+            transforms.append(T.ToImage())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
+
+        if data_augmentation == "hflip":
+            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
+        elif data_augmentation == "lsj":
+            transforms += [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                # TODO: FixedSizeCrop below doesn't work on tensors!
+                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "multiscale":
+            transforms += [
+                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "ssd":
+            fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
+            transforms += [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=fill),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "ssdlite":
+            transforms += [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        else:
+            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
+
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2.
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
+
+        transforms += [T.ToDtype(torch.float, scale=True)]
+
+        if use_v2:
+            transforms += [
+                T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
+                T.SanitizeBoundingBoxes(),
+                T.ToPureTensor(),
+            ]
+
+        self.transforms = T.Compose(transforms)
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
+
+
+class DetectionPresetEval:
+    def __init__(self, backend="pil", use_v2=False):
+        T, _ = get_modules(use_v2)
+        transforms = []
+        backend = backend.lower()
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2?
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
+        elif backend == "tensor":
+            transforms += [T.PILToTensor()]
+        elif backend == "tv_tensor":
+            transforms += [T.ToImage()]
+        else:
+            raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
+
+        transforms += [T.ToDtype(torch.float, scale=True)]
+
+        if use_v2:
+            transforms += [T.ToPureTensor()]
+
+        self.transforms = T.Compose(transforms)
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)

+ 334 - 0
tools/train.py

@@ -0,0 +1,334 @@
+r"""PyTorch Detection Training.
+
+To run in a multi-gpu environment, use the distributed launcher::
+
+    python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
+        trainer.py ... --world-size $NGPU
+
+The default hyperparameters are tuned for training on 8 gpus and 2 images per gpu.
+    --lr 0.02 --batch-size 2 --world-size 8
+If you use different number of gpus, the learning rate should be changed to 0.02/8*$NGPU.
+
+On top of that, for training Faster/Mask R-CNN, the default hyperparameters are
+    --epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3
+
+Also, if you train Keypoint R-CNN, the default hyperparameters are
+    --epochs 46 --lr-steps 36 43 --aspect-ratio-group-factor 3
+Because the number of images is smaller in the person keypoint subset of COCO,
+the number of epochs should be adapted so that we have the same number of iterations.
+"""
+import datetime
+import os
+import time
+
+import presets
+import torch
+import torch.utils.data
+import torchvision
+import torchvision.models.detection
+import torchvision.models.detection.mask_rcnn
+import utils
+from coco_utils import get_coco
+from engine import evaluate, train_one_epoch
+from group_by_aspect_ratio import create_aspect_ratio_groups, GroupedBatchSampler
+from torchvision.transforms import InterpolationMode
+from transforms import SimpleCopyPaste
+
+
+def copypaste_collate_fn(batch):
+    copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
+    return copypaste(*utils.collate_fn(batch))
+
+
+def get_dataset(is_train, args):
+    image_set = "train" if is_train else "val"
+    num_classes, mode = {"coco": (91, "instances"), "coco_kp": (2, "person_keypoints")}[args.dataset]
+    with_masks = "mask" in args.model
+    ds = get_coco(
+        root=args.data_path,
+        image_set=image_set,
+        transforms=get_transform(is_train, args),
+        mode=mode,
+        use_v2=args.use_v2,
+        with_masks=with_masks,
+    )
+    return ds, num_classes
+
+
+def get_transform(is_train, args):
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2
+        )
+    elif args.weights and args.test_only:
+        weights = torchvision.models.get_weight(args.weights)
+        trans = weights.transforms()
+        return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2)
+
+
+def get_args_parser(add_help=True):
+    import argparse
+
+    parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)
+
+    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
+    parser.add_argument(
+        "--dataset",
+        default="coco",
+        type=str,
+        help="dataset name. Use coco for object detection and instance segmentation and coco_kp for Keypoint detection",
+    )
+    parser.add_argument("--models", default="maskrcnn_resnet50_fpn", type=str, help="models name")
+    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
+    parser.add_argument(
+        "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
+    )
+    parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run")
+    parser.add_argument(
+        "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
+    )
+    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
+    parser.add_argument(
+        "--lr",
+        default=0.02,
+        type=float,
+        help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
+    )
+    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+    parser.add_argument(
+        "--wd",
+        "--weight-decay",
+        default=1e-4,
+        type=float,
+        metavar="W",
+        help="weight decay (default: 1e-4)",
+        dest="weight_decay",
+    )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
+    parser.add_argument(
+        "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
+    )
+    parser.add_argument(
+        "--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)"
+    )
+    parser.add_argument(
+        "--lr-steps",
+        default=[16, 22],
+        nargs="+",
+        type=int,
+        help="decrease lr every step-size epochs (multisteplr scheduler only)",
+    )
+    parser.add_argument(
+        "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
+    )
+    parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
+    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
+    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
+    parser.add_argument("--start_epoch", default=0, type=int, help="start epoch")
+    parser.add_argument("--aspect-ratio-group-factor", default=3, type=int)
+    parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn")
+    parser.add_argument(
+        "--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone"
+    )
+    parser.add_argument(
+        "--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
+    )
+    parser.add_argument(
+        "--sync-bn",
+        dest="sync_bn",
+        help="Use sync batch norm",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--test-only",
+        dest="test_only",
+        help="Only test the models",
+        action="store_true",
+    )
+
+    parser.add_argument(
+        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
+    )
+
+    # distributed training parameters
+    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
+    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+    parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")
+
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
+    # Use CopyPaste augmentation training parameter
+    parser.add_argument(
+        "--use-copypaste",
+        action="store_true",
+        help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
+    )
+
+    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
+
+    return parser
+
+
+def main(args):
+    if args.backend.lower() == "tv_tensor" and not args.use_v2:
+        raise ValueError("Use --use-v2 if you want to use the tv_tensor backend.")
+    if args.dataset not in ("coco", "coco_kp"):
+        raise ValueError(f"Dataset should be coco or coco_kp, got {args.dataset}")
+    if "keypoint" in args.model and args.dataset != "coco_kp":
+        raise ValueError("Oops, if you want Keypoint detection, set --dataset coco_kp")
+    if args.dataset == "coco_kp" and args.use_v2:
+        raise ValueError("KeyPoint detection doesn't support V2 transforms yet")
+
+    if args.output_dir:
+        utils.mkdir(args.output_dir)
+
+    utils.init_distributed_mode(args)
+    print(args)
+
+    device = torch.device(args.device)
+
+    if args.use_deterministic_algorithms:
+        torch.use_deterministic_algorithms(True)
+
+    # Data loading code
+    print("Loading data")
+
+    dataset, num_classes = get_dataset(is_train=True, args=args)
+    dataset_test, _ = get_dataset(is_train=False, args=args)
+
+    print("Creating data loaders")
+    if args.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
+    else:
+        train_sampler = torch.utils.data.RandomSampler(dataset)
+        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+
+    if args.aspect_ratio_group_factor >= 0:
+        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
+        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
+    else:
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
+
+    train_collate_fn = utils.collate_fn
+    if args.use_copypaste:
+        if args.data_augmentation != "lsj":
+            raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")
+
+        train_collate_fn = copypaste_collate_fn
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
+    )
+
+    data_loader_test = torch.utils.data.DataLoader(
+        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
+    )
+
+    print("Creating models")
+    kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
+    if args.data_augmentation in ["multiscale", "lsj"]:
+        kwargs["_skip_resize"] = True
+    if "rcnn" in args.model:
+        if args.rpn_score_thresh is not None:
+            kwargs["rpn_score_thresh"] = args.rpn_score_thresh
+    model = torchvision.models.get_model(
+        args.model, weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
+    )
+    model.to(device)
+    if args.distributed and args.sync_bn:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    if args.norm_weight_decay is None:
+        parameters = [p for p in model.parameters() if p.requires_grad]
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
+    opt_name = args.opt.lower()
+    if opt_name.startswith("sgd"):
+        optimizer = torch.optim.SGD(
+            parameters,
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name,
+        )
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    else:
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")
+
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+    args.lr_scheduler = args.lr_scheduler.lower()
+    if args.lr_scheduler == "multisteplr":
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+    elif args.lr_scheduler == "cosineannealinglr":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+    else:
+        raise RuntimeError(
+            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
+        )
+
+    if args.resume:
+        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
+        model_without_ddp.load_state_dict(checkpoint["models"])
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
+
+    if args.test_only:
+        torch.backends.cudnn.deterministic = True
+        evaluate(model, data_loader_test, device=device)
+        return
+
+    print("Start training")
+    start_time = time.time()
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
+        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
+        lr_scheduler.step()
+        if args.output_dir:
+            checkpoint = {
+                "models": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "lr_scheduler": lr_scheduler.state_dict(),
+                "args": args,
+                "epoch": epoch,
+            }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
+
+        # evaluate after every epoch
+        evaluate(model, data_loader_test, device=device)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print(f"Training time {total_time_str}")
+
+
+if __name__ == "__main__":
+    args = get_args_parser().parse_args()
+    main(args)

+ 601 - 0
tools/transforms.py

@@ -0,0 +1,601 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torchvision
+from torch import nn, Tensor
+from torchvision import ops
+from torchvision.transforms import functional as F, InterpolationMode, transforms as T
+
+
+def _flip_coco_person_keypoints(kps, width):
+    flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
+    flipped_data = kps[:, flip_inds]
+    flipped_data[..., 0] = width - flipped_data[..., 0]
+    # Maintain COCO convention that if visibility == 0, then x, y = 0
+    inds = flipped_data[..., 2] == 0
+    flipped_data[inds] = 0
+    return flipped_data
+
+
+class Compose:
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, target):
+        for t in self.transforms:
+            image, target = t(image, target)
+        return image, target
+
+
+class RandomHorizontalFlip(T.RandomHorizontalFlip):
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if torch.rand(1) < self.p:
+            image = F.hflip(image)
+            if target is not None:
+                _, _, width = F.get_dimensions(image)
+                target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
+                if "masks" in target:
+                    target["masks"] = target["masks"].flip(-1)
+                if "keypoints" in target:
+                    keypoints = target["keypoints"]
+                    keypoints = _flip_coco_person_keypoints(keypoints, width)
+                    target["keypoints"] = keypoints
+        return image, target
+
+
+class PILToTensor(nn.Module):
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        image = F.pil_to_tensor(image)
+        return image, target
+
+
+class ToDtype(nn.Module):
+    def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
+        super().__init__()
+        self.dtype = dtype
+        self.scale = scale
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if not self.scale:
+            return image.to(dtype=self.dtype), target
+        image = F.convert_image_dtype(image, self.dtype)
+        return image, target
+
+
+class RandomIoUCrop(nn.Module):
+    def __init__(
+        self,
+        min_scale: float = 0.3,
+        max_scale: float = 1.0,
+        min_aspect_ratio: float = 0.5,
+        max_aspect_ratio: float = 2.0,
+        sampler_options: Optional[List[float]] = None,
+        trials: int = 40,
+    ):
+        super().__init__()
+        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+        self.min_aspect_ratio = min_aspect_ratio
+        self.max_aspect_ratio = max_aspect_ratio
+        if sampler_options is None:
+            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
+        self.options = sampler_options
+        self.trials = trials
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if target is None:
+            raise ValueError("The targets can't be None for this transform.")
+
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        _, orig_h, orig_w = F.get_dimensions(image)
+
+        while True:
+            # sample an option
+            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
+            min_jaccard_overlap = self.options[idx]
+            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
+                return image, target
+
+            for _ in range(self.trials):
+                # check the aspect ratio limitations
+                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
+                new_w = int(orig_w * r[0])
+                new_h = int(orig_h * r[1])
+                aspect_ratio = new_w / new_h
+                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
+                    continue
+
+                # check for 0 area crops
+                r = torch.rand(2)
+                left = int((orig_w - new_w) * r[0])
+                top = int((orig_h - new_h) * r[1])
+                right = left + new_w
+                bottom = top + new_h
+                if left == right or top == bottom:
+                    continue
+
+                # check for any valid boxes with centers within the crop area
+                cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
+                cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
+                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
+                if not is_within_crop_area.any():
+                    continue
+
+                # check at least 1 box with jaccard limitations
+                boxes = target["boxes"][is_within_crop_area]
+                ious = torchvision.ops.boxes.box_iou(
+                    boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device)
+                )
+                if ious.max() < min_jaccard_overlap:
+                    continue
+
+                # keep only valid boxes and perform cropping
+                target["boxes"] = boxes
+                target["labels"] = target["labels"][is_within_crop_area]
+                target["boxes"][:, 0::2] -= left
+                target["boxes"][:, 1::2] -= top
+                target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
+                target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
+                image = F.crop(image, top, left, new_h, new_w)
+
+                return image, target
+
+
+class RandomZoomOut(nn.Module):
+    def __init__(
+        self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
+    ):
+        super().__init__()
+        if fill is None:
+            fill = [0.0, 0.0, 0.0]
+        self.fill = fill
+        self.side_range = side_range
+        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
+            raise ValueError(f"Invalid canvas side range provided {side_range}.")
+        self.p = p
+
+    @torch.jit.unused
+    def _get_fill_value(self, is_pil):
+        # type: (bool) -> int
+        # We fake the type to make it work on JIT
+        return tuple(int(x) for x in self.fill) if is_pil else 0
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        if torch.rand(1) >= self.p:
+            return image, target
+
+        _, orig_h, orig_w = F.get_dimensions(image)
+
+        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
+        canvas_width = int(orig_w * r)
+        canvas_height = int(orig_h * r)
+
+        r = torch.rand(2)
+        left = int((canvas_width - orig_w) * r[0])
+        top = int((canvas_height - orig_h) * r[1])
+        right = canvas_width - (left + orig_w)
+        bottom = canvas_height - (top + orig_h)
+
+        if torch.jit.is_scripting():
+            fill = 0
+        else:
+            fill = self._get_fill_value(F._is_pil_image(image))
+
+        image = F.pad(image, [left, top, right, bottom], fill=fill)
+        if isinstance(image, torch.Tensor):
+            # PyTorch's pad supports only integers on fill. So we need to overwrite the colour
+            v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
+            image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
+                ..., :, (left + orig_w) :
+            ] = v
+
+        if target is not None:
+            target["boxes"][:, 0::2] += left
+            target["boxes"][:, 1::2] += top
+
+        return image, target
+
+
+class RandomPhotometricDistort(nn.Module):
+    def __init__(
+        self,
+        contrast: Tuple[float, float] = (0.5, 1.5),
+        saturation: Tuple[float, float] = (0.5, 1.5),
+        hue: Tuple[float, float] = (-0.05, 0.05),
+        brightness: Tuple[float, float] = (0.875, 1.125),
+        p: float = 0.5,
+    ):
+        super().__init__()
+        self._brightness = T.ColorJitter(brightness=brightness)
+        self._contrast = T.ColorJitter(contrast=contrast)
+        self._hue = T.ColorJitter(hue=hue)
+        self._saturation = T.ColorJitter(saturation=saturation)
+        self.p = p
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        r = torch.rand(7)
+
+        if r[0] < self.p:
+            image = self._brightness(image)
+
+        contrast_before = r[1] < 0.5
+        if contrast_before:
+            if r[2] < self.p:
+                image = self._contrast(image)
+
+        if r[3] < self.p:
+            image = self._saturation(image)
+
+        if r[4] < self.p:
+            image = self._hue(image)
+
+        if not contrast_before:
+            if r[5] < self.p:
+                image = self._contrast(image)
+
+        if r[6] < self.p:
+            channels, _, _ = F.get_dimensions(image)
+            permutation = torch.randperm(channels)
+
+            is_pil = F._is_pil_image(image)
+            if is_pil:
+                image = F.pil_to_tensor(image)
+                image = F.convert_image_dtype(image)
+            image = image[..., permutation, :, :]
+            if is_pil:
+                image = F.to_pil_image(image)
+
+        return image, target
+
+
+class ScaleJitter(nn.Module):
+    """Randomly resizes the image and its bounding boxes  within the specified scale range.
+    The class implements the Scale Jitter augmentation as described in the paper
+    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
+
+    Args:
+        target_size (tuple of ints): The target size for the transform provided in (height, weight) format.
+        scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the
+            range a <= scale <= b.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
+    """
+
+    def __init__(
+        self,
+        target_size: Tuple[int, int],
+        scale_range: Tuple[float, float] = (0.1, 2.0),
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias=True,
+    ):
+        super().__init__()
+        self.target_size = target_size
+        self.scale_range = scale_range
+        self.interpolation = interpolation
+        self.antialias = antialias
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        _, orig_height, orig_width = F.get_dimensions(image)
+
+        scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
+        r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
+        new_width = int(orig_width * r)
+        new_height = int(orig_height * r)
+
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)
+
+        if target is not None:
+            target["boxes"][:, 0::2] *= new_width / orig_width
+            target["boxes"][:, 1::2] *= new_height / orig_height
+            if "masks" in target:
+                target["masks"] = F.resize(
+                    target["masks"],
+                    [new_height, new_width],
+                    interpolation=InterpolationMode.NEAREST,
+                    antialias=self.antialias,
+                )
+
+        return image, target
+
+
+class FixedSizeCrop(nn.Module):
+    def __init__(self, size, fill=0, padding_mode="constant"):
+        super().__init__()
+        size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
+        self.crop_height = size[0]
+        self.crop_width = size[1]
+        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
+        self.padding_mode = padding_mode
+
+    def _pad(self, img, target, padding):
+        # Taken from the functional_tensor.py pad
+        if isinstance(padding, int):
+            pad_left = pad_right = pad_top = pad_bottom = padding
+        elif len(padding) == 1:
+            pad_left = pad_right = pad_top = pad_bottom = padding[0]
+        elif len(padding) == 2:
+            pad_left = pad_right = padding[0]
+            pad_top = pad_bottom = padding[1]
+        else:
+            pad_left = padding[0]
+            pad_top = padding[1]
+            pad_right = padding[2]
+            pad_bottom = padding[3]
+
+        padding = [pad_left, pad_top, pad_right, pad_bottom]
+        img = F.pad(img, padding, self.fill, self.padding_mode)
+        if target is not None:
+            target["boxes"][:, 0::2] += pad_left
+            target["boxes"][:, 1::2] += pad_top
+            if "masks" in target:
+                target["masks"] = F.pad(target["masks"], padding, 0, "constant")
+
+        return img, target
+
+    def _crop(self, img, target, top, left, height, width):
+        img = F.crop(img, top, left, height, width)
+        if target is not None:
+            boxes = target["boxes"]
+            boxes[:, 0::2] -= left
+            boxes[:, 1::2] -= top
+            boxes[:, 0::2].clamp_(min=0, max=width)
+            boxes[:, 1::2].clamp_(min=0, max=height)
+
+            is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])
+
+            target["boxes"] = boxes[is_valid]
+            target["labels"] = target["labels"][is_valid]
+            if "masks" in target:
+                target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)
+
+        return img, target
+
+    def forward(self, img, target=None):
+        _, height, width = F.get_dimensions(img)
+        new_height = min(height, self.crop_height)
+        new_width = min(width, self.crop_width)
+
+        if new_height != height or new_width != width:
+            offset_height = max(height - self.crop_height, 0)
+            offset_width = max(width - self.crop_width, 0)
+
+            r = torch.rand(1)
+            top = int(offset_height * r)
+            left = int(offset_width * r)
+
+            img, target = self._crop(img, target, top, left, new_height, new_width)
+
+        pad_bottom = max(self.crop_height - new_height, 0)
+        pad_right = max(self.crop_width - new_width, 0)
+        if pad_bottom != 0 or pad_right != 0:
+            img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
+
+        return img, target
+
+
+class RandomShortestSize(nn.Module):
+    def __init__(
+        self,
+        min_size: Union[List[int], Tuple[int], int],
+        max_size: int,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    ):
+        super().__init__()
+        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
+        self.max_size = max_size
+        self.interpolation = interpolation
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        _, orig_height, orig_width = F.get_dimensions(image)
+
+        min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
+        r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
+
+        new_width = int(orig_width * r)
+        new_height = int(orig_height * r)
+
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
+
+        if target is not None:
+            target["boxes"][:, 0::2] *= new_width / orig_width
+            target["boxes"][:, 1::2] *= new_height / orig_height
+            if "masks" in target:
+                target["masks"] = F.resize(
+                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
+                )
+
+        return image, target
+
+
+def _copy_paste(
+    image: torch.Tensor,
+    target: Dict[str, Tensor],
+    paste_image: torch.Tensor,
+    paste_target: Dict[str, Tensor],
+    blending: bool = True,
+    resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
+) -> Tuple[torch.Tensor, Dict[str, Tensor]]:
+
+    # Random paste targets selection:
+    num_masks = len(paste_target["masks"])
+
+    if num_masks < 1:
+        # Such degerante case with num_masks=0 can happen with LSJ
+        # Let's just return (image, target)
+        return image, target
+
+    # We have to please torch script by explicitly specifying dtype as torch.long
+    random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
+    random_selection = torch.unique(random_selection).to(torch.long)
+
+    paste_masks = paste_target["masks"][random_selection]
+    paste_boxes = paste_target["boxes"][random_selection]
+    paste_labels = paste_target["labels"][random_selection]
+
+    masks = target["masks"]
+
+    # We resize source and paste data if they have different sizes
+    # This is something we introduced here as originally the algorithm works
+    # on equal-sized data (for example, coming from LSJ data augmentations)
+    size1 = image.shape[-2:]
+    size2 = paste_image.shape[-2:]
+    if size1 != size2:
+        paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
+        paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
+        # resize bboxes:
+        ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
+        paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)
+
+    paste_alpha_mask = paste_masks.sum(dim=0) > 0
+
+    if blending:
+        paste_alpha_mask = F.gaussian_blur(
+            paste_alpha_mask.unsqueeze(0),
+            kernel_size=(5, 5),
+            sigma=[
+                2.0,
+            ],
+        )
+
+    # Copy-paste images:
+    image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
+
+    # Copy-paste masks:
+    masks = masks * (~paste_alpha_mask)
+    non_all_zero_masks = masks.sum((-1, -2)) > 0
+    masks = masks[non_all_zero_masks]
+
+    # Do a shallow copy of the target dict
+    out_target = {k: v for k, v in target.items()}
+
+    out_target["masks"] = torch.cat([masks, paste_masks])
+
+    # Copy-paste boxes and labels
+    boxes = ops.masks_to_boxes(masks)
+    out_target["boxes"] = torch.cat([boxes, paste_boxes])
+
+    labels = target["labels"][non_all_zero_masks]
+    out_target["labels"] = torch.cat([labels, paste_labels])
+
+    # Update additional optional keys: area and iscrowd if exist
+    if "area" in target:
+        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
+
+    if "iscrowd" in target and "iscrowd" in paste_target:
+        # target['iscrowd'] size can be differ from mask size (non_all_zero_masks)
+        # For example, if previous transforms geometrically modifies masks/boxes/labels but
+        # does not update "iscrowd"
+        if len(target["iscrowd"]) == len(non_all_zero_masks):
+            iscrowd = target["iscrowd"][non_all_zero_masks]
+            paste_iscrowd = paste_target["iscrowd"][random_selection]
+            out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
+
+    # Check for degenerated boxes and remove them
+    boxes = out_target["boxes"]
+    degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+    if degenerate_boxes.any():
+        valid_targets = ~degenerate_boxes.any(dim=1)
+
+        out_target["boxes"] = boxes[valid_targets]
+        out_target["masks"] = out_target["masks"][valid_targets]
+        out_target["labels"] = out_target["labels"][valid_targets]
+
+        if "area" in out_target:
+            out_target["area"] = out_target["area"][valid_targets]
+        if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
+            out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
+
+    return image, out_target
+
+
+class SimpleCopyPaste(torch.nn.Module):
+    def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
+        super().__init__()
+        self.resize_interpolation = resize_interpolation
+        self.blending = blending
+
+    def forward(
+        self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
+    ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
+        torch._assert(
+            isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
+            "images should be a list of tensors",
+        )
+        torch._assert(
+            isinstance(targets, (list, tuple)) and len(images) == len(targets),
+            "targets should be a list of the same size as images",
+        )
+        for target in targets:
+            # Can not check for instance type dict with inside torch.jit.script
+            # torch._assert(isinstance(target, dict), "targets item should be a dict")
+            for k in ["masks", "boxes", "labels"]:
+                torch._assert(k in target, f"Key {k} should be present in targets")
+                torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")
+
+        # images = [t1, t2, ..., tN]
+        # Let's define paste_images as shifted list of input images
+        # paste_images = [t2, t3, ..., tN, t1]
+        # FYI: in TF they mix data on the dataset level
+        images_rolled = images[-1:] + images[:-1]
+        targets_rolled = targets[-1:] + targets[:-1]
+
+        output_images: List[torch.Tensor] = []
+        output_targets: List[Dict[str, Tensor]] = []
+
+        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
+            output_image, output_data = _copy_paste(
+                image,
+                target,
+                paste_image,
+                paste_target,
+                blending=self.blending,
+                resize_interpolation=self.resize_interpolation,
+            )
+            output_images.append(output_image)
+            output_targets.append(output_data)
+
+        return output_images, output_targets
+
+    def __repr__(self) -> str:
+        s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
+        return s

+ 295 - 0
tools/utils.py

@@ -0,0 +1,295 @@
+import datetime
+import errno
+import os
+import time
+from collections import defaultdict, deque
+
+import torch
+import torch.distributed as dist
+from torch.utils.data.dataloader import default_collate
+
+class SmoothedValue:
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
+        )
+
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+    data_list = [None] * world_size
+    dist.all_gather_object(data_list, data)
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that all processes
+    have the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.inference_mode():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.all_reduce(values)
+        if average:
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
+
+
+class MetricLogger:
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(f"{name}: {str(meter)}")
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ""
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt="{avg:.4f}")
+        data_time = SmoothedValue(fmt="{avg:.4f}")
+        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+        if torch.cuda.is_available():
+            log_msg = self.delimiter.join(
+                [
+                    header,
+                    "[{0" + space_fmt + "}/{1}]",
+                    "eta: {eta}",
+                    "{meters}",
+                    "time: {time}",
+                    "data: {data}",
+                    "max mem: {memory:.0f}",
+                ]
+            )
+        else:
+            log_msg = self.delimiter.join(
+                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
+            )
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                            memory=torch.cuda.max_memory_allocated() / MB,
+                        )
+                    )
+                else:
+                    print(
+                        log_msg.format(
+                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
+                        )
+                    )
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
+
+
+def collate_fn(batch):
+    print(f'batch:{len(batch)}')
+    return tuple(zip(*batch))
+
+def collate_fn_wirepoint(batch):
+    # print(f'batch[0]:{batch[0]}')
+    # for b in batch:
+    #     b[1]["wires"]= [b[1]["wires"]]
+
+    # default_collate([b[1] for b in batch]),
+
+
+    # batch[0][1]["wires"] = [batch[0][1]["wires"]]
+    batch=tuple(zip(*batch))
+
+    # print(f'batch_post:{batch}')
+    return batch
+def mkdir(path):
+    try:
+        os.makedirs(path)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print("Not using distributed mode")
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
+    torch.distributed.init_process_group(
+        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
+    )
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)