|
@@ -220,10 +220,11 @@ class LineNet(BaseDetectionNet):
|
|
|
self.trainer = Trainer()
|
|
self.trainer = Trainer()
|
|
|
self.trainer.train_from_cfg(model=self, cfg=cfg)
|
|
self.trainer.train_from_cfg(model=self, cfg=cfg)
|
|
|
|
|
|
|
|
- def load_best_model(self,model, save_path, device='cuda'):
|
|
|
|
|
|
|
+ def load_best_model(self,save_path, device='cuda'):
|
|
|
if os.path.exists(save_path):
|
|
if os.path.exists(save_path):
|
|
|
checkpoint = torch.load(save_path, map_location=device)
|
|
checkpoint = torch.load(save_path, map_location=device)
|
|
|
- model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
|
|
|
|
+
|
|
|
|
|
+ self.load_state_dict(checkpoint['model_state_dict'])
|
|
|
# if optimizer is not None:
|
|
# if optimizer is not None:
|
|
|
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
|
# epoch = checkpoint['epoch']
|
|
# epoch = checkpoint['epoch']
|
|
@@ -232,11 +233,14 @@ class LineNet(BaseDetectionNet):
|
|
|
print(f"Loaded model from {save_path}")
|
|
print(f"Loaded model from {save_path}")
|
|
|
else:
|
|
else:
|
|
|
print(f"No saved model found at {save_path}")
|
|
print(f"No saved model found at {save_path}")
|
|
|
- return model
|
|
|
|
|
|
|
+ return self
|
|
|
|
|
|
|
|
# 加载权重和推理一起
|
|
# 加载权重和推理一起
|
|
|
- def predict(self, pt_path, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
|
|
|
|
|
- self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
|
|
|
|
|
|
|
+ def predict(self,img_path, type=0, threshold=0.5, save_path=None, show=False):
|
|
|
|
|
+ # self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
|
|
|
|
|
+ self.eval()
|
|
|
|
|
+ self.to(device)
|
|
|
|
|
+ self.predict = Predict(self, img_path, type, threshold, save_path, show)
|
|
|
self.predict.run()
|
|
self.predict.run()
|
|
|
|
|
|
|
|
# 不加载权重
|
|
# 不加载权重
|