Bladeren bron

添加训练tiff深度图功能

RenLiqiang 7 maanden geleden
bovenliggende
commit
601fbb297e
3 gewijzigde bestanden met toevoegingen van 133 en 4 verwijderingen
  1. 14 3
      models/line_detect/dataset_LD.py
  2. 118 0
      models/line_detect/test_tiff.py
  3. 1 1
      models/line_detect/train.yaml

+ 14 - 3
models/line_detect/dataset_LD.py

@@ -1,4 +1,5 @@
 # ??roi_head??????????????
+import imageio
 from torch.utils.data.dataset import T_co
 
 from models.base.base_dataset import BaseDataset
@@ -72,10 +73,20 @@ class WirePointDataset(BaseDataset):
 
     def __getitem__(self, index) -> T_co:
         img_path = os.path.join(self.img_path, self.imgs[index])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
+        img=imageio.v3.imread(img_path).reshape(512,512,1)
+        img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
+        # 将原始通道复制到第一个通道
+        img_3channel[:, :, 2] = img[:, :, 0]
+
+        # print(f'dataset img shape:{img.shape}')
+        # img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.shape[:2]
+
+        img=torch.from_numpy(img_3channel).permute(2, 0, 1)
+        # img=img.transpose(2,0,1)
+        # print(f'dataset img shape2:{img.shape}')
 
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
 
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))

+ 118 - 0
models/line_detect/test_tiff.py

@@ -0,0 +1,118 @@
+import cv2
+import numpy as np
+import imageio
+import open3d as o3d
+from tifffile import tifffile
+# 相机内参矩阵 [fx, 0, cx; 0, fy, cy; 0, 0, 1]
+K = np.array([
+    [1.30449e3, 0,         5.2602e2],
+    [0,         1.30449e3, 1.07432e3],
+    [0,         0,         1]
+])
+fx, fy = K[0, 0], K[1, 1]
+cx, cy = K[0, 2], K[1, 2]
+
+def pointscloud2depthmap(points):
+    # 初始化一个空的目标数组
+    point_image = np.zeros((height, width, 3), dtype=np.float32)
+    # 遍历点云中的每个点,进行投影并填充目标数组
+    for point in points:
+        X, Y, Z = point
+        if Z > 0:  # 确保Z值有效
+            # 计算2D图像坐标
+            u = int((X * fx) / Z + cx)
+            v = int((Y * fy) / Z + cy)
+
+            # 检查是否在图像边界内
+            if 0 <= u < width and 0 <= v < height:
+                point_image[v, u, :] = point
+
+    return point_image
+
+
+
+# # 使用imageio读取
+# loaded_depth_map = imageio.v3.imread(r"depth_map.tiff")
+# print(loaded_depth_map.shape)
+# print(loaded_depth_map.dtype)
+# # print(loaded_depth_map)
+
+
+
+
+# 加载PCD文件
+pcd = o3d.io.read_point_cloud(r"F:\DevTools\datasets\test.pcd")
+
+# 打印点的数量
+print("Number of points:", len(pcd.points))
+
+# 获取点云数据
+points = np.asarray(pcd.points)
+
+# 打印前5个点的坐标
+print("First 5 points:\n", points[:5])
+#
+# print(f'depth :{loaded_depth_map[0,0:5]}')
+#
+#
+# print(loaded_depth_map[102,113])
+#
+# # 将深度图转换为点云
+# height ,width = loaded_depth_map.shape[:2]
+# print(f'height:{height},width:{width}')
+
+
+# point_cloud_from_depth = []
+# for v in range(height):
+#     for u in range(width):
+#         x_,y_,z_=loaded_depth_map[v,u]
+#         print(f'x_,y_,z_:({x_},{y_},{z_})')
+#         depth = loaded_depth_map[v, u][-1]
+#         print(f'depth:{depth}')
+#         # if depth > 0:  # 忽略无效的深度值
+#         x = (u - cx) * depth / fx
+#         y = (v - cy) * depth / fy
+#         z = depth
+#         print(f'x,y,z:({x},{y},{z})')
+#         point_cloud_from_depth.append([x, y, z])
+#
+# point_cloud_from_depth = np.array(point_cloud_from_depth)
+#
+# # 打印从深度图生成的点云中的前5个点
+# print("First 5 points from depth map:\n", point_cloud_from_depth[:5])
+
+
+# 目标深度图尺寸
+height, width = 2000, 2000
+
+
+
+
+
+
+
+point_image=pointscloud2depthmap(points)
+
+# 打印结果以验证
+print("Shape of the projected point cloud:", point_image.shape)
+print("First few pixels (if any):", point_image[:5, :5, :])
+
+# 提取 Z 值作为深度图
+depth_map = point_image[:, :, 2]
+# depth_map=point_image
+# 处理无效点(例如,设置无效点的深度值为一个极大值)
+# invalid_depth_value = np.max(depth_map) * 2  # 或者选择其他合适的值
+# depth_map[depth_map == 0] = invalid_depth_value  # 将所有无效点(Z=0)替换为极大值
+
+# 打印深度图的一些信息以验证
+print("Depth map shape:", depth_map.shape)
+print("Depth map dtype:", depth_map.dtype)
+print("Min depth value:", np.min(depth_map))
+print("Max depth value:", np.max(depth_map))
+
+# 保存为 TIFF 文件
+output_tiff_path = 'depth_map.tiff'
+tifffile.imwrite(output_tiff_path, depth_map.astype(np.float16))
+
+print(f"Depth map saved to {output_tiff_path}")
+

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: I:/datasets/0322_suanzaisheng
+  datadir: I:/datasets/depth_map0510
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
   resume_from: