Pārlūkot izejas kodu

train 4 channels on 4080 with initial results

lstrlq 6 mēneši atpakaļ
vecāks
revīzija
948627e972

+ 1 - 1
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: train_results
+#  datadir: /data/share/zyh/5月彩色钢板数据汇总/zjh/a_dataset
   datadir: /data/share/lm/1-dataset/a_dataset
-#  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 
   tensorboard_port: 6000

+ 43 - 16
models/line_net/dataset_LD.py

@@ -1,9 +1,9 @@
 # ??roi_head??????????????
-from torch import dtype
 from torch.utils.data.dataset import T_co
 
+from libs.vision_libs.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 from models.base.base_dataset import BaseDataset
-
+from torchvision.transforms import  functional as F
 import glob
 import json
 import math
@@ -70,6 +70,7 @@ class WirePointDataset(BaseDataset):
         self.imgs = os.listdir(self.img_path)
         self.lbls = os.listdir(self.lbl_path)
         self.target_type = target_type
+        self.transform = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
         # self.default_transform = DefaultTransform()
 
     def __getitem__(self, index) -> T_co:
@@ -77,35 +78,61 @@ class WirePointDataset(BaseDataset):
         lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
 
         # img = PIL.Image.open(img_path).convert('RGB')
-        # w, h = img.size
         img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
-        print(img.shape)
-        w, h = img.shape[0:2]
+        img_rgb=img[:,:,:3]
 
+        print(f'img shape:{img.shape}')
+        img_rgb=cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB)
+        # img=np.array(img,copy=True)
+        # img = self.default_transform(img)
+        # print(f'pil img:{img.dtype}')
+        # w, h = img.size
+        # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+        # cv2.imshow('img',img)
+        # cv2.waitKey(1000000)
+        # print(img.shape)
+        w, h = img.shape[0:2]
+        # w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+
+        print(f'self.default_transform:{self.default_transform}')
+
         # if self.transforms:
         #     img, target = self.transforms(img, target)
         # else:
         #     img = self.default_transform(img)
 
-        # 分离RGB和深度通道
-        rgb_channels = img[:, :, :3]
-        depth_channel = img[:, :, 3]
 
-        # rgb_normalized = rgb_channels.astype(np.float32) / 255.0
-        rgb_normalized = rgb_channels
-        depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())*255
-
-        # 将归一化后的RGB和深度通道重新组合
-        normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized))  # 或者使用depth_normalized_fixed_range
+        # # 分离RGB和深度通道
+        # rgb_channels = img[:, :, :3]
+        # depth_channel = img[:, :, 3]
+        #
+        # rgb_normalized = rgb_channels/255
+        # depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())*255
+        #
+        # # 将归一化后的RGB和深度通道重新组合
+        # normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized))  # 或者使用depth_normalized_fixed_range
+        #
+        # print("Normalized RGBA image shape:", normalized_rgba_image.shape)
+        #
+        # img = torch.tensor(normalized_rgba_image,dtype=torch.uint8).permute(2,1,0)
 
-        print("Normalized RGBA image shape:", normalized_rgba_image.shape)
 
-        img = torch.tensor(normalized_rgba_image,dtype=torch.float32).permute(2,1,0)
+        # cv2.imshow('img',img[:3].permute(1,2,0).numpy().astype(np.uint8))
+        # cv2.waitKey(10000000)
+        # plt.imshow(img[:3].permute(1,2,0).numpy())
+        # plt.show()
 
         # new_channel = torch.zeros(1, 512, 512)
         # img=torch.cat((img,new_channel),dim=0)
+        img=np.dstack((img_rgb,img[:,:,3]))
+
+        img=torch.as_tensor(img).permute(2,0,1)
+        img=self.default_transform(img)
+
+
+        # img=F.convert_image_dtype(img, torch.float)
         print(f'img:{img.shape}')
         # print(f'img dtype:{img.dtype}')
         return img, target

+ 2 - 2
models/line_net/line_net.py

@@ -214,11 +214,11 @@ class LineNet(BaseDetectionNet):
 
         if image_mean is None:
             # image_mean = [0.485, 0.456, 0.406]
-            image_mean = [0.485, 0.456, 0.406, 0.5]  # 假设你新加的通道均值为0.5
+            image_mean = [0.485, 0.456, 0.406,  0.2549]  # 假设你新加的通道均值为0.5
 
         if image_std is None:
             # image_std = [0.229, 0.224, 0.225]
-            image_std = [0.229, 0.224, 0.225, 0.2]  # 标准差也补一个值
+            image_std = [0.229, 0.224, 0.225, 0.4093]  # 标准差也补一个值
         transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
 
         super().__init__(backbone, rpn, roi_heads, transform)