# 冻结参数训练.py -- training script with optional parameter freezing
  1. #!/usr/bin/env python3
  2. import datetime
  3. import glob
  4. import os
  5. import os.path as osp
  6. import platform
  7. import pprint
  8. import random
  9. import shlex
  10. import shutil
  11. import subprocess
  12. import sys
  13. import numpy as np
  14. import torch
  15. import torchvision
  16. import yaml
  17. import lcnn
  18. from lcnn.config import C, M
  19. from lcnn.datasets import WireframeDataset, collate
  20. from lcnn.models.line_vectorizer import LineVectorizer
  21. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  22. from torchvision.models import resnet50
  23. def print_model_structure(model):
  24. """
  25. 详细打印模型结构和参数
  26. """
  27. print("\n========= Model Structure =========")
  28. # 打印模型总体信息
  29. print("Model Type:", type(model))
  30. # 打印模型总参数量
  31. total_params = sum(p.numel() for p in model.parameters())
  32. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  33. print(f"\nTotal Parameters: {total_params:,}")
  34. print(f"Trainable Parameters: {trainable_params:,}")
  35. print(f"Non-trainable Parameters: {total_params - trainable_params:,}")
  36. # 打印每个模块的参数量和可训练状态
  37. print("\n===== Detailed Model Components =====")
  38. for name, module in model.named_children():
  39. module_params = sum(p.numel() for p in module.parameters())
  40. module_trainable_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
  41. print(f"\nmodel.named:{name}:")
  42. print(f" Total Parameters: {module_params:,}")
  43. print(f" Trainable Parameters: {module_trainable_params:,}")
  44. # 打印子模块
  45. for subname, submodule in module.named_children():
  46. sub_params = sum(p.numel() for p in submodule.parameters())
  47. sub_trainable_params = sum(p.numel() for p in submodule.parameters() if p.requires_grad)
  48. print(f" {subname}:")
  49. print(f" Total Parameters: {sub_params:,}")
  50. print(f" Trainable Parameters: {sub_trainable_params:,}")
  51. def verify_freeze_params(model, freeze_config):
  52. """
  53. 验证参数冻结是否生效
  54. """
  55. print("\n===== Verifying Parameter Freezing =====")
  56. for name, module in model.named_children():
  57. if name in freeze_config:
  58. if freeze_config[name]:
  59. print(f"\nChecking module: {name}")
  60. for param_name, param in module.named_parameters():
  61. print(f" {param_name}: requires_grad = {param.requires_grad}")
  62. # 特别处理fc2子模块
  63. if name == 'fc2' and 'fc2_submodules' in freeze_config:
  64. for subname, submodule in module.named_children():
  65. if subname in freeze_config['fc2_submodules']:
  66. if freeze_config['fc2_submodules'][subname]:
  67. print(f"\nChecking fc2 submodule: {subname}")
  68. for param_name, param in submodule.named_parameters():
  69. print(f" {param_name}: requires_grad = {param.requires_grad}")
  70. def freeze_params(model, freeze_config=None):
  71. """
  72. 更精细的参数冻结方法
  73. Args:
  74. model: 要冻结参数的模型
  75. freeze_config: 冻结配置字典
  76. """
  77. # 默认冻结配置
  78. default_config = {
  79. 'backbone': False,
  80. 'fc1': False,
  81. 'fc2': False,
  82. 'fc2_submodules': {
  83. '0': False, # fc2的第一个子模块
  84. '2': False, # fc2的第三个子模块
  85. '4': False # fc2的第五个子模块
  86. },
  87. 'pooling': False,
  88. 'loss': False
  89. }
  90. # 更新默认配置
  91. if freeze_config is not None:
  92. for key, value in freeze_config.items():
  93. if isinstance(value, dict):
  94. default_config[key].update(value)
  95. else:
  96. default_config[key] = value
  97. print("\n===== Parameter Freezing Configuration =====")
  98. for name, module in model.named_children():
  99. # 处理主模块冻结
  100. if name in default_config:
  101. for param in module.parameters():
  102. param.requires_grad = not default_config[name]
  103. if not default_config[name]:
  104. print(f"Module {name} is trainable")
  105. else:
  106. print(f"Freezing module: {name}")
  107. # 处理fc2的子模块
  108. if name == 'fc2' and 'fc2_submodules' in default_config:
  109. for subname, submodule in module.named_children():
  110. if subname in default_config['fc2_submodules']:
  111. for param in submodule.parameters():
  112. param.requires_grad = not default_config['fc2_submodules'][subname]
  113. if not default_config['fc2_submodules'][subname]:
  114. print(f"Submodule fc2.{subname} is trainable")
  115. else:
  116. print(f"Freezing submodule: fc2.{subname}")
  117. # 打印参数冻结后的详细信息
  118. total_params = sum(p.numel() for p in model.parameters())
  119. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  120. print(f"\nTotal Parameters: {total_params:,}")
  121. print(f"Trainable Parameters: {trainable_params:,}")
  122. print(f"Frozen Parameters: {total_params - trainable_params:,}")
  123. def get_model(num_classes):
  124. # 加载预训练的ResNet-50 FPN backbone
  125. model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
  126. # 获取分类器的输入特征数
  127. in_features = model.roi_heads.box_predictor.cls_score.in_features
  128. # 替换分类器以适应新的类别数量
  129. model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
  130. return model
  131. def main():
  132. # 训练配置参数
  133. config = {
  134. # 数据集配置
  135. 'datadir': r'D:\python\PycharmProjects\data', # 数据集目录
  136. 'config_file': 'config/wireframe.yaml', # 配置文件路径
  137. # GPU配置
  138. 'devices': '0', # 使用的GPU设备
  139. 'identifier': 'fasterrcnn_resnet50', # 训练标识符 stacked_hourglass unet
  140. # 预训练模型路径
  141. # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth', # 预训练模型路径
  142. # 详细的参数冻结配置 冻结是True
  143. 'freeze_config': {
  144. 'backbone': False, # 冻结backbone
  145. 'fc1': False, # 不冻结fc1
  146. 'fc2': False, # 不冻结fc2
  147. 'fc2_submodules': {
  148. '0': False, # fc2的第一个子模块保持可训练
  149. '2': False, # 冻结fc2的第三个子模块
  150. '4': False # fc2的第五个子模块保持可训练
  151. },
  152. 'pooling': False, # 不冻结pooling
  153. 'loss': False # 不冻结loss
  154. }
  155. }
  156. # 更新配置
  157. C.update(C.from_yaml(filename=config['config_file']))
  158. M.update(C.model)
  159. # 设置随机数种子
  160. random.seed(0)
  161. np.random.seed(0)
  162. torch.manual_seed(0)
  163. # 设备配置
  164. device_name = "cpu"
  165. os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
  166. if torch.cuda.is_available():
  167. device_name = "cuda"
  168. torch.backends.cudnn.deterministic = True
  169. torch.cuda.manual_seed(0)
  170. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  171. else:
  172. print("CUDA is not available")
  173. device = torch.device(device_name)
  174. # 数据加载
  175. kwargs = {
  176. "collate_fn": collate,
  177. "num_workers": C.io.num_workers if os.name != "nt" else 0,
  178. "pin_memory": True,
  179. }
  180. train_loader = torch.utils.data.DataLoader(
  181. WireframeDataset(config['datadir'], dataset_type="train"),
  182. shuffle=True,
  183. batch_size=M.batch_size,
  184. **kwargs,
  185. )
  186. val_loader = torch.utils.data.DataLoader(
  187. WireframeDataset(config['datadir'], dataset_type="val"),
  188. shuffle=False,
  189. batch_size=M.batch_size_eval,
  190. **kwargs,
  191. )
  192. # 构建模型
  193. if M.backbone == "stacked_hourglass":
  194. print(f"backbone == stacked_hourglass")
  195. model = lcnn.models.hg(
  196. depth=M.depth,
  197. head=MultitaskHead,
  198. num_stacks=M.num_stacks,
  199. num_blocks=M.num_blocks,
  200. num_classes=sum(sum(M.head_size, [])),
  201. )
  202. print(f"model.shape:{model}")
  203. model = MultitaskLearner(model)
  204. model = LineVectorizer(model)
  205. elif M.backbone == "unet":
  206. print(f"backbone == unet")
  207. # weights_backbone = ResNet50_Weights.verify(weights_backbone)
  208. model = lcnn.models.unet(
  209. num_classes=sum(sum(M.head_size, [])),
  210. num_stacks=M.num_stacks,
  211. base_channels=kwargs.get("base_channels", 64)
  212. )
  213. model = MultitaskLearner(model)
  214. model = LineVectorizer(model)
  215. elif M.backbone == "resnet50":
  216. print(f"backbone == resnet50")
  217. model = lcnn.models.resnet50(
  218. # num_stacks=M.num_stacks,
  219. num_classes=sum(sum(M.head_size, [])),
  220. )
  221. model = MultitaskLearner(model)
  222. model = LineVectorizer(model)
  223. elif M.backbone == "resnet501":
  224. print(f"backbone == resnet501")
  225. model = lcnn.models.resnet501(
  226. # num_stacks=M.num_stacks,
  227. num_classes=sum(sum(M.head_size, [])),
  228. )
  229. model = MultitaskLearner(model)
  230. model = LineVectorizer(model)
  231. elif M.backbone == "fasterrcnn_resnet50":
  232. print(f"backbone == fasterrcnn_resnet50")
  233. model = lcnn.models.fasterrcnn_resnet50(
  234. # num_stacks=M.num_stacks,
  235. num_classes=sum(sum(M.head_size, [])),
  236. )
  237. model = MultitaskLearner(model)
  238. model = LineVectorizer(model)
  239. else:
  240. raise NotImplementedError
  241. # 加载预训练权重
  242. try:
  243. # 加载模型权重
  244. checkpoint = torch.load(config['pretrained_model'], map_location=device)
  245. # 根据实际的检查点结构选择加载方式
  246. if 'model_state_dict' in checkpoint:
  247. # 如果是完整的检查点
  248. model.load_state_dict(checkpoint['model_state_dict'])
  249. elif 'state_dict' in checkpoint:
  250. # 如果是只有状态字典的检查点
  251. model.load_state_dict(checkpoint['state_dict'])
  252. else:
  253. # 直接加载权重字典
  254. model.load_state_dict(checkpoint)
  255. print("Successfully loaded pre-trained model weights.")
  256. except Exception as e:
  257. print(f"Error loading model weights: {e}")
  258. # 打印模型结构
  259. # print_model_structure(model)
  260. # # 冻结参数
  261. # freeze_params(
  262. # model,
  263. # freeze_config=config['freeze_config']
  264. # )
  265. # # 验证冻结参数
  266. # verify_freeze_params(model, config['freeze_config'])
  267. #
  268. # # 打印模型结构
  269. # print("\n========= After Freezing Backbone =========")
  270. # print_model_structure(model)
  271. # 移动到设备
  272. model = model.to(device)
  273. # 优化器配置
  274. if C.optim.name == "Adam":
  275. optim = torch.optim.Adam(
  276. filter(lambda p: p.requires_grad, model.parameters()),
  277. lr=C.optim.lr,
  278. weight_decay=C.optim.weight_decay,
  279. amsgrad=C.optim.amsgrad,
  280. )
  281. elif C.optim.name == "SGD":
  282. optim = torch.optim.SGD(
  283. filter(lambda p: p.requires_grad, model.parameters()),
  284. lr=C.optim.lr,
  285. weight_decay=C.optim.weight_decay,
  286. momentum=C.optim.momentum,
  287. )
  288. else:
  289. raise NotImplementedError
  290. # 输出目录
  291. outdir = osp.join(
  292. osp.expanduser(C.io.logdir),
  293. f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
  294. )
  295. os.makedirs(outdir, exist_ok=True)
  296. try:
  297. trainer = lcnn.trainer.Trainer(
  298. device=device,
  299. model=model,
  300. optimizer=optim,
  301. train_loader=train_loader,
  302. val_loader=val_loader,
  303. out=outdir,
  304. )
  305. print("Starting training...")
  306. trainer.train()
  307. print("Training completed.")
  308. except BaseException:
  309. if len(glob.glob(f"{outdir}/viz/*")) <= 1:
  310. shutil.rmtree(outdir)
  311. raise
# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()