Procházet zdrojové kódy

完善保存模型权重和日志功能

RenLiqiang před 3 měsíci
rodič
revize
8d26ae5bad
2 změnil soubory, kde provedl 22 přidání a 7 odebrání
  1. 0 4
      models/ins_detect/trainer.py
  2. 22 3
      utils/log_util.py

+ 0 - 4
models/ins_detect/trainer.py

@@ -14,11 +14,7 @@ from tools import utils, presets
 
 
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
-<<<<<<< HEAD
     model.train()
-=======
-    model.train1()
->>>>>>> dev
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"

+ 22 - 3
utils/log_util.py

@@ -7,9 +7,28 @@ from matplotlib import pyplot as plt
 from libs.vision_libs.utils import draw_bounding_boxes
 from models.wirenet.postprocess import postprocess
 from torchvision import transforms
+import matplotlib as mpl
 
 
-def save_latest_model(model, save_path, epoch, optimizer=None):
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+
+def save_last_model(model, save_path, epoch, optimizer=None):
     os.makedirs(os.path.dirname(save_path), exist_ok=True)
 
     checkpoint = {
@@ -25,7 +44,7 @@ def save_latest_model(model, save_path, epoch, optimizer=None):
 def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
     os.makedirs(os.path.dirname(save_path), exist_ok=True)
 
-    if current_loss < best_loss:
+    if current_loss <= best_loss:
         checkpoint = {
             'epoch': epoch,
             'model_state_dict': model.state_dict(),
@@ -80,7 +99,7 @@ def show_line(img, pred, epoch, writer):
         plt.tight_layout()
         fig = plt.gcf()
         fig.canvas.draw()
-        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+        image_from_plot = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).reshape(
             fig.canvas.get_width_height()[::-1] + (3,))
         plt.close()
         img2 = transforms.ToTensor()(image_from_plot)