@@ -0,0 +1,169 @@
+import os
+
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from models.base.base_model import BaseModel
+from models.base.base_trainer import BaseTrainer
+from models.config.config_tool import read_yaml
+from models.line_detect.dataset_LD import WirePointDataset
+from models.line_detect.postprocess import box_line_, show_
+from utils.log_util import show_line, save_latest_model, save_best_model
+from tools import utils
+
+
+def _loss(losses):
+    """Sum every loss term into a single scalar.
+
+    `losses` maps loss names to scalar tensors; the special key
+    "loss_wirepoint" holds a nested dict whose "losses" entry is a list of
+    per-head loss dicts, each of which is averaged and added individually.
+    """
+    total_loss = 0
+    for key in losses.keys():
+        if key != "loss_wirepoint":
+            total_loss += losses[key]
+        else:
+            loss_labels = losses[key]["losses"]
+            for name in loss_labels[0].keys():
+                total_loss += loss_labels[0][name].mean()
+
+    return total_loss
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def move_to_device(data, device):
+    """Recursively move tensors inside lists/tuples/dicts onto `device`."""
+    if isinstance(data, (list, tuple)):
+        return type(data)(move_to_device(item, device) for item in data)
+    elif isinstance(data, dict):
+        return {key: move_to_device(value, device) for key, value in data.items()}
+    elif isinstance(data, torch.Tensor):
+        return data.to(device)
+    else:
+        return data  # non-tensor data is returned unchanged
+
+
+class Trainer(BaseTrainer):
+    def __init__(self, model=None,
+                 dataset=None,
+                 device='cuda',
+                 **kwargs):
+        super().__init__(model, dataset, device, **kwargs)
+
+    def move_to_device(self, data, device):
+        """Recursively move tensors inside lists/tuples/dicts onto `device`."""
+        if isinstance(data, (list, tuple)):
+            return type(data)(self.move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: self.move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # non-tensor data is returned unchanged
+
+    def load_best_model(self, model, optimizer, save_path, device):
+        """Restore model (and optionally optimizer) state from a checkpoint, if one exists."""
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            if optimizer is not None:
+                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            epoch = checkpoint['epoch']
+            loss = checkpoint['loss']
+            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return model, optimizer
+
+    def writer_loss(self, writer, losses, epoch):
+        """Log each individual loss term to TensorBoard."""
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            writer.add_scalar(f'loss/{subkey}',
+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                              epoch)
+                elif isinstance(value, torch.Tensor):
+                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
+
+    def train_cfg(self, model: BaseModel, cfg):
+        # cfg = r'./config/wireframe.yaml'
+        cfg = read_yaml(cfg)
+        print(f'cfg:{cfg}')
+        # print(cfg['n_dyn_negl'])
+        self.train(model, **cfg)
+
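+    # Expected config structure (sketch): train() below reads only the keys
+    # listed here; the contents of wireframe.yaml beyond these keys are an
+    # assumption, and the placeholder values are illustrative.
+    #
+    #   io:
+    #     datadir: <path to the wirepoint dataset>
+    #     logdir:  <output directory for TensorBoard logs and checkpoints>
+    #   optim:
+    #     lr: <learning rate>
+    #     max_epoch: <number of training epochs>
+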
+    def train(self, model, **kwargs):
+        # Training data
+        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
+        train_sampler = torch.utils.data.RandomSampler(dataset_train)
+        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
+        train_collate_fn = utils.collate_fn_wirepoint
+        data_loader_train = torch.utils.data.DataLoader(
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
+        )
+
+        # Validation data
+        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
+        val_sampler = torch.utils.data.RandomSampler(dataset_val)
+        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
+        val_collate_fn = utils.collate_fn_wirepoint
+        data_loader_val = torch.utils.data.DataLoader(
+            dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
+        )
+
+        # model = linenet_resnet50_fpn().to(self.device)
+
+        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
+        writer = SummaryWriter(kwargs['io']['logdir'])
+        model.to(device)
+
+        # Load weights from the best checkpoint, if one exists
+        save_path = 'logs/pth/best_model.pth'
+        model, optimizer = self.load_best_model(model, optimizer, save_path, device)
+
+        logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
+        os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it does not exist
+        latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
+        best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
+        global_step = 0
+        best_loss = 10000  # initialize once, so the best checkpoint persists across epochs
+
+        for epoch in range(kwargs['optim']['max_epoch']):
+            print(f"epoch:{epoch}")
+            total_train_loss = 0.0
+            model.train()
+
+            for imgs, targets in data_loader_train:
+                imgs = move_to_device(imgs, device)
+                targets = move_to_device(targets, device)
+                # print(f'imgs:{len(imgs)}')
+                # print(f'targets:{len(targets)}')
+                losses = model(imgs, targets)
+                # print(losses)
+                loss = _loss(losses)
+                total_train_loss += loss.item()  # accumulate so the epoch average is meaningful
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                self.writer_loss(writer, losses, epoch)
+
+            model.eval()
+            with torch.no_grad():
+                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                    pred = model(self.move_to_device(imgs, self.device))
+                    pred_ = box_line_(pred)  # match predicted boxes with lines
+                    show_(imgs, pred_, epoch, writer)
+                    if batch_idx == 0:
+                        show_line(imgs[0], pred, epoch, writer)
+                    break  # only the first validation batch is visualized
+
+            avg_train_loss = total_train_loss / len(data_loader_train)
+            writer.add_scalar('loss/train', avg_train_loss, epoch)
+
+            save_latest_model(
+                model,
+                latest_model_path,
+                epoch,
+                optimizer
+            )
+            best_loss = save_best_model(
+                model,
+                best_model_path,
+                epoch,
+                avg_train_loss,
+                best_loss,
+                optimizer
+            )
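+
+
+# Usage sketch (illustrative only, not exercised by this module): how the
+# Trainer above could be driven from a config file. The import path of
+# linenet_resnet50_fpn is an assumption based on the commented-out call in
+# train(); only the config keys documented above need to be present.
+#
+# if __name__ == '__main__':
+#     from models.line_detect.line_net import linenet_resnet50_fpn  # assumed module path
+#     model = linenet_resnet50_fpn()
+#     trainer = Trainer()
+#     trainer.train_cfg(model, cfg=r'./config/wireframe.yaml')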