RenLiqiang 6 месяцев назад
Родитель
Сommit
d47a7a84ab

+ 3 - 0
libs/vision_libs/models/detection/transform.py

@@ -166,6 +166,9 @@ class GeneralizedRCNNTransform(nn.Module):
         dtype, device = image.dtype, image.device
         mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
         std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
+        # print(f'mean:{mean}')
+        # print(f'std:{std}')
+        # print(f'image:{image.shape}')
         return (image - mean[:, None, None]) / std[:, None, None]
 
     def torch_choice(self, k: List[int]) -> int:

+ 27 - 13
models/line_detect/dataset_LD.py

@@ -27,6 +27,7 @@ from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks
 
 from tools.presets import DetectionPresetTrain
 
+
 def line_boxes1(target):
     boxs = []
     lines = target.cpu().numpy() * 4
@@ -74,20 +75,37 @@ class WirePointDataset(BaseDataset):
         img_path = os.path.join(self.img_path, self.imgs[index])
         lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
 
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
+        # img = PIL.Image.open(img_path).convert('RGB')
+        # w, h = img.size
+        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+        print(img.shape)
+        w, h = img.shape[0:2]
 
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        if self.transforms:
-            img, target = self.transforms(img, target)
-        else:
-            img = self.default_transform(img)
+        # if self.transforms:
+        #     img, target = self.transforms(img, target)
+        # else:
+        #     img = self.default_transform(img)
+
+        # 分离RGB和深度通道
+        rgb_channels = img[:, :, :3]
+        depth_channel = img[:, :, 3]
+
+        rgb_normalized = rgb_channels.astype(np.float32) / 255.0
+        depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())
+
+        # 将归一化后的RGB和深度通道重新组合
+        normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized))  # 或者使用depth_normalized_fixed_range
+
+        print("Normalized RGBA image shape:", normalized_rgba_image.shape)
 
+        img = torch.tensor(normalized_rgba_image,dtype=torch.float32).permute(2,1,0)
 
         # new_channel = torch.zeros(1, 512, 512)
         # img=torch.cat((img,new_channel),dim=0)
-        # print(f'img:{img.shape}')
+        print(f'img:{img.shape}')
+        # print(f'img dtype:{img.dtype}')
         return img, target
 
     def __len__(self):
@@ -146,13 +164,12 @@ class WirePointDataset(BaseDataset):
         target = {}
         # target["labels"] = torch.stack(labels)
 
-
         target["image_id"] = torch.tensor(item)
         # return wire_labels, target
         target["wires"] = wire_labels
         # target["boxes"] = line_boxes(target)
         target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
-        target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
+        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
         # print(f'target["labels"]:{ target["labels"]}')
         # print(f'boxes:{target["boxes"].shape}')
         if target["boxes"].numel() == 0:
@@ -190,7 +207,6 @@ class WirePointDataset(BaseDataset):
                         break
                     plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # ? s=64
 
-
             img_path = os.path.join(self.img_path, self.imgs[idx])
             img = PIL.Image.open(img_path).convert('RGB')
             boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
@@ -214,11 +230,9 @@ class WirePointDataset(BaseDataset):
         # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
         draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
 
-
     def show_img(self, img_path):
         pass
 
-
 # dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
 # for i in dataset_train:
-#     a = 1
+#     a = 1

+ 12 - 12
models/line_detect/line_net.py

@@ -134,14 +134,14 @@ class LineNet(BaseDetectionNet):
             )
 
         # 修改第一个卷积层,将 in_channels 从 3 改为 4
-        # backbone.body.conv1 = nn.Conv2d(
-        #     in_channels=4,
-        #     out_channels=64,
-        #     kernel_size=7,
-        #     stride=2,
-        #     padding=3,
-        #     bias=False
-        # )
+        backbone.body.conv1 = nn.Conv2d(
+            in_channels=4,
+            out_channels=64,
+            kernel_size=7,
+            stride=2,
+            padding=3,
+            bias=False
+        )
         if num_classes is not None:
             if box_predictor is not None:
                 raise ValueError("num_classes should be None when box_predictor is specified")
@@ -212,12 +212,12 @@ class LineNet(BaseDetectionNet):
         )
 
         if image_mean is None:
-            image_mean = [0.485, 0.456, 0.406]
-            # image_mean = [0.485, 0.456, 0.406, 0.5]  # 假设你新加的通道均值为0.5
+            # image_mean = [0.485, 0.456, 0.406]
+            image_mean = [0.485, 0.456, 0.406, 0.5]  # 假设你新加的通道均值为0.5
 
         if image_std is None:
-            image_std = [0.229, 0.224, 0.225]
-            # image_std = [0.229, 0.224, 0.225, 0.2]  # 标准差也补一个值
+            # image_std = [0.229, 0.224, 0.225]
+            image_std = [0.229, 0.224, 0.225, 0.2]  # 标准差也补一个值
         transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
 
         super().__init__(backbone, rpn, roi_heads, transform)

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/zyh/5月彩色钢板数据汇总/zjh/a_dataset
+  datadir: \\192.168.50.222/share/lm/1-dataset/a_dataset
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 3 - 2
models/line_detect/train_demo.py

@@ -9,8 +9,9 @@ if __name__ == '__main__':
     # model = LineNet('line_net.yaml')
     # model=linenet_resnet50_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
-    # model=linenet_resnet18_fpn()
-    model=linenet_resnet101_fpn_v2()
+    model=linenet_resnet18_fpn()
+    # model=linenet_resnet101_fpn_v2()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
     model.start_train(cfg='train.yaml')
+