lstrlq 5 months ago
parent
commit
186e2cf16b
2 changed files with 19 additions and 9 deletions
  1. models/keypoint/kepointrcnn.py (+1, -0)
  2. models/keypoint/trainer.py (+18, -9)

+ 1 - 0
models/keypoint/kepointrcnn.py

@@ -29,6 +29,7 @@ class KeypointRCNNModel(nn.Module):
         self.__model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=None,num_classes=num_classes,
                                                                               num_keypoints=num_keypoints,
                                                                               progress=False)
+        
         if transforms is None:
             self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
         # if num_classes != 0:
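
The only change to kepointrcnn.py is a blank line after the torchvision model construction. For context, the sketch below shows how that keypointrcnn_resnet50_fpn call behaves on its own; the class and keypoint counts and the input size are illustrative assumptions, not values taken from this repository.

# Minimal sketch (assumptions: num_classes=2 and num_keypoints=2 are illustrative;
# the repository sets its own values when building KeypointRCNNModel).
import torch
import torchvision

model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
    weights=None,        # no pretrained weights, matching the diff
    num_classes=2,       # hypothetical: background + one foreground class
    num_keypoints=2,     # hypothetical: e.g. two endpoints per instance
    progress=False,
)
model.eval()
with torch.no_grad():
    # In eval mode the model returns one dict per input image containing
    # 'boxes', 'labels', 'scores', 'keypoints' and 'keypoints_scores'.
    preds = model([torch.rand(3, 512, 512)])
print(preds[0].keys())
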

+ 18 - 9
models/keypoint/trainer.py

@@ -102,27 +102,29 @@ def c(x):
 
 
 def show_line(img, pred, epoch, writer):
-    im = img.permute(1, 2, 0)
+    im = img.permute(1, 2, 0)   # [512, 512, 3]
     writer.add_image("ori", im, epoch, dataformats="HWC")
 
     boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
                                       colors="yellow", width=1)
     writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+    print(f'box:{pred["boxes"][:5,:]}')
+    print(f'line:{pred["keypoints"][:5,:]}')
 
     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
     # H = pred[1]['wires']
     lines = pred["keypoints"].detach().cpu().numpy()
     scores = pred["keypoints_scores"].detach().cpu().numpy()
-    for i in range(1, len(lines)):
-        if (lines[i] == lines[0]).all():
-            lines = lines[:i]
-            scores = scores[:i]
-            break
+    # for i in range(1, len(lines)):
+    #     if (lines[i] == lines[0]).all():
+    #         lines = lines[:i]
+    #         scores = scores[:i]
+    #         break
 
     # postprocess lines to remove overlapped lines
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)
-    print(f'nscores:{nscores}')
+    # print(f'nscores:{nscores}')
 
     for i, t in enumerate([0.5]):
         plt.gca().set_axis_off()
@@ -184,6 +186,10 @@ def evaluate(model, data_loader, epoch, writer, device):
         if batch_idx == 0:
             show_line(images[0], outputs[0], epoch, writer)
 
+        # print(f'outputs:{outputs}')
+        # print(f'outputs[0]:{outputs[0]}')
+
+
     #     outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
     #     model_time = time.time() - model_time
     #
@@ -278,7 +284,6 @@ def train(model, **kwargs):
                                    dataset_type='val')
 
     train_sampler = torch.utils.data.RandomSampler(dataset)
-    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
     test_sampler = torch.utils.data.RandomSampler(dataset_test)
     train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
     train_collate_fn = utils.collate_fn
@@ -289,6 +294,7 @@ def train(model, **kwargs):
         dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
     )
 
+
     img_results_path = os.path.join(train_result_ptath, 'img_results')
     if os.path.exists(train_result_ptath):
         pass
@@ -310,7 +316,7 @@ def train(model, **kwargs):
         if os.path.exists(f'{wts_path}/last.pt'):
             os.remove(f'{wts_path}/last.pt')
         torch.save(model.state_dict(), f'{wts_path}/last.pt')
-        write_metric_logs(epoch, metric_logger, writer)
+        # write_metric_logs(epoch, metric_logger, writer)
         if epoch == 0:
             best_loss = losses;
         if best_loss >= losses:
@@ -320,6 +326,9 @@ def train(model, **kwargs):
             torch.save(model.state_dict(), f'{wts_path}/best.pt')
 
         evaluate(model, data_loader_test, epoch, writer, device=device)
+        avg_train_loss = total_train_loss / len(data_loader)
+
+        writer.add_scalar('Loss/train', avg_train_loss, epoch)
 
 
 def get_transform(is_train, **kwargs):
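
The last hunk replaces the old write_metric_logs call with an explicit epoch-averaged training loss written to TensorBoard. Below is a minimal, self-contained sketch of that logging pattern; the batch losses and log directory are made up for illustration, and total_train_loss is assumed to be the sum of per-batch losses over the epoch, as the names in the diff suggest.

# Minimal sketch (assumption: total_train_loss accumulates per-batch losses for one
# epoch and data_loader is the training loader, mirroring the names in the diff).
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/keypoint_demo")    # hypothetical log directory
batch_losses = [0.92, 0.71, 0.58]                       # stand-in for one epoch's per-batch losses
epoch = 0

total_train_loss = sum(batch_losses)
avg_train_loss = total_train_loss / len(batch_losses)   # mirrors total_train_loss / len(data_loader)
writer.add_scalar('Loss/train', avg_train_loss, epoch)  # same tag and step as in the commit
writer.close()
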