# trainer.py
# (The original paste carried a viewer header and a fused run of line numbers
# here — extraction residue, replaced by this comment.)
  1. import os
  2. import time
  3. from datetime import datetime
  4. import numpy as np
  5. import torch
  6. from matplotlib import pyplot as plt
  7. from torch.optim.lr_scheduler import ReduceLROnPlateau
  8. from torch.utils.tensorboard import SummaryWriter
  9. from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
  10. from models.base.base_model import BaseModel
  11. from models.base.base_trainer import BaseTrainer
  12. from models.config.config_tool import read_yaml
  13. from models.line_detect.line_dataset import LineDataset
  14. from models.line_net.dataset_LD import WirePointDataset
  15. from models.wirenet.postprocess import postprocess
  16. from tools import utils
  17. from torchvision import transforms
  18. import matplotlib as mpl
# Module-level "jet" colormap helpers: scores are normalized from the
# [0.4, 1.0] range onto the jet color ramp. Used by `c()` below to map a
# scalar score to an RGBA color.
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
  23. def _loss(losses):
  24. total_loss = 0
  25. for i in losses.keys():
  26. if i != "loss_wirepoint":
  27. total_loss += losses[i]
  28. else:
  29. loss_labels = losses[i]["losses"]
  30. loss_labels_k = list(loss_labels[0].keys())
  31. for j, name in enumerate(loss_labels_k):
  32. loss = loss_labels[0][name].mean()
  33. total_loss += loss
  34. return total_loss
def c(x):
    # Map a scalar score to an RGBA tuple via the module-level "jet"
    # ScalarMappable `sm` (values outside [0.4, 1.0] are clipped by the norm).
    return sm.to_rgba(x)
# Single module-level device used throughout (model placement, data moves,
# zero-loss fallbacks): CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  38. import matplotlib.pyplot as plt
  39. from PIL import ImageDraw
  40. from torchvision.transforms import functional as F
  41. import torch
# Color ramp from low to high scores: blue -> yellow -> red
  43. def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
  44. """
  45. 根据得分对线段着色并绘制
  46. :param tensor_image: (3, H, W) uint8 图像
  47. :param lines: (N, 2, 2) 每条线 [ [x1,y1], [x2,y2] ]
  48. :param scores: (N,) 每条线的得分,范围 [0, 1]
  49. :param width: 线宽
  50. :param cmap: matplotlib colormap 名称,例如 'viridis', 'jet', 'coolwarm'
  51. :return: (3, H, W) uint8 画好线的图像
  52. """
  53. assert tensor_image.dtype == torch.uint8
  54. assert tensor_image.shape[0] == 3
  55. assert lines.shape[0] == scores.shape[0]
  56. # 准备色图
  57. colormap = plt.get_cmap(cmap)
  58. colors = (colormap(scores.cpu().numpy())[:, :3] * 255).astype('uint8') # 去掉 alpha 通道
  59. # 转为 PIL 画图
  60. image_pil = F.to_pil_image(tensor_image)
  61. draw = ImageDraw.Draw(image_pil)
  62. for line, color in zip(lines, colors):
  63. start = tuple(map(float, line[0][:2].tolist()))
  64. end = tuple(map(float, line[1][:2].tolist()))
  65. draw.line([start, end], fill=tuple(color), width=width)
  66. return (F.to_tensor(image_pil) * 255).to(torch.uint8)
class Trainer(BaseTrainer):
    """Training driver for the line-detection model.

    Wires together dataset/dataloader construction, optional parameter
    freezing, the train/val epoch loop, TensorBoard logging of losses and
    predicted boxes/keypoints/lines, and best/last checkpoint saving.
    """

    def __init__(self, model=None, **kwargs):
        # `device` is the module-level global selected at import time.
        super().__init__(model, device, **kwargs)
        self.model = model
        # print(f'kwargs:{kwargs}')
        self.init_params(**kwargs)

    def init_params(self, **kwargs):
        """Read run configuration from the nested kwargs dict.

        Expects kwargs['io'] (datadir, data_type, logdir) and
        kwargs['train_params'] (freeze_params, batch_size, num_workers,
        resume_from, max_epoch, augmentation, optim.*). A no-op when kwargs
        is empty, so __init__ can be called without configuration.
        Side effect: creates a SummaryWriter under a timestamped run dir.
        """
        if kwargs != {}:
            print(f'train_params:{kwargs["train_params"]}')
            self.freeze_config = kwargs['train_params']['freeze_params']
            print(f'freeze_config:{self.freeze_config}')
            self.dataset_path = kwargs['io']['datadir']
            self.data_type = kwargs['io']['data_type']
            self.batch_size = kwargs['train_params']['batch_size']
            self.num_workers = kwargs['train_params']['num_workers']
            self.logdir = kwargs['io']['logdir']
            self.resume_from = kwargs['train_params']['resume_from']
            self.optim = ''
            # NOTE(review): attribute name keeps the original "ptath" typo
            # because external code may reference it — fix via a coordinated rename.
            self.train_result_ptath = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
            self.wts_path = os.path.join(self.train_result_ptath, 'weights')
            self.tb_path = os.path.join(self.train_result_ptath, 'logs')
            self.writer = SummaryWriter(self.tb_path)
            self.last_model_path = os.path.join(self.wts_path, 'last.pth')
            self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
            self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
            self.max_epoch = kwargs['train_params']['max_epoch']
            self.augmentation = kwargs['train_params']["augmentation"]

    def move_to_device(self, data, device):
        """Recursively move tensors in nested lists/tuples/dicts to `device`."""
        if isinstance(data, (list, tuple)):
            return type(data)(self.move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: self.move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # leave non-tensor leaves untouched

    def freeze_params(self, model):
        """Freeze model parameters according to self.freeze_config.

        Config values: True means "freeze" (requires_grad = False). Nested
        dicts descend one or two levels into named children (e.g.
        roi_heads.line_predictor.fc2).
        NOTE(review): `default_config.update(...)` is a shallow merge — a
        user-supplied 'roi_heads' dict replaces the whole default subtree
        rather than merging into it; confirm that is intended.
        """
        default_config = {
            'backbone': True,   # freeze the backbone
            'rpn': False,       # keep the RPN trainable
            'roi_heads': {
                'box_head': False,
                'box_predictor': False,
                'line_head': False,
                'line_predictor': {
                    'fc1': False,
                    'fc2': {
                        '0': False,
                        '2': False,
                        '4': False
                    }
                }
            }
        }
        # Overlay the user-provided freeze configuration (shallow).
        default_config.update(self.freeze_config)
        config = default_config
        print("\n===== Parameter Freezing Configuration =====")
        for name, module in model.named_children():
            if name in config:
                if isinstance(config[name], bool):
                    for param in module.parameters():
                        param.requires_grad = not config[name]
                    print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
                elif isinstance(config[name], dict):
                    for subname, submodule in module.named_children():
                        if subname in config[name]:
                            if isinstance(config[name][subname], bool):
                                for param in submodule.parameters():
                                    param.requires_grad = not config[name][subname]
                                print(
                                    f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
                            elif isinstance(config[name][subname], dict):
                                for subsubname, subsubmodule in submodule.named_children():
                                    if subsubname in config[name][subname]:
                                        for param in subsubmodule.parameters():
                                            param.requires_grad = not config[name][subname][subsubname]
                                        print(
                                            f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
        # Summarize trainable vs frozen parameter counts.
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\nTotal Parameters: {total_params:,}")
        print(f"Trainable Parameters: {trainable_params:,}")
        print(f"Frozen Parameters: {total_params - trainable_params:,}")

    def load_best_model(self, model, optimizer, save_path, device):
        """Load a checkpoint into model (and optimizer, if given).

        Returns (model, optimizer) unchanged when save_path does not exist.
        Expects checkpoint keys: model_state_dict, optimizer_state_dict,
        epoch, loss (as written by save_best_model).
        """
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
        else:
            print(f"No saved model found at {save_path}")
        return model, optimizer

    def writer_predict_result(self, img, result, epoch, type=1):
        """Log one prediction to TensorBoard: the raw image, boxes, and
        either keypoints (type==1) or score-colored line segments (type==2).

        NOTE(review): parameter `type` shadows the builtin; renaming would
        change the keyword interface, so it is kept. `result` is assumed to
        contain 'boxes' and, depending on `type`, 'points' or
        'lines'/'liness_scores' — confirm against the model's output schema.
        """
        img = img.cpu().detach()
        im = img.permute(1, 2, 0)  # CHW -> HWC for dataformats="HWC"
        self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
        # Image is assumed to be float in [0, 1]; scale to uint8 for drawing.
        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result["boxes"],
                                          colors="yellow", width=1)
        # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
        # plt.show()
        self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
        if type == 1:
            keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
            self.writer.add_image("z-output", keypoint_img, epoch)
            # print("lines shape:", result['lines'].shape)
        if type == 2:
            # Draw score-colored segments with the local helper.
            # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
            print(f"shape of linescore:{result['liness_scores'].shape}")
            scores = result['liness_scores'].mean(dim=1)  # per-line mean score
            line_image = draw_lines_with_scores((img * 255).to(torch.uint8), result['lines'], scores, width=3, cmap='jet')
            self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")

    def writer_loss(self, losses, epoch, phase='train'):
        """Write each loss component to TensorBoard under {phase}/loss/*.

        'loss_wirepoint' is a nested dict of sub-losses and is flattened one
        level. Logging failures are caught and printed, never raised.
        """
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in losses['loss_wirepoint']['losses']:
                        for subkey, subvalue in subdict.items():
                            self.writer.add_scalar(f'{phase}/loss/{subkey}',
                                                   subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                                   epoch)
                elif isinstance(value, torch.Tensor):
                    self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None):
        """Load a YAML config file and start training.

        `freeze_config` is currently unused (see the commented line) — the
        freeze configuration is taken from the YAML's train_params instead.
        """
        cfg = read_yaml(cfg)
        # print(f'cfg:{cfg}')
        # self.freeze_config = freeze_config or {}
        self.train(model, **cfg)

    def train(self, model, **kwargs):
        """Full training run: build datasets/loaders, optimizer, scheduler,
        then iterate max_epoch epochs of train + validation, saving last and
        best checkpoints along the way.
        """
        self.init_params(**kwargs)
        dataset_train = LineDataset(dataset_path=self.dataset_path, augmentation=self.augmentation, data_type=self.data_type, dataset_type='train')
        dataset_val = LineDataset(dataset_path=self.dataset_path, augmentation=False, data_type=self.data_type, dataset_type='val')
        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        val_sampler = torch.utils.data.RandomSampler(dataset_val)
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=self.batch_size, drop_last=True)
        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=self.batch_size, drop_last=True)
        train_collate_fn = utils.collate_fn
        val_collate_fn = utils.collate_fn
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
        )
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn
        )
        model.to(device)
        # Only optimize parameters left trainable (e.g. after freeze_params).
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=kwargs['train_params']['optim']['lr'],
            weight_decay=kwargs['train_params']['optim']['weight_decay'],
        )
        # scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=30)
        for epoch in range(self.max_epoch):
            print(f"train epoch:{epoch}")
            model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
            scheduler.step(epoch_train_loss)
            # ========== Validation ==========
            with torch.no_grad():
                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
            # NOTE(review): the scheduler is stepped twice per epoch (train
            # loss above, val loss here), which halves its effective
            # patience — confirm this double-step is intentional.
            scheduler.step(epoch_val_loss)
            if epoch == 0:
                best_train_loss = epoch_train_loss
                best_val_loss = epoch_val_loss
            self.save_last_model(model, self.last_model_path, epoch, optimizer)
            best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
                                                   best_train_loss,
                                                   optimizer)
            best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch, epoch_val_loss, best_val_loss,
                                                 optimizer)

    def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
        """Run one epoch over data_loader and return (model, avg_loss).

        phase='train' backpropagates; phase='val' only computes losses (the
        caller wraps the call in torch.no_grad()) and, on the first batch,
        runs a forward pass without targets to log a prediction image.
        In 'val', the model called with targets is assumed to return
        (detections, loss_dict); in 'train', just loss_dict.
        """
        if phase == 'train':
            model.train()
        if phase == 'val':
            model.eval()
        total_loss = 0
        epoch_step = 0
        # Global step keeps per-batch loss curves continuous across epochs.
        global_step = epoch * len(data_loader)
        for imgs, targets in data_loader:
            imgs = self.move_to_device(imgs, device)
            targets = self.move_to_device(targets, device)
            if phase == 'val':
                result, loss_dict = model(imgs, targets)
                losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
                print(f'val losses:{losses}')
                print(f'val result:{result}')
            else:
                loss_dict = model(imgs, targets)
                losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
                print(f'train losses:{losses}')
            # loss = _loss(losses)
            loss = losses
            total_loss += loss.item()
            if phase == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            self.writer_loss(loss_dict, global_step, phase=phase)
            global_step += 1
            if epoch_step == 0 and phase == 'val':
                t_start = time.time()
                print(f'start to predict:{t_start}')
                # NOTE(review): uses self.device here but the module-level
                # `device` everywhere else — confirm they always match.
                result = model(self.move_to_device(imgs, self.device))
                print(f'result:{result}')
                t_end = time.time()
                print(f'predict used:{t_end - t_start}')
                self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)
            epoch_step += 1
        avg_loss = total_loss / len(data_loader)
        print(f'{phase}/loss epoch{epoch}:{avg_loss:4f}')
        self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
        return model, avg_loss

    def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
        """Save a checkpoint when current_loss <= best_loss.

        Returns the new best loss (current_loss if saved, else best_loss
        unchanged), so callers can thread the running best through epochs.
        """
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        if current_loss <= best_loss:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': current_loss
            }
            if optimizer is not None:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            torch.save(checkpoint, save_path)
            print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
            return current_loss
        return best_loss

    def save_last_model(self, model, save_path, epoch, optimizer=None):
        """Overwrite the rolling 'last' checkpoint for this run.

        NOTE(review): removes '<wts_path>/last.pt' but save_path is
        'last.pth' — the remove targets a different file and looks like a
        stale extension; confirm and clean up.
        """
        if os.path.exists(f'{self.wts_path}/last.pt'):
            os.remove(f'{self.wts_path}/last.pt')
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
        }
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(checkpoint, save_path)
if __name__ == '__main__':
    # No standalone behavior; this module is driven through the Trainer class.
    print('')