Prechádzať zdrojové kódy

修复训练报错bugs

RenLiqiang 7 mesiacov pred
rodič
commit
278640668c
1 zmenil súbory, kde vykonal 14 pridanie a 5 odobranie
  1. 14 5
      models/line_detect/trainer.py

+ 14 - 5
models/line_detect/trainer.py

@@ -260,9 +260,13 @@ class Trainer(BaseTrainer):
 
             # ========== Validation ==========
             with torch.no_grad():
-                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val', )
+                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='eval')
 
-            self.save_last_model(model, epoch, optimizer)
+            if epoch==0:
+                best_train_loss = epoch_train_loss
+                best_val_loss = epoch_val_loss
+                
+            self.save_last_model(model,self.last_model_path, epoch, optimizer)
             best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
                                                    best_train_loss,
                                                    optimizer)
@@ -272,8 +276,8 @@ class Trainer(BaseTrainer):
     def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
         if phase == 'train':
             model.train()
-        if phase == 'val':
-            model.eval
+        if phase == 'eval':
+            model.eval()
 
         total_loss = 0
         epoch_step = 0
@@ -281,7 +285,12 @@ class Trainer(BaseTrainer):
         for imgs, targets in data_loader:
             imgs = self.move_to_device(imgs, device)
             targets = self.move_to_device(targets, device)
-            losses = model(imgs, targets)
+            if phase== 'eval':
+
+                result,losses = model(imgs, targets)
+            else:
+                losses = model(imgs, targets)
+
             loss = _loss(losses)
             total_loss += loss.item()
             if phase == 'train':