
self.training is wrong

xue50 3 months ago
commit 6251ff34cd

+ 204 - 2
lcnn/datasets.py

@@ -96,6 +96,9 @@
 #         default_collate([b[2] for b in batch]),
 #     )
 
+
+# Original LCNN data format; attribute names changed, box-related fields added
+
 from torch.utils.data.dataset import T_co
 
 from .models.base.base_dataset import BaseDataset
@@ -122,8 +125,6 @@ from torch.utils.data.dataloader import default_collate
 import matplotlib.pyplot as plt
 from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
 
-
-
 class WireframeDataset(BaseDataset):
     def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
         super().__init__(dataset_path)
@@ -292,3 +293,204 @@ def collate(batch):
 #     dataset.show(0)
 
 
+
+'''
+# roi_head imposes requirements on the data format, so the data format is changed here
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+from tools.presets import DetectionPresetTrain
+
+
+class WireframeDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images", dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels", dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+        self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip")  # "multiscale" would change the image size
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        # if self.transforms:
+        #     img, target = self.transforms(img, target)
+        # else:
+        #     img = self.default_transform(img)
+
+        img, target = self.data_augmentation(img, target)
+
+        print(f'img:{img.shape}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            label_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = label_all["wires"][0]  # a dict
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # if fewer are available, take them all
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # concatenate positive and negative sample coordinates
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of lines that actually exist
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+            # adjacency matrix of lines that do not exist
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        # if self.target_type == 'polygon':
+        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        # elif self.target_type == 'pixel':
+        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_line_segments, 512, 512]
+        target = {}
+        # target["labels"] = torch.stack(labels)
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        target["boxes"] = line_boxes(target)
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+        img_path = os.path.join(self.img_path, self.imgs[idx])
+        self._draw_vecl(img_path, target)
+
+    def show_img(self, img_path):
+
+        """根据给定的图片路径展示图像及其标注信息"""
+        # 获取对应的标签文件路径
+        img_name = os.path.basename(img_path)
+        img_path = os.path.join(self.img_path, img_name)
+        print(img_path)
+        lbl_name = img_name[:-3] + 'json'
+        lbl_path = os.path.join(self.lbl_path, lbl_name)
+        print(lbl_path)
+
+        if not os.path.exists(lbl_path):
+            raise FileNotFoundError(f"Label file {lbl_path} does not exist.")
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        target = self.read_target(0, lbl_path, shape=(h, w))
+
+        # draw the annotations
+        self._draw_vecl(img_path, target)
+
+
+    def _draw_vecl(self, img_path, target, fn=None):
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        lines = lpre
+        sline = np.ones(lpre.shape[0])
+        imshow(io.imread(img_path))
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                if i > 0 and (lines[i] == lines[0]).all():
+                    break
+                plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1] and b[1] have no guaranteed ordering
+        if not (juncs[0] == 0).all():
+            for i, j in enumerate(juncs):
+                if i > 0 and (j == juncs[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                          colors="yellow", width=1)
+        plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+        if fn is not None:
+            plt.savefig(fn)  # save before show(), which clears the current figure
+        plt.show()
+
+'''

+ 6 - 0
libs/vision_libs/models/detection/generalized_rcnn.py

@@ -108,6 +108,12 @@ class GeneralizedRCNN(nn.Module):
         losses = {}
         losses.update(detector_losses)
         losses.update(proposal_losses)
+        # print(f'1{detector_losses.keys()}')
+        # print(f'2{proposal_losses.keys()}')
+        # print(f'123{losses.keys()}')
+        print(f'self.training:{self.training}')
+        print(f'123{losses}')
+
 
         if torch.jit.is_scripting():
             if not self._has_warned:
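The prints added here probe `self.training`. On an `nn.Module` that flag is normally driven by `model.train()` / `model.eval()`, which set it recursively on every child module; flipping it by hand inside `forward` (as `RoIHeads.forward` now does further down in this commit) can therefore disagree with the mode the caller selected. A minimal standalone sketch of the standard mechanism, not taken from this repo:

    import torch
    from torch import nn

    class TinyDetector(nn.Module):
        def __init__(self):
            super().__init__()
            self.head = nn.Linear(4, 2)       # stand-in for RoIHeads

        def forward(self, x, targets=None):
            if self.training:                 # torchvision convention: losses in train mode
                return {"loss_head": self.head(x).sum()}
            return self.head(x)               # detections in eval mode

    model = TinyDetector()
    model.eval()
    print(model.training, model.head.training)   # False False -- propagated to children
    model.train()
    print(model.training, model.head.training)   # True True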

+ 1 - 6
models/dataset_tool.py

@@ -224,17 +224,12 @@ def line_boxes(target):
     lines = lpre
     sline = np.ones(lpre.shape[0])
 
-    keypoints = []
-
     if len(lines) > 0 and not (lines[0] == 0).all():
         for i, ((a, b), s) in enumerate(zip(lines, sline)):
             if i > 0 and (lines[i] == lines[0]).all():
                 break
             # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
 
-            keypoints.append([a[0], b[0]])
-            keypoints.append([a[1], b[1]])
-
             if a[1] > b[1]:
                 ymax = a[1] + 10
                 ymin = b[1] - 10
@@ -249,7 +244,7 @@ def line_boxes(target):
                 xmax = b[0] + 10
             boxs.append([ymin, xmin, ymax, xmax])
 
-    return torch.tensor(boxs), torch.tensor(keypoints)
+    return torch.tensor(boxs)
 
 
 def read_polygon_points_wire(lbl_path, shape):
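`line_boxes` now returns only the box tensor. The ±10-pixel margins it adds can push coordinates outside the image, while torchvision's detection heads expect boxes that lie within it, so a caller may want to clamp the result. A hypothetical helper (not part of dataset_tool.py), assuming the 512×512 images used elsewhere in this repo:

    import torch

    def clamp_line_boxes(boxes: torch.Tensor, h: int = 512, w: int = 512) -> torch.Tensor:
        # hypothetical helper; rows are [ymin, xmin, ymax, xmax] as built by line_boxes
        boxes = boxes.clone()
        boxes[:, 0].clamp_(0, h)   # ymin
        boxes[:, 1].clamp_(0, w)   # xmin
        boxes[:, 2].clamp_(0, h)   # ymax
        boxes[:, 3].clamp_(0, w)   # xmax
        return boxes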

+ 0 - 4
models/line_detect/aaa.py

@@ -1,4 +0,0 @@
-from models.config.config_tool import read_yaml
-
-cfg = read_yaml('wireframe.yaml')
-print(cfg)

+ 178 - 0
models/line_detect/dataset_LD.py

@@ -0,0 +1,178 @@
+# roi_head imposes requirements on the data format, so the data format is changed here
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+from tools.presets import DetectionPresetTrain
+
+
+class WirePointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images", dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels", dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()  # still referenced in __getitem__ and show(), but left unset here
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            label_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = label_all["wires"][0]  # a dict
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # if fewer are available, take them all
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # concatenate positive and negative sample coordinates
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of lines that actually exist
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+            # adjacency matrix of lines that do not exist
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_line_segments, 512, 512]
+        target = {}
+        target["labels"] = torch.stack(labels)
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        target["boxes"] = line_boxes(target)
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1] and b[1] have no guaranteed ordering
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (j == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            if fn is not None:
+                plt.savefig(fn)  # save before show(), which clears the current figure
+            plt.show()
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass
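The per-image tensors inside `target["wires"]` (junctions, sampled lines) differ in length, so `default_collate` cannot batch them; the training script below relies on `utils.collate_fn_wirepoint`, which is not part of this commit. A minimal sketch of what such a collate function typically looks like, following torchvision's detection reference code (an assumption, not the actual implementation):

    def collate_fn_wirepoint(batch):
        # sketch: keep images and targets as parallel lists instead of stacking,
        # since per-image tensors (lines, junctions) differ in size
        images, targets = zip(*batch)
        return list(images), list(targets)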

+ 277 - 31
models/line_detect/line_rcnn.py

@@ -13,7 +13,7 @@ from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_
 from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
 from libs.vision_libs.models.detection._utils import overwrite_eps
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
-from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 
 from models.config.config_tool import read_yaml
 import numpy as np
@@ -196,7 +196,7 @@ class LineRCNN(FasterRCNN):
             backbone,
             num_classes=None,
             # transform parameters
-            min_size=None,
+            min_size=512,   # previously None
             max_size=1333,
             image_mean=None,
             image_std=None,
@@ -292,6 +292,18 @@ class LineRCNN(FasterRCNN):
             **kwargs,
         )
 
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
         roi_heads = RoIHeads(
             # Box
             box_roi_pool,
@@ -311,7 +323,6 @@ class LineRCNN(FasterRCNN):
         )
         # super().roi_heads = roi_heads
         self.roi_heads = roi_heads
-
         self.roi_heads.line_head = line_head
         self.roi_heads.line_predictor = line_predictor
 
@@ -355,7 +366,7 @@ class LineRCNNPredictor(nn.Module):
         super().__init__()
         # self.backbone = backbone
         # self.cfg = read_yaml(cfg)
-        self.cfg = read_yaml(r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\models\line_detect\wireframe.yaml')
+        self.cfg = read_yaml(r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml')
         self.n_pts0 = self.cfg['model']['n_pts0']
         self.n_pts1 = self.cfg['model']['n_pts1']
         self.n_stc_posl = self.cfg['model']['n_stc_posl']
@@ -402,12 +413,15 @@ class LineRCNNPredictor(nn.Module):
             )
         self.loss = nn.BCEWithLogitsLoss(reduction="none")
 
-    def forward(self, result, targets=None):
+    def forward(self, inputs, features, targets=None):
 
-        # result = self.backbone(input_dict)
-        h = result["preds"]
-        x = self.fc1(result["feature"])
-        n_batch, n_channel, row, col = x.shape
+        # outputs, features = input
+        # for out in outputs:
+        #     print(f'out:{out.shape}')
+        # outputs=merge_features(outputs,100)
+        batch, channel, row, col = inputs.shape
+        # print(f'outputs:{inputs.shape}')
+        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
 
         if targets is not None:
             self.training = True
@@ -430,30 +444,61 @@ class LineRCNNPredictor(nn.Module):
             }
         else:
             self.training = False
-            # self.training = False
             t = {
-                "junc_coords": torch.zeros(1, 2).to(device),
-                "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
-                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
-                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
-                "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
+                "junc_coords": torch.zeros(1, 2),
+                "jtyp": torch.zeros(1, dtype=torch.uint8),
+                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
             }
             wires_targets = [t for b in range(inputs.size(0))]
 
             wires_meta = {
-                "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
             }
 
+        T = wires_meta.copy()
+        n_jtyp = T["junc_map"].shape[1]
+        offset = self.head_off
+        result = {}
+        for stack, output in enumerate([inputs]):
+            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+            # print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
+            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+            lmap = output[offset[0]: offset[1]].squeeze(0)
+            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+
+            if stack == 0:
+                result["preds"] = {
+                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                    "lmap": lmap.sigmoid(),
+                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+                }
+                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
+                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
+                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
+
+        h = result["preds"]
+        # print(f'features shape:{features.shape}')
+        x = self.fc1(features)
+
+        # print(f'x:{x.shape}')
+
+        n_batch, n_channel, row, col = x.shape
+
+        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
+
         xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
-        for i, meta in enumerate(input_dict["meta"]):
+
+        for i, meta in enumerate(wires_targets):
             p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
+                meta, h["jmap"][i], h["joff"][i],
             )
-            # print("p.shape:", p.shape)
+            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
             ys.append(label)
-            if input_dict["mode"] == "training" and self.do_static_sampling:
+            if self.training and self.do_static_sampling:
                 p = torch.cat([p, meta["lpre"]])
                 feat = torch.cat([feat, meta["lpre_feat"]])
                 ys.append(meta["lpre_label"])
@@ -480,25 +525,28 @@ class LineRCNNPredictor(nn.Module):
                         + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
                         + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
                 )
-                    .reshape(n_channel, -1, M.n_pts0)
+                    .reshape(n_channel, -1, self.n_pts0)
                     .permute(1, 0, 2)
             )
             xp = self.pooling(xp)
+            # print(f'xp.shape:{xp.shape}')
             xs.append(xp)
             idx.append(idx[-1] + xp.shape[0])
-
+            # print(f'idx__:{idx}')
 
         x, y = torch.cat(xs), torch.cat(ys)
         f = torch.cat(fs)
         x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+
+        # print("Weight dtype:", self.fc2.weight.dtype)
         x = torch.cat([x, f], 1)
+        # print("Input dtype:", x.dtype)
         x = x.to(dtype=torch.float32)
+        # print("Input dtype1:", x.dtype)
         x = self.fc2(x).flatten()
 
         # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
-        all=[x, ys, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc]
-        return all
-        # return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
+        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
 
         # if mode != "training":
         # self.inference(x, idx, jcs, n_batch, ps)
@@ -536,9 +584,6 @@ class LineRCNNPredictor(nn.Module):
             xy_ = xy[..., None, :]
             del x, y, index
 
-            # print(f"xy_.is_cuda: {xy_.is_cuda}")
-            # print(f"junc.is_cuda: {junc.is_cuda}")
-
             # dist: [N_TYPE, K, N]
             dist = torch.sum((xy_ - junc) ** 2, -1)
             cost, match = torch.min(dist, -1)
@@ -604,6 +649,208 @@ class LineRCNNPredictor(nn.Module):
             xy = xy.reshape(n_type, K, 2)
             jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
             return line, label.float(), feat, jcs
+    # def forward(self, result, targets=None):
+    #
+    #     # result = self.backbone(input_dict)
+    #     h = result["preds"]
+    #     x = self.fc1(result["feature"])
+    #     n_batch, n_channel, row, col = x.shape
+    #
+    #     if targets is not None:
+    #         self.training = True
+    #         # print(f'target:{targets}')
+    #         wires_targets = [t["wires"] for t in targets]
+    #         # print(f'wires_target:{wires_targets}')
+    #         # extract all 'junc_map', 'junc_offset', 'line_map' tensors
+    #         junc_maps = [d["junc_map"] for d in wires_targets]
+    #         junc_offsets = [d["junc_offset"] for d in wires_targets]
+    #         line_maps = [d["line_map"] for d in wires_targets]
+    #
+    #         junc_map_tensor = torch.stack(junc_maps, dim=0)
+    #         junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+    #         line_map_tensor = torch.stack(line_maps, dim=0)
+    #
+    #         wires_meta = {
+    #             "junc_map": junc_map_tensor,
+    #             "junc_offset": junc_offset_tensor,
+    #             # "line_map": line_map_tensor,
+    #         }
+    #     else:
+    #         self.training = False
+    #         # self.training = False
+    #         t = {
+    #             "junc_coords": torch.zeros(1, 2).to(device),
+    #             "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
+    #             "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+    #             "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+    #             "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
+    #             "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
+    #         }
+    #         wires_targets = [t for b in range(inputs.size(0))]
+    #
+    #         wires_meta = {
+    #             "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
+    #             "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
+    #         }
+    #
+    #     xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
+    #     for i, meta in enumerate(input_dict["meta"]):
+    #         p, label, feat, jc = self.sample_lines(
+    #             meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
+    #         )
+    #         # print("p.shape:", p.shape)
+    #         ys.append(label)
+    #         if input_dict["mode"] == "training" and self.do_static_sampling:
+    #             p = torch.cat([p, meta["lpre"]])
+    #             feat = torch.cat([feat, meta["lpre_feat"]])
+    #             ys.append(meta["lpre_label"])
+    #             del jc
+    #         else:
+    #             jcs.append(jc)
+    #             ps.append(p)
+    #         fs.append(feat)
+    #
+    #         p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+    #         p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+    #         px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+    #         px0 = px.floor().clamp(min=0, max=127)
+    #         py0 = py.floor().clamp(min=0, max=127)
+    #         px1 = (px0 + 1).clamp(min=0, max=127)
+    #         py1 = (py0 + 1).clamp(min=0, max=127)
+    #         px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+    #
+    #         # xp: [N_LINE, N_CHANNEL, N_POINT]
+    #         xp = (
+    #             (
+    #                     x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+    #                     + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+    #                     + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+    #                     + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+    #             )
+    #                 .reshape(n_channel, -1, M.n_pts0)
+    #                 .permute(1, 0, 2)
+    #         )
+    #         xp = self.pooling(xp)
+    #         xs.append(xp)
+    #         idx.append(idx[-1] + xp.shape[0])
+    #
+    #
+    #     x, y = torch.cat(xs), torch.cat(ys)
+    #     f = torch.cat(fs)
+    #     x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+    #     x = torch.cat([x, f], 1)
+    #     x = x.to(dtype=torch.float32)
+    #     x = self.fc2(x).flatten()
+    #
+    #     # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
+    #     all=[x, ys, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc]
+    #     return all
+    #     # return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
+    #
+    #     # if mode != "training":
+    #     # self.inference(x, idx, jcs, n_batch, ps)
+    #
+    #     # return result
+    #
+    # def sample_lines(self, meta, jmap, joff):
+    #     with torch.no_grad():
+    #         junc = meta["junc_coords"]  # [N, 2]
+    #         jtyp = meta["jtyp"]  # [N]
+    #         Lpos = meta["line_pos_idx"]
+    #         Lneg = meta["line_neg_idx"]
+    #
+    #         n_type = jmap.shape[0]
+    #         jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+    #         joff = joff.reshape(n_type, 2, -1)
+    #         max_K = self.n_dyn_junc // n_type
+    #         N = len(junc)
+    #         # if mode != "training":
+    #         if not self.training:
+    #             K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
+    #         else:
+    #             K = min(int(N * 2 + 2), max_K)
+    #         if K < 2:
+    #             K = 2
+    #         device = jmap.device
+    #
+    #         # index: [N_TYPE, K]
+    #         score, index = torch.topk(jmap, k=K)
+    #         y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+    #         x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+    #
+    #         # xy: [N_TYPE, K, 2]
+    #         xy = torch.cat([y[..., None], x[..., None]], dim=-1)
+    #         xy_ = xy[..., None, :]
+    #         del x, y, index
+    #
+    #         # print(f"xy_.is_cuda: {xy_.is_cuda}")
+    #         # print(f"junc.is_cuda: {junc.is_cuda}")
+    #
+    #         # dist: [N_TYPE, K, N]
+    #         dist = torch.sum((xy_ - junc) ** 2, -1)
+    #         cost, match = torch.min(dist, -1)
+    #
+    #         # xy: [N_TYPE * K, 2]
+    #         # match: [N_TYPE, K]
+    #         for t in range(n_type):
+    #             match[t, jtyp[match[t]] != t] = N
+    #         match[cost > 1.5 * 1.5] = N
+    #         match = match.flatten()
+    #
+    #         _ = torch.arange(n_type * K, device=device)
+    #         u, v = torch.meshgrid(_, _)
+    #         u, v = u.flatten(), v.flatten()
+    #         up, vp = match[u], match[v]
+    #         label = Lpos[up, vp]
+    #
+    #         # if mode == "training":
+    #         if self.training:
+    #             c = torch.zeros_like(label, dtype=torch.bool)
+    #
+    #             # sample positive lines
+    #             cdx = label.nonzero().flatten()
+    #             if len(cdx) > self.n_dyn_posl:
+    #                 # print("too many positive lines")
+    #                 perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
+    #                 cdx = cdx[perm]
+    #             c[cdx] = 1
+    #
+    #             # sample negative lines
+    #             cdx = Lneg[up, vp].nonzero().flatten()
+    #             if len(cdx) > self.n_dyn_negl:
+    #                 # print("too many negative lines")
+    #                 perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
+    #                 cdx = cdx[perm]
+    #             c[cdx] = 1
+    #
+    #             # sample other (unmatched) lines
+    #             cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
+    #             c[cdx] = 1
+    #         else:
+    #             c = (u < v).flatten()
+    #
+    #         # sample lines
+    #         u, v, label = u[c], v[c], label[c]
+    #         xy = xy.reshape(n_type * K, 2)
+    #         xyu, xyv = xy[u], xy[v]
+    #
+    #         u2v = xyu - xyv
+    #         u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
+    #         feat = torch.cat(
+    #             [
+    #                 xyu / 128 * self.use_cood,
+    #                 xyv / 128 * self.use_cood,
+    #                 u2v * self.use_slop,
+    #                 (u[:, None] > K).float(),
+    #                 (v[:, None] > K).float(),
+    #             ],
+    #             1,
+    #         )
+    #         line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+    #
+    #         xy = xy.reshape(n_type, K, 2)
+    #         jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
+    #         return line, label.float(), feat, jcs
 
 
 
@@ -746,7 +993,6 @@ def linercnn_resnet50_fpn(
     """
     weights = LineRCNN_ResNet50_FPN_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
     if weights is not None:
         weights_backbone = None
         num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
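The `box_roi_pool` / `box_head` / `box_predictor` defaults added above mirror torchvision's `FasterRCNN` constructor. A small sketch of the shapes involved, assuming the FPN's `out_channels` is 256 and a placeholder `num_classes` of 5: `MultiScaleRoIAlign` with `output_size=14` yields 256×14×14 features per RoI, `TwoMLPHead` flattens them into a 1024-d vector, and `FastRCNNPredictor` maps that to class logits and per-class box deltas:

    import torch
    from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor

    out_channels, resolution = 256, 14          # assumed FPN channels, RoIAlign output_size
    representation_size, num_classes = 1024, 5  # num_classes is a placeholder

    box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
    box_predictor = FastRCNNPredictor(representation_size, num_classes)

    roi_feats = torch.randn(8, out_channels, resolution, resolution)  # 8 pooled RoIs
    class_logits, box_deltas = box_predictor(box_head(roi_feats))
    print(class_logits.shape, box_deltas.shape)  # torch.Size([8, 5]) torch.Size([8, 20])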

+ 139 - 11
models/line_detect/roi_heads.py

@@ -146,6 +146,117 @@ def line_vectorizer_loss(result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out
     return result
 
 
+
+def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
+    # output, feature: results returned by the line head
+    # x, y, idx: intermediate results produced by the line predictor
+    result = {}
+    batch, channel, row, col = output.shape
+
+    wires_targets = [t["wires"] for t in targets]
+    wires_targets = wires_targets.copy()
+    # print(f'wires_target:{wires_targets}')
+    # extract all 'junc_map', 'junc_offset', 'line_map' tensors
+    junc_maps = [d["junc_map"] for d in wires_targets]
+    junc_offsets = [d["junc_offset"] for d in wires_targets]
+    line_maps = [d["line_map"] for d in wires_targets]
+
+    junc_map_tensor = torch.stack(junc_maps, dim=0)
+    junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+    line_map_tensor = torch.stack(line_maps, dim=0)
+    T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
+
+    n_jtyp = T["junc_map"].shape[1]
+
+    for task in ["junc_map"]:
+        T[task] = T[task].permute(1, 0, 2, 3)
+    for task in ["junc_offset"]:
+        T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+    offset = [2, 3, 5]
+    losses = []
+    output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+    jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+    lmap = output[offset[0]: offset[1]].squeeze(0)
+    joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+    L = OrderedDict()
+    L["junc_map"] = sum(
+        cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+    )
+    L["line_map"] = (
+        F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+            .mean(2)
+            .mean(1)
+    )
+    L["junc_offset"] = sum(
+        sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+        for i in range(n_jtyp)
+        for j in range(2)
+    )
+    for loss_name in L:
+        L[loss_name].mul_(loss_weight[loss_name])
+    losses.append(L)
+    result["losses"] = losses
+
+    loss = nn.BCEWithLogitsLoss(reduction="none")
+    loss = loss(x, y)
+    lpos_mask, lneg_mask = y, 1 - y
+    loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+    def sum_batch(x):
+        xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
+        return torch.cat(xs)
+
+    lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+    lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+    result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
+    result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
+
+    return result
+
+
+def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
+    result = {}
+    result["wires"] = {}
+    p = torch.cat(ps)
+    s = torch.sigmoid(input)
+    b = s > 0.5
+    lines = []
+    score = []
+    # print(f"n_batch:{n_batch}")
+    for i in range(n_batch):
+        # print(f"idx:{idx}")
+        p0 = p[idx[i]: idx[i + 1]]
+        s0 = s[idx[i]: idx[i + 1]]
+        mask = b[idx[i]: idx[i + 1]]
+        p0 = p0[mask]
+        s0 = s0[mask]
+        if len(p0) == 0:
+            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+            score.append(torch.zeros([1, n_out_line], device=p.device))
+        else:
+            arg = torch.argsort(s0, descending=True)
+            p0, s0 = p0[arg], s0[arg]
+            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+        for j in range(len(jcs[i])):
+            if len(jcs[i][j]) == 0:
+                jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
+            jcs[i][j] = jcs[i][j][
+                None, torch.arange(n_out_junc) % len(jcs[i][j])
+            ]
+    result["wires"]["lines"] = torch.cat(lines)
+    result["wires"]["score"] = torch.cat(score)
+    result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+
+    if len(jcs[i]) > 1:
+        result["wires"]["junts"] = torch.cat(
+            [jcs[i][1] for i in range(n_batch)]
+        )
+
+    return result
+
+
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
     # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
     """
@@ -890,6 +1001,13 @@ class RoIHeads(nn.Module):
             image_shapes (List[Tuple[H, W]])
             targets (List[Dict])
         """
+        self.training = targets is not None
+        print(f'self.training:{self.training}')
 
         if targets is not None:
             for t in targets:
@@ -937,20 +1055,29 @@ class RoIHeads(nn.Module):
 
         features_lcnn = features['0']
         if self.has_line():
-            line_features = self.line_head(features_lcnn)
-            loss_weight = {'jmap': 8.0, 'lmap': 0.5, 'joff': 0.25, 'lpos': 1, 'lneg': 1, 'boxes': 1.0}
-            x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
-                features_lcnn)  # x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
+            outputs = self.line_head(features_lcnn)
+            loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
+            x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
+                inputs=outputs, features=features_lcnn, targets=targets)
+
+            # # line_loss(multitasklearner)
+            # if self.training:
+            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=True)
+            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
+            #                                        loss_weight, mode_train=True)
+            # else:
+            #     head_result = line_head_loss(targets, outputs, features_lcnn, loss_weight, mode_train=False)
+            #     line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
+            #                                        loss_weight, mode_train=False)
 
-            # line_loss(multitasklearner)
             if self.training:
-                head_result = line_head_loss(targets, line_features, features_lcnn, loss_weight, mode_train=True)
-                line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
-                                                   loss_weight, mode_train=True)
+                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+                loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
             else:
-                head_result = line_head_loss(targets, line_features, features_lcnn, loss_weight, mode_train=False)
-                line_result = line_vectorizer_loss(head_result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out_junc,
-                                                   loss_weight, mode_train=False)
+                pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
+                result.append(pred)
+                loss_wirepoint = {}
+            losses.update(loss_wirepoint)
 
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
@@ -1041,5 +1168,6 @@ class RoIHeads(nn.Module):
                     r["keypoints"] = keypoint_prob
                     r["keypoints_scores"] = kps
             losses.update(loss_keypoint)
+        print(f'losses111:{losses.keys()}')
 
         return result, losses
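With this change, the dict returned in training mode nests the wirepoint terms under `losses["loss_wirepoint"]["losses"]`, one level deeper than the standard detection losses. A sketch with made-up values showing the structure the training script's `_loss()` has to unpack:

    import torch

    losses = {
        "loss_classifier": torch.tensor(0.70),
        "loss_box_reg": torch.tensor(0.30),
        "loss_wirepoint": {
            "losses": [{  # one OrderedDict per stack; keys match loss_weight above
                "junc_map": torch.tensor([0.9]),   # per-batch values, hence .mean() below
                "line_map": torch.tensor([0.4]),
                "junc_offset": torch.tensor([0.2]),
                "lpos": torch.tensor([0.1]),
                "lneg": torch.tensor([0.1]),
            }],
        },
    }

    total = sum(v for k, v in losses.items() if k != "loss_wirepoint")
    total += sum(t.mean() for t in losses["loss_wirepoint"]["losses"][0].values())
    print(total)  # tensor(2.7000)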

+ 0 - 71
models/line_detect/wireframe.yaml

@@ -1,71 +0,0 @@
-io:
-  logdir: logs/
-  datadir: D:\python\PycharmProjects\data
-  resume_from:
-  num_workers: 0
-  tensorboard_port: 0
-  validation_interval: 300
-
-model:
-  image:
-      mean: [109.730, 103.832, 98.681]
-      stddev: [22.275, 22.124, 23.229]
-
-  batch_size: 4
-  batch_size_eval: 2
-
-  # backbone multi-task parameters
-  head_size: [[2], [1], [2],[4]]
-  loss_weight:
-    jmap: 8.0
-    lmap: 0.5
-    joff: 0.25
-    lpos: 1
-    lneg: 1
-    boxes: 1.0
-
-  # backbone parameters
-  backbone: fasterrcnn_resnet50
-#  backbone: unet
-  depth: 4
-  num_stacks: 1
-  num_blocks: 1
-
-  # sampler parameters
-  ## static sampler
-  n_stc_posl: 300
-  n_stc_negl: 40
-
-  ## dynamic sampler
-  n_dyn_junc: 300
-  n_dyn_posl: 300
-  n_dyn_negl: 80
-  n_dyn_othr: 600
-
-  # LOIPool layer parameters
-  n_pts0: 32
-  n_pts1: 8
-
-  # line verification network parameters
-  dim_loi: 128
-  dim_fc: 1024
-
-  # maximum junction and line outputs
-  n_out_junc: 250
-  n_out_line: 2500
-
-  # additional ablation study parameters
-  use_cood: 0
-  use_slop: 0
-  use_conv: 0
-
-  # junction threashold for evaluation (See #5)
-  eval_junc_thres: 0.008
-
-optim:
-  name: Adam
-  lr: 4.0e-4
-  amsgrad: True
-  weight_decay: 1.0e-4
-  max_epoch: 1000
-  lr_decay_epoch: 10

+ 204 - 0
train——line_rcnn.py

@@ -1,3 +1,6 @@
+
+# Training script written based on LCNN    2025/2/7
+'''
 #!/usr/bin/env python3
 import datetime
 import glob
@@ -156,3 +159,204 @@ def main():
 
 if __name__ == "__main__":
     main()
+'''
+
+import os
+from typing import Optional, Any
+
+import cv2
+import numpy as np
+import torch
+
+from models.config.config_tool import read_yaml
+from models.line_detect.dataset_LD import WirePointDataset
+from tools import utils
+
+from torch.utils.tensorboard import SummaryWriter
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from skimage import io
+
+from models.line_detect.line_rcnn import linercnn_resnet50_fpn
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def _loss(losses):
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        total_loss += loss
+
+    return total_loss
+
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+
+
+def _plot_samples(self, i, index, result, targets, prefix):
+    fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
+    img = io.imread(fn)
+    imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
+
+    def draw_vecl(lines, sline, juncs, junts, fn):
+        imshow(img)
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                if i > 0 and (lines[i] == lines[0]).all():
+                    break
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+        if not (juncs[0] == 0).all():
+            for i, j in enumerate(juncs):
+                if i > 0 and (i == juncs[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+            for i, j in enumerate(junts):
+                if i > 0 and (i == junts[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
+        plt.savefig(fn), plt.close()
+
+    junc = targets[i]["junc"].cpu().numpy() * 4
+    jtyp = targets[i]["jtyp"].cpu().numpy()
+    juncs = junc[jtyp == 0]
+    junts = junc[jtyp == 1]
+    rjuncs = result["juncs"][i].cpu().numpy() * 4
+    rjunts = None
+    if "junts" in result:
+        rjunts = result["junts"][i].cpu().numpy() * 4
+
+    lpre = targets[i]["lpre"].cpu().numpy() * 4
+    vecl_target = targets[i]["lpre_label"].cpu().numpy()
+    vecl_result = result["lines"][i].cpu().numpy() * 4
+    score = result["score"][i].cpu().numpy()
+    lpre = lpre[vecl_target == 1]
+
+    draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
+    draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
+
+    img = cv2.imread(f"{prefix}_vecl_a.jpg")
+    img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
+    self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC')
+
+
+if __name__ == '__main__':
+    cfg = r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml'
+    cfg = read_yaml(cfg)
+    print(f'cfg:{cfg}')
+    print(cfg['model']['n_dyn_negl'])
+    # net = WirepointPredictor()
+
+    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
+    train_sampler = torch.utils.data.RandomSampler(dataset_train)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
+    train_collate_fn = utils.collate_fn_wirepoint
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
+    )
+
+    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+    val_sampler = torch.utils.data.RandomSampler(dataset_val)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
+    val_collate_fn = utils.collate_fn_wirepoint
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
+    )
+
+    model = linercnn_resnet50_fpn().to(device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
+    writer = SummaryWriter(cfg['io']['logdir'])
+
+
+    def move_to_device(data, device):
+        if isinstance(data, (list, tuple)):
+            return type(data)(move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # leave non-tensor data unchanged
+
+
+    def writer_loss(writer, losses, epoch):
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    # walk the nested wirepoint loss dicts
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            # use .item() to get a Python scalar when available
+                            writer.add_scalar(f'loss_wirepoint/{subkey}',
+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                              epoch)
+                elif isinstance(value, torch.Tensor):
+                    writer.add_scalar(key, value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
+
+
+    for epoch in range(cfg['optim']['max_epoch']):
+        print(f"epoch:{epoch}")
+        model.train()
+
+        for imgs, targets in data_loader_train:
+
+            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+            # print(type(losses))
+            # print(losses)
+            loss = _loss(losses)
+            # print(loss)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            writer_loss(writer, losses, epoch)
+
+            model.eval()
+            with torch.no_grad():
+                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                    pred = model(move_to_device(imgs, device))
+                    # print(f"perd:{pred}")
+                    break
+
+                # if batch_idx == 0:
+                #     viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}")
+                #     H = pred["wires"]
+                #     _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}")
+
+# imgs, targets = next(iter(data_loader))
+#
+# model.train()
+# pred = model(imgs, targets)
+# print(f'pred:{pred}')
+
+# result, losses = model(imgs, targets)
+# print(f'result:{result}')
+# print(f'pred:{losses}')
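One consequence for `self.training` in the loop above: `model.eval()` runs inside the per-batch training loop and `model.train()` is only called once at the top of the epoch, so every batch after the first is processed with `self.training == False`. A sketch of the conventional structure, reusing the script's own names, with validation moved out of the batch loop:

    for epoch in range(cfg['optim']['max_epoch']):
        model.train()                         # restore training mode every epoch
        for imgs, targets in data_loader_train:
            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
            loss = _loss(losses)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer_loss(writer, losses, epoch)

        model.eval()                          # validate once per epoch, not per batch
        with torch.no_grad():
            for imgs, targets in data_loader_val:
                pred = model(move_to_device(imgs, device))
                break                         # inspect one batch; extend with real metrics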