import io
import os
from io import BytesIO

import matplotlib as mpl
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from PIL import Image
from torchvision import transforms

from libs.vision_libs.utils import draw_bounding_boxes
from models.wirenet.postprocess import postprocess

# Shared score -> colour mapping: a "jet" colormap normalised to [0.4, 1.0].
# `sm` is used both to colour individual scores (via `c`) and to draw the
# colorbar in `imshow`.
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])


def c(x):
    """Return the RGBA colour for score ``x`` under the module's jet/[0.4, 1.0] mapping."""
    return sm.to_rgba(x)


def imshow(im):
    """Show ``im`` in a fresh pyplot figure with the score colorbar attached.

    Closes the current figure, draws ``im``, adds a colorbar for the shared
    score mapping ``sm`` and sets the axis limits so the image exactly fills
    the axes with row 0 at the top.

    Args:
        im: Array-like image indexed as ``(rows, cols, ...)``.
    """
    plt.close()
    plt.tight_layout()
    plt.imshow(im)
    # `sm` is not attached to any Axes, so tell Matplotlib explicitly where to
    # place the colorbar; newer releases refuse to fall back to gca() silently.
    plt.colorbar(sm, ax=plt.gca(), fraction=0.046)
    # x spans the columns (width) and y the rows (height). Using shape[0] for
    # both only worked for square images.
    plt.xlim([0, im.shape[1]])
    plt.ylim([im.shape[0], 0])


def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` (if it has one) when missing."""
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)


def save_last_model(model, save_path, epoch, optimizer=None):
    """Unconditionally write a checkpoint of ``model`` to ``save_path``.

    The saved dict contains ``'epoch'`` and ``'model_state_dict'``, plus
    ``'optimizer_state_dict'`` when ``optimizer`` is given. The parent
    directory of ``save_path`` is created if needed.

    Args:
        model: Object exposing ``state_dict()``.
        save_path: Destination file path passed to ``torch.save``.
        epoch: Epoch number stored in the checkpoint.
        optimizer: Optional object exposing ``state_dict()``.
    """
    _ensure_parent_dir(save_path)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
    }
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, save_path)


def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
    """Write a checkpoint when ``current_loss`` is no worse than ``best_loss``.

    The saved dict contains ``'epoch'``, ``'model_state_dict'`` and ``'loss'``,
    plus ``'optimizer_state_dict'`` when ``optimizer`` is given. The parent
    directory of ``save_path`` is created if needed.

    Args:
        model: Object exposing ``state_dict()``.
        save_path: Destination file path passed to ``torch.save``.
        epoch: Epoch number stored in the checkpoint.
        current_loss: Loss of the current epoch.
        best_loss: Best loss seen so far.
        optimizer: Optional object exposing ``state_dict()``.

    Returns:
        ``current_loss`` if a checkpoint was written (``current_loss <=
        best_loss``), otherwise ``best_loss`` unchanged.
    """
    _ensure_parent_dir(save_path)
    if current_loss <= best_loss:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'loss': current_loss
        }
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(checkpoint, save_path)
        print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
        return current_loss
    return best_loss
def _extract_wire_lines(wires, height, width):
    """Return ``(lines, scores)`` for the first sample, scaled to pixel units.

    ``wires["lines"][0]`` is divided by 128 and multiplied by ``(height,
    width)``. Presumably the predictions live on a 128x128 grid — TODO
    confirm. Each segment endpoint is ordered ``(y, x)``.

    Rows that repeat the first line are treated as padding. Only the rows
    before the first repeat are kept.
    """
    lines = wires["lines"][0].detach().cpu().numpy() / 128 * np.array([height, width])
    scores = wires["score"][0].detach().cpu().numpy()
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            return lines[:i], scores[:i]
    return lines, scores


def _render_wires(im_np, nlines, nscores, score_threshold):
    """Draw segments with score >= ``score_threshold`` over ``im_np``.

    A dedicated figure is used, so no state leaks in from (or out to) other
    pyplot figures. The figure is always closed, even if drawing fails.

    Returns:
        The rendered canvas as a ``uint8`` array of shape ``(H', W', 3)``.
    """
    junction_style = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    fig, ax = plt.subplots()
    try:
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < score_threshold:
                continue
            # Endpoints are (y, x); plot expects x first.
            ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            ax.scatter(a[1], a[0], **junction_style)
            ax.scatter(b[1], b[0], **junction_style)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        ax.imshow(im_np)
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() already has the canvas's true (H, W, 4) shape.
        # Rebuilding it from figsize*dpi breaks on HiDPI canvases.
        # Copy before the figure is closed.
        return np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)


def show_line(img, pred, epoch, writer, score_threshold=0.85):
    """Log the input image, its predicted boxes and its predicted wires.

    Three images are written with ``writer.add_image`` at step ``epoch``:

    * ``"z-ori"``: the input image (HWC).
    * ``"z-boxes"``: the input with ``pred[0]["boxes"]`` drawn in yellow.
    * ``"z-output"``: a Matplotlib rendering of the post-processed segments
      from ``pred[-1]['wires']`` whose score is >= ``score_threshold``. Each
      segment is coloured by score via ``c``.

    Args:
        img: Image tensor indexed ``(C, H, W)``. It is scaled by 255 and cast
            to uint8 for box drawing, so presumably float in [0, 1] — TODO
            confirm.
        pred: Sequence of prediction dicts. ``pred[0]`` has ``"boxes"``, and
            ``pred[-1]['wires']`` has ``"lines"`` and ``"score"`` tensors.
        epoch: Global step passed to the writer.
        writer: Object with an ``add_image`` method (e.g. a TensorBoard
            ``SummaryWriter``).
        score_threshold: Minimum score for a segment to be drawn.
    """
    im = img.permute(1, 2, 0)
    writer.add_image("z-ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"], colors="yellow", width=1)
    writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    height, width = im.shape[:2]
    lines, scores = _extract_wire_lines(pred[-1]['wires'], height, width)

    # Post-process to remove overlapping lines (tolerance: 1% of the diagonal).
    diag = (height ** 2 + width ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    # Matplotlib needs a host-side array; a CUDA or grad-tracking tensor
    # would make imshow fail.
    im_np = im.detach().cpu().numpy()
    img_rgb = _render_wires(im_np, nlines, nscores, score_threshold)
    writer.add_image("z-output", transforms.ToTensor()(img_rgb), epoch)