111.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. import os
  2. import time
  3. from datetime import datetime
  4. import torch
  5. from torch.utils.tensorboard import SummaryWriter
  6. from models.base.base_model import BaseModel
  7. from models.base.base_trainer import BaseTrainer
  8. from models.config.config_tool import read_yaml
  9. from models.line_detect.dataset_LD import WirePointDataset
  10. from models.line_detect.postprocess import box_line_, show_
  11. from utils.log_util import show_line, save_last_model, save_best_model
  12. from tools import utils
  13. def _loss(losses):
  14. total_loss = 0
  15. for i in losses.keys():
  16. if i != "loss_wirepoint":
  17. total_loss += losses[i]
  18. else:
  19. loss_labels = losses[i]["losses"]
  20. loss_labels_k = list(loss_labels[0].keys())
  21. for j, name in enumerate(loss_labels_k):
  22. loss = loss_labels[0][name].mean()
  23. total_loss += loss
  24. return total_loss
# Module-level default device: CUDA when torch reports it available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  26. def move_to_device(data, device):
  27. if isinstance(data, (list, tuple)):
  28. return type(data)(move_to_device(item, device) for item in data)
  29. elif isinstance(data, dict):
  30. return {key: move_to_device(value, device) for key, value in data.items()}
  31. elif isinstance(data, torch.Tensor):
  32. return data.to(device)
  33. else:
  34. return data # 对于非张量类型的数据不做任何改变
  35. class Trainer(BaseTrainer):
  36. def __init__(self, model=None,
  37. dataset=None,
  38. device='cuda',
  39. freeze_config=None, # 新增:冻结参数配置
  40. **kwargs):
  41. super().__init__(model, dataset, device, **kwargs)
  42. self.freeze_config = freeze_config or {} # 默认冻结配置为空
  43. def move_to_device(self, data, device):
  44. if isinstance(data, (list, tuple)):
  45. return type(data)(self.move_to_device(item, device) for item in data)
  46. elif isinstance(data, dict):
  47. return {key: self.move_to_device(value, device) for key, value in data.items()}
  48. elif isinstance(data, torch.Tensor):
  49. return data.to(device)
  50. else:
  51. return data # 对于非张量类型的数据不做任何改变
  52. def freeze_params(self, model):
  53. """根据配置冻结模型参数"""
  54. default_config = {
  55. 'backbone': True, # 冻结 backbone
  56. 'rpn': False, # 不冻结 rpn
  57. 'roi_heads': {
  58. 'box_head': False,
  59. 'box_predictor': False,
  60. 'line_head': False,
  61. 'line_predictor': {
  62. 'fc1': False,
  63. 'fc2': {
  64. '0': False,
  65. '2': False,
  66. '4': False
  67. }
  68. }
  69. }
  70. }
  71. # 更新默认配置
  72. default_config.update(self.freeze_config)
  73. config = default_config
  74. print("\n===== Parameter Freezing Configuration =====")
  75. for name, module in model.named_children():
  76. if name in config:
  77. if isinstance(config[name], bool):
  78. for param in module.parameters():
  79. param.requires_grad = not config[name]
  80. print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
  81. elif isinstance(config[name], dict):
  82. for subname, submodule in module.named_children():
  83. if subname in config[name]:
  84. if isinstance(config[name][subname], bool):
  85. for param in submodule.parameters():
  86. param.requires_grad = not config[name][subname]
  87. print(
  88. f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
  89. elif isinstance(config[name][subname], dict):
  90. for subsubname, subsubmodule in submodule.named_children():
  91. if subsubname in config[name][subname]:
  92. for param in subsubmodule.parameters():
  93. param.requires_grad = not config[name][subname][subsubname]
  94. print(
  95. f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
  96. # 打印参数统计
  97. total_params = sum(p.numel() for p in model.parameters())
  98. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  99. print(f"\nTotal Parameters: {total_params:,}")
  100. print(f"Trainable Parameters: {trainable_params:,}")
  101. print(f"Frozen Parameters: {total_params - trainable_params:,}")
  102. def load_best_model(self, model, optimizer, save_path, device):
  103. if os.path.exists(save_path):
  104. checkpoint = torch.load(save_path, map_location=device)
  105. model.load_state_dict(checkpoint['model_state_dict'])
  106. if optimizer is not None:
  107. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  108. epoch = checkpoint['epoch']
  109. loss = checkpoint['loss']
  110. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  111. else:
  112. print(f"No saved model found at {save_path}")
  113. return model, optimizer
  114. def writer_loss(self, writer, losses, epoch):
  115. try:
  116. for key, value in losses.items():
  117. if key == 'loss_wirepoint':
  118. for subdict in losses['loss_wirepoint']['losses']:
  119. for subkey, subvalue in subdict.items():
  120. writer.add_scalar(f'loss/{subkey}',
  121. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  122. epoch)
  123. elif isinstance(value, torch.Tensor):
  124. writer.add_scalar(f'loss/{key}', value.item(), epoch)
  125. except Exception as e:
  126. print(f"TensorBoard logging error: {e}")
  127. def train_cfg(self, model: BaseModel, cfg, freeze_config=None): # 新增:支持传入冻结配置
  128. cfg = read_yaml(cfg)
  129. self.freeze_config = freeze_config or {} # 更新冻结配置
  130. self.train(model, **cfg)
  131. def train(self, model, **kwargs):
  132. dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
  133. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  134. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
  135. train_collate_fn = utils.collate_fn_wirepoint
  136. data_loader_train = torch.utils.data.DataLoader(
  137. dataset_train, batch_sampler=train_batch_sampler, num_workers=1, collate_fn=train_collate_fn
  138. )
  139. dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
  140. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  141. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=4, drop_last=True)
  142. val_collate_fn = utils.collate_fn_wirepoint
  143. data_loader_val = torch.utils.data.DataLoader(
  144. dataset_val, batch_sampler=val_batch_sampler, num_workers=1, collate_fn=val_collate_fn
  145. )
  146. train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
  147. wts_path = os.path.join(train_result_ptath, 'weights')
  148. tb_path = os.path.join(train_result_ptath, 'logs')
  149. writer = SummaryWriter(tb_path)
  150. model.to(device)
  151. # # 加载权重
  152. # save_path =r"F:\BaiduNetdiskDownload\r50fpn_wts_e350\best.pth"
  153. # model, _ = self.load_best_model(model, None, save_path, device)
  154. # 冻结参数
  155. # self.freeze_params(model)
  156. # 初始化优化器(仅训练未冻结参数)
  157. optimizer = torch.optim.Adam(
  158. filter(lambda p: p.requires_grad, model.parameters()),
  159. lr=kwargs['optim']['lr']
  160. )
  161. last_model_path = os.path.join(wts_path, 'last.pth')
  162. best_model_path = os.path.join(wts_path, 'best.pth')
  163. global_step = 0
  164. for epoch in range(kwargs['optim']['max_epoch']):
  165. print(f"epoch:{epoch}")
  166. total_train_loss = 0.0
  167. model.train()
  168. for imgs, targets in data_loader_train:
  169. imgs = move_to_device(imgs, device)
  170. targets = move_to_device(targets, device)
  171. losses = model(imgs, targets)
  172. loss = _loss(losses)
  173. total_train_loss += loss.item()
  174. optimizer.zero_grad()
  175. loss.backward()
  176. optimizer.step()
  177. self.writer_loss(writer, losses, global_step)
  178. global_step += 1
  179. avg_train_loss = total_train_loss / len(data_loader_train)
  180. if epoch == 0:
  181. best_loss = avg_train_loss
  182. writer.add_scalar('loss/train', avg_train_loss, epoch)
  183. if os.path.exists(f'{wts_path}/last.pt'):
  184. os.remove(f'{wts_path}/last.pt')
  185. save_last_model(model, last_model_path, epoch, optimizer)
  186. best_loss = save_best_model(model, best_model_path, epoch, avg_train_loss, best_loss, optimizer)
  187. model.eval()
  188. with torch.no_grad():
  189. for batch_idx, (imgs, targets) in enumerate(data_loader_val):
  190. t_start = time.time()
  191. print(f'start to predict:{t_start}')
  192. pred = model(self.move_to_device(imgs, self.device))
  193. # print(f'pred:{pred}')
  194. t_end = time.time()
  195. print(f'predict used:{t_end - t_start}')
  196. if batch_idx == 0:
  197. show_line(imgs[0], pred, epoch, writer)
  198. break
  199. import torch
  200. from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
  201. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  202. if __name__ == '__main__':
  203. # model = LineNet('line_net.yaml')
  204. model=linenet_resnet50_fpn()
  205. #model=linenet_resnet18_fpn()
  206. # trainer = Trainer()
  207. # trainer.train_cfg(model,cfg='./train.yaml')
  208. # model.train_by_cfg(cfg='train.yaml')
  209. trainer = Trainer()
  210. trainer.train_cfg(model=model, cfg='train.yaml')