# trainer.py
  1. import os
  2. import time
  3. from datetime import datetime
  4. import numpy as np
  5. import torch
  6. from matplotlib import pyplot as plt
  7. from torch.utils.tensorboard import SummaryWriter
  8. from libs.vision_libs.utils import draw_bounding_boxes
  9. from models.base.base_model import BaseModel
  10. from models.base.base_trainer import BaseTrainer
  11. from models.config.config_tool import read_yaml
  12. from models.line_detect.dataset_LD import WirePointDataset
  13. from models.wirenet.postprocess import postprocess
  14. from tools import utils
  15. from torchvision import transforms
  16. import matplotlib as mpl
  17. cmap = plt.get_cmap("jet")
  18. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  19. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  20. sm.set_array([])
  21. def _loss(losses):
  22. total_loss = 0
  23. for i in losses.keys():
  24. if i != "loss_wirepoint":
  25. total_loss += losses[i]
  26. else:
  27. loss_labels = losses[i]["losses"]
  28. loss_labels_k = list(loss_labels[0].keys())
  29. for j, name in enumerate(loss_labels_k):
  30. loss = loss_labels[0][name].mean()
  31. total_loss += loss
  32. return total_loss
  33. def c(x):
  34. return sm.to_rgba(x)
  35. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  36. class Trainer(BaseTrainer):
  37. def __init__(self, model=None, **kwargs):
  38. super().__init__(model, device, **kwargs)
  39. self.model = model
  40. # print(f'kwargs:{kwargs}')
  41. self.init_params(**kwargs)
  42. def init_params(self, **kwargs):
  43. if kwargs != {}:
  44. print(f'train_params:{kwargs["train_params"]}')
  45. self.freeze_config = kwargs['train_params']['freeze_params']
  46. print(f'freeze_config:{self.freeze_config}')
  47. self.dataset_path = kwargs['io']['datadir']
  48. self.batch_size = kwargs['train_params']['batch_size']
  49. self.num_workers = kwargs['train_params']['num_workers']
  50. self.logdir = kwargs['io']['logdir']
  51. self.resume_from = kwargs['train_params']['resume_from']
  52. self.optim = ''
  53. self.train_result_ptath = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
  54. self.wts_path = os.path.join(self.train_result_ptath, 'weights')
  55. self.tb_path = os.path.join(self.train_result_ptath, 'logs')
  56. self.writer = SummaryWriter(self.tb_path)
  57. self.last_model_path = os.path.join(self.wts_path, 'last.pth')
  58. self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
  59. self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
  60. self.max_epoch = kwargs['train_params']['max_epoch']
  61. def move_to_device(self, data, device):
  62. if isinstance(data, (list, tuple)):
  63. return type(data)(self.move_to_device(item, device) for item in data)
  64. elif isinstance(data, dict):
  65. return {key: self.move_to_device(value, device) for key, value in data.items()}
  66. elif isinstance(data, torch.Tensor):
  67. return data.to(device)
  68. else:
  69. return data # 对于非张量类型的数据不做任何改变
  70. def freeze_params(self, model):
  71. """根据配置冻结模型参数"""
  72. default_config = {
  73. 'backbone': True, # 冻结 backbone
  74. 'rpn': False, # 不冻结 rpn
  75. 'roi_heads': {
  76. 'box_head': False,
  77. 'box_predictor': False,
  78. 'line_head': False,
  79. 'line_predictor': {
  80. 'fc1': False,
  81. 'fc2': {
  82. '0': False,
  83. '2': False,
  84. '4': False
  85. }
  86. }
  87. }
  88. }
  89. # 更新默认配置
  90. default_config.update(self.freeze_config)
  91. config = default_config
  92. print("\n===== Parameter Freezing Configuration =====")
  93. for name, module in model.named_children():
  94. if name in config:
  95. if isinstance(config[name], bool):
  96. for param in module.parameters():
  97. param.requires_grad = not config[name]
  98. print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
  99. elif isinstance(config[name], dict):
  100. for subname, submodule in module.named_children():
  101. if subname in config[name]:
  102. if isinstance(config[name][subname], bool):
  103. for param in submodule.parameters():
  104. param.requires_grad = not config[name][subname]
  105. print(
  106. f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
  107. elif isinstance(config[name][subname], dict):
  108. for subsubname, subsubmodule in submodule.named_children():
  109. if subsubname in config[name][subname]:
  110. for param in subsubmodule.parameters():
  111. param.requires_grad = not config[name][subname][subsubname]
  112. print(
  113. f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
  114. # 打印参数统计
  115. total_params = sum(p.numel() for p in model.parameters())
  116. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  117. print(f"\nTotal Parameters: {total_params:,}")
  118. print(f"Trainable Parameters: {trainable_params:,}")
  119. print(f"Frozen Parameters: {total_params - trainable_params:,}")
  120. def load_best_model(self, model, optimizer, save_path, device):
  121. if os.path.exists(save_path):
  122. checkpoint = torch.load(save_path, map_location=device)
  123. model.load_state_dict(checkpoint['model_state_dict'])
  124. if optimizer is not None:
  125. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  126. epoch = checkpoint['epoch']
  127. loss = checkpoint['loss']
  128. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  129. else:
  130. print(f"No saved model found at {save_path}")
  131. return model, optimizer
  132. def writer_predict_result(self, img, result, epoch):
  133. img = img.cpu().detach()
  134. im = img.permute(1, 2, 0)
  135. self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
  136. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result[0]["boxes"],
  137. colors="yellow", width=1)
  138. self.writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
  139. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  140. # print(f'pred[1]:{pred[1]}')
  141. heatmaps = result[-2][0]
  142. print(f'heatmaps:{heatmaps.shape}')
  143. jmap = heatmaps[1: 2].cpu().detach()
  144. lmap = heatmaps[2: 3].cpu().detach()
  145. self.writer.add_image("z-jmap", jmap, epoch)
  146. self.writer.add_image("z-lmap", lmap, epoch)
  147. # plt.imshow(lmap)
  148. # plt.show()
  149. H = result[-1]['wires']
  150. lines = H["lines"][0].cpu().numpy()
  151. scores = H["score"][0].cpu().numpy()
  152. for i in range(1, len(lines)):
  153. if (lines[i] == lines[0]).all():
  154. lines = lines[:i]
  155. scores = scores[:i]
  156. break
  157. # postprocess lines to remove overlapped lines
  158. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  159. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  160. for i, t in enumerate([0]):
  161. plt.gca().set_axis_off()
  162. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  163. plt.margins(0, 0)
  164. for (a, b), s in zip(nlines, nscores):
  165. if s < t:
  166. continue
  167. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  168. plt.scatter(a[1], a[0], **PLTOPTS)
  169. plt.scatter(b[1], b[0], **PLTOPTS)
  170. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  171. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  172. plt.imshow(im)
  173. plt.tight_layout()
  174. fig = plt.gcf()
  175. fig.canvas.draw()
  176. width, height = fig.get_size_inches() * fig.get_dpi() # 获取图像尺寸
  177. tmp_img = fig.canvas.tostring_argb()
  178. tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
  179. tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
  180. img_rgb = tmp_img_np[:, :, 1:] # 提取RGB部分,忽略Alpha通道
  181. # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
  182. # fig.canvas.get_width_height()[::-1] + (3,))
  183. plt.close()
  184. img2 = transforms.ToTensor()(img_rgb)
  185. self.writer.add_image("z-output", img2, epoch)
  186. def writer_loss(self, losses, epoch, phase='train'):
  187. try:
  188. for key, value in losses.items():
  189. if key == 'loss_wirepoint':
  190. for subdict in losses['loss_wirepoint']['losses']:
  191. for subkey, subvalue in subdict.items():
  192. self.writer.add_scalar(f'{phase}/loss/{subkey}',
  193. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  194. epoch)
  195. elif isinstance(value, torch.Tensor):
  196. self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
  197. except Exception as e:
  198. print(f"TensorBoard logging error: {e}")
  199. def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None): # 新增:支持传入冻结配置
  200. cfg = read_yaml(cfg)
  201. # print(f'cfg:{cfg}')
  202. # self.freeze_config = freeze_config or {} # 更新冻结配置
  203. self.train(model, **cfg)
  204. def train(self, model, **kwargs):
  205. self.init_params(**kwargs)
  206. dataset_train = WirePointDataset(dataset_path=self.dataset_path, dataset_type='train')
  207. dataset_val = WirePointDataset(dataset_path=self.dataset_path, dataset_type='val')
  208. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  209. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  210. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=self.batch_size, drop_last=True)
  211. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=self.batch_size, drop_last=True)
  212. train_collate_fn = utils.collate_fn
  213. val_collate_fn = utils.collate_fn
  214. data_loader_train = torch.utils.data.DataLoader(
  215. dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
  216. )
  217. data_loader_val = torch.utils.data.DataLoader(
  218. dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn
  219. )
  220. model.to(device)
  221. optimizer = torch.optim.Adam(
  222. filter(lambda p: p.requires_grad, model.parameters()),
  223. lr=kwargs['train_params']['optim']['lr']
  224. )
  225. for epoch in range(self.max_epoch):
  226. print(f"train epoch:{epoch}")
  227. model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
  228. # ========== Validation ==========
  229. with torch.no_grad():
  230. model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
  231. if epoch==0:
  232. best_train_loss = epoch_train_loss
  233. best_val_loss = epoch_val_loss
  234. self.save_last_model(model,self.last_model_path, epoch, optimizer)
  235. best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
  236. best_train_loss,
  237. optimizer)
  238. best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch, epoch_val_loss, best_val_loss,
  239. optimizer)
  240. def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
  241. if phase == 'train':
  242. model.train()
  243. if phase == 'val':
  244. model.eval()
  245. total_loss = 0
  246. epoch_step = 0
  247. global_step = epoch_step * len(data_loader)
  248. for imgs, targets in data_loader:
  249. imgs = self.move_to_device(imgs, device)
  250. targets = self.move_to_device(targets, device)
  251. if phase== 'val':
  252. result,losses = model(imgs, targets)
  253. else:
  254. losses = model(imgs, targets)
  255. loss = _loss(losses)
  256. total_loss += loss.item()
  257. if phase == 'train':
  258. optimizer.zero_grad()
  259. loss.backward()
  260. optimizer.step()
  261. self.writer_loss(losses, global_step, phase=phase)
  262. global_step += 1
  263. if epoch_step == 0 and phase == 'val':
  264. t_start = time.time()
  265. print(f'start to predict:{t_start}')
  266. result = model(self.move_to_device(imgs, self.device))
  267. t_end = time.time()
  268. print(f'predict used:{t_end - t_start}')
  269. self.writer_predict_result(img=imgs[0], result=result, epoch=epoch)
  270. epoch_step+=1
  271. avg_loss = total_loss / len(data_loader)
  272. print(f'{phase}/loss epoch{epoch}:{avg_loss:4f}')
  273. self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
  274. return model, avg_loss
  275. def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
  276. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  277. if current_loss <= best_loss:
  278. checkpoint = {
  279. 'epoch': epoch,
  280. 'model_state_dict': model.state_dict(),
  281. 'loss': current_loss
  282. }
  283. if optimizer is not None:
  284. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  285. torch.save(checkpoint, save_path)
  286. print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
  287. return current_loss
  288. return best_loss
  289. def save_last_model(self, model, save_path, epoch, optimizer=None):
  290. if os.path.exists(f'{self.wts_path}/last.pt'):
  291. os.remove(f'{self.wts_path}/last.pt')
  292. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  293. checkpoint = {
  294. 'epoch': epoch,
  295. 'model_state_dict': model.state_dict(),
  296. }
  297. if optimizer is not None:
  298. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  299. torch.save(checkpoint, save_path)
  300. if __name__ == '__main__':
  301. print('')