|
|
@@ -284,6 +284,7 @@ class Trainer(BaseTrainer):
|
|
|
points = result['circles']
|
|
|
points = points.squeeze(1)
|
|
|
print(f'points shape:{points.shape}')
|
|
|
+ features = result['features']
|
|
|
|
|
|
circle_image = img.cpu().numpy().transpose((1, 2, 0)) # CHW -> HWC
|
|
|
circle_image = (circle_image * 255).clip(0, 255).astype(np.uint8)
|
|
|
@@ -309,6 +310,7 @@ class Trainer(BaseTrainer):
|
|
|
|
|
|
# keypoint_img = draw_keypoints((img * 255).to(torch.uint8), points, colors='red', width=3)
|
|
|
self.writer.add_image('z-out-circle', img_tensor, global_step=epoch)
|
|
|
+ self.writer.add_image('z-feature', features, global_step=epoch)
|
|
|
|
|
|
# cv2.imshow('arc', img_rgb)
|
|
|
# cv2.waitKey(1000000)
|