log_util.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import os
  2. import numpy as np
  3. import torch
  4. from matplotlib import pyplot as plt
  5. from libs.vision_libs.utils import draw_bounding_boxes
  6. from models.wirenet.postprocess import postprocess
  7. from torchvision import transforms
  8. def save_latest_model(model, save_path, epoch, optimizer=None):
  9. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  10. checkpoint = {
  11. 'epoch': epoch,
  12. 'model_state_dict': model.state_dict(),
  13. }
  14. if optimizer is not None:
  15. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  16. torch.save(checkpoint, save_path)
  17. def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
  18. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  19. if current_loss < best_loss:
  20. checkpoint = {
  21. 'epoch': epoch,
  22. 'model_state_dict': model.state_dict(),
  23. 'loss': current_loss
  24. }
  25. if optimizer is not None:
  26. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  27. torch.save(checkpoint, save_path)
  28. print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
  29. return current_loss
  30. return best_loss
  31. def show_line(img, pred, epoch, writer):
  32. im = img.permute(1, 2, 0)
  33. writer.add_image("ori", im, epoch, dataformats="HWC")
  34. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
  35. colors="yellow", width=1)
  36. writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
  37. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  38. # print(f'pred[1]:{pred[1]}')
  39. H = pred[-1]['wires']
  40. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  41. scores = H["score"][0].cpu().numpy()
  42. for i in range(1, len(lines)):
  43. if (lines[i] == lines[0]).all():
  44. lines = lines[:i]
  45. scores = scores[:i]
  46. break
  47. # postprocess lines to remove overlapped lines
  48. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  49. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  50. for i, t in enumerate([0.85]):
  51. plt.gca().set_axis_off()
  52. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  53. plt.margins(0, 0)
  54. for (a, b), s in zip(nlines, nscores):
  55. if s < t:
  56. continue
  57. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  58. plt.scatter(a[1], a[0], **PLTOPTS)
  59. plt.scatter(b[1], b[0], **PLTOPTS)
  60. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  61. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  62. plt.imshow(im)
  63. plt.tight_layout()
  64. fig = plt.gcf()
  65. fig.canvas.draw()
  66. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
  67. fig.canvas.get_width_height()[::-1] + (3,))
  68. plt.close()
  69. img2 = transforms.ToTensor()(image_from_plot)
  70. writer.add_image("output", img2, epoch)