@@ -32,9 +32,8 @@ def log_losses_to_tensorboard(writer, result, step):
     writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
 
 
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq,writer, scaler=None):
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, scaler=None):
     model.train()
-    total_train_loss=0.0
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"
@@ -130,7 +129,7 @@ def show_line(img, pred, epoch, writer):
     plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
     plt.margins(0, 0)
     for (a, b), s in zip(nlines, nscores):
-        if s < t :
+        if s < t:
             continue
         plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
         plt.scatter(a[1], a[0], **PLTOPTS)
@@ -186,24 +185,24 @@ def evaluate(model, data_loader, epoch, writer, device):
         show_line(images[0], outputs[0], epoch, writer)
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
 
 
 def train_cfg(model, cfg):
@@ -300,8 +299,11 @@ def train(model, **kwargs):
         os.mkdir(wts_path)
         os.mkdir(img_results_path)
 
+
+    total_train_loss = 0.0
+
     for epoch in range(epochs):
-        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, None)
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, total_train_loss, None)
         losses = metric_logger.meters['loss'].global_avg
         print(f'epoch {epoch}:loss:{losses}')
         if os.path.exists(f'{wts_path}/last.pt'):
@@ -363,4 +365,4 @@ def write_metric_logs(epoch, metric_logger, writer):
 
 
 
-
+
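
For reference, below is a minimal sketch of how the updated train_one_epoch signature is driven by a caller after this change. It only mirrors the call pattern already shown in train() above; the run_training wrapper and its parameters are hypothetical placeholders, not part of the patch, and the per-epoch loss is still read from the returned metric logger.

# Hypothetical caller; assumes train_one_epoch is the patched function above.
def run_training(model, optimizer, data_loader, device, writer,
                 epochs=10, print_freq=50):
    # The accumulator is now created once by the caller and passed in,
    # matching the change made in train() in this patch.
    total_train_loss = 0.0
    for epoch in range(epochs):
        metric_logger = train_one_epoch(
            model, optimizer, data_loader, device, epoch,
            print_freq, writer, total_train_loss, None)
        # The epoch-average loss comes from the returned metric logger.
        losses = metric_logger.meters['loss'].global_avg
        print(f'epoch {epoch}: loss: {losses}')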