train——line_rcnn.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. #!/usr/bin/env python3
  2. import datetime
  3. import glob
  4. import os
  5. import os.path as osp
  6. import platform
  7. import pprint
  8. import random
  9. import shlex
  10. import shutil
  11. import subprocess
  12. import sys
  13. import numpy as np
  14. import torch
  15. import torchvision
  16. import yaml
  17. import lcnn
  18. from lcnn.config import C, M
  19. from lcnn.datasets import WireframeDataset, collate
  20. from lcnn.models.line_vectorizer import LineVectorizer
  21. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  22. from torchvision.models import resnet50
  23. from models.line_detect.line_rcnn import linercnn_resnet50_fpn
  24. def main():
  25. # 训练配置参数
  26. config = {
  27. # 数据集配置
  28. 'datadir': r'D:\python\PycharmProjects\data', # 数据集目录
  29. 'config_file': 'config/wireframe.yaml', # 配置文件路径
  30. # GPU配置
  31. 'devices': '0', # 使用的GPU设备
  32. 'identifier': 'fasterrcnn_resnet50', # 训练标识符 stacked_hourglass unet
  33. # 预训练模型路径
  34. # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth', # 预训练模型路径
  35. }
  36. # 更新配置
  37. C.update(C.from_yaml(filename=config['config_file']))
  38. M.update(C.model)
  39. # 设置随机数种子
  40. random.seed(0)
  41. np.random.seed(0)
  42. torch.manual_seed(0)
  43. # 设备配置
  44. device_name = "cpu"
  45. os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
  46. if torch.cuda.is_available():
  47. device_name = "cuda"
  48. torch.backends.cudnn.deterministic = True
  49. torch.cuda.manual_seed(0)
  50. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  51. else:
  52. print("CUDA is not available")
  53. device = torch.device(device_name)
  54. # 数据加载
  55. kwargs = {
  56. "collate_fn": collate,
  57. "num_workers": C.io.num_workers if os.name != "nt" else 0,
  58. "pin_memory": True,
  59. }
  60. train_loader = torch.utils.data.DataLoader(
  61. WireframeDataset(config['datadir'], dataset_type="train"),
  62. shuffle=True,
  63. batch_size=M.batch_size,
  64. **kwargs,
  65. )
  66. val_loader = torch.utils.data.DataLoader(
  67. WireframeDataset(config['datadir'], dataset_type="val"),
  68. shuffle=False,
  69. batch_size=M.batch_size_eval,
  70. **kwargs,
  71. )
  72. model = linercnn_resnet50_fpn().to(device)
  73. # 加载预训练权重
  74. try:
  75. # 加载模型权重
  76. checkpoint = torch.load(config['pretrained_model'], map_location=device)
  77. # 根据实际的检查点结构选择加载方式
  78. if 'model_state_dict' in checkpoint:
  79. # 如果是完整的检查点
  80. model.load_state_dict(checkpoint['model_state_dict'])
  81. elif 'state_dict' in checkpoint:
  82. # 如果是只有状态字典的检查点
  83. model.load_state_dict(checkpoint['state_dict'])
  84. else:
  85. # 直接加载权重字典
  86. model.load_state_dict(checkpoint)
  87. print("Successfully loaded pre-trained model weights.")
  88. except Exception as e:
  89. print(f"Error loading model weights: {e}")
  90. # 优化器配置
  91. if C.optim.name == "Adam":
  92. optim = torch.optim.Adam(
  93. filter(lambda p: p.requires_grad, model.parameters()),
  94. lr=C.optim.lr,
  95. weight_decay=C.optim.weight_decay,
  96. amsgrad=C.optim.amsgrad,
  97. )
  98. elif C.optim.name == "SGD":
  99. optim = torch.optim.SGD(
  100. filter(lambda p: p.requires_grad, model.parameters()),
  101. lr=C.optim.lr,
  102. weight_decay=C.optim.weight_decay,
  103. momentum=C.optim.momentum,
  104. )
  105. else:
  106. raise NotImplementedError
  107. # 输出目录
  108. outdir = osp.join(
  109. osp.expanduser(C.io.logdir),
  110. f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
  111. )
  112. os.makedirs(outdir, exist_ok=True)
  113. try:
  114. trainer = lcnn.trainer.Trainer(
  115. device=device,
  116. model=model,
  117. optimizer=optim,
  118. train_loader=train_loader,
  119. val_loader=val_loader,
  120. out=outdir,
  121. )
  122. print("Starting training...")
  123. trainer.train()
  124. print("Training completed.")
  125. except BaseException:
  126. if len(glob.glob(f"{outdir}/viz/*")) <= 1:
  127. shutil.rmtree(outdir)
  128. raise
  129. if __name__ == "__main__":
  130. main()