lstrlq 6 hónapja
szülő
commit
9e8b3a7435

+ 3 - 0
libs/vision_libs/models/detection/transform.py

@@ -197,7 +197,10 @@ class GeneralizedRCNNTransform(nn.Module):
             return image, target
 
         bbox = target["boxes"]
+        print(f'bbox:{bbox}')
+        print(f'image.shape[-2:]:{image.shape},,,{image.shape[-2:]}')
         bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
+
         target["boxes"] = bbox
 
         if "keypoints" in target:

+ 2 - 1
models/line_detect/dataset_LD.py

@@ -92,7 +92,8 @@ class WirePointDataset(BaseDataset):
         rgb_channels = img[:, :, :3]
         depth_channel = img[:, :, 3]
 
-        rgb_normalized = rgb_channels.astype(np.float32) / 255.0
+        # rgb_normalized = rgb_channels.astype(np.float32) / 255.0
+        rgb_normalized = rgb_channels
         depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())
 
         # 将归一化后的RGB和深度通道重新组合

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/lm/1-dataset/a_dataset
+  datadir: /data/share/lm/1-dataset/a_dataset
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 2 - 2
models/line_detect/train_demo.py

@@ -7,9 +7,9 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
-    # model=linenet_resnet50_fpn()
+    model=linenet_resnet50_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
-    model=linenet_resnet18_fpn()
+    # model=linenet_resnet18_fpn()
     # model=linenet_resnet101_fpn_v2()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')

+ 1 - 0
models/line_detect/trainer.py

@@ -148,6 +148,7 @@ class Trainer(BaseTrainer):
 
     def writer_predict_result(self, img, result, epoch):
         img = img.cpu().detach()
+        img=img[:3,:,:]
         im = img.permute(1, 2, 0)
         self.writer.add_image("z-ori", im, epoch, dataformats="HWC")