@@ -88,7 +88,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, wr
         metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
         metric_logger.update(lr=optimizer.param_groups[0]["lr"])

-    return metric_logger
+    return metric_logger, total_train_loss


 cmap = plt.get_cmap("jet")
@@ -310,7 +310,7 @@ def train(model, **kwargs):
     total_train_loss = 0.0

     for epoch in range(epochs):
-        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, None)
+        metric_logger, total_train_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, None)
         losses = metric_logger.meters['loss'].global_avg
         print(f'epoch {epoch}:loss:{losses}')
         if os.path.exists(f'{wts_path}/last.pt'):
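
Taken together, the two hunks thread a running loss sum through training: `train_one_epoch` receives `total_train_loss`, updates it, and returns it alongside the `metric_logger`, and the caller in `train` re-binds it each epoch. Below is a minimal, self-contained sketch of that pattern; the toy model, dataset, and the per-batch `loss.item()` accumulation are assumptions for illustration, not the repository's actual code:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, optimizer, data_loader, device, total_train_loss):
    """Sketch: run one epoch and return the updated running loss sum."""
    model.train()
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Assumption: the running total accumulates the scalar batch loss,
        # which is what the changed signature carries across epochs.
        total_train_loss += loss.item()
    return total_train_loss

device = torch.device("cpu")
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
loader = DataLoader(data, batch_size=8)

total_train_loss = 0.0
for epoch in range(3):
    # Re-bind the returned sum each epoch, mirroring the second hunk.
    total_train_loss = train_one_epoch(model, optimizer, loader, device, total_train_loss)
    print(f"epoch {epoch}: cumulative loss {total_train_loss:.4f}")
```

Returning the accumulated value rather than mutating shared state keeps `train_one_epoch` explicit about the running total, at the small cost of the caller having to re-bind it each epoch.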