Sfoglia il codice sorgente

keypoint tensorboard

xue50 5 mesi fa
parent
commit
3f2d8516b5
1 file modificato con 25 aggiunte e 23 eliminazioni
  1. 25 23
      models/keypoint/trainer.py

+ 25 - 23
models/keypoint/trainer.py

@@ -32,9 +32,8 @@ def log_losses_to_tensorboard(writer, result, step):
     writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
 
 
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq,writer, scaler=None):
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, scaler=None):
     model.train()
-    total_train_loss=0.0
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"
@@ -130,7 +129,7 @@ def show_line(img, pred, epoch, writer):
         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
         plt.margins(0, 0)
         for (a, b), s in zip(nlines, nscores):
-            if s < t :
+            if s < t:
                 continue
             plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
             plt.scatter(a[1], a[0], **PLTOPTS)
@@ -186,24 +185,24 @@ def evaluate(model, data_loader, epoch, writer, device):
             show_line(images[0], outputs[0], epoch, writer)
 
     #     outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
-    #     #     model_time = time.time() - model_time
-    #     #
-    #     #     res = {target["image_id"]: output for target, output in zip(targets, outputs)}
-    #     #     evaluator_time = time.time()
-    #     #     coco_evaluator.update(res)
-    #     #     evaluator_time = time.time() - evaluator_time
-    #     #     metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
-    #     #
-    #     # # gather the stats from all processes
-    #     # metric_logger.synchronize_between_processes()
-    #     # print("Averaged stats:", metric_logger)
-    #     # coco_evaluator.synchronize_between_processes()
-    #     #
-    #     # # accumulate predictions from all images
-    #     # coco_evaluator.accumulate()
-    #     # coco_evaluator.summarize()
-    #     # torch.set_num_threads(n_threads)
-    #     # return coco_evaluator
+    #     model_time = time.time() - model_time
+    #
+    #     res = {target["image_id"]: output for target, output in zip(targets, outputs)}
+    #     evaluator_time = time.time()
+    #     coco_evaluator.update(res)
+    #     evaluator_time = time.time() - evaluator_time
+    #     metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+    #
+    # # gather the stats from all processes
+    # metric_logger.synchronize_between_processes()
+    # print("Averaged stats:", metric_logger)
+    # coco_evaluator.synchronize_between_processes()
+    #
+    # # accumulate predictions from all images
+    # coco_evaluator.accumulate()
+    # coco_evaluator.summarize()
+    # torch.set_num_threads(n_threads)
+    # return coco_evaluator
 
 
 def train_cfg(model, cfg):
@@ -300,8 +299,11 @@ def train(model, **kwargs):
         os.mkdir(wts_path)
         os.mkdir(img_results_path)
 
+
+    total_train_loss = 0.0
+
     for epoch in range(epochs):
-        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, None)
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, None)
         losses = metric_logger.meters['loss'].global_avg
         print(f'epoch {epoch}:loss:{losses}')
         if os.path.exists(f'{wts_path}/last.pt'):
@@ -363,4 +365,4 @@ def write_metric_logs(epoch, metric_logger, writer):
 #     writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
 #     writer.add_scalar('Loss/box_reg', result['loss_keypoint'].item(), step)
 #     writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
-#     writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
+#     writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)