|
|
@@ -260,12 +260,12 @@ class Trainer(BaseTrainer):
|
|
|
|
|
|
# ========== Validation ==========
|
|
|
with torch.no_grad():
|
|
|
- model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='eval')
|
|
|
+ model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
|
|
|
|
|
|
if epoch==0:
|
|
|
best_train_loss = epoch_train_loss
|
|
|
best_val_loss = epoch_val_loss
|
|
|
-
|
|
|
+
|
|
|
self.save_last_model(model,self.last_model_path, epoch, optimizer)
|
|
|
best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
|
|
|
best_train_loss,
|
|
|
@@ -276,7 +276,7 @@ class Trainer(BaseTrainer):
|
|
|
def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
|
|
|
if phase == 'train':
|
|
|
model.train()
|
|
|
- if phase == 'eval':
|
|
|
+ if phase == 'val':
|
|
|
model.eval()
|
|
|
|
|
|
total_loss = 0
|
|
|
@@ -285,7 +285,7 @@ class Trainer(BaseTrainer):
|
|
|
for imgs, targets in data_loader:
|
|
|
imgs = self.move_to_device(imgs, device)
|
|
|
targets = self.move_to_device(targets, device)
|
|
|
- if phase== 'eval':
|
|
|
+ if phase== 'val':
|
|
|
|
|
|
result,losses = model(imgs, targets)
|
|
|
else:
|