Ver Fonte

完善dataset,从train.yaml传入参数,rgb或单通道tiff

xue50 há 5 meses atrás
pai
commit
c59317204e

+ 15 - 7
models/line_detect/line_dataset.py

@@ -10,7 +10,7 @@ import os
 import random
 import cv2
 import PIL
-
+import imageio
 import matplotlib.pyplot as plt
 import matplotlib as mpl
 from torchvision.utils import draw_bounding_boxes
@@ -33,10 +33,11 @@ def validate_keypoints(keypoints, image_width, image_height):
 
 
 class LineDataset(BaseDataset):
-    def __init__(self, dataset_path, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
+    def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
         super().__init__(dataset_path)
 
         self.data_path = dataset_path
+        self.data_type = data_type
         print(f'data_path:{dataset_path}')
         self.transforms = transforms
         self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
@@ -49,11 +50,18 @@ class LineDataset(BaseDataset):
 
     def __getitem__(self, index) -> T_co:
         img_path = os.path.join(self.img_path, self.imgs[index])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
-
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
-
+        if self.data_type == 'tiff':
+            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
+            img = imageio.v3.imread(img_path).reshape(512, 512, 1)
+            img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
+            img_3channel[:, :, 2] = img[:, :, 0]
+
+            w, h = img.shape[:2]
+            img = torch.from_numpy(img_3channel).permute(2, 0, 1)
+        else:
+            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+            img = PIL.Image.open(img_path).convert('RGB')
+            w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         if self.transforms:

+ 1 - 1
models/line_detect/loi_heads.py

@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
 import torch
 import torch.nn.functional as F
 import torchvision
-from scipy.optimize import linear_sum_assignment
+# from scipy.optimize import linear_sum_assignment
 from torch import nn, Tensor
 from  libs.vision_libs.ops import boxes as box_ops, roi_align
 

+ 2 - 1
models/line_detect/train.yaml

@@ -1,6 +1,7 @@
 io:
   logdir: train_results
-  datadir: /data/share/lm/Dataset_all
+  datadir: \\192.168.50.222\share\lm\Dataset_all
+  data_type: jpg
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 3 - 2
models/line_detect/trainer.py

@@ -90,6 +90,7 @@ class Trainer(BaseTrainer):
             self.freeze_config = kwargs['train_params']['freeze_params']
             print(f'freeze_config:{self.freeze_config}')
             self.dataset_path = kwargs['io']['datadir']
+            self.data_type = kwargs['io']['data_type']
             self.batch_size = kwargs['train_params']['batch_size']
             self.num_workers = kwargs['train_params']['num_workers']
             self.logdir = kwargs['io']['logdir']
@@ -241,8 +242,8 @@ class Trainer(BaseTrainer):
 
         self.init_params(**kwargs)
 
-        dataset_train = LineDataset(dataset_path=self.dataset_path, dataset_type='train')
-        dataset_val = LineDataset(dataset_path=self.dataset_path, dataset_type='val')
+        dataset_train = LineDataset(dataset_path=self.dataset_path, data_type=self.data_type, dataset_type='train')
+        dataset_val = LineDataset(dataset_path=self.dataset_path, data_type=self.data_type, dataset_type='val')
 
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         val_sampler = torch.utils.data.RandomSampler(dataset_val)