Jelajahi Sumber

添加convnext backbone

RenLiqiang 7 bulan lalu
induk
komit
dd55089e7e

+ 1 - 1
models/line_detect/line_net.py

@@ -339,7 +339,7 @@ _COMMON_META = {
 }
 
 
-def create_efficientnetv2_backbone(name='efficientnet_v2_m', pretrained=True):
+def create_efficientnetv2_backbone(name='efficientnet_v2_l', pretrained=True):
     # 加载EfficientNetV2模型
     if name == 'efficientnet_v2_s':
         weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None

+ 18 - 8
models/line_detect/predict2.py

@@ -171,6 +171,7 @@ from torchvision import transforms
 # from models.wirenet.postprocess import postprocess
 from models.wirenet.postprocess import postprocess
 from rtree import index
+import imageio
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -423,11 +424,18 @@ def predict(pt_path, model, img):
 
     model.eval()
 
-    if isinstance(img, str):
-        img = Image.open(img).convert("RGB")
+    # if isinstance(img, str):
+    #     img = Image.open(img).convert("RGB")
+    print(imageio.v3.imread(img_path).shape)
+    img = imageio.v3.imread(img_path).reshape(2114, 1332, 1)
+    img_3channel = np.zeros((2114, 1332, 3), dtype=img.dtype)
+    img_3channel[:, :, 2] = img[:, :, 0]
+    img = torch.from_numpy(img_3channel).permute(2, 0, 1)
 
-    transform = transforms.ToTensor()
-    img_tensor = transform(img)  # [3, 512, 512]
+    img_tensor = img
+
+    # transform = transforms.ToTensor()
+    # img_tensor = transform(img)  # [3, 512, 512]
 
     # 将图像调整为512x512大小
     t_start = time.time()
@@ -437,7 +445,7 @@ def predict(pt_path, model, img):
 
     im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
     if im.shape != (512, 512, 3):
-        im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
+        im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_NEAREST)
     img_ = torch.tensor(im).permute(2, 0, 1)  # [3, 512, 512]
 
     t_end = time.time()
@@ -463,16 +471,18 @@ def predict(pt_path, model, img):
     #
     # show_predict(img_, pred, t_start)
 
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, get_line_net_efficientnetv2
 
 if __name__ == '__main__':
     t_start = time.time()
     print(f'start to predict:{t_start}')
-    model = linenet_resnet50_fpn().to(device)
+    # model = linenet_resnet50_fpn().to(device)
+    model = get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
-    pt_path = r"\\192.168.50.222\share\lm\weight\20250425_112601\weights\best.pth"
+    pt_path = r"\\192.168.50.222\share\lm\weight\20250510_155941\weights\best.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
-    img_path = r"C:\Users\m2337\Desktop\p\140502.png"
+    img_path = r"D:\python\PycharmProjects\20250214\cloud\新建文件夹\depth_map.tiff"
     predict(pt_path, model, img_path)
     t_end = time.time()
     print(f'predict used:{t_end - t_start}')

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: I:/datasets/depth_map0510
+  datadir: \\192.168.50.222\share\lm\04\0510\0510_split
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
   resume_from: