
Delete lines_roi_pool

RenLiqiang, 5 months ago
commit 91ca516394

+ 8 - 8
models/line_detect/line_dataset.py

@@ -46,7 +46,7 @@ def apply_transform_with_boxes_and_keypoints(img,target):
     # Define the set of transforms used for data augmentation
     data_transforms = transforms.Compose([
         # Random resized crop
-        transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0), antialias=True),
+        # transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0), antialias=True),
 
         # Random horizontal flip
         transforms.RandomHorizontalFlip(p=0.5),
@@ -167,22 +167,22 @@ class LineDataset(BaseDataset):
         sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
         sm.set_array([])
 
-        img_path = os.path.join(self.img_path, self.imgs[idx])
-        img = PIL.Image.open(img_path).convert('RGB')
+        # img_path = os.path.join(self.img_path, self.imgs[idx])
+        img = image
         if show_type=='all':
-            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                                   colors="yellow", width=1)
             keypoint_img=draw_keypoints(boxed_image,target['lines'],colors='red',width=3)
             plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
             plt.show()
 
         if show_type=='lines':
-            keypoint_img=draw_keypoints((self.default_transform(img) * 255).to(torch.uint8),target['lines'],colors='red',width=3)
+            keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['lines'],colors='red',width=3)
             plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
             plt.show()
 
         if show_type=='boxes':
-            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                               colors="yellow", width=1)
             plt.imshow(boxed_image.permute(1, 2, 0).numpy())
             plt.show()
@@ -223,5 +223,5 @@ def get_boxes_lines(objs,shape):
 
 if __name__ == '__main__':
     path=r"\\192.168.50.222/share/rlq/datasets/0706_"
-    dataset= LineDataset(dataset_path=path, dataset_type='train',data_type='jpg')
-    dataset.show(1,show_type='lines')
+    dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=True, data_type='jpg')
+    dataset.show(1,show_type='all')
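
The rewritten show() reuses the tensor that __getitem__ already produced instead of re-reading the image from disk. That tensor is a float CHW image in [0, 1], while torchvision's drawing utilities expect uint8, hence the (img * 255).to(torch.uint8) conversion. A minimal, self-contained sketch of the pattern using stock torchvision utilities (shapes and values are illustrative):

    import torch
    from torchvision.utils import draw_bounding_boxes, draw_keypoints

    img = torch.rand(3, 512, 512)                     # float CHW in [0, 1], as the dataset returns
    boxes = torch.tensor([[10., 20., 100., 120.]])    # (x1, y1, x2, y2)
    lines = torch.tensor([[[30., 40.], [90., 80.]]])  # one line as two endpoints, shape (1, 2, 2)

    canvas = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="yellow", width=1)
    canvas = draw_keypoints(canvas, lines, colors="red", width=3)
    # for matplotlib: plt.imshow(canvas.permute(1, 2, 0).numpy())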

+ 212 - 0
models/line_detect/line_dataset_old.py

@@ -0,0 +1,212 @@
+from torch.utils.data.dataset import T_co
+
+from libs.vision_libs.utils import draw_keypoints
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+import imageio
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+def validate_keypoints(keypoints, image_width, image_height):
+    for kp in keypoints:
+        x, y, v = kp
+        if not (0 <= x < image_width and 0 <= y < image_height):
+            raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
+
+
+class LineDataset(BaseDataset):
+    def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        self.data_type = data_type
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        self.img_type=img_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        if self.data_type == 'tiff':
+            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
+            # img = imageio.v3.imread(img_path).reshape(512, 512, 1)
+            img = imageio.v3.imread(img_path)[:, :, :3]
+            # img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
+            # img_3channel[:, :, 2] = img[:, :, 0]
+
+            img_3channel=img
+            h, w = img.shape[:2]  # numpy image shape is (H, W, C)
+            img = torch.from_numpy(img_3channel).permute(2, 0, 1)
+        else:
+            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+            img = PIL.Image.open(img_path).convert('RGB')
+            w, h = img.size
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'shape:{shape}')
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            label_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = label_all["wires"][0]  # dict of wire annotations
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # take as many as available if fewer than n_stc_posl
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative sample coordinates stacked together
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of the ground-truth lines
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_lines, 512, 512]
+        target = {}
+
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+
+        # target["labels"] = torch.stack(labels)
+
+        # print(f'labels:{target["labels"]}')
+        # target["boxes"] = line_boxes(target)
+        target["boxes"], lines = get_boxes_lines(target)
+        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
+        # keypoints=keypoints/512
+        # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
+
+        # keypoints= wire_labels["junc_coords"]
+        a = torch.full((lines.shape[0],), 2).unsqueeze(1)  # visibility flag 2 (visible) for each endpoint
+        lines = torch.cat((lines, a), dim=1)
+
+        target["lines"] = lines.to(torch.float32).view(-1,2,3)
+
+        target["img_size"] = shape
+        # print(f'boxes:{target["boxes"].shape}')
+        # call this function in the __getitem__ method
+        validate_keypoints(lines, shape[1], shape[0])  # shape is (h, w); the validator takes (width, height)
+
+        # print(f'keypoints:{target["keypoints"].shape}')
+        # print(f'target:{target}')
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        img_path = os.path.join(self.img_path, self.imgs[idx])
+        img = PIL.Image.open(img_path).convert('RGB')
+        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+        keypoint_img=draw_keypoints(boxed_image,target['lines'],colors='red',width=3)  # read_target stores the endpoints under "lines"
+        plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
+        plt.show()
+
+
+    def show_img(self, img_path):
+        pass
+
+def get_boxes_lines(target):
+    boxs = []
+    lpre = target['wires']["lpre"].cpu().numpy()
+    vecl_target = target['wires']["lpre_label"].cpu().numpy()
+    lpre = lpre[vecl_target == 1]
+    lines = lpre
+    sline = np.ones(lpre.shape[0])
+    line_point_pairs = []
+
+    if len(lines) > 0 and not (lines[0] == 0).all():
+        for i, ((a, b), s) in enumerate(zip(lines, sline)):
+            if i > 0 and (lines[i] == lines[0]).all():
+                break
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
+            line_point_pairs.append([a[1], a[0]])
+            line_point_pairs.append([b[1], b[0]])
+
+            xmin = max(0, (min(a[0], b[0]) - 6))
+            xmax = min(511, (max(a[0], b[0]) + 6))
+            ymin = max(0, (min(a[1], b[1]) - 6))
+            ymax = min(511, (max(a[1], b[1]) + 6))
+
+            boxs.append([ymin, xmin, ymax, xmax])
+
+    return torch.tensor(boxs), torch.tensor(line_point_pairs)
+
+if __name__ == '__main__':
+    path=r"\\192.168.50.222/share/lm/Dataset_all"
+    dataset= LineDataset(dataset_path=path, dataset_type='train', data_type='jpg')  # data_type is required; 'jpg' assumed
+    dataset.show(10)
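
For reference, get_boxes_lines pads each positive line's endpoints by 6 px and clamps the result to the 512x512 canvas; coordinates are stored as (row, col) pairs, so the box is assembled as [ymin, xmin, ymax, xmax] to land in torchvision's (x1, y1, x2, y2) order. A worked sketch of that arithmetic on a single segment (pure Python, values illustrative):

    a, b = (100.0, 30.0), (100.0, 500.0)   # endpoints as (row, col)
    pad, last = 6, 511

    xmin = max(0, min(a[0], b[0]) - pad)   # row extent, padded and clamped
    xmax = min(last, max(a[0], b[0]) + pad)
    ymin = max(0, min(a[1], b[1]) - pad)   # column extent
    ymax = min(last, max(a[1], b[1]) + pad)

    box = [ymin, xmin, ymax, xmax]         # (x1, y1, x2, y2) in image coordinates
    print(box)                             # [24.0, 94.0, 506.0, 106.0]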

+ 2 - 2
models/line_detect/line_detect.py

@@ -167,8 +167,8 @@ class LineDetect(BaseDetectionNet):
             line_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
 
         if line_head is None:
-            keypoint_layers = tuple(512 for _ in range(8))
-            line_head = LineHeads(out_channels, keypoint_layers)
+            keypoint_layers = tuple(1 for _ in range(8))
+            line_head = LineHeads(16, keypoint_layers)
 
         if line_predictor is None:
             keypoint_dim_reduced = 512  # == keypoint_layers[-1]
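The head shrinks from 512-channel layers over out_channels inputs to 1-channel layers over 16-channel inputs, matching the 16-channel output of the new channel_compress block in loi_heads.py (note that the unchanged keypoint_dim_reduced = 512 line just above no longer equals keypoint_layers[-1]). A rough sketch of the implied channel flow, assuming LineHeads is a plain stack of 3x3 conv + ReLU blocks; the real class lives in models/line_detect and may differ:

    import torch
    import torch.nn as nn

    class LineHeadsSketch(nn.Sequential):
        # illustrative stand-in for LineHeads: one 3x3 conv + ReLU per layer width
        def __init__(self, in_channels, layer_widths):
            blocks = []
            for width in layer_widths:
                blocks.append(nn.Conv2d(in_channels, width, kernel_size=3, padding=1))
                blocks.append(nn.ReLU(inplace=True))
                in_channels = width
            super().__init__(*blocks)

    head = LineHeadsSketch(16, tuple(1 for _ in range(8)))
    x = torch.randn(1, 16, 512, 512)   # e.g. the output of channel_compress
    print(head(x).shape)               # torch.Size([1, 1, 512, 512])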

+ 45 - 3
models/line_detect/loi_heads.py

@@ -496,6 +496,31 @@ def heatmaps_to_lines(maps, rois):
 
     return xy_preds.permute(0, 2, 1), end_scores
 
+
+def lines_features_align(features, proposals, img_size):
+    print(f'lines_features_align features:{features.shape}')
+
+    align_feat_list = []
+    for feat, proposals_per_img in zip(features, proposals):
+        # print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
+        feat = feat.unsqueeze(0)
+        for proposal in proposals_per_img:
+            align_feat = torch.zeros_like(feat)
+            # print(f'align_feat:{align_feat.shape}')
+            x1, y1, x2, y2 = map(lambda v: int(v.item()), proposal)
+            # copy the region inside this proposal box into the matching
+            # position of align_feat; everything outside the box stays zero
+            align_feat[:, :, y1:y2 + 1, x1:x2 + 1] = feat[:, :, y1:y2 + 1, x1:x2 + 1]
+            align_feat_list.append(align_feat)
+
+    # concatenate the per-proposal maps collected over all images in the batch
+    feats_tensor = torch.cat(align_feat_list)
+
+    print(f'align features :{feats_tensor.shape}')
+
+    return feats_tensor
+
+
 def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     N, K, H, W = line_logits.shape
@@ -909,6 +934,12 @@ class RoIHeads(nn.Module):
         self.keypoint_head = keypoint_head
         self.keypoint_predictor = keypoint_predictor
 
+        # 1x1 conv to squeeze the 256-channel FPN map down to 16 channels
+        self.channel_compress = nn.Sequential(
+            nn.Conv2d(256, 16, kernel_size=1),
+            nn.BatchNorm2d(16),
+            nn.ReLU(inplace=True),
+        )
+
     def has_mask(self):
         if self.mask_roi_pool is None:
             return False
@@ -1208,14 +1239,22 @@ class RoIHeads(nn.Module):
             print(f'line_proposals:{len(line_proposals)}')
 
 
-            line_features = self.line_roi_pool(features, line_proposals, image_shapes)
+
+            # line_features = self.line_roi_pool(features, line_proposals, image_shapes)
+            # print(f'line_features from line_roi_pool:{line_features.shape}')
+
+            line_features = self.channel_compress(features['0'])
+            line_features = lines_features_align(line_features, line_proposals, image_shapes)
 
 
-            print(f'line_features from line_roi_pool:{line_features.shape}')
             line_features = self.line_head(line_features)
             print(f'line_features from line_head:{line_features.shape}')
-            line_logits = self.line_predictor(line_features)
+            # line_logits = self.line_predictor(line_features)
 
+            line_logits = line_features
             print(f'line_logits:{line_logits.shape}')
 
             loss_line = {}
@@ -1329,6 +1368,9 @@ class RoIHeads(nn.Module):
                 pos_matched_idxs = None
 
             keypoint_features = self.line_roi_pool(features, keypoint_proposals, image_shapes)
+
             keypoint_features = self.line_head(keypoint_features)
             keypoint_logits = self.line_predictor(keypoint_features)
 

+ 3 - 3
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/rlq/datasets/0706_
-  data_type: tiff
+  datadir: \\192.168.50.222/share/zyh/202507/a_dataset
+  data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 
@@ -11,7 +11,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 4
+  batch_size: 1
   max_epoch: 80000
   augmentation: True
   optim:

+ 2 - 2
models/line_detect/train_demo.py

@@ -15,7 +15,7 @@ if __name__ == '__main__':
     # model=linenet_newresnet50fpn()
     # model = lineDetect_resnet18_fpn()
 
-    model=linedetect_resnet18_fpn()
-    # model=linedetect_newresnet18fpn()
+    # model=linedetect_resnet18_fpn()
+    model=linedetect_newresnet18fpn()
 
     model.start_train(cfg='train.yaml')

+ 4 - 3
models/line_detect/trainer.py

@@ -12,7 +12,8 @@ from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
 from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
-from models.line_detect.line_dataset import LineDataset
+from models.line_detect.line_dataset_old import LineDataset
+
 from models.line_net.dataset_LD import WirePointDataset
 from models.wirenet.postprocess import postprocess
 from tools import utils
@@ -244,8 +245,8 @@ class Trainer(BaseTrainer):
 
         self.init_params(**kwargs)
 
-        dataset_train = LineDataset(dataset_path=self.dataset_path,augmentation=self.augmentation, data_type=self.data_type, dataset_type='train')
-        dataset_val = LineDataset(dataset_path=self.dataset_path, augmentation=self.augmentation, data_type=self.data_type, dataset_type='val')
+        dataset_train = LineDataset(dataset_path=self.dataset_path, data_type=self.data_type, dataset_type='train')
+        dataset_val = LineDataset(dataset_path=self.dataset_path, data_type=self.data_type, dataset_type='val')
 
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         val_sampler = torch.utils.data.RandomSampler(dataset_val)