@@ -1,5 +1,7 @@
+
 import os
 import time
+from datetime import datetime
 
 import torch
 from torch.utils.tensorboard import SummaryWriter
@@ -9,7 +11,7 @@ from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
 from models.line_detect.dataset_LD import WirePointDataset
 from models.line_detect.postprocess import box_line_, show_
-from utils.log_util import show_line, save_latest_model, save_best_model
+from utils.log_util import show_line, save_last_model, save_best_model
 from tools import utils
 
 
@@ -108,25 +110,31 @@ class Trainer(BaseTrainer):
             dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
         )
 
-        # model = linenet_resnet50_fpn().to(self.device)
+        train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+        wts_path = os.path.join(train_result_path, 'weights')
+        tb_path = os.path.join(train_result_path, 'logs')
+        writer = SummaryWriter(tb_path)
 
         optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
-        writer = SummaryWriter(kwargs['io']['logdir'])
+        # writer = SummaryWriter(kwargs['io']['logdir'])
         model.to(device)
 
-        # load the weights
-        save_path = 'logs/pth/best_model.pth'
-        model, optimizer = self.load_best_model(model, optimizer, save_path, device)
-        logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
-        os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it does not exist
-        latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
-        best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
+
+        # # load the weights
+        # save_path = 'logs/pth/best_model.pth'
+        # model, optimizer = self.load_best_model(model, optimizer, save_path, device)
+
+        # logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
+        # os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it does not exist
+        last_model_path = os.path.join(wts_path, 'last.pth')
+        best_model_path = os.path.join(wts_path, 'best.pth')
 
         global_step = 0
 
         for epoch in range(kwargs['optim']['max_epoch']):
             print(f"epoch:{epoch}")
             total_train_loss = 0.0
+
             model.train()
 
             for imgs, targets in data_loader_train:
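A note on the new run layout: each training run now writes to a timestamped directory, e.g. `train_results/20250101_120000/`, with `weights/` and `logs/` subdirectories. `SummaryWriter(tb_path)` creates the log directory on its own, but nothing in this hunk creates `wts_path` (the old `os.makedirs` call is now commented out). Unless `save_last_model` and `save_best_model` create parent directories themselves, a guard along these lines is needed before the first checkpoint is written:

    # Hypothetical guard; only needed if the save helpers in utils/log_util
    # do not already create parent directories before calling torch.save.
    os.makedirs(wts_path, exist_ok=True)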
@@ -135,44 +143,37 @@ class Trainer(BaseTrainer):
                 # print(f'imgs:{len(imgs)}')
                 # print(f'targets:{len(targets)}')
                 losses = model(imgs, targets)
-                # print(losses)
                 loss = _loss(losses)
+                total_train_loss += loss.item()
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
                 self.writer_loss(writer, losses, epoch)
 
+
+            avg_train_loss = total_train_loss / len(data_loader_train)
+            if epoch == 0:
+                best_loss = avg_train_loss
+
+            writer.add_scalar('loss/train', avg_train_loss, epoch)
+
+
+            if os.path.exists(f'{wts_path}/last.pt'):
+                os.remove(f'{wts_path}/last.pt')
+            # torch.save(model.state_dict(), f'{wts_path}/last.pt')
+            save_last_model(model, last_model_path, epoch, optimizer)
+            best_loss = save_best_model(model, best_model_path, epoch, avg_train_loss, best_loss, optimizer)
+
             model.eval()
             with torch.no_grad():
                 for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                     t_start = time.time()
                     print(f'start to predict:{t_start}')
                     pred = model(self.move_to_device(imgs, self.device))
-                    # t_end = time.time()
-                    # print(f'predict used:{t_end - t_start}')
-                    # t_start = time.time()
-                    # print(f'start to box_line:{t_start}')
-                    # pred_ = box_line_(pred)  # match each box with its line
-                    # t_end = time.time()
-                    # print(f'box_line_ used:{t_end - t_start}')
-                    # show_(imgs, pred_, epoch, writer)
+                    t_end = time.time()
+                    print(f'predict used:{t_end - t_start}')
                     if batch_idx == 0:
                         show_line(imgs[0], pred, epoch, writer)
                     break
-            avg_train_loss = total_train_loss / len(data_loader_train)
-            writer.add_scalar('loss/train', avg_train_loss, epoch)
-            best_loss = 10000
-            save_latest_model(
-                model,
-                latest_model_path,
-                epoch,
-                optimizer
-            )
-            best_loss = save_best_model(
-                model,
-                best_model_path,
-                epoch,
-                avg_train_loss,
-                best_loss,
-                optimizer
-            )
+
+
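For reference, a minimal sketch of the checkpoint helpers as they are assumed to behave from the call sites above (`save_last_model` is the renamed `save_latest_model`; the actual implementations live in `utils/log_util` and are not part of this diff). The key contract is that `save_best_model` returns the updated `best_loss`, which the loop seeds from the first epoch's average loss, so the comparison below must use `<=` for a checkpoint to be written on epoch 0:

    import torch

    def save_last_model(model, save_path, epoch, optimizer=None):
        # Rolling checkpoint, overwritten every epoch.
        checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict()}
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(checkpoint, save_path)

    def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
        # Saves only when the tracked loss improves, and returns the updated
        # best value so the caller can thread it through the training loop.
        if current_loss <= best_loss:
            checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict()}
            if optimizer is not None:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            torch.save(checkpoint, save_path)
            best_loss = current_loss
        return best_loss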