@@ -148,7 +148,7 @@ class Trainer(BaseTrainer):
def writer_predict_result(self, img, result, epoch):
img = img.cpu().detach()
- img=img[:3,:,:]
+ img=img[:3]
im = img.permute(1, 2, 0)
self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
@@ -94,7 +94,7 @@ class WirePointDataset(BaseDataset):
# rgb_normalized = rgb_channels.astype(np.float32) / 255.0
rgb_normalized = rgb_channels
- depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())
+ depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())*255
# 将归一化后的RGB和深度通道重新组合
normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized)) # 或者使用depth_normalized_fixed_range