backbone is ResNet18

xue50 5 months ago
parent
commit
42d97d47b3

+ 26 - 0
.gitignore

@@ -1,5 +1,31 @@
 .idea
 *.pt
+*.log
+*.onnx
 runs
+logs
+log
+
+/tensorboard/
+logs/
+tensorboard_logs/
+summaries/
+events.out.tfevents.*
+
+# If you have a specific directory for your runs, you can ignore it directly
+/runs/
+
+# Ignore checkpoint files if you don't want to track them
+checkpoint
+*.ckpt.data-*
+*.ckpt.index
+*.ckpt.meta
+
+# Ignore TensorFlow model files that are not necessary for version control
+*.pb
+/*.pbtxt
+# Ignore Jupyter Notebook checkpoints
+/.ipynb_checkpoints/
+
 __pycache__
 train_results

+ 2 - 1
models/ins/trainer.py

@@ -80,7 +80,8 @@ def train(model, **kwargs):
     # default parameters
     default_params = {
         'dataset_path': '/path/to/dataset',
-        'num_classes': 10,
+        'num_classes': 2,
+        'num_keypoints': 2,
         'opt': 'adamw',
         'batch_size': 2,
         'epochs': 10,

+ 0 - 0
models/keypoint/__init__.py


+ 78 - 0
models/keypoint/kepointrcnn.py

@@ -0,0 +1,78 @@
+import math
+import os
+import sys
+from datetime import datetime
+from typing import Mapping, Any
+import cv2
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision.io import read_image
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.utils import draw_bounding_boxes
+
+from models.config.config_tool import read_yaml
+from models.keypoint.trainer import train_cfg
+
+from tools import utils
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+class KeypointRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=2, num_keypoints=2, transforms=None):
+        super(KeypointRCNNModel, self).__init__()
+        default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+        # note: the detector itself starts from random weights (weights=None);
+        # only the default weights' preprocessing transforms are reused below
+        self.__model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=None, num_classes=num_classes,
+                                                                              num_keypoints=num_keypoints,
+                                                                              progress=False)
+        if transforms is None:
+            self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
+        # if num_classes != 0:
+        #     self.set_num_classes(num_classes)
+            # self.__num_classes=0
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        outputs = self.__model(inputs)
+        return outputs
+
+    # caution: this overrides nn.Module.train(mode); calling .eval() on this wrapper
+    # (which internally calls self.train(False)) will therefore fail
+    def train(self, cfg):
+        parameters = read_yaml(cfg)
+        num_classes = parameters['num_classes']
+        num_keypoints = parameters['num_keypoints']
+        # print(f'num_classes:{num_classes}')
+        # self.set_num_classes(num_classes)
+        self.num_keypoints = num_keypoints
+        train_cfg(self.__model, cfg)
+
+    # def set_num_classes(self, num_classes):
+    #     in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+    #     self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+    #
+    #     # in_features_mask = self.__model.roi_heads.mask_predictor.conv5_mask.in_channels
+    #     in_channels = self.__model.roi_heads.keypoint_predictor.
+    #     hidden_layer = 256
+    #     self.__model.roi_heads.mask_predictor = KeypointRCNNPredictor(in_channels, hidden_layer,
+    #                                                               num_classes=num_classes)
+    #     self.__model.roi_heads.keypoint_predictor=KeypointRCNNPredictor(in_channels, num_keypoints=num_classes)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        self.__model.load_state_dict(state_dict, strict=strict)
+        # return super().load_state_dict(state_dict, strict)
+
+
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    keypoint_model = KeypointRCNNModel(num_keypoints=2)
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    keypoint_model.train(cfg='train.yaml')

+ 202 - 0
models/keypoint/keypoint_dataset.py

@@ -0,0 +1,202 @@
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+def validate_keypoints(keypoints, image_width, image_height):
+    for kp in keypoints:
+        x, y, v = kp
+        if not (0 <= x < image_width and 0 <= y < image_height):
+            raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
+
+
+class KeypointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images", dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels", dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        print(f'shape:{shape}')
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            label_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = label_all["wires"][0]  # dict
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # take as many as available if fewer
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative samples concatenated
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of the ground-truth lines
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_lines, 512, 512]
+        target = {}
+
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+
+        target["labels"] = torch.stack(labels)
+        # print(f'labels:{target["labels"]}')
+        target["boxes"] = line_boxes(target)
+        # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
+
+        keypoints = wire_labels["junc_coords"]
+        keypoints[:, 2] = 2  # mark every junction keypoint as visible
+        # keypoints[:,0]=keypoints[:,0]/shape[0]
+        # keypoints[:, 1] = keypoints[:, 1] / shape[1]
+        target["keypoints"] = keypoints
+        # shape is (h, w), so width is shape[1] and height is shape[0]
+        validate_keypoints(keypoints, shape[1], shape[0])
+        print(f'keypoints:{target["keypoints"].shape}')
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[1]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (j == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            if fn is not None:
+                plt.savefig(fn)  # save before show(), which clears the current figure
+            plt.show()
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass
+
+
+
+if __name__ == '__main__':
+    path = r"I:\wirenet_dateset"
+    dataset = KeypointDataset(dataset_path=path, dataset_type='train')
+    dataset.show(0)
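
For reference, read_target above assumes each label JSON follows roughly this layout (a sketch inferred from the keys the code accesses; the axis order and shapes in the comments are illustrative assumptions, not confirmed by the commit):

    # hypothetical contents of labels/train/<name>.json
    label = {
        "wires": [{
            "junc_coords":     {"content": [[y, x, jtyp], ...]},    # junctions; column 2 is the junction type
            "line_pos_idx":    {"content": [[i, j], ...]},          # junction index pairs of real lines
            "line_neg_idx":    {"content": [[i, j], ...]},          # junction index pairs of negative samples
            "line_pos_coords": {"content": [[[y1, x1, v], [y2, x2, v]], ...]},
            "line_neg_coords": {"content": [[[y1, x1, v], [y2, x2, v]], ...]},
            "junc_map":        {"content": [...]},                  # junction heatmap
            "junc_offset":     {"content": [...]},                  # junction sub-pixel offsets
            "line_map":        {"content": [...]},                  # line heatmap
        }]
    }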

+ 61 - 0
models/keypoint/test.py

@@ -0,0 +1,61 @@
+import time
+
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.io import decode_image, read_image
+import torchvision.transforms.functional as F
+from torchvision.utils import draw_keypoints
+
+
+def show(imgs):
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[0, i].imshow(np.asarray(img))
+        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+
+img_path=r"D:\python\PycharmProjects\data\images\val\00031591_0.png"
+# img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
+img_int = read_image(img_path)
+
+
+# person_int = decode_image(r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg")
+
+weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+print(f'transforms:{transforms}')
+img = transforms(img_int)
+
+model = keypointrcnn_resnet50_fpn(weights=None, progress=False)
+# note: weights=None leaves the model randomly initialized; pass weights=weights
+# to load the pretrained COCO checkpoint whose transforms are applied above
+model = model.eval()
+t1 = time.time()
+# img = torch.ones((3, 3, 512, 512))
+
+
+outputs = model([img])
+t2 = time.time()
+print(f'time:{t2 - t1}')
+# print(f'outputs:{outputs}')
+
+kpts = outputs[0]['keypoints']
+scores = outputs[0]['scores']
+
+print(f'kpts:{kpts}')
+print(f'scores:{scores}')
+
+detect_threshold = 0.75
+idx = torch.where(scores > detect_threshold)
+keypoints = kpts[idx]
+
+# print(f'keypoints:{keypoints}')
+
+
+
+res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
+show(res)
+plt.show()

+ 32 - 0
models/keypoint/train.yaml

@@ -0,0 +1,32 @@
+
+
+dataset_path: D:\python\PycharmProjects\data
+
+#train parameters
+num_classes: 2
+num_keypoints: 2
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: pixel
+enable_logs: True
+augmentation: False
+checkpoint: null  # YAML null loads as Python None; the bare word None would load as the string "None"
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+
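
These keys mirror the default_params dict in models/keypoint/trainer.py; a minimal way to consume the file, matching the __main__ block of kepointrcnn.py:

    from models.keypoint.kepointrcnn import KeypointRCNNModel

    keypoint_model = KeypointRCNNModel(num_classes=2, num_keypoints=2)
    keypoint_model.train(cfg='train.yaml')  # read_yaml() parses the keys above and forwards them to train()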

+ 212 - 0
models/keypoint/trainer.py

@@ -0,0 +1,212 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.config.config_tool import read_yaml
+from models.keypoint.keypoint_dataset import KeypointDataset
+from tools import utils, presets
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        # print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            print(f'loss_dict:{loss_dict}')
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+def train(model, **kwargs):
+    # default parameters
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 2,
+        'num_keypoints': 2,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint':None
+    }
+    # merge user-supplied overrides into the defaults
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # unpack parameters
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # select the device
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_path, 'weights')
+    tb_path = os.path.join(train_result_path, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = KeypointDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = KeypointDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    # data_loader_test = torch.utils.data.DataLoader(
+    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    # )
+
+    img_results_path = os.path.join(train_result_path, 'img_results')
+    os.makedirs(wts_path, exist_ok=True)  # also creates train_result_path itself
+    os.makedirs(img_results_path, exist_ok=True)
+
+    for epoch in range(epochs):
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0:
+            best_loss = losses
+        if best_loss >= losses:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+
+    }
+    # 更新默认参数
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # 解析参数
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar('loss_classifier', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar('loss_box_reg', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    # KeypointRCNN reports 'loss_keypoint' rather than Mask R-CNN's 'loss_mask'
+    writer.add_scalar('loss_keypoint', metric_logger.meters['loss_keypoint'].global_avg, epoch)
+    writer.add_scalar('loss_objectness', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar('loss_rpn_box_reg', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar('train loss', metric_logger.meters['loss'].global_avg, epoch)

+ 0 - 0
models/obj/__init__.py


+ 46 - 0
models/wirenet/TestPointMap.py

@@ -0,0 +1,46 @@
+import torch
+
+
+def map_heatmap_keypoints_to_original_image(heatmap, rois, downsample_ratio=4, joff=None):
+    """
+    Map keypoints found in a heatmap back to their positions in the original image.
+
+    Args:
+        heatmap (torch.Tensor): heatmap of shape [H, W]
+        rois (list of tuples): ROI coordinates [(x_min, y_min, x_max, y_max), ...]
+        downsample_ratio (int): downsampling ratio, defaults to 4 (currently unused; the 128px heatmap size is hard-coded below)
+        joff (torch.Tensor, optional): offset map of shape [2, H, W]
+
+    Returns:
+        list of tuples: the keypoint in the original image for each ROI [(x, y), ...]
+    """
+    keypoints_in_original_image = []
+
+    for i, (x_min, y_min, x_max, y_max) in enumerate(rois):
+        roi_width = x_max - x_min
+        roi_height = y_max - y_min
+
+        # locate the peak of the heatmap
+        heatmap_roi = heatmap[i] if len(heatmap.shape) == 4 else heatmap
+        y_prime, x_prime = torch.where(heatmap_roi == torch.max(heatmap_roi))
+
+        if len(y_prime) > 0 and len(x_prime) > 0:
+            y_prime, x_prime = y_prime[0].item(), x_prime[0].item()  # take the first peak if there are ties
+
+            # apply the sub-pixel offset correction if an offset map is given
+            if joff is not None:
+                offset_x = joff[0, y_prime, x_prime].item()
+                offset_y = joff[1, y_prime, x_prime].item()
+                x_prime += offset_x
+                y_prime += offset_y
+
+            # relative coordinates within the ROI (the heatmap is assumed to be 128x128)
+            relative_x = x_prime / 128 * roi_width
+            relative_y = y_prime / 128 * roi_height
+
+            # map back to original-image coordinates
+            final_x = relative_x + x_min
+            final_y = relative_y + y_min
+
+            keypoints_in_original_image.append((float(final_x), float(final_y)))
+        else:
+            keypoints_in_original_image.append(None)  # no keypoint found
+
+    return keypoints_in_original_image
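
A quick self-contained check of the helper above (the single-peak heatmap and ROI are dummy values made up for illustration):

    heatmap = torch.zeros(128, 128)
    heatmap[64, 32] = 1.0                    # one peak at (y=64, x=32)
    rois = [(100.0, 200.0, 356.0, 456.0)]    # one 256x256 ROI in the original image
    print(map_heatmap_keypoints_to_original_image(heatmap, rois))
    # -> [(164.0, 328.0)]: x = 32/128*256 + 100, y = 64/128*256 + 200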

File diff suppressed because it is too large
+ 2693 - 0
models/wirenet/head.py


+ 1 - 0
models/wirenet/roi_head.py

@@ -767,6 +767,7 @@ class RoIHeads(nn.Module):
             regression_targets = None
             matched_idxs = None
 
+
         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
         class_logits, box_regression = self.box_predictor(box_features)

+ 11 - 2
models/wirenet/test.py

@@ -1,6 +1,8 @@
 from models.wirenet.wirepoint_dataset import WirePointDataset
 from models.config.config_tool import read_yaml
 
+import matplotlib.pyplot as plt
+
 # image_file = "D:/python/PycharmProjects/data"
 #
 # label_file = "D:/python/PycharmProjects/data/labels/train"
@@ -8,6 +10,8 @@ from models.config.config_tool import read_yaml
 # dataset_test.show(0)
 # for i in dataset_test:
 #     print(i)
+
+
 cfg = 'wirenet.yaml'
 cfg = read_yaml(cfg)
 print(f'cfg:{cfg}')
@@ -17,7 +21,12 @@ print(cfg['model']['n_dyn_negl'])
 dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
 # dataset.show(0)
 
-for i in range(len(dataset)):
-    dataset.show(i)
+# for i in range(len(dataset)):
+#     # dataset.show(i)
+for i in dataset:
+    # dataset.show(i)
+    # print(i)
+    print(i[1]['wires']['line_map'].shape)
+    plt.imshow(i[1]['wires']['line_map'])
+    plt.show()
 
 

+ 1 - 1
models/wirenet/wirenet.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: D:/python/PycharmProjects/data
+  datadir: D:\python\PycharmProjects\data
   resume_from:
   num_workers: 4
   tensorboard_port: 0

+ 182 - 43
models/wirenet/wirepoint_rcnn.py

@@ -10,6 +10,7 @@ import torch.nn.functional as F
 # from torchinfo import summary
 from torchvision.io import read_image
 from torchvision.models import resnet50, ResNet50_Weights
+from torchvision.models import resnet18, ResNet18_Weights
 from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
 from torchvision.models.detection._utils import overwrite_eps
 from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
@@ -522,26 +523,55 @@ class WirepointPredictor(nn.Module):
             jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
             return line, label.float(), feat, jcs
 
+# def wirepointrcnn_resnet50_fpn(
+#         *,
+#         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
+#         progress: bool = True,
+#         num_classes: Optional[int] = None,
+#         num_keypoints: Optional[int] = None,
+#         weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+#         trainable_backbone_layers: Optional[int] = None,
+#         **kwargs: Any,
+# ) -> WirepointRCNN:
+#     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
+#     weights_backbone = ResNet50_Weights.verify(weights_backbone)
+#
+#     is_trained = weights is not None or weights_backbone is not None
+#     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+#
+#     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+#
+#     backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+#     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+#     model = WirepointRCNN(backbone, num_classes=5, **kwargs)
+#
+#     if weights is not None:
+#         model.load_state_dict(weights.get_state_dict(progress=progress))
+#         if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
+#             overwrite_eps(model, 0.0)
+#
+#     return model
 
-def wirepointrcnn_resnet50_fpn(
+
+def wirepointrcnn_resnet18_fpn(
         *,
         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
         progress: bool = True,
         num_classes: Optional[int] = None,
         num_keypoints: Optional[int] = None,
-        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
         trainable_backbone_layers: Optional[int] = None,
         **kwargs: Any,
 ) -> WirepointRCNN:
     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    weights_backbone = ResNet18_Weights.verify(weights_backbone)
 
     is_trained = weights is not None or weights_backbone is not None
     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
 
     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
 
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = WirepointRCNN(backbone, num_classes=5, **kwargs)
 
@@ -577,27 +607,128 @@ sm.set_array([])
 def c(x):
     return sm.to_rgba(x)
 
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+    # plt.show()
+
+
+# def _plot_samples(img, i, result, prefix, epoch):
+#     print(f"prefix:{prefix}")
+#     def draw_vecl(lines, sline, juncs, junts, fn):
+#         directory = os.path.dirname(fn)
+#         if not os.path.exists(directory):
+#             os.makedirs(directory)
+#         imshow(img.permute(1, 2, 0))
+#         if len(lines) > 0 and not (lines[0] == 0).all():
+#             for i, ((a, b), s) in enumerate(zip(lines, sline)):
+#                 if i > 0 and (lines[i] == lines[0]).all():
+#                     break
+#                 plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+#         if not (juncs[0] == 0).all():
+#             for i, j in enumerate(juncs):
+#                 if i > 0 and (i == juncs[0]).all():
+#                     break
+#                 plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+#         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+#             for i, j in enumerate(junts):
+#                 if i > 0 and (i == junts[0]).all():
+#                     break
+#                 plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
+#         plt.savefig(fn), plt.close()
+#
+#     rjuncs = result["juncs"][i].cpu().numpy() * 4
+#     rjunts = None
+#     if "junts" in result:
+#         rjunts = result["junts"][i].cpu().numpy() * 4
+#
+#     vecl_result = result["lines"][i].cpu().numpy() * 4
+#     score = result["score"][i].cpu().numpy()
 #
-# def imshow(im):
-#     plt.close()
-#     plt.tight_layout()
-#     plt.imshow(im)
-#     plt.colorbar(sm, fraction=0.046)
-#     plt.xlim([0, im.shape[0]])
-#     plt.ylim([im.shape[0], 0])
-#     # plt.show()
+#     draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
+#
+#     img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
+#     writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
+
+def _plot_samples(img, i, result, prefix, epoch, writer):
+    # print(f"prefix:{prefix}")
+
+    def draw_vecl(lines, sline, juncs, junts, fn):
+        # make sure the output directory exists
+        directory = os.path.dirname(fn)
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+
+        # draw the image
+        plt.figure()
+        plt.imshow(img.permute(1, 2, 0).cpu().numpy())
+        plt.axis('off')  # optional: hide the axes
+
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for idx, ((a, b), s) in enumerate(zip(lines, sline)):
+                if idx > 0 and (lines[idx] == lines[0]).all():
+                    break
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1)
+
+        if not (juncs[0] == 0).all():
+            for idx, j in enumerate(juncs):
+                if idx > 0 and (j == juncs[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="red", s=20, zorder=100)
+
+        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+            for idx, j in enumerate(junts):
+                if idx > 0 and (j == junts[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="blue", s=20, zorder=100)
+
+        # plt.show()
+
+        # convert the matplotlib figure to a numpy array
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
 
+        return image_from_plot
 
-def show_line(img, pred,  epoch, write):
-    im = img.permute(1, 2, 0)
-    writer.add_image("ori", im, epoch, dataformats="HWC")
+    # fetch the results and convert them to numpy arrays
+    rjuncs = result["juncs"][i].cpu().numpy() * 4
+    rjunts = None
+    if "junts" in result:
+        rjunts = result["junts"][i].cpu().numpy() * 4
 
-    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
-                                      colors="yellow", width=1)
-    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+    vecl_result = result["lines"][i].cpu().numpy() * 4
+    score = result["score"][i].cpu().numpy()
 
+    # call the drawing helper and get the rendered image
+    image_path = f"{prefix}_vecl_b.jpg"
+    image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path)
+
+    # convert the numpy array to a torch tensor and write it to TensorBoard
+    image_tensor = transforms.ToTensor()(image_array)
+    writer.add_image('output_epoch', image_tensor, global_step=epoch)
+    writer.add_image('ori_epoch', img, global_step=epoch)
+
+
+def show_line(img, pred, prefix, epoch, writer):
+    fn = f"{prefix}_line.jpg"
+    directory = os.path.dirname(fn)
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    print(fn)
     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
-    H = pred[1]['wires']
+    H = pred
+
+    im = img.permute(1, 2, 0)
+
     lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
     scores = H["score"][0].cpu().numpy()
     for i in range(1, len(lines)):
@@ -623,14 +754,14 @@ def show_line(img, pred,  epoch, write):
         plt.gca().xaxis.set_major_locator(plt.NullLocator())
         plt.gca().yaxis.set_major_locator(plt.NullLocator())
         plt.imshow(im)
-        plt.tight_layout()
-        fig = plt.gcf()
-        fig.canvas.draw()
-        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
-            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.savefig(fn, bbox_inches="tight")
+        plt.show()
         plt.close()
-        img2 = transforms.ToTensor()(image_from_plot)
 
+        img2 = transforms.ToTensor()(cv2.imread(fn))  # prediction image as a CHW tensor (note: cv2 loads BGR)
+        # img1 = im.resize(img2.shape)  # original image
+
+        # writer.add_images(f"{epoch}", torch.tensor([img1, img2]), dataformats='NHWC')
         writer.add_image("output", img2, epoch)
 
 
@@ -669,7 +800,11 @@ if __name__ == '__main__':
         dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
     )
 
-    model = wirepointrcnn_resnet50_fpn().to(device)
+    model = wirepointrcnn_resnet18_fpn().to(device)
+    # print(model)
+
+    # model1 = wirepointrcnn_resnet50_fpn().to(device)
+    # print(model1)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
     writer = SummaryWriter(cfg['io']['logdir'])
@@ -709,23 +844,27 @@ if __name__ == '__main__':
         print(f"epoch:{epoch}")
         model.train()
 
-        # for imgs, targets in data_loader_train:
-        # losses = model(move_to_device(imgs, device), move_to_device(targets, device))
-        # loss = _loss(losses)
-        # print(loss)
-        # optimizer.zero_grad()
-        # loss.backward()
-        # optimizer.step()
-        # writer_loss(writer, losses, epoch)
-
-        model.eval()
-        with torch.no_grad():
-            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                pred = model(move_to_device(imgs, device))  # # pred[0].keys()   ['boxes', 'labels', 'scores']
-                # print(f"pred:{pred}")
-
-                if batch_idx == 0:
-                    show_line(imgs[0], pred,  epoch, writer)
+        for imgs, targets in data_loader_train:
+            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+            loss = _loss(losses)
+            print(f"loss:{loss}")
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            writer_loss(writer, losses, epoch)
+
+        # model.eval()
+        # with torch.no_grad():
+        #     for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+        #         pred = model(move_to_device(imgs, device))
+        #         print(f"pred:{pred}")
+        #
+        #         if batch_idx == 0:
+        #             result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
+        #             print(imgs[0].shape)  # [3,512,512]
+        #             # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3)
+        #             _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
+        #             show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)
 
 # imgs, targets = next(iter(data_loader))
 #

Some files were not shown because too many files changed in this diff