Pārlūkot izejas kodu

keypoint tensorboard_loss

xue50 5 mēneši atpakaļ
vecāks
revīzija
0fe1d94de5
1 mainīts fails ar 19 papildinājumiem un 9 dzēšanām
  1. 19 9
      models/keypoint/trainer.py

+ 19 - 9
models/keypoint/trainer.py

@@ -102,27 +102,29 @@ def c(x):
 
 
 def show_line(img, pred, epoch, writer):
-    im = img.permute(1, 2, 0)
+    im = img.permute(1, 2, 0)   # [512, 512, 3]
     writer.add_image("ori", im, epoch, dataformats="HWC")
 
     boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
                                       colors="yellow", width=1)
     writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+    print(f'box:{pred["boxes"][:5,:]}')
+    print(f'line:{pred["keypoints"][:5,:]}')
 
     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
     # H = pred[1]['wires']
     lines = pred["keypoints"].detach().cpu().numpy()
     scores = pred["keypoints_scores"].detach().cpu().numpy()
-    for i in range(1, len(lines)):
-        if (lines[i] == lines[0]).all():
-            lines = lines[:i]
-            scores = scores[:i]
-            break
+    # for i in range(1, len(lines)):
+    #     if (lines[i] == lines[0]).all():
+    #         lines = lines[:i]
+    #         scores = scores[:i]
+    #         break
 
     # postprocess lines to remove overlapped lines
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)
-    print(f'nscores:{nscores}')
+    # print(f'nscores:{nscores}')
 
     for i, t in enumerate([0.5]):
         plt.gca().set_axis_off()
@@ -184,6 +186,10 @@ def evaluate(model, data_loader, epoch, writer, device):
         if batch_idx == 0:
             show_line(images[0], outputs[0], epoch, writer)
 
+        # print(f'outputs:{outputs}')
+        # print(f'outputs[0]:{outputs[0]}')
+
+
     #     outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
     #     model_time = time.time() - model_time
     #
@@ -278,7 +284,7 @@ def train(model, **kwargs):
                                    dataset_type='val')
 
     train_sampler = torch.utils.data.RandomSampler(dataset)
-    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    test_sampler = torch.utils.data.RandomSampler(dataset_test)
     train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
     train_collate_fn = utils.collate_fn
     data_loader = torch.utils.data.DataLoader(
@@ -288,6 +294,7 @@ def train(model, **kwargs):
         dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
     )
 
+
     img_results_path = os.path.join(train_result_ptath, 'img_results')
     if os.path.exists(train_result_ptath):
         pass
@@ -309,7 +316,7 @@ def train(model, **kwargs):
         if os.path.exists(f'{wts_path}/last.pt'):
             os.remove(f'{wts_path}/last.pt')
         torch.save(model.state_dict(), f'{wts_path}/last.pt')
-        write_metric_logs(epoch, metric_logger, writer)
+        # write_metric_logs(epoch, metric_logger, writer)
         if epoch == 0:
             best_loss = losses;
         if best_loss >= losses:
@@ -319,6 +326,9 @@ def train(model, **kwargs):
             torch.save(model.state_dict(), f'{wts_path}/best.pt')
 
         evaluate(model, data_loader_test, epoch, writer, device=device)
+        avg_train_loss = total_train_loss / len(data_loader)
+
+        writer.add_scalar('Loss/train', avg_train_loss, epoch)
 
 
 def get_transform(is_train, **kwargs):