RenLiqiang 7 месяцев назад
Родитель
Сommit
a6b19e6074

+ 9 - 5
models/line_detect/line_net.py

@@ -220,10 +220,11 @@ class LineNet(BaseDetectionNet):
         self.trainer = Trainer()
         self.trainer.train_from_cfg(model=self, cfg=cfg)
 
-    def load_best_model(self,model,  save_path, device='cuda'):
+    def load_best_model(self,save_path, device='cuda'):
         if os.path.exists(save_path):
             checkpoint = torch.load(save_path, map_location=device)
-            model.load_state_dict(checkpoint['model_state_dict'])
+
+            self.load_state_dict(checkpoint['model_state_dict'])
             # if optimizer is not None:
             #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
             # epoch = checkpoint['epoch']
@@ -232,11 +233,14 @@ class LineNet(BaseDetectionNet):
             print(f"Loaded model from {save_path}")
         else:
             print(f"No saved model found at {save_path}")
-        return model
+        return self
 
     # 加载权重和推理一起
-    def predict(self, pt_path, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
-        self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
+    def predict(self,img_path, type=0, threshold=0.5, save_path=None, show=False):
+        # self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
+        self.eval()
+        self.to(device)
+        self.predict = Predict(self, img_path, type, threshold, save_path, show)
         self.predict.run()
 
     # 不加载权重

+ 11 - 9
models/line_detect/predict.py

@@ -14,9 +14,8 @@ import os
 import torch
 from PIL import Image
 import matplotlib.pyplot as plt
-import matplotlib as mpl
 import numpy as np
-# from models.line_detect.line_net import linenet_resnet50_fpn
+
 from torchvision import transforms
 
 from models.wirenet.postprocess import postprocess
@@ -276,7 +275,7 @@ def show_predict(imgs, pred, threshold, t_start):
 
 
 class Predict:
-    def __init__(self, pt_path, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
+    def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
         """
         初始化预测器。
 
@@ -290,9 +289,11 @@ class Predict:
             show: 是否显示结果。
             device: 运行设备(默认 'cuda')。
         """
-        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+        # self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+
         self.model = model
-        self.pt_path = pt_path
+        self.device = next(model.parameters()).device
+        # self.pt_path = pt_path
         self.img = self.load_image(img)
         self.type = type
         self.threshold = threshold
@@ -338,16 +339,16 @@ class Predict:
 
     def predict(self):
         """执行预测"""
-        model = self.load_best_model(self.model, self.pt_path, device)
-
-        model.eval()
+        # model = self.load_best_model(self.model, self.pt_path, device)
+        #
+        # model.eval()
 
         # 预处理图像
         img_ = self.preprocess_image(self.img)
 
         # 模型推理
         with torch.no_grad():
-            predictions = model([img_.to(self.device)])
+            predictions =self.model([img_.to(self.device)])
             print("Model predictions completed.")
 
         # 后处理
@@ -469,3 +470,4 @@ class Predict1:
     def run(self):
         """运行预测流程"""
         self.predict()
+

+ 9 - 0
models/line_detect/predict_demo.py

@@ -0,0 +1,9 @@
+from models.line_detect.line_net import linenet_resnet18_fpn
+
+if __name__ == '__main__':
+    model=linenet_resnet18_fpn()
+    model.load_best_model(r'E:\projects\tmp\MultiVisionModels\models\line_detect\train_results\20250515_173829\weights\best_val.pth')
+
+    img_path=r"\\192.168.50.222\share\rlq\datasets\修订513pcd转换彩图标注后汇总\2025-05-13-08-37-48_LaserData_ID019504_color.jpg"
+
+    model.predict(img_path)

+ 1 - 1
models/line_detect/trainer.py

@@ -44,7 +44,7 @@ class Trainer(BaseTrainer):
     def __init__(self, model=None, **kwargs):
         super().__init__(model, device, **kwargs)
         self.model = model
-        print(f'kwargs:{kwargs}')
+        # print(f'kwargs:{kwargs}')
         self.init_params(**kwargs)
 
     def init_params(self, **kwargs):