# trainer.py
import os
import time

import torch
from torch.utils.tensorboard import SummaryWriter

from models.base.base_model import BaseModel
from models.base.base_trainer import BaseTrainer
from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from models.line_detect.postprocess import box_line_, show_
from tools import utils
from utils.log_util import show_line, save_latest_model, save_best_model
  12. def _loss(losses):
  13. total_loss = 0
  14. for i in losses.keys():
  15. if i != "loss_wirepoint":
  16. total_loss += losses[i]
  17. else:
  18. loss_labels = losses[i]["losses"]
  19. loss_labels_k = list(loss_labels[0].keys())
  20. for j, name in enumerate(loss_labels_k):
  21. loss = loss_labels[0][name].mean()
  22. total_loss += loss
  23. return total_loss
# Module-level default device (CUDA when available, else CPU). NOTE(review):
# Trainer.train below uses this global for the training path while the eval
# path uses self.device — confirm the two are always the same in practice.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  25. def move_to_device(data, device):
  26. if isinstance(data, (list, tuple)):
  27. return type(data)(move_to_device(item, device) for item in data)
  28. elif isinstance(data, dict):
  29. return {key: move_to_device(value, device) for key, value in data.items()}
  30. elif isinstance(data, torch.Tensor):
  31. return data.to(device)
  32. else:
  33. return data # 对于非张量类型的数据不做任何改变
  34. class Trainer(BaseTrainer):
  35. def __init__(self, model=None,
  36. dataset=None,
  37. device='cuda',
  38. **kwargs):
  39. super().__init__(model,dataset,device,**kwargs)
  40. def move_to_device(self, data, device):
  41. if isinstance(data, (list, tuple)):
  42. return type(data)(self.move_to_device(item, device) for item in data)
  43. elif isinstance(data, dict):
  44. return {key: self.move_to_device(value, device) for key, value in data.items()}
  45. elif isinstance(data, torch.Tensor):
  46. return data.to(device)
  47. else:
  48. return data # 对于非张量类型的数据不做任何改变
  49. def load_best_model(self,model, optimizer, save_path, device):
  50. if os.path.exists(save_path):
  51. checkpoint = torch.load(save_path, map_location=device)
  52. model.load_state_dict(checkpoint['model_state_dict'])
  53. if optimizer is not None:
  54. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  55. epoch = checkpoint['epoch']
  56. loss = checkpoint['loss']
  57. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  58. else:
  59. print(f"No saved model found at {save_path}")
  60. return model, optimizer
  61. def writer_loss(self, writer, losses, epoch):
  62. try:
  63. for key, value in losses.items():
  64. if key == 'loss_wirepoint':
  65. for subdict in losses['loss_wirepoint']['losses']:
  66. for subkey, subvalue in subdict.items():
  67. writer.add_scalar(f'loss/{subkey}',
  68. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  69. epoch)
  70. elif isinstance(value, torch.Tensor):
  71. writer.add_scalar(f'loss/{key}', value.item(), epoch)
  72. except Exception as e:
  73. print(f"TensorBoard logging error: {e}")
  74. def train_cfg(self, model:BaseModel, cfg):
  75. # cfg = r'./config/wireframe.yaml'
  76. cfg = read_yaml(cfg)
  77. print(f'cfg:{cfg}')
  78. # print(cfg['n_dyn_negl'])
  79. self.train(model, **cfg)
  80. def train(self, model, **kwargs):
  81. dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
  82. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  83. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  84. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
  85. train_collate_fn = utils.collate_fn_wirepoint
  86. data_loader_train = torch.utils.data.DataLoader(
  87. dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
  88. )
  89. dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
  90. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  91. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  92. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
  93. val_collate_fn = utils.collate_fn_wirepoint
  94. data_loader_val = torch.utils.data.DataLoader(
  95. dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
  96. )
  97. # model = linenet_resnet50_fpn().to(self.device)
  98. optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
  99. writer = SummaryWriter(kwargs['io']['logdir'])
  100. model.to(device)
  101. # 加载权重
  102. save_path = 'logs/pth/best_model.pth'
  103. model, optimizer = self.load_best_model(model, optimizer, save_path, device)
  104. logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
  105. os.makedirs(logdir_with_pth, exist_ok=True) # 创建目录(如果不存在)
  106. latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
  107. best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
  108. global_step = 0
  109. for epoch in range(kwargs['optim']['max_epoch']):
  110. print(f"epoch:{epoch}")
  111. total_train_loss = 0.0
  112. model.train()
  113. for imgs, targets in data_loader_train:
  114. imgs = move_to_device(imgs, device)
  115. targets=move_to_device(targets,device)
  116. # print(f'imgs:{len(imgs)}')
  117. # print(f'targets:{len(targets)}')
  118. losses = model(imgs, targets)
  119. # print(losses)
  120. loss = _loss(losses)
  121. optimizer.zero_grad()
  122. loss.backward()
  123. optimizer.step()
  124. self.writer_loss(writer, losses, epoch)
  125. model.eval()
  126. with torch.no_grad():
  127. for batch_idx, (imgs, targets) in enumerate(data_loader_val):
  128. t_start = time.time()
  129. print(f'start to predict:{t_start}')
  130. pred = model(self.move_to_device(imgs, self.device))
  131. # t_end = time.time()
  132. # print(f'predict used:{t_end - t_start}')
  133. # t_start=time.time()
  134. # print(f'start to box_line:{t_start}')
  135. # pred_ = box_line_(pred) # 将box与line对应
  136. # t_end=time.time()
  137. # print(f'box_line_ used:{t_end-t_start}')
  138. # show_(imgs, pred_, epoch, writer)
  139. if batch_idx == 0:
  140. show_line(imgs[0], pred, epoch, writer)
  141. break
  142. avg_train_loss = total_train_loss / len(data_loader_train)
  143. writer.add_scalar('loss/train', avg_train_loss, epoch)
  144. best_loss = 10000
  145. save_latest_model(
  146. model,
  147. latest_model_path,
  148. epoch,
  149. optimizer
  150. )
  151. best_loss = save_best_model(
  152. model,
  153. best_model_path,
  154. epoch,
  155. avg_train_loss,
  156. best_loss,
  157. optimizer
  158. )