Browse Source

line_detect: add support for reading xanylabel annotation data files directly

RenLiqiang 5 months ago
parent
commit
51ec3b67b1

+ 10 - 0
libs/vision_libs/models/detection/transform.py

@@ -201,6 +201,11 @@ class GeneralizedRCNNTransform(nn.Module):
             keypoints = target["keypoints"]
             keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
             target["keypoints"] = keypoints
+
+        if "lines" in target:
+            lines = target["lines"]
+            lines = resize_keypoints(lines, (h, w), image.shape[-2:])
+            target["lines"] = lines
         return image, target
 
     # _onnx_batch_images() is an implementation of
@@ -274,6 +279,11 @@ class GeneralizedRCNNTransform(nn.Module):
                 keypoints = pred["keypoints"]
                 keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                 result[i]["keypoints"] = keypoints
+
+            if "lines" in pred:
+                keypoints = pred["lines"]
+                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
+                result[i]["lines"] = keypoints
         return result
 
     def __repr__(self) -> str:

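For reference, the new "lines" entries store 2D endpoints in the same layout as keypoints, which is why they can be rescaled with the existing resize_keypoints helper. A minimal sketch of the assumed scaling behavior (illustrative only, not the torchvision source):

    import torch

    def resize_lines(lines, original_size, new_size):
        # lines: (N, 2, 3) endpoint pairs stored as (x, y, visibility)
        ratio_h = new_size[0] / original_size[0]
        ratio_w = new_size[1] / original_size[1]
        resized = lines.clone()
        resized[..., 0] *= ratio_w  # x scales with image width
        resized[..., 1] *= ratio_h  # y scales with image height
        return resized
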
+ 1 - 0
models/base/base_detection_net.py

@@ -92,6 +92,7 @@ class BaseDetectionNet(BaseModel):
             original_image_sizes.append((val[0], val[1]))
 
         images, targets = self.transform(images, targets)
+        # print(f'images shape from transform: {images.tensors.shape}')
 
         # Check for degenerate boxes
         # TODO: Move this to a function

+ 1 - 0
models/keypoint/keypoint_dataset.py

@@ -202,6 +202,7 @@ def line_boxes(target):
                 xmax = b[0] + 1
             boxs.append([ymin, xmin, ymax, xmax])
 
+    print(f'torch.tensor(boxs):{torch.tensor(boxs).shape}, torch.tensor(keypoints):{torch.tensor(keypoints).shape}')
     return torch.tensor(boxs), torch.tensor(keypoints)
 
 if __name__ == '__main__':

+ 62 - 101
models/line_detect/line_dataset.py

@@ -31,7 +31,10 @@ def validate_keypoints(keypoints, image_width, image_height):
         if not (0 <= x < image_width and 0 <= y < image_height):
             raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
 
+"""
+Reads datasets annotated with xanylabel directly from their JSON format.
 
+"""
 class LineDataset(BaseDataset):
     def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
         super().__init__(dataset_path)
@@ -50,28 +53,20 @@ class LineDataset(BaseDataset):
 
     def __getitem__(self, index) -> T_co:
         img_path = os.path.join(self.img_path, self.imgs[index])
-        if self.data_type == 'tiff':
-            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
-            # img = imageio.v3.imread(img_path).reshape(512, 512, 1)
-            img = imageio.v3.imread(img_path)[:, :, :3]
-            # img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
-            # img_3channel[:, :, 2] = img[:, :, 0]
-
-            img_3channel=img
-            w, h = img.shape[:2]
-            img = torch.from_numpy(img_3channel).permute(2, 0, 1)
-        else:
-            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
-            img = PIL.Image.open(img_path).convert('RGB')
-            w, h = img.size
+
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         if self.transforms:
             img, target = self.transforms(img, target)
+
         else:
             img = self.default_transform(img)
 
         # print(f'img:{img}')
+        # print(f'img shape:{img.shape}')
         return img, target
 
     def __len__(self):
@@ -83,78 +78,31 @@ class LineDataset(BaseDataset):
         with open(lbl_path, 'r') as file:
             lable_all = json.load(file)
 
-        n_stc_posl = 300
-        n_stc_negl = 40
-        use_cood = 0
-        use_slop = 0
-
-        wire = lable_all["wires"][0]  # dict
-        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # take as many as are available if fewer
-        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
-        npos, nneg = len(line_pos_coords), len(line_neg_coords)
-        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # concatenate positive and negative sample coordinates
-        for i in range(len(lpre)):
-            if random.random() > 0.5:
-                lpre[i] = lpre[i, ::-1]
-        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
-        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
-        feat = [
-            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
-            ldir * use_slop,
-            lpre[:, :, 2],
-        ]
-        feat = np.concatenate(feat, 1)
-
-        wire_labels = {
-            "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
-            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
-            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
-            # adjacency matrix of the lines that actually exist
-            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
-
-            "lpre": torch.tensor(lpre)[:, :, :2],
-            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
-            "lpre_feat": torch.from_numpy(feat),
-            "junc_map": torch.tensor(wire['junc_map']["content"]),
-            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
-            "line_map": torch.tensor(wire['line_map']["content"]),
-        }
-
-        labels = []
-        if self.target_type == 'polygon':
-            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
-        elif self.target_type == 'pixel':
-            labels = read_masks_from_pixels_wire(lbl_path, shape)
-
-        # print(torch.stack(masks).shape)    # [num_line_segments, 512, 512]
+
+        objs = lable_all["shapes"]
+        point_pairs = objs[0]['points']
+
+
+        # print(f'point_pairs:{point_pairs}')
         target = {}
 
         target["image_id"] = torch.tensor(item)
-        # return wire_labels, target
-        target["wires"] = wire_labels
-
-        # target["labels"] = torch.stack(labels)
 
-        # print(f'labels:{target["labels"]}')
-        # target["boxes"] = line_boxes(target)
-        target["boxes"], lines = get_boxes_lines(target)
+        target["boxes"], lines = get_boxes_lines(objs,shape)
+        # print(f'lines:{lines}')
         target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
-        # keypoints=keypoints/512
-        # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
 
-        # keypoints= wire_labels["junc_coords"]
+
         a = torch.full((lines.shape[0],), 2).unsqueeze(1)
         lines = torch.cat((lines, a), dim=1)
 
         target["lines"] = lines.to(torch.float32).view(-1,2,3)
-        # print(f'boxes:{target["boxes"].shape}')
-        # call this function from within __getitem__
+        target["img_size"]=shape
+
         validate_keypoints(lines, shape[0], shape[1])
-        # print(f'keypoints:{target["keypoints"].shape}')
-        # print(f'target:{target}')
         return target
 
-    def show(self, idx):
+    def show(self, idx, show_type='all'):
         image, target = self.__getitem__(idx)
 
         cmap = plt.get_cmap("jet")
@@ -164,12 +112,23 @@ class LineDataset(BaseDataset):
 
         img_path = os.path.join(self.img_path, self.imgs[idx])
         img = PIL.Image.open(img_path).convert('RGB')
-        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+        if show_type == 'all':
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            keypoint_img = draw_keypoints(boxed_image, target['lines'], colors='red', width=3)
+            plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
+            plt.show()
+
+        if show_type == 'lines':
+            keypoint_img = draw_keypoints((self.default_transform(img) * 255).to(torch.uint8), target['lines'], colors='red', width=3)
+            plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
+            plt.show()
+
+        if show_type == 'boxes':
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
                                               colors="yellow", width=1)
-        keypoint_img=draw_keypoints(boxed_image,target['keypoints'],colors='red',width=3)
-        plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
-        plt.show()
-
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            plt.show()
 
 
 
@@ -177,33 +136,35 @@ class LineDataset(BaseDataset):
     def show_img(self, img_path):
         pass
 
-def get_boxes_lines(target):
-    boxs = []
-    lpre = target['wires']["lpre"].cpu().numpy()
-    vecl_target = target['wires']["lpre_label"].cpu().numpy()
-    lpre = lpre[vecl_target == 1]
-    lines = lpre
-    sline = np.ones(lpre.shape[0])
+def get_boxes_lines(objs, shape):
+    boxes = []
+    h, w = shape
     line_point_pairs = []
 
-    if len(lines) > 0 and not (lines[0] == 0).all():
-        for i, ((a, b), s) in enumerate(zip(lines, sline)):
-            if i > 0 and (lines[i] == lines[0]).all():
-                break
-            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] ordering is not guaranteed
-            line_point_pairs.append([a[1], a[0]])
-            line_point_pairs.append([b[1], b[0]])
+    for obj in objs:
+        # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] ordering is not guaranteed
+
+        # print(f"points:{obj['points']}")
+
+        a, b = obj['points'][0], obj['points'][1]
+
+        line_point_pairs.append(a)
+        line_point_pairs.append(b)
+
+        xmin = max(0, (min(a[0], b[0]) - 6))
+        xmax = min(w, (max(a[0], b[0]) + 6))
+        ymin = max(0, (min(a[1], b[1]) - 6))
+        ymax = min(h, (max(a[1], b[1]) + 6))
 
-            xmin = max(0, (min(a[0], b[0]) - 6))
-            xmax = min(511, (max(a[0], b[0]) + 6))
-            ymin = max(0, (min(a[1], b[1]) - 6))
-            ymax = min(511, (max(a[1], b[1]) + 6))
+        boxes.append([xmin, ymin, xmax, ymax])
 
-            boxs.append([ymin, xmin, ymax, xmax])
+    boxes = torch.tensor(boxes)
+    line_point_pairs = torch.tensor(line_point_pairs)
 
-    return torch.tensor(boxs), torch.tensor(line_point_pairs)
+    # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
+    return boxes,line_point_pairs
 
 if __name__ == '__main__':
-    path=r"\\192.168.50.222/share/lm/Dataset_all"
-    dataset= LineDataset(dataset_path=path, dataset_type='train')
-    dataset.show(10)
+    path=r"\\192.168.50.222/share/rlq/datasets/0706_"
+    dataset= LineDataset(dataset_path=path, dataset_type='train',data_type='jpg')
+    dataset.show(1,show_type='lines')

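The rewritten read_target assumes a labelme-style JSON layout from xanylabel: one shape per line segment, each carrying exactly two endpoints in its points list. A hypothetical example of such a file, plus the box derivation used by get_boxes_lines (each endpoint pair padded by 6 px and clamped to the image bounds):

    # hypothetical label file in the assumed xanylabel layout
    label = {
        "shapes": [
            {"label": "line", "points": [[120.0, 45.0], [300.0, 210.0]]},
            {"label": "line", "points": [[10.0, 10.0], [10.0, 400.0]]},
        ]
    }

    # box derivation as in get_boxes_lines: pad by 6 px, clamp to the
    # image, store as [xmin, ymin, xmax, ymax]
    h, w = 512, 512
    for obj in label["shapes"]:
        (ax, ay), (bx, by) = obj["points"]
        box = [max(0, min(ax, bx) - 6), max(0, min(ay, by) - 6),
               min(w, max(ax, bx) + 6), min(h, max(ay, by) + 6)]
        print(box)
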
+ 1 - 2
models/line_detect/line_detect.py

@@ -35,7 +35,6 @@ from ..base.high_reso_resnet import resnet50fpn, resnet18fpn
 
 __all__ = [
     "LineDetect",
-    "LineDetect_ResNet50_FPN_Weights",
     "linedetect_resnet50_fpn",
 ]
 
@@ -54,7 +53,7 @@ class LineDetect(BaseDetectionNet):
             num_classes=None,
             # transform parameters
             min_size=512,
-            max_size=1333,
+            max_size=2048,
             image_mean=None,
             image_std=None,
             # RPN parameters

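For context, GeneralizedRCNNTransform rescales each input so its shorter side reaches min_size, shrinking further if the longer side would then exceed max_size; raising max_size from 1333 to 2048 lets elongated images keep the full min_size resolution. A sketch of that standard scale computation:

    def compute_scale(image_h, image_w, min_size=512, max_size=2048):
        # scale the shorter side to min_size, but never let the
        # longer side exceed max_size
        short_side, long_side = min(image_h, image_w), max(image_h, image_w)
        scale = min_size / short_side
        if long_side * scale > max_size:
            scale = max_size / long_side
        return scale
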
+ 10 - 4
models/line_detect/loi_heads.py

@@ -176,7 +176,7 @@ def line_points_to_heatmap(keypoints, rois, heatmap_size):
 
     # show_heatmap(gs_heatmap[0],'feature')
 
-    print(f'gs_heatmap:{gs_heatmap.shape}')
+    # print(f'gs_heatmap:{gs_heatmap.shape}')
     #
     # lin_ind = y * heatmap_size + x
     # print(f'lin_ind:{lin_ind.shape}')
@@ -622,8 +622,9 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511):
         loss = 1.0 - ious
         losses.append(loss)
 
-    if not losses:
-        return None
+    if not losses:  # if the loss list is empty, return a default value instead of None
+        print("Warning: No valid losses were computed.")
+        return torch.tensor(1.0, requires_grad=True, device=x.device)  # scalar fallback; must be a float tensor to allow requires_grad
 
     total_loss = torch.mean(torch.cat(losses))
     return total_loss
@@ -1219,12 +1220,15 @@ class RoIHeads(nn.Module):
 
             loss_line = {}
             loss_line_iou={}
-            img_size=512
+
             if self.training:
+
                 if targets is None or pos_matched_idxs is None:
                     raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
                 gt_lines = [t["lines"] for t in targets]
+                h, w = targets[0]["img_size"]
+                img_size = h  # assumes square inputs (h == w)
                 rcnn_loss_line = lines_point_pair_loss(
                     line_logits, line_proposals, gt_lines, pos_matched_idxs
                 )
@@ -1235,6 +1239,8 @@ class RoIHeads(nn.Module):
 
             else:
                 if targets is not None:
+                    h, w = targets[0]["img_size"]
+                    img_size = h  # assumes square inputs (h == w)
                     gt_lines = [t["lines"] for t in targets]
                     rcnn_loss_lines = lines_point_pair_loss(
                         line_logits, line_proposals, gt_lines, pos_matched_idxs

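One caveat with the empty-loss fallback in line_iou_loss: a freshly created torch.tensor(1.0, requires_grad=True) is a leaf tensor disconnected from the network, so backward() runs but contributes no gradients to the model parameters. If gradient flow matters, a common graph-connected alternative (an assumption about intent, not what this commit does) is to derive a zero loss from the logits:

    # hypothetical graph-connected fallback: a zero loss that still
    # passes through x, so autograd reaches the usual parameters
    if not losses:
        return x.sum() * 0.0
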
+ 2 - 2
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: train_results
-  datadir: /data/share/zyh/202507/a_dataset
-  data_type: jpg
+  datadir: \\192.168.50.222/share/rlq/datasets/0706_
+  data_type: tiff
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 2 - 3
models/line_detect/train_demo.py

@@ -15,8 +15,7 @@ if __name__ == '__main__':
     # model=linenet_newresnet50fpn()
     # model = lineDetect_resnet18_fpn()
 
-    # model=linedetect_resnet18_fpn()
-    model=linedetect_newresnet18fpn()
-    model.load_weights(r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250706_150832/weights/best_val.pth')
+    model=linedetect_resnet18_fpn()
+    # model=linedetect_newresnet18fpn()
 
     model.start_train(cfg='train.yaml')

+ 4 - 6
models/line_detect/trainer.py

@@ -5,7 +5,7 @@ from datetime import datetime
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
-from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.tensorboard import SummaryWriter
 
 from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
@@ -274,9 +274,9 @@ class Trainer(BaseTrainer):
         for epoch in range(self.max_epoch):
             print(f"train epoch:{epoch}")
 
-
             model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
             scheduler.step(epoch_train_loss)
+
             # ========== Validation ==========
             with torch.no_grad():
                 model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
@@ -286,6 +286,7 @@ class Trainer(BaseTrainer):
                 best_train_loss = epoch_train_loss
                 best_val_loss = epoch_val_loss
 
+
             self.save_last_model(model,self.last_model_path, epoch, optimizer)
             best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
                                                    best_train_loss,
@@ -293,9 +294,6 @@ class Trainer(BaseTrainer):
             best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch, epoch_val_loss, best_val_loss,
                                                  optimizer)
 
-
-
-
     def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
         if phase == 'train':
             model.train()
@@ -331,7 +329,7 @@ class Trainer(BaseTrainer):
                 t_start = time.time()
                 print(f'start to predict:{t_start}')
                 result = model(self.move_to_device(imgs, self.device))
-                # print(f'result:{result}')
+                print(f'result:{result}')
                 t_end = time.time()
                 print(f'predict used:{t_end - t_start}')
                 self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)
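The training loop steps ReduceLROnPlateau with the epoch's mean training loss. A minimal sketch of that wiring (the optimizer settings and helper names such as one_epoch are placeholders, not taken from this repo):

    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # placeholder lr
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    for epoch in range(max_epoch):
        train_loss = one_epoch(model, loader, optimizer)  # mean loss for the epoch
        scheduler.step(train_loss)  # lowers the lr when the metric plateaus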