|
|
@@ -260,9 +260,13 @@ class Trainer(BaseTrainer):
|
|
|
|
|
|
# ========== Validation ==========
|
|
|
with torch.no_grad():
|
|
|
- model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val', )
|
|
|
+ model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='eval')
|
|
|
|
|
|
- self.save_last_model(model, epoch, optimizer)
|
|
|
+ if epoch==0:
|
|
|
+ best_train_loss = epoch_train_loss
|
|
|
+ best_val_loss = epoch_val_loss
|
|
|
+
|
|
|
+ self.save_last_model(model,self.last_model_path, epoch, optimizer)
|
|
|
best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
|
|
|
best_train_loss,
|
|
|
optimizer)
|
|
|
@@ -272,8 +276,8 @@ class Trainer(BaseTrainer):
|
|
|
def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
|
|
|
if phase == 'train':
|
|
|
model.train()
|
|
|
- if phase == 'val':
|
|
|
- model.eval
|
|
|
+ if phase == 'eval':
|
|
|
+ model.eval()
|
|
|
|
|
|
total_loss = 0
|
|
|
epoch_step = 0
|
|
|
@@ -281,7 +285,12 @@ class Trainer(BaseTrainer):
|
|
|
for imgs, targets in data_loader:
|
|
|
imgs = self.move_to_device(imgs, device)
|
|
|
targets = self.move_to_device(targets, device)
|
|
|
- losses = model(imgs, targets)
|
|
|
+ if phase== 'eval':
|
|
|
+
|
|
|
+ result,losses = model(imgs, targets)
|
|
|
+ else:
|
|
|
+ losses = model(imgs, targets)
|
|
|
+
|
|
|
loss = _loss(losses)
|
|
|
total_loss += loss.item()
|
|
|
if phase == 'train':
|