|
|
@@ -191,7 +191,7 @@ class Trainer(BaseTrainer):
|
|
|
|
|
|
|
|
|
|
|
|
- def writer_predict_result(self, img, result, epoch):
|
|
|
+ def writer_predict_result(self, img, result, epoch,type=1):
|
|
|
img = img.cpu().detach()
|
|
|
im = img.permute(1, 2, 0) # [512, 512, 3]
|
|
|
self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
|
|
|
@@ -203,19 +203,24 @@ class Trainer(BaseTrainer):
|
|
|
# plt.show()
|
|
|
|
|
|
self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
|
|
|
- keypoint_img = draw_keypoints(boxed_image, result['lines'], colors='red', width=3)
|
|
|
|
|
|
- self.writer.add_image("z-output", keypoint_img, epoch)
|
|
|
- print("lines shape:", result['lines'].shape)
|
|
|
|
|
|
- # 用自己写的函数画线段
|
|
|
- # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
|
|
|
- print(f"shape of linescore:{result['liness_scores'].shape}")
|
|
|
- scores = result['liness_scores'].mean(dim=1) # shape: [31]
|
|
|
+ if type==1:
|
|
|
+ keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
|
|
|
|
|
|
- line_image = draw_lines_with_scores((img * 255).to(torch.uint8), result['lines'],scores, width=3, cmap='jet')
|
|
|
+ self.writer.add_image("z-output", keypoint_img, epoch)
|
|
|
+ # print("lines shape:", result['lines'].shape)
|
|
|
|
|
|
- self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
|
|
|
+
|
|
|
+ if type==2:
|
|
|
+ # 用自己写的函数画线段
|
|
|
+ # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
|
|
|
+ print(f"shape of linescore:{result['liness_scores'].shape}")
|
|
|
+ scores = result['liness_scores'].mean(dim=1) # shape: [31]
|
|
|
+
|
|
|
+ line_image = draw_lines_with_scores((img * 255).to(torch.uint8), result['lines'],scores, width=3, cmap='jet')
|
|
|
+
|
|
|
+ self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
|
|
|
|
|
|
|
|
|
|
|
|
@@ -311,7 +316,9 @@ class Trainer(BaseTrainer):
|
|
|
if phase== 'val':
|
|
|
result,loss_dict = model(imgs, targets)
|
|
|
losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
|
|
|
+
|
|
|
print(f'val losses:{losses}')
|
|
|
+ print(f'val result:{result}')
|
|
|
else:
|
|
|
loss_dict = model(imgs, targets)
|
|
|
losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
|