log_util.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import io
  2. import os
  3. import numpy as np
  4. import torch
  5. from PIL import Image
  6. from matplotlib import pyplot as plt
  7. from libs.vision_libs.utils import draw_bounding_boxes
  8. from models.wirenet.postprocess import postprocess
  9. from torchvision import transforms
  10. import matplotlib as mpl
  11. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  12. from io import BytesIO
  13. from PIL import Image
  14. cmap = plt.get_cmap("jet")
  15. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  16. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  17. sm.set_array([])
  18. def c(x):
  19. return sm.to_rgba(x)
  20. def imshow(im):
  21. plt.close()
  22. plt.tight_layout()
  23. plt.imshow(im)
  24. plt.colorbar(sm, fraction=0.046)
  25. plt.xlim([0, im.shape[0]])
  26. plt.ylim([im.shape[0], 0])
  27. def save_last_model(model, save_path, epoch, optimizer=None):
  28. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  29. checkpoint = {
  30. 'epoch': epoch,
  31. 'model_state_dict': model.state_dict(),
  32. }
  33. if optimizer is not None:
  34. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  35. torch.save(checkpoint, save_path)
  36. def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
  37. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  38. if current_loss <= best_loss:
  39. checkpoint = {
  40. 'epoch': epoch,
  41. 'model_state_dict': model.state_dict(),
  42. 'loss': current_loss
  43. }
  44. if optimizer is not None:
  45. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  46. torch.save(checkpoint, save_path)
  47. print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
  48. return current_loss
  49. return best_loss
  50. def show_line(img, pred, epoch, writer):
  51. img=img.cpu().detach()
  52. im = img.permute(1, 2, 0)
  53. writer.add_image("z-ori", im, epoch, dataformats="HWC")
  54. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
  55. colors="yellow", width=1)
  56. writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
  57. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  58. # print(f'pred[1]:{pred[1]}')
  59. heatmaps=pred[-2][0]
  60. print(f'heatmaps:{heatmaps.shape}')
  61. jmap = heatmaps[1: 2].cpu().detach()
  62. lmap = heatmaps[2: 3].cpu().detach()
  63. writer.add_image("z-jmap", jmap, epoch)
  64. writer.add_image("z-lmap", lmap, epoch)
  65. # plt.imshow(lmap)
  66. # plt.show()
  67. H = pred[-1]['wires']
  68. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  69. scores = H["score"][0].cpu().numpy()
  70. for i in range(1, len(lines)):
  71. if (lines[i] == lines[0]).all():
  72. lines = lines[:i]
  73. scores = scores[:i]
  74. break
  75. # postprocess lines to remove overlapped lines
  76. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  77. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  78. for i, t in enumerate([0]):
  79. plt.gca().set_axis_off()
  80. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  81. plt.margins(0, 0)
  82. for (a, b), s in zip(nlines, nscores):
  83. if s < t:
  84. continue
  85. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  86. plt.scatter(a[1], a[0], **PLTOPTS)
  87. plt.scatter(b[1], b[0], **PLTOPTS)
  88. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  89. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  90. plt.imshow(im)
  91. plt.tight_layout()
  92. fig = plt.gcf()
  93. fig.canvas.draw()
  94. width, height = fig.get_size_inches() * fig.get_dpi() # 获取图像尺寸
  95. tmp_img=fig.canvas.tostring_argb()
  96. tmp_img_np=np.frombuffer(tmp_img, dtype=np.uint8)
  97. tmp_img_np=tmp_img_np.reshape(int(height), int(width), 4)
  98. img_rgb = tmp_img_np[:, :, 1:] # 提取RGB部分,忽略Alpha通道
  99. # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
  100. # fig.canvas.get_width_height()[::-1] + (3,))
  101. plt.close()
  102. img2 = transforms.ToTensor()(img_rgb)
  103. writer.add_image("z-output", img2, epoch)