# trainer.py
  1. import math
  2. import os
  3. import sys
  4. from datetime import datetime
  5. import torch
  6. import torchvision
  7. from torch.utils.tensorboard import SummaryWriter
  8. from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
  9. from models.wirenet.postprocess import postprocess_keypoint
  10. from torchvision.utils import draw_bounding_boxes
  11. from torchvision import transforms
  12. import matplotlib.pyplot as plt
  13. import numpy as np
  14. import matplotlib as mpl
  15. from tools.coco_utils import get_coco_api_from_dataset
  16. from tools.coco_eval import CocoEvaluator
  17. import time
  18. from models.config.config_tool import read_yaml
  19. from models.ins.maskrcnn_dataset import MaskRCNNDataset
  20. from models.keypoint.keypoint_dataset import KeypointDataset
  21. from tools import utils, presets
  22. def log_losses_to_tensorboard(writer, result, step):
  23. writer.add_scalar('Loss/classifier', result['loss_classifier'].item(), step)
  24. writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
  25. writer.add_scalar('Loss/keypoint', result['loss_keypoint'].item(), step)
  26. writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
  27. writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
  28. def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
  29. model.train1()
  30. metric_logger = utils.MetricLogger(delimiter=" ")
  31. metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
  32. header = f"Epoch: [{epoch}]"
  33. lr_scheduler = None
  34. if epoch == 0:
  35. warmup_factor = 1.0 / 1000
  36. warmup_iters = min(1000, len(data_loader) - 1)
  37. lr_scheduler = torch.optim.lr_scheduler.LinearLR(
  38. optimizer, start_factor=warmup_factor, total_iters=warmup_iters
  39. )
  40. total_train_loss=0
  41. for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
  42. global_step = epoch * len(data_loader) + batch_idx
  43. # print(f'images:{images}')
  44. images = list(image.to(device) for image in images)
  45. targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
  46. with torch.cuda.amp.autocast(enabled=scaler is not None):
  47. loss_dict = model(images, targets)
  48. # print(f'loss_dict:{loss_dict}')
  49. losses = sum(loss for loss in loss_dict.values())
  50. total_train_loss += losses.item()
  51. log_losses_to_tensorboard(writer, loss_dict, global_step)
  52. # reduce losses over all GPUs for logging purposes
  53. loss_dict_reduced = utils.reduce_dict(loss_dict)
  54. losses_reduced = sum(loss for loss in loss_dict_reduced.values())
  55. loss_value = losses_reduced.item()
  56. if not math.isfinite(loss_value):
  57. print(f"Loss is {loss_value}, stopping training")
  58. print(loss_dict_reduced)
  59. sys.exit(1)
  60. optimizer.zero_grad()
  61. if scaler is not None:
  62. scaler.scale(losses).backward()
  63. scaler.step(optimizer)
  64. scaler.update()
  65. else:
  66. losses.backward()
  67. optimizer.step()
  68. if lr_scheduler is not None:
  69. lr_scheduler.step()
  70. metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
  71. metric_logger.update(lr=optimizer.param_groups[0]["lr"])
  72. return metric_logger, total_train_loss
  73. cmap = plt.get_cmap("jet")
  74. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  75. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  76. sm.set_array([])
  77. def c(x):
  78. return sm.to_rgba(x)
def show_line(img, pred, epoch, writer):
    """Render one image's predicted boxes and keypoint line segments to TensorBoard.

    Logs three images under the tags "ori" (raw input), "boxes" (input with
    predicted boxes drawn in yellow) and "output" (a matplotlib rendering of
    the post-processed line segments drawn over the input).

    Args:
        img: input image tensor, CHW, float in [0, 1] — assumed 3x512x512
            per the inline comment; TODO confirm for other input sizes.
        pred: single-image prediction dict with "boxes", "keypoints" and
            "keypoints_scores" tensors.
        epoch: global step passed to every ``writer.add_image`` call.
        writer: TensorBoard ``SummaryWriter``.
    """
    im = img.permute(1, 2, 0)  # CHW -> HWC, e.g. [512, 512, 3]
    writer.add_image("ori", im, epoch, dataformats="HWC")
    # draw_bounding_boxes requires a uint8 image in [0, 255].
    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
                                      colors="yellow", width=1)
    # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
    # plt.show()
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
    # Scatter style for line endpoints.
    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    lines = pred["keypoints"].detach().cpu().numpy()
    scores = pred["keypoints_scores"].detach().cpu().numpy()
    # for i in range(1, len(lines)):
    #     if (lines[i] == lines[0]).all():
    #         lines = lines[:i]
    #         scores = scores[:i]
    #         break

    # Postprocess lines to remove overlapped lines; the merge threshold is
    # 1% of the image diagonal.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)

    # One rendering pass per score threshold (currently only t = 0.5).
    for i, t in enumerate([0.5]):
        # Strip axes/margins so the figure is just the image + overlays.
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < t:
                continue
            # plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            # plt.scatter(a[1], a[0], **PLTOPTS)
            # plt.scatter(b[1], b[0], **PLTOPTS)
            plt.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=2, zorder=s)
            plt.scatter(a[0], a[1], **PLTOPTS)
            plt.scatter(b[0], b[1], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im.cpu())
        plt.tight_layout()
        # Rasterize the current figure into an RGB numpy array.
        fig = plt.gcf()
        fig.canvas.draw()
        # NOTE(review): ``tostring_rgb`` is deprecated/removed in newer
        # matplotlib releases (``buffer_rgba`` is the replacement) — confirm
        # against the pinned matplotlib version.
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
            fig.canvas.get_width_height()[::-1] + (3,))
        plt.close()

        img2 = transforms.ToTensor()(image_from_plot)
        writer.add_image("output", img2, epoch)
  123. def _get_iou_types(model):
  124. model_without_ddp = model
  125. if isinstance(model, torch.nn.parallel.DistributedDataParallel):
  126. model_without_ddp = model.module
  127. iou_types = ["bbox"]
  128. if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
  129. iou_types.append("segm")
  130. if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
  131. iou_types.append("keypoints")
  132. return iou_types
  133. def evaluate(model, data_loader, epoch, writer, device):
  134. n_threads = torch.get_num_threads()
  135. # FIXME remove this and make paste_masks_in_image run on the GPU
  136. torch.set_num_threads(1)
  137. cpu_device = torch.device("cpu")
  138. model.eval()
  139. metric_logger = utils.MetricLogger(delimiter=" ")
  140. header = "Test:"
  141. coco = get_coco_api_from_dataset(data_loader.dataset)
  142. iou_types = _get_iou_types(model)
  143. coco_evaluator = CocoEvaluator(coco, iou_types)
  144. print(f'start to evaluate!!!')
  145. for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)):
  146. images = list(img.to(device) for img in images)
  147. model_time = time.time()
  148. outputs = model(images)
  149. # print(f'outputs:{outputs}')
  150. if batch_idx == 0:
  151. show_line(images[0], outputs[0], epoch, writer)
  152. # print(f'outputs:{outputs}')
  153. # print(f'outputs[0]:{outputs[0]}')
  154. # outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
  155. # model_time = time.time() - model_time
  156. #
  157. # res = {target["image_id"]: output for target, output in zip(targets, outputs)}
  158. # evaluator_time = time.time()
  159. # coco_evaluator.update(res)
  160. # evaluator_time = time.time() - evaluator_time
  161. # metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
  162. #
  163. # # gather the stats from all processes
  164. # metric_logger.synchronize_between_processes()
  165. # print("Averaged stats:", metric_logger)
  166. # coco_evaluator.synchronize_between_processes()
  167. #
  168. # # accumulate predictions from all images
  169. # coco_evaluator.accumulate()
  170. # coco_evaluator.summarize()
  171. # torch.set_num_threads(n_threads)
  172. # return coco_evaluator
  173. def train_cfg(model, cfg):
  174. parameters = read_yaml(cfg)
  175. print(f'train parameters:{parameters}')
  176. train(model, **parameters)
  177. def train(model, **kwargs):
  178. # 默认参数
  179. default_params = {
  180. 'dataset_path': '/path/to/dataset',
  181. 'num_classes': 2,
  182. 'num_keypoints': 2,
  183. 'opt': 'adamw',
  184. 'batch_size': 2,
  185. 'epochs': 10,
  186. 'lr': 0.005,
  187. 'momentum': 0.9,
  188. 'weight_decay': 1e-4,
  189. 'lr_step_size': 3,
  190. 'lr_gamma': 0.1,
  191. 'num_workers': 4,
  192. 'print_freq': 10,
  193. 'target_type': 'polygon',
  194. 'enable_logs': True,
  195. 'augmentation': False,
  196. 'checkpoint': None
  197. }
  198. # 更新默认参数
  199. for key, value in kwargs.items():
  200. if key in default_params:
  201. default_params[key] = value
  202. else:
  203. raise ValueError(f"Unknown argument: {key}")
  204. # 解析参数
  205. dataset_path = default_params['dataset_path']
  206. num_classes = default_params['num_classes']
  207. batch_size = default_params['batch_size']
  208. epochs = default_params['epochs']
  209. lr = default_params['lr']
  210. momentum = default_params['momentum']
  211. weight_decay = default_params['weight_decay']
  212. lr_step_size = default_params['lr_step_size']
  213. lr_gamma = default_params['lr_gamma']
  214. num_workers = default_params['num_workers']
  215. print_freq = default_params['print_freq']
  216. target_type = default_params['target_type']
  217. augmentation = default_params['augmentation']
  218. # 设置设备
  219. device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  220. train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
  221. wts_path = os.path.join(train_result_ptath, 'weights')
  222. tb_path = os.path.join(train_result_ptath, 'logs')
  223. writer = SummaryWriter(tb_path)
  224. transforms = None
  225. # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
  226. if augmentation:
  227. transforms = get_transform(is_train=True)
  228. print(f'transforms:{transforms}')
  229. if not os.path.exists('train_results'):
  230. os.mkdir('train_results')
  231. model.to(device)
  232. optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
  233. dataset = KeypointDataset(dataset_path=dataset_path,
  234. transforms=transforms, dataset_type='train', target_type=target_type)
  235. dataset_test = KeypointDataset(dataset_path=dataset_path, transforms=None,
  236. dataset_type='val')
  237. train_sampler = torch.utils.data.RandomSampler(dataset)
  238. test_sampler = torch.utils.data.RandomSampler(dataset_test)
  239. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
  240. train_collate_fn = utils.collate_fn
  241. data_loader = torch.utils.data.DataLoader(
  242. dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
  243. )
  244. data_loader_test = torch.utils.data.DataLoader(
  245. dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
  246. )
  247. img_results_path = os.path.join(train_result_ptath, 'img_results')
  248. if os.path.exists(train_result_ptath):
  249. pass
  250. # os.remove(train_result_ptath)
  251. else:
  252. os.mkdir(train_result_ptath)
  253. if os.path.exists(train_result_ptath):
  254. os.mkdir(wts_path)
  255. os.mkdir(img_results_path)
  256. for epoch in range(epochs):
  257. metric_logger, total_train_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, None)
  258. losses = metric_logger.meters['loss'].global_avg
  259. print(f'epoch {epoch}:loss:{losses}')
  260. if os.path.exists(f'{wts_path}/last.pt'):
  261. os.remove(f'{wts_path}/last.pt')
  262. torch.save(model.state_dict(), f'{wts_path}/last.pt')
  263. # write_metric_logs(epoch, metric_logger, writer)
  264. if epoch == 0:
  265. best_loss = losses;
  266. if best_loss >= losses:
  267. best_loss = losses
  268. if os.path.exists(f'{wts_path}/best.pt'):
  269. os.remove(f'{wts_path}/best.pt')
  270. torch.save(model.state_dict(), f'{wts_path}/best.pt')
  271. evaluate(model, data_loader_test, epoch, writer, device=device)
  272. avg_train_loss = total_train_loss / len(data_loader)
  273. writer.add_scalar('Loss/train', avg_train_loss, epoch)
  274. def get_transform(is_train, **kwargs):
  275. default_params = {
  276. 'augmentation': 'multiscale',
  277. 'backend': 'tensor',
  278. 'use_v2': False,
  279. }
  280. # 更新默认参数
  281. for key, value in kwargs.items():
  282. if key in default_params:
  283. default_params[key] = value
  284. else:
  285. raise ValueError(f"Unknown argument: {key}")
  286. # 解析参数
  287. augmentation = default_params['augmentation']
  288. backend = default_params['backend']
  289. use_v2 = default_params['use_v2']
  290. if is_train:
  291. return presets.DetectionPresetTrain(
  292. data_augmentation=augmentation, backend=backend, use_v2=use_v2
  293. )
  294. # elif weights and test_only:
  295. # weights = torchvision.models.get_weight(args.weights)
  296. # trans = weights.transforms()
  297. # return lambda img, target: (trans(img), target)
  298. else:
  299. return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
  300. def write_metric_logs(epoch, metric_logger, writer):
  301. writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
  302. writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
  303. # writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
  304. writer.add_scalar('Loss/box_reg', metric_logger.meters['loss_keypoint'].global_avg, epoch)
  305. writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
  306. writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
  307. writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
  308. # def log_losses_to_tensorboard(writer, result, step):
  309. # writer.add_scalar('Loss/classifier', result['loss_classifier'].item(), step)
  310. # writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
  311. # writer.add_scalar('Loss/box_reg', result['loss_keypoint'].item(), step)
  312. # writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
  313. # writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)