浏览代码

4080 train rgbd

lstrlq 6 月之前
父节点
当前提交
2758de1491
共有 2 个文件被更改,包括 2 次插入2 次删除
  1. 1 1
      models/line_detect/trainer.py
  2. 1 1
      models/line_net/dataset_LD.py

+ 1 - 1
models/line_detect/trainer.py

@@ -148,7 +148,7 @@ class Trainer(BaseTrainer):
 
 
     def writer_predict_result(self, img, result, epoch):
     def writer_predict_result(self, img, result, epoch):
         img = img.cpu().detach()
         img = img.cpu().detach()
-        img=img[:3,:,:]
+        img=img[:3]
         im = img.permute(1, 2, 0)
         im = img.permute(1, 2, 0)
         self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
         self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
 
 

+ 1 - 1
models/line_net/dataset_LD.py

@@ -94,7 +94,7 @@ class WirePointDataset(BaseDataset):
 
 
         # rgb_normalized = rgb_channels.astype(np.float32) / 255.0
         # rgb_normalized = rgb_channels.astype(np.float32) / 255.0
         rgb_normalized = rgb_channels
         rgb_normalized = rgb_channels
-        depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())
+        depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())*255
 
 
         # 将归一化后的RGB和深度通道重新组合
         # 将归一化后的RGB和深度通道重新组合
         normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized))  # 或者使用depth_normalized_fixed_range
         normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized))  # 或者使用depth_normalized_fixed_range