|
|
@@ -14,9 +14,8 @@ import os
|
|
|
import torch
|
|
|
from PIL import Image
|
|
|
import matplotlib.pyplot as plt
|
|
|
-import matplotlib as mpl
|
|
|
import numpy as np
|
|
|
-# from models.line_detect.line_net import linenet_resnet50_fpn
|
|
|
+
|
|
|
from torchvision import transforms
|
|
|
|
|
|
from models.wirenet.postprocess import postprocess
|
|
|
@@ -276,7 +275,7 @@ def show_predict(imgs, pred, threshold, t_start):
|
|
|
|
|
|
|
|
|
class Predict:
|
|
|
- def __init__(self, pt_path, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
|
|
|
+ def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
|
|
|
"""
|
|
|
初始化预测器。
|
|
|
|
|
|
@@ -290,9 +289,11 @@ class Predict:
|
|
|
show: 是否显示结果。
|
|
|
device: 运行设备(默认 'cuda')。
|
|
|
"""
|
|
|
- self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
|
|
+ # self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
|
|
+
|
|
|
self.model = model
|
|
|
- self.pt_path = pt_path
|
|
|
+ self.device = next(model.parameters()).device
|
|
|
+ # self.pt_path = pt_path
|
|
|
self.img = self.load_image(img)
|
|
|
self.type = type
|
|
|
self.threshold = threshold
|
|
|
@@ -338,16 +339,16 @@ class Predict:
|
|
|
|
|
|
def predict(self):
|
|
|
"""执行预测"""
|
|
|
- model = self.load_best_model(self.model, self.pt_path, device)
|
|
|
-
|
|
|
- model.eval()
|
|
|
+ # model = self.load_best_model(self.model, self.pt_path, device)
|
|
|
+ #
|
|
|
+ # model.eval()
|
|
|
|
|
|
# 预处理图像
|
|
|
img_ = self.preprocess_image(self.img)
|
|
|
|
|
|
# 模型推理
|
|
|
with torch.no_grad():
|
|
|
- predictions = model([img_.to(self.device)])
|
|
|
+ predictions =self.model([img_.to(self.device)])
|
|
|
print("Model predictions completed.")
|
|
|
|
|
|
# 后处理
|
|
|
@@ -469,3 +470,4 @@ class Predict1:
|
|
|
def run(self):
|
|
|
"""运行预测流程"""
|
|
|
self.predict()
|
|
|
+
|