log_util.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import os
  2. import numpy as np
  3. import torch
  4. from matplotlib import pyplot as plt
  5. from libs.vision_libs.utils import draw_bounding_boxes
  6. from models.wirenet.postprocess import postprocess
  7. from torchvision import transforms
  8. import matplotlib as mpl
  9. cmap = plt.get_cmap("jet")
  10. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  11. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  12. sm.set_array([])
  13. def c(x):
  14. return sm.to_rgba(x)
  15. def imshow(im):
  16. plt.close()
  17. plt.tight_layout()
  18. plt.imshow(im)
  19. plt.colorbar(sm, fraction=0.046)
  20. plt.xlim([0, im.shape[0]])
  21. plt.ylim([im.shape[0], 0])
  22. def save_last_model(model, save_path, epoch, optimizer=None):
  23. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  24. checkpoint = {
  25. 'epoch': epoch,
  26. 'model_state_dict': model.state_dict(),
  27. }
  28. if optimizer is not None:
  29. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  30. torch.save(checkpoint, save_path)
  31. def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
  32. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  33. if current_loss <= best_loss:
  34. checkpoint = {
  35. 'epoch': epoch,
  36. 'model_state_dict': model.state_dict(),
  37. 'loss': current_loss
  38. }
  39. if optimizer is not None:
  40. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  41. torch.save(checkpoint, save_path)
  42. print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
  43. return current_loss
  44. return best_loss
  45. def show_line(img, pred, epoch, writer):
  46. im = img.permute(1, 2, 0)
  47. writer.add_image("ori", im, epoch, dataformats="HWC")
  48. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
  49. colors="yellow", width=1)
  50. writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
  51. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  52. # print(f'pred[1]:{pred[1]}')
  53. H = pred[-1]['wires']
  54. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  55. scores = H["score"][0].cpu().numpy()
  56. for i in range(1, len(lines)):
  57. if (lines[i] == lines[0]).all():
  58. lines = lines[:i]
  59. scores = scores[:i]
  60. break
  61. # postprocess lines to remove overlapped lines
  62. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  63. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  64. for i, t in enumerate([0.85]):
  65. plt.gca().set_axis_off()
  66. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  67. plt.margins(0, 0)
  68. for (a, b), s in zip(nlines, nscores):
  69. if s < t:
  70. continue
  71. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  72. plt.scatter(a[1], a[0], **PLTOPTS)
  73. plt.scatter(b[1], b[0], **PLTOPTS)
  74. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  75. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  76. plt.imshow(im)
  77. plt.tight_layout()
  78. fig = plt.gcf()
  79. fig.canvas.draw()
  80. image_from_plot = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).reshape(
  81. fig.canvas.get_width_height()[::-1] + (3,))
  82. plt.close()
  83. img2 = transforms.ToTensor()(image_from_plot)
  84. writer.add_image("output", img2, epoch)