Browse Source

修复matloptlib3.10版本stringargb函数导致的bug

RenLiqiang 3 tháng trước cách đây
mục cha
commit
64109be8bb
4 tập tin đã thay đổi với 52 bổ sung40 xóa
  1. 0 23
      models/keypoint/trainer.py
  2. 2 2
      models/line_detect/trainer.py
  3. 8 7
      readme.md
  4. 42 8
      utils/log_util.py

+ 0 - 23
models/keypoint/trainer.py

@@ -191,29 +191,6 @@ def evaluate(model, data_loader, epoch, writer, device):
         if batch_idx == 0:
             show_line(images[0], outputs[0], epoch, writer)
 
-        # print(f'outputs:{outputs}')
-        # print(f'outputs[0]:{outputs[0]}')
-
-
-    #     outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
-    #     model_time = time.time() - model_time
-    #
-    #     res = {target["image_id"]: output for target, output in zip(targets, outputs)}
-    #     evaluator_time = time.time()
-    #     coco_evaluator.update(res)
-    #     evaluator_time = time.time() - evaluator_time
-    #     metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
-    #
-    # # gather the stats from all processes
-    # metric_logger.synchronize_between_processes()
-    # print("Averaged stats:", metric_logger)
-    # coco_evaluator.synchronize_between_processes()
-    #
-    # # accumulate predictions from all images
-    # coco_evaluator.accumulate()
-    # coco_evaluator.summarize()
-    # torch.set_num_threads(n_threads)
-    # return coco_evaluator
 
 
 def train_cfg(model, cfg):

+ 2 - 2
models/line_detect/trainer.py

@@ -95,7 +95,7 @@ class Trainer(BaseTrainer):
         dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
         train_collate_fn = utils.collate_fn_wirepoint
         data_loader_train = torch.utils.data.DataLoader(
             dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
@@ -104,7 +104,7 @@ class Trainer(BaseTrainer):
         dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
         val_collate_fn = utils.collate_fn_wirepoint
         data_loader_val = torch.utils.data.DataLoader(
             dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn

+ 8 - 7
readme.md

@@ -2,13 +2,14 @@
 
 ## 1.Dependences
 
-| Name  | Version  | Build|  Channel |
-| -------- | -------- | -------- |------|
-| python    |3.12.8      | h3f84c4b_1_cpython|conda-forge|
-|pytorch| 2.2.2 |py3.12_cuda12.1_cudnn8_0|pytorch
-pytorch-cuda|12.1| hde6ce7c_6| pytorch
-torchvision|0.17.2 |pypi_0|pypi
-numpy|1.26.3|py312h8753938_0|conda-forge
+| Name  | Version | Build|  Channel |
+| -------- |---------| -------- |------|
+| python    | 3.12.8  | h3f84c4b_1_cpython|conda-forge|
+|pytorch| 2.3.1   |py3.12_cuda12.1_cudnn8_0|pytorch
+pytorch-cuda| 12.1    | hde6ce7c_6| pytorch
+torchvision| 0.17.2  |pypi_0|pypi
+numpy| 1.26.3  |py312h8753938_0|conda-forge
+matplotlib| 3.10.0  | pypi_0  |  pypi
 
 
 ## 2.Overview

+ 42 - 8
utils/log_util.py

@@ -1,14 +1,18 @@
+import io
 import os
 
 import numpy as np
 import torch
+from PIL import Image
 from matplotlib import pyplot as plt
 
 from libs.vision_libs.utils import draw_bounding_boxes
 from models.wirenet.postprocess import postprocess
 from torchvision import transforms
 import matplotlib as mpl
-
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+from io import BytesIO
+from PIL import Image
 
 cmap = plt.get_cmap("jet")
 norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
@@ -60,13 +64,34 @@ def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=
         return current_loss
 
     return best_loss
+
+
+# def show_line(img, pred, epoch, writer):
+#     fig = plt.figure(figsize=(15, 15))
+#
+#     # ... your plotting code here ...
+#
+#     # Save the figure to a BytesIO buffer
+#     buf = BytesIO()
+#     plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
+#     buf.seek(0)
+#
+#     # Load the image from the buffer and convert to numpy array
+#     image = Image.open(buf)
+#     image_from_plot = np.array(image)[..., :3]  # Keep RGB channels if there's an alpha
+#
+#     # Close the figure to free memory
+#     plt.close(fig)
+#
+#     # Log the image to TensorBoard or other logger
+#     writer.add_image('validate', image_from_plot, epoch, dataformats='HWC')
 def show_line(img, pred, epoch, writer):
     im = img.permute(1, 2, 0)
-    writer.add_image("ori", im, epoch, dataformats="HWC")
+    writer.add_image("z-ori", im, epoch, dataformats="HWC")
 
     boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
                                       colors="yellow", width=1)
-    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+    writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
 
     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
     # print(f'pred[1]:{pred[1]}')
@@ -99,9 +124,18 @@ def show_line(img, pred, epoch, writer):
         plt.tight_layout()
         fig = plt.gcf()
         fig.canvas.draw()
-        image_from_plot = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).reshape(
-            fig.canvas.get_width_height()[::-1] + (3,))
-        plt.close()
-        img2 = transforms.ToTensor()(image_from_plot)
 
-        writer.add_image("output", img2, epoch)
+        width, height = fig.get_size_inches() * fig.get_dpi()  # 获取图像尺寸
+        tmp_img=fig.canvas.tostring_argb()
+        tmp_img_np=np.frombuffer(tmp_img, dtype=np.uint8)
+        tmp_img_np=tmp_img_np.reshape(int(height), int(width), 4)
+
+        img_rgb = tmp_img_np[:, :, 1:]  # 提取RGB部分,忽略Alpha通道
+
+        # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
+        #     fig.canvas.get_width_height()[::-1] + (3,))
+        # plt.close()
+
+        img2 = transforms.ToTensor()(img_rgb)
+
+        writer.add_image("z-output", img2, epoch)