# trainer.py

import math
import os
import sys
from datetime import datetime

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights

from models.config.config_tool import read_yaml
from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
from tools import utils, presets


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"
    lr_scheduler = None
    if epoch == 0:
        # Linear warmup over the first (up to) 1000 iterations.
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
        # Reduce losses over all GPUs for logging purposes.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step()
        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    return metric_logger
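

# Usage note (illustrative; nothing in this file currently passes a scaler):
# train_one_epoch accepts an optional torch.cuda.amp.GradScaler for mixed-precision
# training. A minimal sketch, assuming a CUDA device is available:
#
#     scaler = torch.cuda.amp.GradScaler()
#     for epoch in range(epochs):
#         train_one_epoch(model, optimizer, data_loader, device, epoch,
#                         print_freq=10, scaler=scaler)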


def load_train_parameter(cfg):
    parameters = read_yaml(cfg)
    return parameters


def train_cfg(model, cfg):
    parameters = read_yaml(cfg)
    print(f'train parameters: {parameters}')
    train(model, **parameters)
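

# Example YAML consumed by train_cfg via read_yaml (a sketch; the keys must match
# the default_params accepted by train(), and the file path is hypothetical):
#
#     dataset_path: /path/to/dataset
#     num_classes: 2
#     opt: adamw
#     batch_size: 2
#     epochs: 10
#     lr: 0.005
#     target_type: polygon
#     augmentation: false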


def train(model, **kwargs):
    # Default parameters
    default_params = {
        'dataset_path': '/path/to/dataset',
        'num_classes': 2,
        'num_keypoints': 2,
        'opt': 'adamw',
        'batch_size': 2,
        'epochs': 10,
        'lr': 0.005,
        'momentum': 0.9,
        'weight_decay': 1e-4,
        'lr_step_size': 3,
        'lr_gamma': 0.1,
        'num_workers': 4,
        'print_freq': 10,
        'target_type': 'polygon',
        'enable_logs': True,
        'augmentation': False,
        'checkpoint': None
    }
    # Update defaults with caller-supplied keyword arguments
    for key, value in kwargs.items():
        if key in default_params:
            default_params[key] = value
        else:
            raise ValueError(f"Unknown argument: {key}")
    # Unpack parameters
    dataset_path = default_params['dataset_path']
    num_classes = default_params['num_classes']
    opt = default_params['opt']
    batch_size = default_params['batch_size']
    epochs = default_params['epochs']
    lr = default_params['lr']
    momentum = default_params['momentum']
    weight_decay = default_params['weight_decay']
    lr_step_size = default_params['lr_step_size']
    lr_gamma = default_params['lr_gamma']
    num_workers = default_params['num_workers']
    print_freq = default_params['print_freq']
    target_type = default_params['target_type']
    augmentation = default_params['augmentation']
    # Select device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
    wts_path = os.path.join(train_result_path, 'weights')
    tb_path = os.path.join(train_result_path, 'logs')
    writer = SummaryWriter(tb_path)
    transforms = None
    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
    if augmentation:
        transforms = get_transform(is_train=True)
        print(f'transforms: {transforms}')
    os.makedirs('train_results', exist_ok=True)
    model.to(device)
    # Honour the 'opt' parameter: SGD uses the configured momentum; AdamW ignores it.
    if opt == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
    dataset = MaskRCNNDataset(dataset_path=dataset_path,
                              transforms=transforms, dataset_type='train', target_type=target_type)
    dataset_test = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
                                   dataset_type='val')
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
    train_collate_fn = utils.collate_fn
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
    )
    # data_loader_test = torch.utils.data.DataLoader(
    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
    # )
    img_results_path = os.path.join(train_result_path, 'img_results')
    # makedirs creates the parent run directory as needed and tolerates re-runs.
    os.makedirs(wts_path, exist_ok=True)
    os.makedirs(img_results_path, exist_ok=True)
    best_loss = float('inf')
    for epoch in range(epochs):
        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
        main_lr_scheduler.step()
        losses = metric_logger.meters['loss'].global_avg
        print(f'epoch {epoch}: loss: {losses}')
        if os.path.exists(f'{wts_path}/last.pt'):
            os.remove(f'{wts_path}/last.pt')
        torch.save(model.state_dict(), f'{wts_path}/last.pt')
        write_metric_logs(epoch, metric_logger, writer)
        # Keep the checkpoint with the lowest average training loss so far.
        if losses <= best_loss:
            best_loss = losses
            if os.path.exists(f'{wts_path}/best.pt'):
                os.remove(f'{wts_path}/best.pt')
            torch.save(model.state_dict(), f'{wts_path}/best.pt')


def get_transform(is_train, **kwargs):
    default_params = {
        'augmentation': 'multiscale',
        'backend': 'tensor',
        'use_v2': False,
    }
    # Update defaults with caller-supplied keyword arguments
    for key, value in kwargs.items():
        if key in default_params:
            default_params[key] = value
        else:
            raise ValueError(f"Unknown argument: {key}")
    # Unpack parameters
    augmentation = default_params['augmentation']
    backend = default_params['backend']
    use_v2 = default_params['use_v2']
    if is_train:
        return presets.DetectionPresetTrain(
            data_augmentation=augmentation, backend=backend, use_v2=use_v2
        )
    # elif weights and test_only:
    #     weights = torchvision.models.get_weight(args.weights)
    #     trans = weights.transforms()
    #     return lambda img, target: (trans(img), target)
    else:
        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)


def write_metric_logs(epoch, metric_logger, writer):
    writer.add_scalar('loss_classifier', metric_logger.meters['loss_classifier'].global_avg, epoch)
    writer.add_scalar('loss_box_reg', metric_logger.meters['loss_box_reg'].global_avg, epoch)
    writer.add_scalar('loss_mask', metric_logger.meters['loss_mask'].global_avg, epoch)
    writer.add_scalar('loss_objectness', metric_logger.meters['loss_objectness'].global_avg, epoch)
    writer.add_scalar('loss_rpn_box_reg', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
    writer.add_scalar('train_loss', metric_logger.meters['loss'].global_avg, epoch)
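

# Minimal entry-point sketch (assumptions: torchvision's maskrcnn_resnet50_fpn_v2
# as the model and 'config/train_maskrcnn.yaml' as a hypothetical config path;
# adjust both to the actual project setup):
if __name__ == '__main__':
    from torchvision.models.detection import maskrcnn_resnet50_fpn_v2

    # num_classes includes the background class (class 0).
    model = maskrcnn_resnet50_fpn_v2(num_classes=2)
    train_cfg(model, 'config/train_maskrcnn.yaml')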