trainer.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686
  1. import os
  2. import time
  3. from datetime import datetime
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import torchvision
  8. from PIL.ImageDraw import ImageDraw
  9. from matplotlib import pyplot as plt
  10. from scipy.ndimage import gaussian_filter
  11. from torch.optim.lr_scheduler import ReduceLROnPlateau
  12. from torch.utils.tensorboard import SummaryWriter
  13. from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
  14. from models.base.base_model import BaseModel
  15. from models.base.base_trainer import BaseTrainer
  16. from models.config.config_tool import read_yaml
  17. from models.line_detect.line_dataset import LineDataset
  18. import torch.nn.functional as F
  19. from torchvision.transforms import functional as TF
  20. from tools import utils
  21. import matplotlib as mpl
  22. from utils.data_process.show_prams import print_params
  23. cmap = plt.get_cmap("jet")
  24. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  25. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  26. sm.set_array([])
  27. def _loss(losses):
  28. total_loss = 0
  29. for i in losses.keys():
  30. if i != "loss_wirepoint":
  31. total_loss += losses[i]
  32. else:
  33. loss_labels = losses[i]["losses"]
  34. loss_labels_k = list(loss_labels[0].keys())
  35. for j, name in enumerate(loss_labels_k):
  36. loss = loss_labels[0][name].mean()
  37. total_loss += loss
  38. return total_loss
  39. def c(x):
  40. return sm.to_rgba(x)
  41. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  42. def draw_ellipses_on_image(image, masks_pred, threshold=0.5, color=(0, 255, 0), thickness=2):
  43. """
  44. 在单张原始图像上绘制从 masks 拟合出的椭圆。
  45. 自动将 masks resize 到 image 的空间尺寸。
  46. Args:
  47. image: Tensor [3, H_img, W_img] —— 原始图像(如 [3, 2000, 2000])
  48. masks_pred: Tensor [N, 1, H_mask, W_mask] or [N, H_mask, W_mask] —— 模型输出 mask(如 [2, 1, 672, 672])
  49. threshold: 二值化阈值
  50. color: BGR color for OpenCV
  51. thickness: ellipse line thickness
  52. Returns:
  53. drawn_image: numpy array [H_img, W_img, 3] in RGB
  54. """
  55. # Step 1: 标准化 masks_pred to [N, H, W]
  56. if masks_pred.ndim == 4:
  57. if masks_pred.shape[1] == 1:
  58. masks_pred = masks_pred.squeeze(1) # [N, 1, H, W] -> [N, H, W]
  59. else:
  60. raise ValueError(f"Expected channel=1 in masks_pred, got shape {masks_pred.shape}")
  61. elif masks_pred.ndim != 3:
  62. raise ValueError(f"masks_pred must be 3D (N, H, W) or 4D (N, 1, H, W), got {masks_pred.shape}")
  63. N, H_mask, W_mask = masks_pred.shape
  64. C, H_img, W_img = image.shape
  65. # Step 2: Resize masks to original image size using bilinear interpolation
  66. masks_resized = F.interpolate(
  67. masks_pred.unsqueeze(1).float(), # [N, 1, H_mask, W_mask]
  68. size=(H_img, W_img),
  69. mode='bilinear',
  70. align_corners=False
  71. ).squeeze(1) # [N, H_img, W_img]
  72. # Step 3: Convert image to numpy RGB
  73. img_tensor = image.detach().cpu()
  74. if img_tensor.max() <= 1.0:
  75. img_np = (img_tensor * 255).byte().numpy() # [3, H, W]
  76. else:
  77. img_np = img_tensor.byte().numpy()
  78. img_rgb = np.transpose(img_np, (1, 2, 0)) # [H, W, 3]
  79. img_out = img_rgb.copy()
  80. # Step 4: Process each mask
  81. for mask in masks_resized:
  82. mask_cpu = mask.detach().cpu()
  83. mask_prob = torch.sigmoid(mask_cpu) if mask_cpu.min() < 0 else mask_cpu
  84. binary = (mask_prob > threshold).numpy().astype(np.uint8) * 255 # [H_img, W_img]
  85. contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  86. if contours:
  87. largest_contour = max(contours, key=cv2.contourArea)
  88. if len(largest_contour) >= 5:
  89. try:
  90. ellipse = cv2.fitEllipse(largest_contour)
  91. img_bgr = cv2.cvtColor(img_out, cv2.COLOR_RGB2BGR)
  92. cv2.ellipse(img_bgr, ellipse, color=color, thickness=thickness)
  93. img_out = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
  94. except cv2.error as e:
  95. print(f"Warning: Failed to fit ellipse: {e}")
  96. return img_out
  97. def fit_circle(points):
  98. """
  99. Fit a circle to a set of points (at least 3).
  100. Args:
  101. points: torch.Tensor 或 numpy array, shape (N, 2)
  102. Returns:
  103. center (cx, cy), radius r
  104. """
  105. # 如果是 torch.Tensor,先转为 numpy
  106. if isinstance(points, torch.Tensor):
  107. if points.dim() == 3:
  108. points = points[0] # 去掉 batch 维度
  109. points = points.detach().cpu().numpy()
  110. if not (isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 2):
  111. raise ValueError(f"Expected points shape (N, 2), got {points.shape}")
  112. x = points[:, 0].astype(float)
  113. y = points[:, 1].astype(float)
  114. # 确保 A 是二维数组
  115. A = np.column_stack((x, y, np.ones_like(x))) # 使用 column_stack 代替 stack 可能更清晰
  116. b = -(x ** 2 + y ** 2)
  117. try:
  118. sol, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None)
  119. except np.linalg.LinAlgError as e:
  120. print(f"Linear algebra error occurred: {e}")
  121. raise ValueError("Could not fit circle to points.")
  122. D, E, F = sol
  123. cx = -D / 2.0
  124. cy = -E / 2.0
  125. r = np.sqrt(cx ** 2 + cy ** 2 - F)
  126. return (cx, cy), r
  127. from PIL import ImageDraw, Image
  128. import io
  129. def draw_el(all, background_img):
  130. """
  131. all = [x_center, y_center, a, b, theta, x1, y1, x2, y2]
  132. theta: ellipse rotation (degrees)
  133. (x1, y1): start point
  134. (x2, y2): end point
  135. """
  136. if isinstance(all, torch.Tensor):
  137. all = all.cpu().numpy()
  138. # Unpack parameters
  139. cx, cy, a, b, theta_deg, x1, y1, x2, y2 = all
  140. # cx = cx / 672 * 2000
  141. # cy = cy / 672 * 2000
  142. # # a = a / 672 * 2000
  143. # # b = b / 672 * 2000
  144. # x1 = x1 / 672 * 2000
  145. # y1 = y1 / 672 * 2000
  146. # x2 = x2 / 672 * 2000
  147. # y2 = y2 / 672 * 2000
  148. theta = np.radians(theta_deg)
  149. # ====== Draw ellipse ======
  150. phi = np.linspace(0, np.pi * 2, 500)
  151. x_ellipse = cx + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
  152. y_ellipse = cy + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
  153. # ====== Draw image ======
  154. plt.figure(figsize=(10, 10))
  155. plt.imshow(background_img)
  156. # Ellipse
  157. plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
  158. # Center
  159. plt.plot(cx, cy, 'ko', markersize=8)
  160. # Start & End points (now real coordinates)
  161. plt.plot(x1, y1, 'ro', markersize=10)
  162. plt.plot(x2, y2, 'go', markersize=10)
  163. # ====== Convert to tensor ======
  164. buf = io.BytesIO()
  165. plt.savefig(buf, format='png', bbox_inches='tight')
  166. buf.seek(0)
  167. result_img = Image.open(buf).convert('RGB')
  168. img_tensor = torch.from_numpy(np.array(result_img)).permute(2, 0, 1)
  169. plt.close()
  170. return img_tensor
  171. # from PIL import ImageDraw, Image
  172. # import io
  173. # # 绘制椭圆
  174. # def draw_el(all, background_img):
  175. # # 解析椭圆参数
  176. # if isinstance(all, torch.Tensor):
  177. # all = all.cpu().numpy()
  178. # print_params(all)
  179. # x, y, a, b, q, q1, q2 = all
  180. # theta = np.radians(q)
  181. # phi1 = np.radians(q1) # 第一个点的参数角
  182. # phi2 = np.radians(q2) # 第二个点的参数角
  183. #
  184. # # 生成椭圆上的点
  185. # phi = np.linspace(0, 2 * np.pi, 500)
  186. # x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
  187. # y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
  188. #
  189. # # 计算两个指定点的坐标
  190. # def param_to_point(phi, xc, yc, a, b, theta):
  191. # x = xc + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
  192. # y = yc + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
  193. # return x, y
  194. #
  195. # P1 = param_to_point(phi1, x, y, a, b, theta)
  196. # P2 = param_to_point(phi2, x, y, a, b, theta)
  197. #
  198. # # 创建画布并显示背景图片(使用传入的background_img,shape为[H, W, C])
  199. # plt.figure(figsize=(10, 10))
  200. # plt.imshow(background_img) # 直接显示背景图
  201. #
  202. # # 绘制椭圆及相关元素
  203. # plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
  204. # plt.plot(x, y, 'ko', markersize=8)
  205. # plt.plot(P1[0], P1[1], 'ro', markersize=10)
  206. # plt.plot(P2[0], P2[1], 'go', markersize=10)
  207. # 转换为TensorBoard所需的张量格式 [C, H, W]
  208. # buf = io.BytesIO()
  209. # plt.savefig(buf, format='png', bbox_inches='tight')
  210. # buf.seek(0)
  211. # result_img = Image.open(buf).convert('RGB')
  212. # img_tensor = torch.from_numpy(np.array(result_img)).permute(2, 0, 1)
  213. # plt.close()
  214. #
  215. # return img_tensor
  216. # 由低到高蓝黄红
  217. def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
  218. """
  219. 根据得分对线段着色并绘制
  220. :param tensor_image: (3, H, W) uint8 图像
  221. :param lines: (N, 2, 2) 每条线 [ [x1,y1], [x2,y2] ]
  222. :param scores: (N,) 每条线的得分,范围 [0, 1]
  223. :param width: 线宽
  224. :param cmap: matplotlib colormap 名称,例如 'viridis', 'jet', 'coolwarm'
  225. :return: (3, H, W) uint8 画好线的图像
  226. """
  227. assert tensor_image.dtype == torch.uint8
  228. assert tensor_image.shape[0] == 3
  229. assert lines.shape[0] == scores.shape[0]
  230. # 准备色图
  231. colormap = plt.get_cmap(cmap)
  232. colors = (colormap(scores.cpu().numpy())[:, :3] * 255).astype('uint8') # 去掉 alpha 通道
  233. # 转为 PIL 画图
  234. image_pil = TF.to_pil_image(tensor_image)
  235. draw = ImageDraw.Draw(image_pil)
  236. for line, color in zip(lines, colors):
  237. start = tuple(map(float, line[0][:2].tolist()))
  238. end = tuple(map(float, line[1][:2].tolist()))
  239. draw.line([start, end], fill=tuple(color), width=width)
  240. return (torchvision.transforms.functional.to_tensor(image_pil) * 255).to(torch.uint8)
  241. class Trainer(BaseTrainer):
  242. def __init__(self, model=None, **kwargs):
  243. super().__init__(model, device, **kwargs)
  244. self.model = model
  245. # print(f'kwargs:{kwargs}')
  246. self.init_params(**kwargs)
  247. def init_params(self, **kwargs):
  248. if kwargs != {}:
  249. print(f'train_params:{kwargs["train_params"]}')
  250. self.freeze_config = kwargs['train_params']['freeze_params']
  251. print(f'freeze_config:{self.freeze_config}')
  252. self.dataset_path = kwargs['io']['datadir']
  253. self.data_type = kwargs['io']['data_type']
  254. self.batch_size = kwargs['train_params']['batch_size']
  255. self.num_workers = kwargs['train_params']['num_workers']
  256. self.logdir = kwargs['io']['logdir']
  257. self.resume_from = kwargs['train_params']['resume_from']
  258. self.optim = ''
  259. self.train_result_ptath = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
  260. self.wts_path = os.path.join(self.train_result_ptath, 'weights')
  261. self.tb_path = os.path.join(self.train_result_ptath, 'logs')
  262. self.writer = SummaryWriter(self.tb_path)
  263. self.last_model_path = os.path.join(self.wts_path, 'last.pth')
  264. self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
  265. self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
  266. self.max_epoch = kwargs['train_params']['max_epoch']
  267. self.augmentation = kwargs['train_params']["augmentation"]
  268. def move_to_device(self, data, device):
  269. if isinstance(data, (list, tuple)):
  270. return type(data)(self.move_to_device(item, device) for item in data)
  271. elif isinstance(data, dict):
  272. return {key: self.move_to_device(value, device) for key, value in data.items()}
  273. elif isinstance(data, torch.Tensor):
  274. return data.to(device)
  275. else:
  276. return data # 对于非张量类型的数据不做任何改变
  277. def freeze_params(self, model):
  278. """根据配置冻结模型参数"""
  279. default_config = {
  280. 'backbone': True, # 冻结 backbone
  281. 'rpn': False, # 不冻结 rpn
  282. 'roi_heads': {
  283. 'box_head': False,
  284. 'box_predictor': False,
  285. 'line_head': False,
  286. 'line_predictor': {
  287. 'fc1': False,
  288. 'fc2': {
  289. '0': False,
  290. '2': False,
  291. '4': False
  292. }
  293. }
  294. }
  295. }
  296. # 更新默认配置
  297. default_config.update(self.freeze_config)
  298. config = default_config
  299. print("\n===== Parameter Freezing Configuration =====")
  300. for name, module in model.named_children():
  301. if name in config:
  302. if isinstance(config[name], bool):
  303. for param in module.parameters():
  304. param.requires_grad = not config[name]
  305. print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
  306. elif isinstance(config[name], dict):
  307. for subname, submodule in module.named_children():
  308. if subname in config[name]:
  309. if isinstance(config[name][subname], bool):
  310. for param in submodule.parameters():
  311. param.requires_grad = not config[name][subname]
  312. print(
  313. f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
  314. elif isinstance(config[name][subname], dict):
  315. for subsubname, subsubmodule in submodule.named_children():
  316. if subsubname in config[name][subname]:
  317. for param in subsubmodule.parameters():
  318. param.requires_grad = not config[name][subname][subsubname]
  319. print(
  320. f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
  321. # 打印参数统计
  322. total_params = sum(p.numel() for p in model.parameters())
  323. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  324. print(f"\nTotal Parameters: {total_params:,}")
  325. print(f"Trainable Parameters: {trainable_params:,}")
  326. print(f"Frozen Parameters: {total_params - trainable_params:,}")
  327. def load_best_model(self, model, optimizer, save_path, device):
  328. if os.path.exists(save_path):
  329. checkpoint = torch.load(save_path, map_location=device)
  330. model.load_state_dict(checkpoint['model_state_dict'])
  331. if optimizer is not None:
  332. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  333. epoch = checkpoint['epoch']
  334. loss = checkpoint['loss']
  335. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  336. else:
  337. print(f"No saved model found at {save_path}")
  338. return model, optimizer
  339. def writer_predict_result(self, img, result, epoch, ):
  340. img = img.cpu().detach()
  341. im = img.permute(1, 2, 0) # [512, 512, 3]
  342. self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
  343. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result["boxes"],
  344. colors="yellow", width=1)
  345. # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
  346. # plt.show()
  347. self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
  348. if 'points' in result:
  349. keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
  350. self.writer.add_image("z-output", keypoint_img, epoch)
  351. # print("lines shape:", result['lines'].shape)
  352. if 'lines' in result:
  353. # 用自己写的函数画线段
  354. # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
  355. print(f"shape of linescore:{result['lines_scores'].shape}")
  356. scores = result['lines_scores'].mean(dim=1) # shape: [31]
  357. line_image = draw_lines_with_scores((img * 255).to(torch.uint8), result['lines'], scores, width=3,
  358. cmap='jet')
  359. self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
  360. if 'arcs' in result:
  361. arcs = result['arcs'][0]
  362. print(f'arcs in draw:{arcs}')
  363. ellipse_img = draw_el(arcs, background_img=im)
  364. # img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
  365. #
  366. # img_tensor =torch.tensor(img_rgb)
  367. # img_tensor = np.transpose(img_tensor)
  368. self.writer.add_image('z-out-arc', ellipse_img, global_step=epoch)
  369. if 'ins_masks' in result:
  370. # points=result['circles']
  371. # points=points.squeeze(1)
  372. ppp = result['ins_masks']
  373. bbb = result['boxes']
  374. print(f'boxes shape:{bbb.shape}')
  375. print(f'ppp:{ppp.shape}')
  376. ins_masks = result['ins_masks']
  377. ins_masks = ins_masks.squeeze(1)
  378. print(f'ins_masks shape:{ins_masks.shape}')
  379. features = result['features']
  380. circle_image = img.cpu().numpy().transpose((1, 2, 0)) # CHW -> HWC
  381. circle_image = (circle_image * 255).clip(0, 255).astype(np.uint8)
  382. sum_mask = ins_masks.sum(dim=0, keepdim=True)
  383. sum_mask = sum_mask / (sum_mask.max() + 1e-8)
  384. # keypoint_img = draw_keypoints((img * 255).to(torch.uint8), points, colors='red', width=3)
  385. self.writer.add_image('z-ins-masks', sum_mask.squeeze(0), global_step=epoch)
  386. result_imgs = draw_ellipses_on_image(img, ins_masks, threshold=0.5)
  387. self.writer.add_image('z-out-ellipses', result_imgs, dataformats='HWC', global_step=epoch)
  388. features = self.apply_gaussian_blur_to_tensor(features, sigma=3)
  389. self.writer.add_image('z-feature', features, global_step=epoch)
  390. # cv2.imshow('arc', img_rgb)
  391. # cv2.waitKey(1000000)
  392. def normalize_tensor(self, tensor):
  393. """Normalize tensor to [0, 1]"""
  394. min_val = tensor.min()
  395. max_val = tensor.max()
  396. return (tensor - min_val) / (max_val - min_val)
  397. def apply_gaussian_blur_to_tensor(self, feature_map, sigma=3):
  398. """
  399. Apply Gaussian blur to a feature map and convert it into an RGB heatmap.
  400. :param feature_map: Tensor of shape (H, W) or (1, H, W)
  401. :param sigma: Standard deviation for Gaussian kernel
  402. :return: Tensor of shape (3, H, W) representing the RGB heatmap
  403. """
  404. if feature_map.dim() == 3:
  405. if feature_map.shape[0] != 1:
  406. raise ValueError("Only single-channel feature map supported.")
  407. feature_map = feature_map.squeeze(0)
  408. # Normalize to [0, 1]
  409. normalized_feat = self.normalize_tensor(feature_map).cpu().numpy()
  410. # Apply Gaussian blur
  411. blurred_feat = gaussian_filter(normalized_feat, sigma=sigma)
  412. # Convert to colormap (e.g., 'jet')
  413. colormap = plt.get_cmap('jet')
  414. colored = colormap(blurred_feat) # shape: (H, W, 4) RGBA
  415. # Convert to (3, H, W), drop alpha channel
  416. colored_rgb = colored[:, :, :3] # (H, W, 3)
  417. colored_tensor = torch.from_numpy(colored_rgb).permute(2, 0, 1) # (3, H, W)
  418. return colored_tensor.float()
  419. def writer_loss(self, losses, epoch, phase='train'):
  420. try:
  421. for key, value in losses.items():
  422. if key == 'loss_wirepoint':
  423. for subdict in losses['loss_wirepoint']['losses']:
  424. for subkey, subvalue in subdict.items():
  425. self.writer.add_scalar(f'{phase}/loss/{subkey}',
  426. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  427. epoch)
  428. elif isinstance(value, torch.Tensor):
  429. self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
  430. except Exception as e:
  431. print(f"TensorBoard logging error: {e}")
  432. def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None): # 新增:支持传入冻结配置
  433. cfg = read_yaml(cfg)
  434. # print(f'cfg:{cfg}')
  435. # self.freeze_config = freeze_config or {} # 更新冻结配置
  436. self.train(model, **cfg)
  437. def train(self, model, **kwargs):
  438. self.init_params(**kwargs)
  439. dataset_train = LineDataset(dataset_path=self.dataset_path, augmentation=self.augmentation,
  440. data_type=self.data_type, dataset_type='train')
  441. dataset_val = LineDataset(dataset_path=self.dataset_path, augmentation=self.augmentation,
  442. data_type=self.data_type, dataset_type='val')
  443. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  444. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  445. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=self.batch_size, drop_last=True)
  446. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=self.batch_size, drop_last=True)
  447. train_collate_fn = utils.collate_fn
  448. val_collate_fn = utils.collate_fn
  449. data_loader_train = torch.utils.data.DataLoader(
  450. dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
  451. )
  452. data_loader_val = torch.utils.data.DataLoader(
  453. dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn
  454. )
  455. model.to(device)
  456. optimizer = torch.optim.Adam(
  457. filter(lambda p: p.requires_grad, model.parameters()),
  458. lr=kwargs['train_params']['optim']['lr'],
  459. weight_decay=kwargs['train_params']['optim']['weight_decay'],
  460. )
  461. model, optimizer = self.load_best_model(model, optimizer,
  462. r"\\192.168.50.222\share\rlq\weights\250725_arc_res152_best_val.pth",
  463. device)
  464. # scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
  465. scheduler = ReduceLROnPlateau(optimizer, 'min', patience=30)
  466. for epoch in range(self.max_epoch):
  467. print(f"train epoch:{epoch}")
  468. model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
  469. scheduler.step(epoch_train_loss)
  470. # ========== Validation ==========
  471. with torch.no_grad():
  472. model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
  473. scheduler.step(epoch_val_loss)
  474. if epoch == 0:
  475. best_train_loss = epoch_train_loss
  476. best_val_loss = epoch_val_loss
  477. self.save_last_model(model, self.last_model_path, epoch, optimizer)
  478. best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
  479. best_train_loss,
  480. optimizer)
  481. best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch, epoch_val_loss, best_val_loss,
  482. optimizer)
  483. def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
  484. if phase == 'train':
  485. model.train()
  486. if phase == 'val':
  487. model.eval()
  488. total_loss = 0
  489. epoch_step = 0
  490. global_step = epoch * len(data_loader)
  491. for imgs, targets in data_loader:
  492. imgs = self.move_to_device(imgs, device)
  493. targets = self.move_to_device(targets, device)
  494. if phase == 'val':
  495. result, loss_dict = model(imgs, targets)
  496. losses = sum(loss_dict.values())
  497. print(f'val losses:{losses}')
  498. # print(f'val result:{result}')
  499. else:
  500. loss_dict = model(imgs, targets)
  501. losses = sum(loss_dict.values())
  502. print(f'train losses:{losses}')
  503. # loss = _loss(losses)
  504. loss = losses
  505. total_loss += loss.item()
  506. if phase == 'train':
  507. optimizer.zero_grad()
  508. loss.backward()
  509. torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
  510. optimizer.step()
  511. self.writer_loss(loss_dict, global_step, phase=phase)
  512. global_step += 1
  513. if epoch_step == 0 and phase == 'val':
  514. t_start = time.time()
  515. print(f'start to predict:{t_start}')
  516. result = model(self.move_to_device(imgs, self.device))
  517. # print(f'result:{result}')
  518. t_end = time.time()
  519. print(f'predict used:{t_end - t_start}')
  520. from utils.data_process.show_prams import print_params
  521. print_params(imgs[0], result[0], epoch)
  522. self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)
  523. epoch_step += 1
  524. avg_loss = total_loss / len(data_loader)
  525. print(f'{phase}/loss epoch{epoch}:{avg_loss:4f}')
  526. self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
  527. return model, avg_loss
  528. def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
  529. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  530. if current_loss <= best_loss:
  531. checkpoint = {
  532. 'epoch': epoch,
  533. 'model_state_dict': model.state_dict(),
  534. 'loss': current_loss
  535. }
  536. if optimizer is not None:
  537. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  538. torch.save(checkpoint, save_path)
  539. print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
  540. return current_loss
  541. return best_loss
  542. def save_last_model(self, model, save_path, epoch, optimizer=None):
  543. if os.path.exists(f'{self.wts_path}/last.pt'):
  544. os.remove(f'{self.wts_path}/last.pt')
  545. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  546. checkpoint = {
  547. 'epoch': epoch,
  548. 'model_state_dict': model.state_dict(),
  549. }
  550. if optimizer is not None:
  551. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  552. torch.save(checkpoint, save_path)
  553. if __name__ == '__main__':
  554. print('')