trainer.py

import os
import time
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

from models.base.base_model import BaseModel
from models.base.base_trainer import BaseTrainer
from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from models.line_detect.postprocess import box_line_, show_
from utils.log_util import show_line, save_last_model, save_best_model
from tools import utils


def _loss(losses):
    """Sum all scalar losses; 'loss_wirepoint' nests a list of sub-loss dicts."""
    total_loss = 0
    for name, value in losses.items():
        if name != "loss_wirepoint":
            total_loss += value
        else:
            loss_labels = value["losses"]
            for key in loss_labels[0].keys():
                total_loss += loss_labels[0][key].mean()
    return total_loss
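
# Illustration of the structure _loss expects (values are hypothetical):
#   losses = {
#       'loss_box': torch.tensor(0.5),
#       'loss_wirepoint': {'losses': [{'lmap': torch.tensor([0.1, 0.3]),
#                                      'joff': torch.tensor([0.2, 0.4])}]},
#   }
#   _loss(losses)  # -> 0.5 + mean(lmap) + mean(joff) = tensor(1.0)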

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def move_to_device(data, device):
    """Recursively move tensors nested in lists/tuples/dicts onto `device`."""
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    elif isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data  # leave non-tensor data unchanged
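
# Example: nested containers are traversed recursively and non-tensor values
# pass through unchanged (the keys here are illustrative):
#   batch = {'image': torch.zeros(3, 512, 512), 'meta': {'id': 'a001'}}
#   batch = move_to_device(batch, device)  # only the tensor is moved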


class Trainer(BaseTrainer):
    def __init__(self, model=None,
                 dataset=None,
                 device='cuda',
                 **kwargs):
        super().__init__(model, dataset, device, **kwargs)

    def move_to_device(self, data, device):
        """Recursively move tensors nested in lists/tuples/dicts onto `device`."""
        if isinstance(data, (list, tuple)):
            return type(data)(self.move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: self.move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # leave non-tensor data unchanged

    def load_best_model(self, model, optimizer, save_path, device):
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
        else:
            print(f"No saved model found at {save_path}")
        return model, optimizer
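
    # Resume sketch (the path is a placeholder; this mirrors the commented-out
    # resume block inside train() below):
    #   model, optimizer = self.load_best_model(model, optimizer,
    #                                           'train_results/<run>/weights/best.pth', device)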

    def writer_loss(self, writer, losses, epoch):
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in losses['loss_wirepoint']['losses']:
                        for subkey, subvalue in subdict.items():
                            writer.add_scalar(f'loss/{subkey}',
                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                              epoch)
                elif isinstance(value, torch.Tensor):
                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")
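
    # Resulting TensorBoard tags: one scalar per wirepoint sub-loss (e.g.
    # 'loss/lmap') plus one per top-level tensor loss (e.g. 'loss/loss_box').
    # The sub-loss names are illustrative; they come from the model's loss dict.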

    def train_cfg(self, model: BaseModel, cfg):
        # e.g. cfg = r'./config/wireframe.yaml'
        cfg = read_yaml(cfg)
        print(f'cfg:{cfg}')
        # print(cfg['n_dyn_negl'])
        self.train(model, **cfg)
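
    # Typical invocation (the YAML path is the one suggested in the comment
    # above, and `my_model` stands for any BaseModel instance; the file must
    # provide 'io' and 'optim' sections matching the defaults in train() below):
    #   trainer = Trainer()
    #   trainer.train_cfg(model=my_model, cfg=r'./config/wireframe.yaml')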

    def train(self, model, **kwargs):
        default_params = {
            'io': {
                'logdir': 'logs/',
                'datadir': '/root/autodl-tmp/wirenet_rgb_gray',
                'num_workers': 8,
                'tensorboard_port': 6000,
                'validation_interval': 300,
                'batch_size': 4,
                'batch_size_eval': 2,
            },
            'optim': {
                'name': 'Adam',
                'lr': 4.0e-4,
                'amsgrad': True,
                'weight_decay': 1.0e-4,
                'max_epoch': 90000000,
                'lr_decay_epoch': 10,
            },
        }

        # Merge caller overrides into the defaults (per-section, so unspecified
        # sub-keys keep their default values)
        for key, value in kwargs.items():
            if key in default_params:
                if isinstance(value, dict):
                    default_params[key].update(value)
                else:
                    default_params[key] = value
            else:
                raise ValueError(f"Unknown argument: {key}")

        # Unpack the parameters
        dataset_path = default_params['io']['datadir']
        num_workers = default_params['io']['num_workers']
        batch_size_train = default_params['io']['batch_size']
        batch_size_eval = default_params['io']['batch_size_eval']
        epochs = default_params['optim']['max_epoch']
        lr = default_params['optim']['lr']

        dataset_train = WirePointDataset(dataset_path=dataset_path, dataset_type='train')
        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=batch_size_train, drop_last=True)
        train_collate_fn = utils.collate_fn_wirepoint
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
        )

        dataset_val = WirePointDataset(dataset_path=dataset_path, dataset_type='val')
        val_sampler = torch.utils.data.RandomSampler(dataset_val)
        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=batch_size_eval, drop_last=True)
        val_collate_fn = utils.collate_fn_wirepoint
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, batch_sampler=val_batch_sampler, num_workers=num_workers, collate_fn=val_collate_fn
        )

        train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
        wts_path = os.path.join(train_result_path, 'weights')
        tb_path = os.path.join(train_result_path, 'logs')
        writer = SummaryWriter(tb_path)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        model.to(device)

        # Optionally resume from a saved checkpoint:
        # save_path = 'logs/pth/best_model.pth'
        # model, optimizer = self.load_best_model(model, optimizer, save_path, device)

        os.makedirs(wts_path, exist_ok=True)  # ensure the weights directory exists before saving
        last_model_path = os.path.join(wts_path, 'last.pth')
        best_model_path = os.path.join(wts_path, 'best.pth')
        global_step = 0

        for epoch in range(epochs):
            print(f"epoch:{epoch}")
            total_train_loss = 0.0

            model.train()
            for imgs, targets in data_loader_train:
                imgs = move_to_device(imgs, device)
                targets = move_to_device(targets, device)
                losses = model(imgs, targets)
                loss = _loss(losses)
                total_train_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                self.writer_loss(writer, losses, global_step)
                global_step += 1

            avg_train_loss = total_train_loss / len(data_loader_train)
            if epoch == 0:
                best_loss = avg_train_loss
            writer.add_scalar('loss/train', avg_train_loss, epoch)

            # Overwrite the previous 'last' checkpoint, then update 'best' if improved
            if os.path.exists(last_model_path):
                os.remove(last_model_path)
            save_last_model(model, last_model_path, epoch, optimizer)
            best_loss = save_best_model(model, best_model_path, epoch, avg_train_loss, best_loss, optimizer)

            model.eval()
            with torch.no_grad():
                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                    t_start = time.time()
                    print(f'start to predict: {t_start}')
                    pred = model(self.move_to_device(imgs, self.device))
                    t_end = time.time()
                    print(f'predict used: {t_end - t_start}')
                    if batch_idx == 0:
                        show_line(imgs[0], pred, epoch, writer)
                    break  # only the first validation batch is visualized each epoch
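

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline. `LineNet` is a
    # placeholder for whichever BaseModel subclass this trainer is meant to
    # train, and the config path is assumed from the comment in train_cfg().
    # from models.line_detect.line_net import LineNet
    # model = LineNet()
    # trainer = Trainer(model=model, device=device)
    # trainer.train_cfg(model=model, cfg=r'./config/wireframe.yaml')
    pass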