|
@@ -1,14 +1,18 @@
|
|
|
+import io
|
|
|
import os
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
+from PIL import Image
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
from libs.vision_libs.utils import draw_bounding_boxes
|
|
|
from models.wirenet.postprocess import postprocess
|
|
|
from torchvision import transforms
|
|
|
import matplotlib as mpl
|
|
|
-
|
|
|
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
|
|
+from io import BytesIO
|
|
|
+from PIL import Image
|
|
|
|
|
|
cmap = plt.get_cmap("jet")
|
|
|
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
|
|
@@ -60,13 +64,34 @@ def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=
|
|
|
return current_loss
|
|
|
|
|
|
return best_loss
|
|
|
+
|
|
|
+
|
|
|
+# def show_line(img, pred, epoch, writer):
|
|
|
+# fig = plt.figure(figsize=(15, 15))
|
|
|
+#
|
|
|
+# # ... your plotting code here ...
|
|
|
+#
|
|
|
+# # Save the figure to a BytesIO buffer
|
|
|
+# buf = BytesIO()
|
|
|
+# plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
|
|
|
+# buf.seek(0)
|
|
|
+#
|
|
|
+# # Load the image from the buffer and convert to numpy array
|
|
|
+# image = Image.open(buf)
|
|
|
+# image_from_plot = np.array(image)[..., :3] # Keep RGB channels if there's an alpha
|
|
|
+#
|
|
|
+# # Close the figure to free memory
|
|
|
+# plt.close(fig)
|
|
|
+#
|
|
|
+# # Log the image to TensorBoard or other logger
|
|
|
+# writer.add_image('validate', image_from_plot, epoch, dataformats='HWC')
|
|
|
def show_line(img, pred, epoch, writer):
|
|
|
im = img.permute(1, 2, 0)
|
|
|
- writer.add_image("ori", im, epoch, dataformats="HWC")
|
|
|
+ writer.add_image("z-ori", im, epoch, dataformats="HWC")
|
|
|
|
|
|
boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
|
|
|
colors="yellow", width=1)
|
|
|
- writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
|
|
|
+ writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
|
|
|
|
|
|
PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
|
|
|
# print(f'pred[1]:{pred[1]}')
|
|
@@ -99,9 +124,18 @@ def show_line(img, pred, epoch, writer):
|
|
|
plt.tight_layout()
|
|
|
fig = plt.gcf()
|
|
|
fig.canvas.draw()
|
|
|
- image_from_plot = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).reshape(
|
|
|
- fig.canvas.get_width_height()[::-1] + (3,))
|
|
|
- plt.close()
|
|
|
- img2 = transforms.ToTensor()(image_from_plot)
|
|
|
|
|
|
- writer.add_image("output", img2, epoch)
|
|
|
+ width, height = fig.get_size_inches() * fig.get_dpi() # 获取图像尺寸
|
|
|
+ tmp_img=fig.canvas.tostring_argb()
|
|
|
+ tmp_img_np=np.frombuffer(tmp_img, dtype=np.uint8)
|
|
|
+ tmp_img_np=tmp_img_np.reshape(int(height), int(width), 4)
|
|
|
+
|
|
|
+ img_rgb = tmp_img_np[:, :, 1:] # 提取RGB部分,忽略Alpha通道
|
|
|
+
|
|
|
+ # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
|
|
|
+ # fig.canvas.get_width_height()[::-1] + (3,))
|
|
|
+ # plt.close()
|
|
|
+
|
|
|
+ img2 = transforms.ToTensor()(img_rgb)
|
|
|
+
|
|
|
+ writer.add_image("z-output", img2, epoch)
|