Browse Source

keypoint可以正常训练,有初步效果

RenLiqiang 5 tháng trước cách đây
mục cha
commit
70c2f71295
2 tập tin đã thay đổi với 3 bổ sung3 xóa
  1. 1 1
      models/keypoint/keypoint_dataset.py
  2. 2 2
      models/keypoint/trainer.py

+ 1 - 1
models/keypoint/keypoint_dataset.py

@@ -125,7 +125,7 @@ class KeypointDataset(BaseDataset):
         # print(f'labels:{target["labels"]}')
         # target["boxes"] = line_boxes(target)
         target["boxes"], keypoints = line_boxes(target)
-        keypoints=keypoints/512
+        # keypoints=keypoints/512
         # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
 
         # keypoints= wire_labels["junc_coords"]

+ 2 - 2
models/keypoint/trainer.py

@@ -27,7 +27,7 @@ from tools import utils, presets
 def log_losses_to_tensorboard(writer, result, step):
     writer.add_scalar('Loss/classifier', result['loss_classifier'].item(), step)
     writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
-    writer.add_scalar('Loss/loss_keypoint', result['loss_keypoint'].item(), step)
+    writer.add_scalar('Loss/keypoint', result['loss_keypoint'].item(), step)
     writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
     writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
 
@@ -137,7 +137,7 @@ def show_line(img, pred, epoch, writer):
             # plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
             # plt.scatter(a[1], a[0], **PLTOPTS)
             # plt.scatter(b[1], b[0], **PLTOPTS)
-            plt.plot([a[0], b[0]], [a[1], b[1]], c=c(s), linewidth=2, zorder=s)
+            plt.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=2, zorder=s)
             plt.scatter(a[0], a[1], **PLTOPTS)
             plt.scatter(b[0], b[1], **PLTOPTS)
         plt.gca().xaxis.set_major_locator(plt.NullLocator())