trainer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. import math
  2. import os
  3. import sys
  4. from datetime import datetime
  5. import cv2
  6. import numpy as np
  7. import torch
  8. import torchvision
  9. from torch.utils.tensorboard import SummaryWriter
  10. from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
  11. from torchvision.utils import draw_bounding_boxes
  12. from models.config.config_tool import read_yaml
  13. from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
  14. from tools import utils, presets
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Run one full training epoch over ``data_loader``.

    Args:
        model: detection model that returns a dict of losses in train mode.
        optimizer: optimizer updating ``model``'s parameters.
        data_loader: iterable yielding ``(images, targets)`` batches.
        device: device the batch tensors are moved to.
        epoch: current epoch index; epoch 0 additionally gets LR warmup.
        print_freq: logging frequency (iterations) for the metric logger.
        scaler: optional ``torch.cuda.amp.GradScaler``; when given, the
            mixed-precision backward/step path is used.

    Returns:
        The ``utils.MetricLogger`` holding the accumulated epoch statistics.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"
    lr_scheduler = None
    if epoch == 0:
        # Linear LR warmup on the first epoch only, capped at 1000 iterations.
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        # print(f'images:{images}')
        images = list(image.to(device) for image in images)
        # Only tensor values inside each target dict are moved; other fields pass through.
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()
        if not math.isfinite(loss_value):
            # A non-finite loss means training diverged; abort rather than continue.
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)
        optimizer.zero_grad()
        if scaler is not None:
            # AMP path: the scale -> backward -> step -> update order is required.
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step()
        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    return metric_logger
  55. def load_train_parameter(cfg):
  56. parameters = read_yaml(cfg)
  57. return parameters
  58. def train_cfg(model, cfg):
  59. parameters = read_yaml(cfg)
  60. print(f'train parameters:{parameters}')
  61. train(model, **parameters)
  62. def move_to_device(data, device):
  63. if isinstance(data, (list, tuple)):
  64. return type(data)(move_to_device(item, device) for item in data)
  65. elif isinstance(data, dict):
  66. return {key: move_to_device(value, device) for key, value in data.items()}
  67. elif isinstance(data, torch.Tensor):
  68. return data.to(device)
  69. else:
  70. return data # 对于非张量类型的数据不做任何改变
  71. def train(model, **kwargs):
  72. # 默认参数
  73. default_params = {
  74. 'dataset_path': '/path/to/dataset',
  75. 'num_classes': 2,
  76. 'num_keypoints':2,
  77. 'opt': 'adamw',
  78. 'batch_size': 2,
  79. 'epochs': 10,
  80. 'lr': 0.005,
  81. 'momentum': 0.9,
  82. 'weight_decay': 1e-4,
  83. 'lr_step_size': 3,
  84. 'lr_gamma': 0.1,
  85. 'num_workers': 4,
  86. 'print_freq': 10,
  87. 'target_type': 'polygon',
  88. 'enable_logs': True,
  89. 'augmentation': False,
  90. 'checkpoint':None
  91. }
  92. # 更新默认参数
  93. for key, value in kwargs.items():
  94. if key in default_params:
  95. default_params[key] = value
  96. else:
  97. raise ValueError(f"Unknown argument: {key}")
  98. # 解析参数
  99. dataset_path = default_params['dataset_path']
  100. num_classes = default_params['num_classes']
  101. batch_size = default_params['batch_size']
  102. epochs = default_params['epochs']
  103. lr = default_params['lr']
  104. momentum = default_params['momentum']
  105. weight_decay = default_params['weight_decay']
  106. lr_step_size = default_params['lr_step_size']
  107. lr_gamma = default_params['lr_gamma']
  108. num_workers = default_params['num_workers']
  109. print_freq = default_params['print_freq']
  110. target_type = default_params['target_type']
  111. augmentation = default_params['augmentation']
  112. # 设置设备
  113. device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  114. train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
  115. wts_path = os.path.join(train_result_ptath, 'weights')
  116. tb_path = os.path.join(train_result_ptath, 'logs')
  117. writer = SummaryWriter(tb_path)
  118. transforms = None
  119. # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
  120. if augmentation:
  121. transforms = get_transform(is_train=True)
  122. print(f'transforms:{transforms}')
  123. if not os.path.exists('train_results'):
  124. os.mkdir('train_results')
  125. model.to(device)
  126. optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
  127. dataset = MaskRCNNDataset(dataset_path=dataset_path,
  128. transforms=transforms, dataset_type='train', target_type=target_type)
  129. val_dataset = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
  130. dataset_type='val')
  131. train_sampler = torch.utils.data.RandomSampler(dataset)
  132. val_sampler = torch.utils.data.RandomSampler(val_dataset)
  133. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
  134. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size, drop_last=True)
  135. train_collate_fn = utils.collate_fn
  136. val_collate_fn = utils.collate_fn
  137. data_loader = torch.utils.data.DataLoader(
  138. dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
  139. )
  140. val_data_loader = torch.utils.data.DataLoader(
  141. val_dataset, batch_sampler=val_batch_sampler, num_workers=num_workers, collate_fn=val_collate_fn
  142. )
  143. # data_loader_test = torch.utils.data.DataLoader(
  144. # dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
  145. # )
  146. img_results_path = os.path.join(train_result_ptath, 'img_results')
  147. if os.path.exists(train_result_ptath):
  148. pass
  149. # os.remove(train_result_ptath)
  150. else:
  151. os.mkdir(train_result_ptath)
  152. if os.path.exists(train_result_ptath):
  153. os.mkdir(wts_path)
  154. os.mkdir(img_results_path)
  155. for epoch in range(epochs):
  156. metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
  157. losses = metric_logger.meters['loss'].global_avg
  158. print(f'epoch {epoch}:loss:{losses}')
  159. if os.path.exists(f'{wts_path}/last.pt'):
  160. os.remove(f'{wts_path}/last.pt')
  161. torch.save(model.state_dict(), f'{wts_path}/last.pt')
  162. write_metric_logs(epoch, metric_logger, writer)
  163. if epoch == 0:
  164. best_loss = losses;
  165. if best_loss >= losses:
  166. best_loss = losses
  167. if os.path.exists(f'{wts_path}/best.pt'):
  168. os.remove(f'{wts_path}/best.pt')
  169. torch.save(model.state_dict(), f'{wts_path}/best.pt')
  170. with torch.no_grad():
  171. model.eval()
  172. for step, (imgs,targets) in enumerate(val_data_loader):
  173. imgs,targets=(imgs,targets)
  174. imgs=move_to_device(imgs,device)
  175. if step ==0:
  176. result=model([imgs[0]])
  177. print(f'model eval result :{result}')
  178. write_val_imgs(epoch,imgs[0],result,writer)
  179. def get_transform(is_train, **kwargs):
  180. default_params = {
  181. 'augmentation': 'multiscale',
  182. 'backend': 'tensor',
  183. 'use_v2': False,
  184. }
  185. # 更新默认参数
  186. for key, value in kwargs.items():
  187. if key in default_params:
  188. default_params[key] = value
  189. else:
  190. raise ValueError(f"Unknown argument: {key}")
  191. # 解析参数
  192. augmentation = default_params['augmentation']
  193. backend = default_params['backend']
  194. use_v2 = default_params['use_v2']
  195. if is_train:
  196. return presets.DetectionPresetTrain(
  197. data_augmentation=augmentation, backend=backend, use_v2=use_v2
  198. )
  199. # elif weights and test_only:
  200. # weights = torchvision.models.get_weight(args.weights)
  201. # trans = weights.transforms()
  202. # return lambda img, target: (trans(img), target)
  203. else:
  204. return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
  205. def write_metric_logs(epoch, metric_logger, writer):
  206. writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
  207. writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
  208. writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
  209. writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
  210. writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
  211. writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
  212. def generate_colors(n):
  213. """
  214. 生成n个均匀分布在HSV色彩空间中的颜色,并转换成BGR色彩空间。
  215. :param n: 需要的颜色数量
  216. :return: 一个包含n个颜色的列表,每个颜色为BGR格式的元组
  217. """
  218. hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
  219. bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0])) for hsv in hsv_colors]
  220. return bgr_colors
  221. def overlay_masks_on_image(image, masks, alpha=0.6):
  222. """
  223. 在原图上叠加多个掩码,每个掩码使用不同的颜色。
  224. :param image: 原图 (NumPy 数组)
  225. :param masks: 掩码列表 (每个都是 NumPy 数组,二值图像)
  226. :param colors: 颜色列表 (每个颜色都是 (B, G, R) 格式的元组)
  227. :param alpha: 掩码的透明度 (0.0 到 1.0)
  228. :return: 叠加了多个掩码的图像
  229. """
  230. colors = generate_colors(len(masks))
  231. if len(masks) != len(colors):
  232. raise ValueError("The number of masks and colors must be the same.")
  233. # 复制原图,避免修改原始图像
  234. overlay = image.copy()
  235. for mask, color in zip(masks, colors):
  236. # 确保掩码是二值图像
  237. mask = mask.cpu().detach().permute(1, 2, 0).numpy()
  238. binary_mask = (mask > 0).astype(np.uint8) * 255 # 你可以根据实际情况调整阈值
  239. # 创建彩色掩码
  240. colored_mask = np.zeros_like(image)
  241. colored_mask[:] = color
  242. colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
  243. # 将彩色掩码与当前的叠加图像混合
  244. overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
  245. return overlay
  246. def write_val_imgs(epoch, img, results, writer):
  247. masks = results[0]['masks']
  248. boxes = results[0]['boxes']
  249. print(f'writer img shape:{img.shape}')
  250. # cv2.imshow('mask',masks[0].cpu().detach().numpy())
  251. boxes = boxes.cpu().detach()
  252. drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
  253. print(f'drawn_boxes:{drawn_boxes.shape}')
  254. drawn_boxes=drawn_boxes.cpu()
  255. # boxed_img = drawn_boxes.permute(1, 2, 0).numpy()
  256. writer.add_image("z-boxes", drawn_boxes, epoch)
  257. # boxed_img=cv2.resize(boxed_img,(800,800))
  258. # cv2.imshow('boxes',boxed_img)
  259. if masks.shape[0]>0:
  260. mask = masks[0].cpu().detach().permute(1, 2, 0).numpy()
  261. mask = cv2.resize(mask, (800, 800))
  262. # cv2.imshow('mask',mask)
  263. img = img.cpu().detach().permute(1, 2, 0).numpy()
  264. masked_img = overlay_masks_on_image(boxed_img, masks)
  265. masked_img = cv2.resize(masked_img, (800, 800))
  266. writer.add_image("z-masks", masked_img, epoch)