# trainer.py
  1. import os
  2. import time
  3. from datetime import datetime
  4. import torch
  5. from torch.utils.tensorboard import SummaryWriter
  6. from models.base.base_model import BaseModel
  7. from models.base.base_trainer import BaseTrainer
  8. from models.config.config_tool import read_yaml
  9. from models.line_detect.dataset_LD import WirePointDataset
  10. from models.line_detect.postprocess import box_line_, show_
  11. from utils.log_util import show_line, save_last_model, save_best_model
  12. from tools import utils
  13. def _loss(losses):
  14. total_loss = 0
  15. for i in losses.keys():
  16. if i != "loss_wirepoint":
  17. total_loss += losses[i]
  18. else:
  19. loss_labels = losses[i]["losses"]
  20. loss_labels_k = list(loss_labels[0].keys())
  21. for j, name in enumerate(loss_labels_k):
  22. loss = loss_labels[0][name].mean()
  23. total_loss += loss
  24. return total_loss
# Module-level default device: prefer CUDA when available, else fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  26. def move_to_device(data, device):
  27. if isinstance(data, (list, tuple)):
  28. return type(data)(move_to_device(item, device) for item in data)
  29. elif isinstance(data, dict):
  30. return {key: move_to_device(value, device) for key, value in data.items()}
  31. elif isinstance(data, torch.Tensor):
  32. return data.to(device)
  33. else:
  34. return data # 对于非张量类型的数据不做任何改变
  35. class Trainer(BaseTrainer):
  36. def __init__(self, model=None,
  37. dataset=None,
  38. device='cuda',
  39. **kwargs):
  40. super().__init__(model,dataset,device,**kwargs)
  41. def move_to_device(self, data, device):
  42. if isinstance(data, (list, tuple)):
  43. return type(data)(self.move_to_device(item, device) for item in data)
  44. elif isinstance(data, dict):
  45. return {key: self.move_to_device(value, device) for key, value in data.items()}
  46. elif isinstance(data, torch.Tensor):
  47. return data.to(device)
  48. else:
  49. return data # 对于非张量类型的数据不做任何改变
  50. def load_best_model(self,model, optimizer, save_path, device):
  51. if os.path.exists(save_path):
  52. checkpoint = torch.load(save_path, map_location=device)
  53. model.load_state_dict(checkpoint['model_state_dict'])
  54. if optimizer is not None:
  55. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  56. epoch = checkpoint['epoch']
  57. loss = checkpoint['loss']
  58. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  59. else:
  60. print(f"No saved model found at {save_path}")
  61. return model, optimizer
  62. def writer_loss(self, writer, losses, epoch):
  63. try:
  64. for key, value in losses.items():
  65. if key == 'loss_wirepoint':
  66. for subdict in losses['loss_wirepoint']['losses']:
  67. for subkey, subvalue in subdict.items():
  68. writer.add_scalar(f'loss/{subkey}',
  69. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  70. epoch)
  71. elif isinstance(value, torch.Tensor):
  72. writer.add_scalar(f'loss/{key}', value.item(), epoch)
  73. except Exception as e:
  74. print(f"TensorBoard logging error: {e}")
    def train_cfg(self, model: BaseModel, cfg):
        """Load a YAML config file and launch training with its contents.

        :param model: model instance to train.
        :param cfg: path to a YAML config file (e.g. './config/wireframe.yaml');
            its top-level keys are expanded as keyword arguments to ``self.train``.
        """
        # cfg = r'./config/wireframe.yaml'
        cfg = read_yaml(cfg)
        print(f'cfg:{cfg}')
        # print(cfg['n_dyn_negl'])
        self.train(model, **cfg)
  81. def train(self, model, **kwargs):
  82. dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
  83. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  84. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  85. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=64, drop_last=True)
  86. train_collate_fn = utils.collate_fn_wirepoint
  87. data_loader_train = torch.utils.data.DataLoader(
  88. dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
  89. )
  90. dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
  91. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  92. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  93. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=64, drop_last=True)
  94. val_collate_fn = utils.collate_fn_wirepoint
  95. data_loader_val = torch.utils.data.DataLoader(
  96. dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
  97. )
  98. train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
  99. wts_path = os.path.join(train_result_ptath, 'weights')
  100. tb_path = os.path.join(train_result_ptath, 'logs')
  101. writer = SummaryWriter(tb_path)
  102. optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
  103. # writer = SummaryWriter(kwargs['io']['logdir'])
  104. model.to(device)
  105. # # 加载权重
  106. # save_path = 'logs/pth/best_model.pth'
  107. # model, optimizer = self.load_best_model(model, optimizer, save_path, device)
  108. # logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
  109. # os.makedirs(logdir_with_pth, exist_ok=True) # 创建目录(如果不存在)
  110. last_model_path = os.path.join(wts_path, 'last.pth')
  111. best_model_path = os.path.join(wts_path, 'best.pth')
  112. global_step = 0
  113. for epoch in range(kwargs['optim']['max_epoch']):
  114. print(f"epoch:{epoch}")
  115. total_train_loss = 0.0
  116. model.train()
  117. for imgs, targets in data_loader_train:
  118. imgs = move_to_device(imgs, device)
  119. targets=move_to_device(targets,device)
  120. # print(f'imgs:{len(imgs)}')
  121. # print(f'targets:{len(targets)}')
  122. losses = model(imgs, targets)
  123. loss = _loss(losses)
  124. total_train_loss += loss.item()
  125. optimizer.zero_grad()
  126. loss.backward()
  127. optimizer.step()
  128. self.writer_loss(writer, losses, global_step)
  129. global_step+=1
  130. avg_train_loss = total_train_loss / len(data_loader_train)
  131. if epoch == 0:
  132. best_loss = avg_train_loss;
  133. writer.add_scalar('loss/train', avg_train_loss, epoch)
  134. if os.path.exists(f'{wts_path}/last.pt'):
  135. os.remove(f'{wts_path}/last.pt')
  136. # torch.save(model.state_dict(), f'{wts_path}/last.pt')
  137. save_last_model(model,last_model_path,epoch,optimizer)
  138. best_loss = save_best_model(model,best_model_path,epoch,avg_train_loss,best_loss,optimizer)
  139. model.eval()
  140. with torch.no_grad():
  141. for batch_idx, (imgs, targets) in enumerate(data_loader_val):
  142. t_start = time.time()
  143. print(f'start to predict:{t_start}')
  144. pred = model(self.move_to_device(imgs, self.device))
  145. t_end = time.time()
  146. print(f'predict used:{t_end - t_start}')
  147. if batch_idx == 0:
  148. show_line(imgs[0], pred, epoch, writer)
  149. break