소스 검색

修复tensorboard绘line图重叠bug

RenLiqiang 3 달 전
부모
커밋
0f7ce4ab80
4개의 변경된 파일18개의 추가작업 그리고 10개의 파일을 삭제
  1. 1 1
      config/wireframe.yaml
  2. 1 1
      models/line_detect/line_net.py
  3. 15 7
      train——line_rcnn.py
  4. 1 1
      utils/log_util.py

+ 1 - 1
config/wireframe.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: I:/datasets/wirenet_lm
+  datadir: I:/datasets/wirenet_1000
   resume_from:
   num_workers: 8
   tensorboard_port: 6000

+ 1 - 1
models/line_detect/line_net.py

@@ -67,7 +67,7 @@ class LineNet(BaseDetectionNet):
     #         backbone=backbone_factory.get_resnet18_fpn()
     #
     #     self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
-
+    
     def __init__(
             self,
             backbone,

+ 15 - 7
train——line_rcnn.py

@@ -97,12 +97,20 @@ def show_line(img, pred, epoch, writer):
         plt.tight_layout()
         fig = plt.gcf()
         fig.canvas.draw()
-        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
-            fig.canvas.get_width_height()[::-1] + (3,))
-        plt.close()
-        img2 = transforms.ToTensor()(image_from_plot)
+        width, height = fig.get_size_inches() * fig.get_dpi()  # 获取图像尺寸
+        tmp_img = fig.canvas.tostring_argb()
+        tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
+        tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
 
-        writer.add_image("output", img2, epoch)
+        img_rgb = tmp_img_np[:, :, 1:]  # 提取RGB部分,忽略Alpha通道
+
+        # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
+        #     fig.canvas.get_width_height()[::-1] + (3,))
+        # plt.close()
+
+        img2 = transforms.ToTensor()(img_rgb)
+
+        writer.add_image("z-output", img2, epoch)
 
 
 def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
@@ -282,8 +290,8 @@ if __name__ == '__main__':
             for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                 pred = model(move_to_device(imgs, device))
 
-                pred_ = box_line_(pred)  # 将box与line对应
-                show_(imgs, pred_, epoch, writer)
+                # pred_ = box_line_(pred)  # 将box与line对应
+                # show_(imgs, pred_, epoch, writer)
 
                 if batch_idx == 0:
                     show_line(imgs[0], pred, epoch, writer)

+ 1 - 1
utils/log_util.py

@@ -134,7 +134,7 @@ def show_line(img, pred, epoch, writer):
 
         # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
         #     fig.canvas.get_width_height()[::-1] + (3,))
-        # plt.close()
+        plt.close()
 
         img2 = transforms.ToTensor()(img_rgb)