|
@@ -7,9 +7,28 @@ from matplotlib import pyplot as plt
|
|
|
from libs.vision_libs.utils import draw_bounding_boxes
|
|
|
from models.wirenet.postprocess import postprocess
|
|
|
from torchvision import transforms
|
|
|
+import matplotlib as mpl
|
|
|
|
|
|
|
|
|
-def save_latest_model(model, save_path, epoch, optimizer=None):
|
|
|
+cmap = plt.get_cmap("jet")
|
|
|
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
|
|
|
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
|
+sm.set_array([])
|
|
|
+
|
|
|
+
|
|
|
+def c(x):
|
|
|
+ return sm.to_rgba(x)
|
|
|
+
|
|
|
+
|
|
|
+def imshow(im):
|
|
|
+ plt.close()
|
|
|
+ plt.tight_layout()
|
|
|
+ plt.imshow(im)
|
|
|
+ plt.colorbar(sm, fraction=0.046)
|
|
|
+ plt.xlim([0, im.shape[0]])
|
|
|
+ plt.ylim([im.shape[0], 0])
|
|
|
+
|
|
|
+def save_last_model(model, save_path, epoch, optimizer=None):
|
|
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
|
|
|
checkpoint = {
|
|
@@ -25,7 +44,7 @@ def save_latest_model(model, save_path, epoch, optimizer=None):
|
|
|
def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
|
|
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
|
|
|
- if current_loss < best_loss:
|
|
|
+ if current_loss <= best_loss:
|
|
|
checkpoint = {
|
|
|
'epoch': epoch,
|
|
|
'model_state_dict': model.state_dict(),
|
|
@@ -80,7 +99,7 @@ def show_line(img, pred, epoch, writer):
|
|
|
plt.tight_layout()
|
|
|
fig = plt.gcf()
|
|
|
fig.canvas.draw()
|
|
|
- image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
|
|
|
+ image_from_plot = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).reshape(
|
|
|
fig.canvas.get_width_height()[::-1] + (3,))
|
|
|
plt.close()
|
|
|
img2 = transforms.ToTensor()(image_from_plot)
|