import os
import torch
from torch.utils.tensorboard import SummaryWriter
from models.base.base_model import BaseModel
from models.base.base_trainer import BaseTrainer
from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from models.line_detect.postprocess import box_line_, show_
from utils.log_util import show_line, save_latest_model, save_best_model
from tools import utils


def _loss(losses):
    """Collapse the model's loss dictionary into a single value.

    Every entry except ``"loss_wirepoint"`` is added as-is. For
    ``"loss_wirepoint"`` the value is expected to hold a ``"losses"`` sequence
    of dicts; only the first dict is used, and the ``.mean()`` of each of its
    values is added.

    Args:
        losses: Mapping from loss name to loss value, as returned by the model
            in training mode.

    Returns:
        The accumulated loss (the integer ``0`` if ``losses`` is empty).
    """
    total_loss = 0
    for key, value in losses.items():
        if key != "loss_wirepoint":
            total_loss += value
            continue
        # Only the first dict of the "losses" list contributes to the total.
        for term in value["losses"][0].values():
            total_loss += term.mean()
    return total_loss


# Module-level default device; kept for code that imports it from here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def move_to_device(data, device):
    """Recursively move every tensor inside ``data`` to ``device``.

    Lists and tuples are rebuilt with the same type, dicts are rebuilt as plain
    dicts, tensors are moved with ``Tensor.to`` and anything else is returned
    unchanged.
    """
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    elif isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data  # non-tensor values are passed through untouched


class Trainer(BaseTrainer):
    """Trainer for the wire-point / line detection model."""

    def __init__(self, model=None, dataset=None, device='cuda', **kwargs):
        super().__init__(model, dataset, device, **kwargs)

    def move_to_device(self, data, device):
        """Recursively move every tensor inside ``data`` to ``device``.

        Thin wrapper around the module-level ``move_to_device`` so there is a
        single implementation of the traversal.
        """
        return move_to_device(data, device)

    def load_best_model(self, model, optimizer, save_path, device):
        """Restore model (and optionally optimizer) state from a checkpoint.

        Args:
            model: Module whose ``load_state_dict`` receives
                ``checkpoint['model_state_dict']``.
            optimizer: Optimizer to restore from
                ``checkpoint['optimizer_state_dict']``; skipped when ``None``.
            save_path: Path of the checkpoint file.
            device: Passed to ``torch.load`` as ``map_location``.

        Returns:
            The ``(model, optimizer)`` pair. Both are returned unchanged when
            ``save_path`` does not exist.

        Raises:
            KeyError: If the checkpoint lacks one of the keys read here
                (``model_state_dict``, ``optimizer_state_dict``, ``epoch``,
                ``loss``).
        """
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
        else:
            print(f"No saved model found at {save_path}")
        return model, optimizer

    def writer_loss(self, writer, losses, epoch):
        """Write each individual loss term to TensorBoard.

        Entries of ``losses['loss_wirepoint']['losses']`` are logged as
        ``loss/<subkey>``; every other tensor-valued entry is logged as
        ``loss/<key>``. Non-tensor top-level entries are ignored.

        Args:
            writer: TensorBoard ``SummaryWriter``.
            losses: Loss dictionary returned by the model.
            epoch: Step value used as the x-axis of the scalars. The name is
                historical; ``train`` passes a per-batch global step.

        Logging is best-effort: any exception is printed rather than raised so
        that a logging problem does not abort training.
        """
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in value['losses']:
                        for subkey, subvalue in subdict.items():
                            writer.add_scalar(
                                f'loss/{subkey}',
                                subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                epoch,
                            )
                elif isinstance(value, torch.Tensor):
                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    def train_cfg(self, model: BaseModel, cfg):
        """Read a YAML config and start training with it.

        Args:
            model: Model to train.
            cfg: Argument for ``read_yaml`` (e.g. ``'./config/wireframe.yaml'``);
                the parsed mapping is expanded into keyword arguments of
                ``train``.
        """
        cfg = read_yaml(cfg)
        print(f'cfg:{cfg}')
        self.train(model, **cfg)

    @staticmethod
    def _build_wirepoint_loader(datadir, dataset_type, batch_size=2, num_workers=8):
        """Build a randomly sampled, drop-last DataLoader over one dataset split."""
        dataset = WirePointDataset(dataset_path=datadir, dataset_type=dataset_type)
        sampler = torch.utils.data.RandomSampler(dataset)
        batch_sampler = torch.utils.data.BatchSampler(
            sampler, batch_size=batch_size, drop_last=True
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=utils.collate_fn_wirepoint,
        )

    def _run_train_epoch(self, model, optimizer, data_loader, writer, device, global_step):
        """Run one optimisation pass; return ``(mean batch loss, new global_step)``."""
        model.train()
        total_train_loss = 0.0
        for imgs, targets in data_loader:
            imgs = self.move_to_device(imgs, device)
            targets = self.move_to_device(targets, device)

            losses = model(imgs, targets)
            loss = _loss(losses)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate so the epoch average reflects the real training loss.
            total_train_loss += loss.item()
            # One point per batch at a unique step; logging every batch at the
            # same step would stack all points of an epoch at a single x value.
            self.writer_loss(writer, losses, global_step)
            global_step += 1

        return total_train_loss / len(data_loader), global_step

    def _log_first_val_batch(self, model, data_loader, writer, epoch, device):
        """Run inference on the first validation batch and log its visualisations."""
        model.eval()
        with torch.no_grad():
            for imgs, _targets in data_loader:
                pred = model(self.move_to_device(imgs, device))
                pred_ = box_line_(pred)  # match boxes with lines
                show_(imgs, pred_, epoch, writer)
                show_line(imgs[0], pred, epoch, writer)
                break  # only the first batch is visualised per epoch

    def train(self, model, **kwargs):
        """Train ``model`` on the wire-point dataset.

        Expected keyword arguments (normally the parsed YAML config):
            io.datadir: Dataset root passed to ``WirePointDataset``.
            io.logdir: TensorBoard log directory; checkpoints are written to
                ``<logdir>/pth``.
            optim.lr: Learning rate for Adam.
            optim.max_epoch: Number of epochs to run.

        If ``logs/pth/best_model.pth`` exists, model and optimizer state are
        restored from it before training starts. After every epoch the first
        validation batch is visualised, the mean training loss is logged as
        ``loss/train``, the latest checkpoint is saved, and
        ``save_best_model`` is called with the running best loss.
        """
        datadir = kwargs['io']['datadir']
        logdir = kwargs['io']['logdir']
        # Use the device this trainer was configured with everywhere, so the
        # model, training batches and validation batches cannot end up on
        # different devices.
        run_device = self.device

        data_loader_train = self._build_wirepoint_loader(datadir, 'train')
        data_loader_val = self._build_wirepoint_loader(datadir, 'val')

        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
        writer = SummaryWriter(logdir)
        model.to(run_device)

        # Load previously saved weights, if any.
        save_path = 'logs/pth/best_model.pth'
        model, optimizer = self.load_best_model(model, optimizer, save_path, run_device)

        logdir_with_pth = os.path.join(logdir, 'pth')
        os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if missing
        latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
        best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')

        global_step = 0
        # Must live outside the epoch loop: resetting it every epoch makes the
        # "best" checkpoint track the latest epoch instead of the best one.
        # NOTE(review): relies on save_best_model returning the updated best
        # loss, as the assignment below implies - confirm in utils.log_util.
        best_loss = float('inf')

        try:
            for epoch in range(kwargs['optim']['max_epoch']):
                print(f"epoch:{epoch}")

                avg_train_loss, global_step = self._run_train_epoch(
                    model, optimizer, data_loader_train, writer, run_device, global_step
                )
                self._log_first_val_batch(model, data_loader_val, writer, epoch, run_device)

                writer.add_scalar('loss/train', avg_train_loss, epoch)

                save_latest_model(
                    model,
                    latest_model_path,
                    epoch,
                    optimizer
                )
                best_loss = save_best_model(
                    model,
                    best_model_path,
                    epoch,
                    avg_train_loss,
                    best_loss,
                    optimizer
                )
        finally:
            # Flush pending events even if training is interrupted.
            writer.close()