|
@@ -128,17 +128,17 @@ class Trainer(BaseTrainer):
|
|
|
print(f"No saved model found at {save_path}")
|
|
print(f"No saved model found at {save_path}")
|
|
|
return model, optimizer
|
|
return model, optimizer
|
|
|
|
|
|
|
|
def writer_loss(self, writer, losses, epoch, mode='train'):
    """Write the scalar losses in ``losses`` to a TensorBoard-style writer.

    Each scalar is emitted with ``writer.add_scalar`` under the tag
    ``"{mode}/loss/{name}"``.

    Args:
        writer: Object exposing ``add_scalar(tag, value, step)``.
        losses: Mapping of loss name to value. Top-level values are logged
            only when they are ``torch.Tensor`` instances. The special key
            ``'loss_wirepoint'`` is expected to hold a mapping whose
            ``'losses'`` entry is an iterable of ``{name: value}`` dicts;
            every value in those dicts is logged (tensors via ``.item()``,
            anything else as-is).
        epoch: Value forwarded as the third argument of ``add_scalar``.
        mode: Tag prefix that keeps separate runs apart (e.g. ``'train'``
            or ``'val'``).

    Returns:
        None. Logging is best-effort: failures are printed, never raised,
        so a logging problem cannot interrupt the caller.
    """

    def log_scalar(name, raw):
        # Guard each scalar on its own so that one bad value (for example
        # a multi-element tensor, for which ``.item()`` raises) does not
        # prevent the remaining losses from being written.
        try:
            scalar = raw.item() if hasattr(raw, 'item') else raw
            writer.add_scalar(f'{mode}/loss/{name}', scalar, epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    try:
        entries = list(losses.items())
    except Exception as e:
        # ``losses`` is not a mapping; there is nothing that can be logged.
        print(f"TensorBoard logging error: {e}")
        return

    for key, value in entries:
        if key == 'loss_wirepoint':
            try:
                # Flatten the list of sub-dicts first so a malformed
                # payload is reported once and the other keys still log.
                sub_entries = [
                    pair
                    for subdict in value['losses']
                    for pair in subdict.items()
                ]
            except Exception as e:
                print(f"TensorBoard logging error: {e}")
                continue
            for subkey, subvalue in sub_entries:
                log_scalar(subkey, subvalue)
        elif isinstance(value, torch.Tensor):
            log_scalar(key, value)
|
|
|
|
|
|
|
@@ -184,7 +184,8 @@ class Trainer(BaseTrainer):
|
|
|
last_model_path = os.path.join(wts_path, 'last.pth')
|
|
last_model_path = os.path.join(wts_path, 'last.pth')
|
|
|
best_train_model_path = os.path.join(wts_path, 'best_train.pth')
|
|
best_train_model_path = os.path.join(wts_path, 'best_train.pth')
|
|
|
best_val_model_path = os.path.join(wts_path, 'best_val.pth')
|
|
best_val_model_path = os.path.join(wts_path, 'best_val.pth')
|
|
|
- global_step = 0
|
|
|
|
|
|
|
+ global_train_step = 0
|
|
|
|
|
+ global_val_step = 0
|
|
|
|
|
|
|
|
for epoch in range(kwargs['optim']['max_epoch']):
|
|
for epoch in range(kwargs['optim']['max_epoch']):
|
|
|
print(f"epoch:{epoch}")
|
|
print(f"epoch:{epoch}")
|
|
@@ -199,8 +200,8 @@ class Trainer(BaseTrainer):
|
|
|
optimizer.zero_grad()
|
|
optimizer.zero_grad()
|
|
|
loss.backward()
|
|
loss.backward()
|
|
|
optimizer.step()
|
|
optimizer.step()
|
|
|
- self.writer_loss(writer, losses, global_step)
|
|
|
|
|
- global_step += 1
|
|
|
|
|
|
|
+ self.writer_loss(writer, losses, global_train_step)
|
|
|
|
|
+ global_train_step += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -208,8 +209,9 @@ class Trainer(BaseTrainer):
|
|
|
print(f'model.eval!!')
|
|
print(f'model.eval!!')
|
|
|
# ========== Validation ==========
|
|
# ========== Validation ==========
|
|
|
total_val_loss = 0.0
|
|
total_val_loss = 0.0
|
|
|
|
|
+ batch_idx=0
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
|
- for batch_idx, (imgs, targets) in enumerate(data_loader_val):
|
|
|
|
|
|
|
+ for imgs, targets in data_loader_val:
|
|
|
t_start = time.time()
|
|
t_start = time.time()
|
|
|
print(f'start to predict:{t_start}')
|
|
print(f'start to predict:{t_start}')
|
|
|
|
|
|
|
@@ -217,7 +219,9 @@ class Trainer(BaseTrainer):
|
|
|
targets = move_to_device(targets, device)
|
|
targets = move_to_device(targets, device)
|
|
|
print(f'targets:{targets}')
|
|
print(f'targets:{targets}')
|
|
|
|
|
|
|
|
- losses = model(imgs, targets)
|
|
|
|
|
|
|
+ _,losses = model(imgs, targets)
|
|
|
|
|
+ self.writer_loss(writer, losses, global_val_step,mode='val')
|
|
|
|
|
+ global_val_step+=1
|
|
|
print(f'val losses:{losses}')
|
|
print(f'val losses:{losses}')
|
|
|
loss = _loss(losses)
|
|
loss = _loss(losses)
|
|
|
total_val_loss += loss.item()
|
|
total_val_loss += loss.item()
|
|
@@ -229,7 +233,8 @@ class Trainer(BaseTrainer):
|
|
|
print(f'predict used:{t_end - t_start}')
|
|
print(f'predict used:{t_end - t_start}')
|
|
|
if batch_idx == 0:
|
|
if batch_idx == 0:
|
|
|
show_line(imgs[0], pred, epoch, writer)
|
|
show_line(imgs[0], pred, epoch, writer)
|
|
|
- break
|
|
|
|
|
|
|
+ batch_idx+=1
|
|
|
|
|
+
|
|
|
|
|
|
|
|
avg_val_loss = total_val_loss / len(data_loader_val)
|
|
avg_val_loss = total_val_loss / len(data_loader_val)
|
|
|
# print(f'avg_val_loss:{avg_val_loss}')
|
|
# print(f'avg_val_loss:{avg_val_loss}')
|