123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- import io
- import os
- import numpy as np
- import torch
- from PIL import Image
- from matplotlib import pyplot as plt
- from libs.vision_libs.utils import draw_bounding_boxes
- from models.wirenet.postprocess import postprocess
- from torchvision import transforms
- import matplotlib as mpl
- from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
- from io import BytesIO
- from PIL import Image
# Shared colour scale for line-confidence scores: values in [0.4, 1.0] are
# normalised and mapped through the "jet" colormap. ``sm`` is used both to
# colour individual lines (via ``c``) and as the source of the colorbar.
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
cmap = plt.get_cmap("jet")
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# Give the data-less ScalarMappable an (empty) array so it can back
# plt.colorbar; older Matplotlib versions require this.
sm.set_array([])
def c(x):
    """Map a score ``x`` to an RGBA colour using the module-level ``sm`` mapping."""
    rgba = sm.to_rgba(x)
    return rgba
def imshow(im):
    """Display ``im`` in a fresh pyplot figure with the shared score colorbar.

    The current figure is closed first. The axes are then clamped to the
    image: x spans the columns (width) and y spans the rows (height), inverted
    so that row 0 is at the top.

    Args:
        im: Image-like array indexed as (rows, cols, ...). Only
            ``im.shape[0]`` (height) and ``im.shape[1]`` (width) are read here.
    """
    plt.close()
    # NOTE(review): this runs right after plt.close(), before any axes exist,
    # so it likely has no effect on the final layout. Confirm the intent.
    plt.tight_layout()
    plt.imshow(im)
    # ``sm`` is the module-level ScalarMappable used for score colours.
    plt.colorbar(sm, fraction=0.046)
    # Width is shape[1]. Using shape[0] here mis-sized non-square images.
    plt.xlim([0, im.shape[1]])
    plt.ylim([im.shape[0], 0])
def save_last_model(model, save_path, epoch, optimizer=None):
    """Write an unconditional "latest" checkpoint of ``model`` to ``save_path``.

    The checkpoint is a dict with ``'epoch'`` and ``'model_state_dict'``.
    When ``optimizer`` is given, it also contains ``'optimizer_state_dict'``.
    Missing parent directories are created. A bare filename is saved in the
    current working directory.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        save_path: Destination file path (str or path-like).
        epoch: Epoch number stored alongside the weights.
        optimizer: Optional object exposing ``state_dict()``.
    """
    save_dir = os.path.dirname(save_path)
    # A bare filename has no directory component, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
    }
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, save_path)
def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
    """Save a checkpoint if ``current_loss`` matches or beats ``best_loss``.

    The checkpoint is a dict with ``'epoch'``, ``'model_state_dict'``,
    ``'loss'`` and, when ``optimizer`` is given, ``'optimizer_state_dict'``.
    Missing parent directories are created. A bare filename is saved in the
    current working directory.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        save_path: Destination file path (str or path-like).
        epoch: Epoch number stored in the checkpoint and printed.
        current_loss: Loss for this epoch.
        best_loss: Best loss so far, or ``None`` if nothing has been saved yet.
        optimizer: Optional object exposing ``state_dict()``.

    Returns:
        The best loss after this call. This is ``current_loss`` if a
        checkpoint was written, otherwise ``best_loss`` unchanged.
    """
    save_dir = os.path.dirname(save_path)
    # A bare filename has no directory component, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # ``None`` means no best loss is known yet, so the first call always saves.
    # ``<=`` lets a tie refresh the checkpoint with the newer epoch. A NaN loss
    # never compares ``<=``, so it never overwrites a good checkpoint.
    if best_loss is None or current_loss <= best_loss:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'loss': current_loss
        }
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(checkpoint, save_path)
        print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
        return current_loss
    return best_loss
- # def show_line(img, pred, epoch, writer):
- # fig = plt.figure(figsize=(15, 15))
- #
- # # ... your plotting code here ...
- #
- # # Save the figure to a BytesIO buffer
- # buf = BytesIO()
- # plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
- # buf.seek(0)
- #
- # # Load the image from the buffer and convert to numpy array
- # image = Image.open(buf)
- # image_from_plot = np.array(image)[..., :3] # Keep RGB channels if there's an alpha
- #
- # # Close the figure to free memory
- # plt.close(fig)
- #
- # # Log the image to TensorBoard or other logger
- # writer.add_image('validate', image_from_plot, epoch, dataformats='HWC')
def _trim_repeated_lines(lines, scores):
    """Truncate ``lines``/``scores`` at the first later line equal to ``lines[0]``.

    The first entry that repeats ``lines[0]`` and everything after it are
    dropped. Presumably the model pads its fixed-size output by repeating the
    first line; this is inferred from the check itself and should be confirmed
    against the model.
    """
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            return lines[:i], scores[:i]
    return lines, scores


def _render_lines(image, lines, scores, threshold):
    """Draw ``lines`` with score >= ``threshold`` over ``image``; return RGB pixels.

    A new figure is created and always closed before returning, so nothing
    leaks into later calls. Each line is an (a, b) pair of (row, col)
    endpoints. Lines are coloured by ``c(score)`` and endpoints are marked
    with cyan dots.

    Returns:
        A ``(height, width, 3)`` uint8 array copied from the Agg canvas.
    """
    point_opts = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    fig = plt.figure()
    try:
        ax = fig.gca()
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        for (a, b), s in zip(lines, scores):
            if s < threshold:
                continue
            # Endpoints are (row, col); Matplotlib expects (x=col, y=row).
            ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            ax.scatter(a[1], a[0], **point_opts)
            ax.scatter(b[1], b[0], **point_opts)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        ax.imshow(image)
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() reports the canvas's real pixel shape. Computing it
        # from inches * dpi can be off by rounding or HiDPI scaling.
        # np.array copies the buffer so the result outlives the figure.
        # Drop the alpha channel.
        return np.array(fig.canvas.buffer_rgba())[..., :3]
    finally:
        # Without this, every call drew onto the same global figure, so lines
        # from earlier epochs piled up in later renders and memory grew.
        plt.close(fig)


def show_line(img, pred, epoch, writer, threshold=0.85):
    """Log an input image, its predicted boxes, and its predicted wire lines.

    Writes three images to ``writer`` at step ``epoch``:

    * ``"z-ori"``: the input image, in HWC layout.
    * ``"z-boxes"``: the input with ``pred[0]["boxes"]`` drawn in yellow.
    * ``"z-output"``: a Matplotlib rendering of the post-processed lines from
      ``pred[-1]["wires"]`` whose score is at least ``threshold``, coloured by
      score and overlaid on the image.

    Args:
        img: CHW image tensor. Values are assumed to lie in [0, 1], since they
            are scaled by 255 before the uint8 cast (TODO confirm).
        pred: Model output list. ``pred[0]["boxes"]`` holds the boxes.
            ``pred[-1]["wires"]`` is a dict with ``"lines"`` and ``"score"``.
        epoch: Global step passed to every ``add_image`` call.
        writer: TensorBoard-style writer exposing ``add_image``.
        threshold: Minimum score for a line to be drawn. Defaults to 0.85,
            the value previously hard-coded.
    """
    im = img.permute(1, 2, 0)
    writer.add_image("z-ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
                                      colors="yellow", width=1)
    writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    wires = pred[-1]["wires"]
    # Scale endpoints from the model's grid to image pixels in (row, col)
    # order. The 128 factor suggests a 128x128 prediction grid; confirm
    # against the model.
    lines = wires["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = wires["score"][0].cpu().numpy()
    lines, scores = _trim_repeated_lines(lines, scores)

    # Remove overlapping lines. The distance tolerance is 1% of the image
    # diagonal.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    # Matplotlib needs a host array. detach().cpu() also covers GPU tensors
    # and tensors that require grad.
    img_rgb = _render_lines(im.detach().cpu().numpy(), nlines, nscores, threshold)
    writer.add_image("z-output", transforms.ToTensor()(img_rgb), epoch)
|