瀏覽代碼

keypoint tensorboard_loss

xue50 5 月之前
父節點
當前提交
2dd75622af
共有 1 個文件被更改,包括 2 次插入2 次删除
  1. 2 2
      models/keypoint/trainer.py

+ 2 - 2
models/keypoint/trainer.py

@@ -88,7 +88,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, wr
         metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
 
-    return metric_logger
+    return metric_logger, total_train_loss
 
 
 cmap = plt.get_cmap("jet")
@@ -310,7 +310,7 @@ def train(model, **kwargs):
     total_train_loss = 0.0
 
     for epoch in range(epochs):
-        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, None)
+        metric_logger, total_train_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, None)
         losses = metric_logger.meters['loss'].global_avg
         print(f'epoch {epoch}:loss:{losses}')
         if os.path.exists(f'{wts_path}/last.pt'):