Browse Source

修复训练报错bugs

RenLiqiang 7 months ago
parent
commit
70ccab6e16
1 changed files with 4 additions and 4 deletions
  1. 4 4
      models/line_detect/trainer.py

+ 4 - 4
models/line_detect/trainer.py

@@ -260,12 +260,12 @@ class Trainer(BaseTrainer):
 
             # ========== Validation ==========
             with torch.no_grad():
-                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='eval')
+                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
 
             if epoch==0:
                 best_train_loss = epoch_train_loss
                 best_val_loss = epoch_val_loss
-                
+
             self.save_last_model(model,self.last_model_path, epoch, optimizer)
             best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
                                                    best_train_loss,
@@ -276,7 +276,7 @@ class Trainer(BaseTrainer):
     def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
         if phase == 'train':
             model.train()
-        if phase == 'eval':
+        if phase == 'val':
             model.eval()
 
         total_loss = 0
@@ -285,7 +285,7 @@ class Trainer(BaseTrainer):
         for imgs, targets in data_loader:
             imgs = self.move_to_device(imgs, device)
             targets = self.move_to_device(targets, device)
-            if phase== 'eval':
+            if phase== 'val':
 
                 result,losses = model(imgs, targets)
             else: