trainer.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import os
  2. import torch
  3. from torch.utils.tensorboard import SummaryWriter
  4. from models.base.base_model import BaseModel
  5. from models.base.base_trainer import BaseTrainer
  6. from models.config.config_tool import read_yaml
  7. from models.line_detect.dataset_LD import WirePointDataset
  8. from models.line_detect.postprocess import box_line_, show_
  9. from utils.log_util import show_line, save_latest_model, save_best_model
  10. from tools import utils
  11. def _loss(losses):
  12. total_loss = 0
  13. for i in losses.keys():
  14. if i != "loss_wirepoint":
  15. total_loss += losses[i]
  16. else:
  17. loss_labels = losses[i]["losses"]
  18. loss_labels_k = list(loss_labels[0].keys())
  19. for j, name in enumerate(loss_labels_k):
  20. loss = loss_labels[0][name].mean()
  21. total_loss += loss
  22. return total_loss
  23. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  24. def move_to_device(data, device):
  25. if isinstance(data, (list, tuple)):
  26. return type(data)(move_to_device(item, device) for item in data)
  27. elif isinstance(data, dict):
  28. return {key: move_to_device(value, device) for key, value in data.items()}
  29. elif isinstance(data, torch.Tensor):
  30. return data.to(device)
  31. else:
  32. return data # 对于非张量类型的数据不做任何改变
  33. class Trainer(BaseTrainer):
  34. def __init__(self, model=None,
  35. dataset=None,
  36. device='cuda',
  37. **kwargs):
  38. super().__init__(model,dataset,device,**kwargs)
  39. def move_to_device(self, data, device):
  40. if isinstance(data, (list, tuple)):
  41. return type(data)(self.move_to_device(item, device) for item in data)
  42. elif isinstance(data, dict):
  43. return {key: self.move_to_device(value, device) for key, value in data.items()}
  44. elif isinstance(data, torch.Tensor):
  45. return data.to(device)
  46. else:
  47. return data # 对于非张量类型的数据不做任何改变
  48. def load_best_model(self,model, optimizer, save_path, device):
  49. if os.path.exists(save_path):
  50. checkpoint = torch.load(save_path, map_location=device)
  51. model.load_state_dict(checkpoint['model_state_dict'])
  52. if optimizer is not None:
  53. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  54. epoch = checkpoint['epoch']
  55. loss = checkpoint['loss']
  56. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  57. else:
  58. print(f"No saved model found at {save_path}")
  59. return model, optimizer
  60. def writer_loss(self, writer, losses, epoch):
  61. try:
  62. for key, value in losses.items():
  63. if key == 'loss_wirepoint':
  64. for subdict in losses['loss_wirepoint']['losses']:
  65. for subkey, subvalue in subdict.items():
  66. writer.add_scalar(f'loss/{subkey}',
  67. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  68. epoch)
  69. elif isinstance(value, torch.Tensor):
  70. writer.add_scalar(f'loss/{key}', value.item(), epoch)
  71. except Exception as e:
  72. print(f"TensorBoard logging error: {e}")
  73. def train_cfg(self, model:BaseModel, cfg):
  74. # cfg = r'./config/wireframe.yaml'
  75. cfg = read_yaml(cfg)
  76. print(f'cfg:{cfg}')
  77. # print(cfg['n_dyn_negl'])
  78. self.train(model, **cfg)
  79. def train(self, model, **kwargs):
  80. dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
  81. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  82. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  83. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
  84. train_collate_fn = utils.collate_fn_wirepoint
  85. data_loader_train = torch.utils.data.DataLoader(
  86. dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
  87. )
  88. dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
  89. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  90. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  91. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
  92. val_collate_fn = utils.collate_fn_wirepoint
  93. data_loader_val = torch.utils.data.DataLoader(
  94. dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
  95. )
  96. # model = linenet_resnet50_fpn().to(self.device)
  97. optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
  98. writer = SummaryWriter(kwargs['io']['logdir'])
  99. model.to(device)
  100. # 加载权重
  101. save_path = 'logs/pth/best_model.pth'
  102. model, optimizer = self.load_best_model(model, optimizer, save_path, device)
  103. logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
  104. os.makedirs(logdir_with_pth, exist_ok=True) # 创建目录(如果不存在)
  105. latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
  106. best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
  107. global_step = 0
  108. for epoch in range(kwargs['optim']['max_epoch']):
  109. print(f"epoch:{epoch}")
  110. total_train_loss = 0.0
  111. model.train()
  112. for imgs, targets in data_loader_train:
  113. imgs = move_to_device(imgs, device)
  114. targets=move_to_device(targets,device)
  115. # print(f'imgs:{len(imgs)}')
  116. # print(f'targets:{len(targets)}')
  117. losses = model(imgs, targets)
  118. # print(losses)
  119. loss = _loss(losses)
  120. optimizer.zero_grad()
  121. loss.backward()
  122. optimizer.step()
  123. self.writer_loss(writer, losses, epoch)
  124. model.eval()
  125. with torch.no_grad():
  126. for batch_idx, (imgs, targets) in enumerate(data_loader_val):
  127. pred = model(self.move_to_device(imgs, self.device))
  128. pred_ = box_line_(pred) # 将box与line对应
  129. show_(imgs, pred_, epoch, writer)
  130. if batch_idx == 0:
  131. show_line(imgs[0], pred, epoch, writer)
  132. break
  133. avg_train_loss = total_train_loss / len(data_loader_train)
  134. writer.add_scalar('loss/train', avg_train_loss, epoch)
  135. best_loss = 10000
  136. save_latest_model(
  137. model,
  138. latest_model_path,
  139. epoch,
  140. optimizer
  141. )
  142. best_loss = save_best_model(
  143. model,
  144. best_model_path,
  145. epoch,
  146. avg_train_loss,
  147. best_loss,
  148. optimizer
  149. )