trainer.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. import os
  2. import time
  3. from datetime import datetime
  4. import cv2
  5. import numpy as np
  6. import torch
  7. from matplotlib import pyplot as plt
  8. from torch.optim.lr_scheduler import ReduceLROnPlateau
  9. from torch.utils.tensorboard import SummaryWriter
  10. from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
  11. from models.base.base_model import BaseModel
  12. from models.base.base_trainer import BaseTrainer
  13. from models.config.config_tool import read_yaml
  14. from models.line_detect.line_dataset import LineDataset
  15. from models.line_net.dataset_LD import WirePointDataset
  16. from models.wirenet.postprocess import postprocess
  17. from tools import utils
  18. from torchvision import transforms
  19. import matplotlib as mpl
  20. cmap = plt.get_cmap("jet")
  21. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  22. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  23. sm.set_array([])
  24. def _loss(losses):
  25. total_loss = 0
  26. for i in losses.keys():
  27. if i != "loss_wirepoint":
  28. total_loss += losses[i]
  29. else:
  30. loss_labels = losses[i]["losses"]
  31. loss_labels_k = list(loss_labels[0].keys())
  32. for j, name in enumerate(loss_labels_k):
  33. loss = loss_labels[0][name].mean()
  34. total_loss += loss
  35. return total_loss
  36. def c(x):
  37. return sm.to_rgba(x)
  38. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  39. import matplotlib.pyplot as plt
  40. from PIL import ImageDraw
  41. from torchvision.transforms import functional as F
  42. import torch
  43. def fit_circle(points):
  44. """
  45. Fit a circle to a set of points (at least 3).
  46. Args:
  47. points: torch.Tensor 或 numpy array, shape (N, 2)
  48. Returns:
  49. center (cx, cy), radius r
  50. """
  51. # 如果是 torch.Tensor,先转为 numpy
  52. if isinstance(points, torch.Tensor):
  53. if points.dim() == 3:
  54. points = points[0] # 去掉 batch 维度
  55. points = points.detach().cpu().numpy()
  56. if not (isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 2):
  57. raise ValueError(f"Expected points shape (N, 2), got {points.shape}")
  58. x = points[:, 0].astype(float)
  59. y = points[:, 1].astype(float)
  60. # 确保 A 是二维数组
  61. A = np.column_stack((x, y, np.ones_like(x))) # 使用 column_stack 代替 stack 可能更清晰
  62. b = -(x ** 2 + y ** 2)
  63. try:
  64. sol, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None)
  65. except np.linalg.LinAlgError as e:
  66. print(f"Linear algebra error occurred: {e}")
  67. raise ValueError("Could not fit circle to points.")
  68. D, E, F = sol
  69. cx = -D / 2.0
  70. cy = -E / 2.0
  71. r = np.sqrt(cx ** 2 + cy ** 2 - F)
  72. return (cx, cy), r
  73. # 由低到高蓝黄红
  74. def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
  75. """
  76. 根据得分对线段着色并绘制
  77. :param tensor_image: (3, H, W) uint8 图像
  78. :param lines: (N, 2, 2) 每条线 [ [x1,y1], [x2,y2] ]
  79. :param scores: (N,) 每条线的得分,范围 [0, 1]
  80. :param width: 线宽
  81. :param cmap: matplotlib colormap 名称,例如 'viridis', 'jet', 'coolwarm'
  82. :return: (3, H, W) uint8 画好线的图像
  83. """
  84. assert tensor_image.dtype == torch.uint8
  85. assert tensor_image.shape[0] == 3
  86. assert lines.shape[0] == scores.shape[0]
  87. # 准备色图
  88. colormap = plt.get_cmap(cmap)
  89. colors = (colormap(scores.cpu().numpy())[:, :3] * 255).astype('uint8') # 去掉 alpha 通道
  90. # 转为 PIL 画图
  91. image_pil = F.to_pil_image(tensor_image)
  92. draw = ImageDraw.Draw(image_pil)
  93. for line, color in zip(lines, colors):
  94. start = tuple(map(float, line[0][:2].tolist()))
  95. end = tuple(map(float, line[1][:2].tolist()))
  96. draw.line([start, end], fill=tuple(color), width=width)
  97. return (F.to_tensor(image_pil) * 255).to(torch.uint8)
  98. class Trainer(BaseTrainer):
  99. def __init__(self, model=None, **kwargs):
  100. super().__init__(model, device, **kwargs)
  101. self.model = model
  102. # print(f'kwargs:{kwargs}')
  103. self.init_params(**kwargs)
  104. def init_params(self, **kwargs):
  105. if kwargs != {}:
  106. print(f'train_params:{kwargs["train_params"]}')
  107. self.freeze_config = kwargs['train_params']['freeze_params']
  108. print(f'freeze_config:{self.freeze_config}')
  109. self.dataset_path = kwargs['io']['datadir']
  110. self.data_type = kwargs['io']['data_type']
  111. self.batch_size = kwargs['train_params']['batch_size']
  112. self.num_workers = kwargs['train_params']['num_workers']
  113. self.logdir = kwargs['io']['logdir']
  114. self.resume_from = kwargs['train_params']['resume_from']
  115. self.optim = ''
  116. self.train_result_ptath = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
  117. self.wts_path = os.path.join(self.train_result_ptath, 'weights')
  118. self.tb_path = os.path.join(self.train_result_ptath, 'logs')
  119. self.writer = SummaryWriter(self.tb_path)
  120. self.last_model_path = os.path.join(self.wts_path, 'last.pth')
  121. self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
  122. self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
  123. self.max_epoch = kwargs['train_params']['max_epoch']
  124. self.augmentation= kwargs['train_params']["augmentation"]
  125. def move_to_device(self, data, device):
  126. if isinstance(data, (list, tuple)):
  127. return type(data)(self.move_to_device(item, device) for item in data)
  128. elif isinstance(data, dict):
  129. return {key: self.move_to_device(value, device) for key, value in data.items()}
  130. elif isinstance(data, torch.Tensor):
  131. return data.to(device)
  132. else:
  133. return data # 对于非张量类型的数据不做任何改变
  134. def freeze_params(self, model):
  135. """根据配置冻结模型参数"""
  136. default_config = {
  137. 'backbone': True, # 冻结 backbone
  138. 'rpn': False, # 不冻结 rpn
  139. 'roi_heads': {
  140. 'box_head': False,
  141. 'box_predictor': False,
  142. 'line_head': False,
  143. 'line_predictor': {
  144. 'fc1': False,
  145. 'fc2': {
  146. '0': False,
  147. '2': False,
  148. '4': False
  149. }
  150. }
  151. }
  152. }
  153. # 更新默认配置
  154. default_config.update(self.freeze_config)
  155. config = default_config
  156. print("\n===== Parameter Freezing Configuration =====")
  157. for name, module in model.named_children():
  158. if name in config:
  159. if isinstance(config[name], bool):
  160. for param in module.parameters():
  161. param.requires_grad = not config[name]
  162. print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
  163. elif isinstance(config[name], dict):
  164. for subname, submodule in module.named_children():
  165. if subname in config[name]:
  166. if isinstance(config[name][subname], bool):
  167. for param in submodule.parameters():
  168. param.requires_grad = not config[name][subname]
  169. print(
  170. f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
  171. elif isinstance(config[name][subname], dict):
  172. for subsubname, subsubmodule in submodule.named_children():
  173. if subsubname in config[name][subname]:
  174. for param in subsubmodule.parameters():
  175. param.requires_grad = not config[name][subname][subsubname]
  176. print(
  177. f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
  178. # 打印参数统计
  179. total_params = sum(p.numel() for p in model.parameters())
  180. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  181. print(f"\nTotal Parameters: {total_params:,}")
  182. print(f"Trainable Parameters: {trainable_params:,}")
  183. print(f"Frozen Parameters: {total_params - trainable_params:,}")
  184. def load_best_model(self, model, optimizer, save_path, device):
  185. if os.path.exists(save_path):
  186. checkpoint = torch.load(save_path, map_location=device)
  187. model.load_state_dict(checkpoint['model_state_dict'])
  188. if optimizer is not None:
  189. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  190. epoch = checkpoint['epoch']
  191. loss = checkpoint['loss']
  192. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  193. else:
  194. print(f"No saved model found at {save_path}")
  195. return model, optimizer
  196. def writer_predict_result(self, img, result, epoch,):
  197. img = img.cpu().detach()
  198. im = img.permute(1, 2, 0) # [512, 512, 3]
  199. self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
  200. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result["boxes"],
  201. colors="yellow", width=1)
  202. # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
  203. # plt.show()
  204. self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
  205. if 'points' in result:
  206. keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
  207. self.writer.add_image("z-output", keypoint_img, epoch)
  208. # print("lines shape:", result['lines'].shape)
  209. if 'lines' in result:
  210. # 用自己写的函数画线段
  211. # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
  212. print(f"shape of linescore:{result['lines_scores'].shape}")
  213. scores = result['lines_scores'].mean(dim=1) # shape: [31]
  214. line_image = draw_lines_with_scores((img * 255).to(torch.uint8), result['lines'],scores, width=3, cmap='jet')
  215. self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
  216. if 'arcs' in result:
  217. arcs = result['arcs'][0]
  218. # img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
  219. # img_tensor =torch.tensor(img_rgb)
  220. # img_tensor = np.transpose(img_tensor)
  221. self.writer.add_image('z-out-arc', arcs, global_step=epoch)
  222. if 'circles' in result:
  223. # points=result['circles']
  224. # points=points.squeeze(1)
  225. ppp=result['circles']
  226. bbb=result['boxes']
  227. print(f'boxes shape:{bbb.shape}')
  228. print(f'ppp:{ppp.shape}')
  229. points = result['circles']
  230. points = points.squeeze(1)
  231. print(f'points shape:{points.shape}')
  232. features = result['features']
  233. circle_image = img.cpu().numpy().transpose((1, 2, 0)) # CHW -> HWC
  234. circle_image = (circle_image * 255).clip(0, 255).astype(np.uint8)
  235. if isinstance(points, torch.Tensor):
  236. points = points.cpu().numpy()
  237. for point_set in points:
  238. center, radius = fit_circle(point_set)
  239. cx, cy = map(int, center)
  240. circle_image = cv2.circle(circle_image, (cx, cy), int(radius), (0, 0, 255), 2)
  241. for point in point_set:
  242. px, py = map(int, point)
  243. circle_image = cv2.circle(circle_image, (px, py), 3, (0, 255, 255), -1)
  244. img_rgb = cv2.cvtColor(circle_image, cv2.COLOR_BGR2RGB)
  245. img_tensor = img_rgb.transpose((2, 0, 1)) # HWC -> CHW
  246. img_tensor = img_tensor / 255.0 # 归一化到 [0, 1]
  247. # keypoint_img = draw_keypoints((img * 255).to(torch.uint8), points, colors='red', width=3)
  248. self.writer.add_image('z-out-circle', img_tensor, global_step=epoch)
  249. self.writer.add_image('z-feature', features, global_step=epoch)
  250. # cv2.imshow('arc', img_rgb)
  251. # cv2.waitKey(1000000)
  252. def writer_loss(self, losses, epoch, phase='train'):
  253. try:
  254. for key, value in losses.items():
  255. if key == 'loss_wirepoint':
  256. for subdict in losses['loss_wirepoint']['losses']:
  257. for subkey, subvalue in subdict.items():
  258. self.writer.add_scalar(f'{phase}/loss/{subkey}',
  259. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  260. epoch)
  261. elif isinstance(value, torch.Tensor):
  262. self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
  263. except Exception as e:
  264. print(f"TensorBoard logging error: {e}")
  265. def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None): # 新增:支持传入冻结配置
  266. cfg = read_yaml(cfg)
  267. # print(f'cfg:{cfg}')
  268. # self.freeze_config = freeze_config or {} # 更新冻结配置
  269. self.train(model, **cfg)
  270. def train(self, model, **kwargs):
  271. self.init_params(**kwargs)
  272. dataset_train = LineDataset(dataset_path=self.dataset_path,augmentation=self.augmentation, data_type=self.data_type, dataset_type='train')
  273. dataset_val = LineDataset(dataset_path=self.dataset_path,augmentation=self.augmentation, data_type=self.data_type, dataset_type='val')
  274. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  275. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  276. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=self.batch_size, drop_last=True)
  277. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=self.batch_size, drop_last=True)
  278. train_collate_fn = utils.collate_fn
  279. val_collate_fn = utils.collate_fn
  280. data_loader_train = torch.utils.data.DataLoader(
  281. dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
  282. )
  283. data_loader_val = torch.utils.data.DataLoader(
  284. dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn
  285. )
  286. model.to(device)
  287. optimizer = torch.optim.Adam(
  288. filter(lambda p: p.requires_grad, model.parameters()),
  289. lr=kwargs['train_params']['optim']['lr'],
  290. weight_decay=kwargs['train_params']['optim']['weight_decay'],
  291. )
  292. model, optimizer = self.load_best_model(model, optimizer,
  293. r"\\192.168.50.222\share\rlq\weights\250725_arc_res152_best_val.pth",
  294. device)
  295. # scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
  296. scheduler = ReduceLROnPlateau(optimizer, 'min', patience=30)
  297. for epoch in range(self.max_epoch):
  298. print(f"train epoch:{epoch}")
  299. model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
  300. scheduler.step(epoch_train_loss)
  301. # ========== Validation ==========
  302. with torch.no_grad():
  303. model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
  304. scheduler.step(epoch_val_loss)
  305. if epoch==0:
  306. best_train_loss = epoch_train_loss
  307. best_val_loss = epoch_val_loss
  308. self.save_last_model(model,self.last_model_path, epoch, optimizer)
  309. best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
  310. best_train_loss,
  311. optimizer)
  312. best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch, epoch_val_loss, best_val_loss,
  313. optimizer)
  314. def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
  315. if phase == 'train':
  316. model.train()
  317. if phase == 'val':
  318. model.eval()
  319. total_loss = 0
  320. epoch_step = 0
  321. global_step = epoch * len(data_loader)
  322. for imgs, targets in data_loader:
  323. imgs = self.move_to_device(imgs, device)
  324. targets = self.move_to_device(targets, device)
  325. if phase== 'val':
  326. result,loss_dict = model(imgs, targets)
  327. losses = sum(loss_dict.values())
  328. print(f'val losses:{losses}')
  329. # print(f'val result:{result}')
  330. else:
  331. loss_dict = model(imgs, targets)
  332. losses = sum(loss_dict.values())
  333. print(f'train losses:{losses}')
  334. # loss = _loss(losses)
  335. loss=losses
  336. total_loss += loss.item()
  337. if phase == 'train':
  338. optimizer.zero_grad()
  339. loss.backward()
  340. optimizer.step()
  341. self.writer_loss(loss_dict, global_step, phase=phase)
  342. global_step += 1
  343. if epoch_step == 0 and phase == 'val':
  344. t_start = time.time()
  345. print(f'start to predict:{t_start}')
  346. result = model(self.move_to_device(imgs, self.device))
  347. # print(f'result:{result}')
  348. t_end = time.time()
  349. print(f'predict used:{t_end - t_start}')
  350. self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)
  351. epoch_step+=1
  352. avg_loss = total_loss / len(data_loader)
  353. print(f'{phase}/loss epoch{epoch}:{avg_loss:4f}')
  354. self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
  355. return model, avg_loss
  356. def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
  357. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  358. if current_loss <= best_loss:
  359. checkpoint = {
  360. 'epoch': epoch,
  361. 'model_state_dict': model.state_dict(),
  362. 'loss': current_loss
  363. }
  364. if optimizer is not None:
  365. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  366. torch.save(checkpoint, save_path)
  367. print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
  368. return current_loss
  369. return best_loss
  370. def save_last_model(self, model, save_path, epoch, optimizer=None):
  371. if os.path.exists(f'{self.wts_path}/last.pt'):
  372. os.remove(f'{self.wts_path}/last.pt')
  373. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  374. checkpoint = {
  375. 'epoch': epoch,
  376. 'model_state_dict': model.state_dict(),
  377. }
  378. if optimizer is not None:
  379. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  380. torch.save(checkpoint, save_path)
  381. if __name__ == '__main__':
  382. print('')