Explorar o código

修改图片训练为3通道

RenLiqiang hai 7 meses
pai
achega
44829ea84f

+ 2 - 0
.gitignore

@@ -34,3 +34,5 @@ __pycache__
 train_results
 *.tif
 *.tiff
+*.jpg
+*.png

+ 2 - 2
models/line_detect/111.py

@@ -234,8 +234,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
     # model = LineNet('line_net.yaml')
     # model = linenet_resnet50_fpn().to(device)
-    model=get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
-    # model=get_line_net_convnext_fpn(num_classes=2).to(device)
+    # model=get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
+    model=get_line_net_convnext_fpn(num_classes=2).to(device)
     # model=linenet_resnet18_fpn()
     trainer = Trainer()
     trainer.train_cfg(model,cfg='./train.yaml')

BIN=BIN
models/line_detect/color_img.jpg


+ 3 - 27
models/line_detect/dataset_LD.py

@@ -1,5 +1,4 @@
 # ??roi_head??????????????
-import imageio
 from torch.utils.data.dataset import T_co
 
 from models.base.base_dataset import BaseDataset
@@ -73,24 +72,10 @@ class WirePointDataset(BaseDataset):
 
     def __getitem__(self, index) -> T_co:
         img_path = os.path.join(self.img_path, self.imgs[index])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
-        img=imageio.v3.imread(img_path).reshape(512,512,1)
-        img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
-        # 将原始通道复制到第一个通道
-        img_3channel[:, :, 2] = img[:, :, 0]
-
-        # print(f'dataset img shape:{img.shape}')
-        # img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.shape[:2]
-
-        img=torch.from_numpy(img_3channel).permute(2, 0, 1)
-
-        img=self.zscore_normalize_depth(img)
-
-
-        # img=img.transpose(2,0,1)
-        # print(f'dataset img shape2:{img.shape}')
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
 
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
 
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
@@ -105,15 +90,6 @@ class WirePointDataset(BaseDataset):
     def __len__(self):
         return len(self.imgs)
 
-    def zscore_normalize_depth(self,img):
-        depth = img[2]
-        mean = depth.mean()
-        std = depth.std()
-        depth_normalized = (depth - mean) / (std + 1e-8)
-        img_normalized = img.clone()
-        img_normalized[2] = depth_normalized
-        return img_normalized
-
     def read_target(self, item, lbl_path, shape, extra=None):
         # print(f'lbl_path:{lbl_path}')
         with open(lbl_path, 'r') as file:

+ 1 - 1
models/line_detect/test_tiff.py

@@ -58,7 +58,7 @@ def pointscloud2colorimg(points):
 
 
 # 加载PCD文件
-pcd = o3d.io.read_point_cloud(r"F:\test_pointcloud\color.pcd")
+pcd = o3d.io.read_point_cloud(r"F:\test_pointcloud\color2.pcd")
 
 # 打印点的数量
 print("Number of points:", len(pcd.points))

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: \\192.168.50.222/share/zyh/512/init/last
+  datadir: \\192.168.50.222/share/zyh/513/a_dataset
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
   resume_from: