| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672 |
- import os
- import time
- from datetime import datetime
- import cv2
- import numpy as np
- import torch
- from PIL.ImageDraw import ImageDraw
- from matplotlib import pyplot as plt
- from scipy.ndimage import gaussian_filter
- from torch.optim.lr_scheduler import ReduceLROnPlateau
- from torch.utils.tensorboard import SummaryWriter
- from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
- from models.base.base_model import BaseModel
- from models.base.base_trainer import BaseTrainer
- from models.config.config_tool import read_yaml
- from models.line_detect.line_dataset import LineDataset
- import torch.nn.functional as F
- from tools import utils
- import matplotlib as mpl
- from utils.data_process.show_prams import print_params
# Shared 'jet' colour mapping used by c(): the normaliser maps 0.4 -> 0.0 and 1.0 -> 1.0
# before the colormap lookup.
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
cmap = plt.get_cmap("jet")
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
- def _loss(losses):
- total_loss = 0
- for i in losses.keys():
- if i != "loss_wirepoint":
- total_loss += losses[i]
- else:
- loss_labels = losses[i]["losses"]
- loss_labels_k = list(loss_labels[0].keys())
- for j, name in enumerate(loss_labels_k):
- loss = loss_labels[0][name].mean()
- total_loss += loss
- return total_loss
def c(x):
    """Return the RGBA colour(s) for ``x`` from the module-level ScalarMappable ``sm``."""
    to_rgba = sm.to_rgba
    return to_rgba(x)
# Default compute device for this module: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def draw_ellipses_on_image(image, masks_pred, threshold=0.5, color=(0, 255, 0), thickness=2):
    """
    Draw, on a single image, the ellipses fitted to a set of predicted masks.

    Each mask is resized (bilinearly) to the image's spatial size and binarised.
    An ellipse is fitted to its largest external contour and drawn on the image.

    Args:
        image: Tensor [3, H_img, W_img], e.g. [3, 2000, 2000]. If its maximum is
            <= 1.0 it is scaled by 255; values are clamped to [0, 255] before the
            uint8 conversion.
        masks_pred: Tensor [N, 1, H_mask, W_mask] or [N, H_mask, W_mask], e.g.
            [2, 1, 672, 672]. A mask whose minimum is negative is treated as
            logits and passed through a sigmoid.
        threshold: binarisation threshold applied to the mask probabilities.
        color: BGR colour for OpenCV.
        thickness: ellipse line thickness.

    Returns:
        drawn_image: numpy uint8 array [H_img, W_img, 3] in RGB order.

    Raises:
        ValueError: if masks_pred is neither 3-D nor 4-D with a single channel.
    """
    # Step 1: normalise masks_pred to [N, H, W].
    if masks_pred.ndim == 4:
        if masks_pred.shape[1] != 1:
            raise ValueError(f"Expected channel=1 in masks_pred, got shape {masks_pred.shape}")
        masks_pred = masks_pred.squeeze(1)  # [N, 1, H, W] -> [N, H, W]
    elif masks_pred.ndim != 3:
        raise ValueError(f"masks_pred must be 3D (N, H, W) or 4D (N, 1, H, W), got {masks_pred.shape}")

    _, H_img, W_img = image.shape

    # Step 2: convert the image to an HWC uint8 RGB array. Clamp before the cast
    # so out-of-range values saturate instead of wrapping around.
    img_tensor = image.detach().cpu()
    scaled = img_tensor * 255 if img_tensor.max() <= 1.0 else img_tensor
    img_np = scaled.clamp(0, 255).byte().numpy()  # [3, H, W]
    img_out = np.transpose(img_np, (1, 2, 0)).copy()  # [H, W, 3], C-contiguous for OpenCV

    if masks_pred.shape[0] == 0:
        return img_out

    # Step 3: resize masks to the image size.
    masks_resized = F.interpolate(
        masks_pred.unsqueeze(1).float(),  # [N, 1, H_mask, W_mask]
        size=(H_img, W_img),
        mode='bilinear',
        align_corners=False,
    ).squeeze(1)  # [N, H_img, W_img]

    # OpenCV draws in the array's own channel order. The image is RGB, so reverse
    # the BGR colour instead of converting the whole image for every mask.
    rgb_color = tuple(color[::-1])

    # Step 4: fit and draw one ellipse per mask.
    for mask in masks_resized:
        mask_cpu = mask.detach().cpu()
        # Heuristic: negative values mean raw logits rather than probabilities.
        mask_prob = torch.sigmoid(mask_cpu) if mask_cpu.min() < 0 else mask_cpu
        binary = (mask_prob > threshold).numpy().astype(np.uint8) * 255  # [H_img, W_img]
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue
        largest_contour = max(contours, key=cv2.contourArea)
        if len(largest_contour) < 5:  # cv2.fitEllipse needs at least 5 points
            continue
        try:
            ellipse = cv2.fitEllipse(largest_contour)
            cv2.ellipse(img_out, ellipse, color=rgb_color, thickness=thickness)
        except cv2.error as e:
            print(f"Warning: Failed to fit ellipse: {e}")
    return img_out
def fit_circle(points):
    """
    Fit a circle to a set of 2-D points (at least 3) by linear least squares.

    Solves x**2 + y**2 + D*x + E*y + F = 0 for (D, E, F) in the least-squares
    sense, then derives the centre (-D/2, -E/2) and radius.

    Args:
        points: torch.Tensor, numpy array or sequence of (x, y) pairs with shape
            (N, 2), N >= 3. A 3-D torch tensor is treated as batched and only its
            first element is used.

    Returns:
        center (cx, cy), radius r. r is NaN if cx**2 + cy**2 - F comes out negative.

    Raises:
        ValueError: if the points do not have shape (N, 2), if fewer than 3 points
            are given, or if the least-squares solve fails.
    """
    # Convert torch tensors to numpy.
    if isinstance(points, torch.Tensor):
        if points.dim() == 3:
            points = points[0]  # drop the batch dimension
        points = points.detach().cpu().numpy()

    # Also accept lists/tuples; this makes the shape check below reachable for them.
    points = np.asarray(points, dtype=float)
    if points.ndim != 2 or points.shape[1] != 2:
        raise ValueError(f"Expected points shape (N, 2), got {points.shape}")
    if points.shape[0] < 3:
        raise ValueError(f"At least 3 points are required to fit a circle, got {points.shape[0]}")

    x = points[:, 0]
    y = points[:, 1]
    # Design matrix [x, y, 1] (2-D, as lstsq requires).
    A = np.column_stack((x, y, np.ones_like(x)))
    b = -(x ** 2 + y ** 2)
    try:
        sol, *_ = np.linalg.lstsq(A, b, rcond=None)
    except np.linalg.LinAlgError as err:
        print(f"Linear algebra error occurred: {err}")
        raise ValueError("Could not fit circle to points.") from err

    # Named coef_* to avoid shadowing the module-level `F` (torch.nn.functional).
    coef_d, coef_e, coef_f = sol
    cx = -coef_d / 2.0
    cy = -coef_e / 2.0
    r = np.sqrt(cx ** 2 + cy ** 2 - coef_f)
    return (cx, cy), r
- from PIL import ImageDraw, Image
- import io
def draw_el(all, background_img):
    """
    Render an ellipse, its centre and a start/end point over a background image.

    Args:
        all: 9 values [x_center, y_center, a, b, theta, x1, y1, x2, y2] as a
            tensor, array or sequence.
            - theta: ellipse rotation in degrees.
            - a, b: semi-axes along the rotated x / y directions.
            - (x1, y1): start point, drawn red.
            - (x2, y2): end point, drawn green.
            Points are in image coordinates.
        background_img: anything matplotlib's ``imshow`` accepts, e.g. an
            (H, W, 3) image.

    Returns:
        uint8 tensor [3, H', W'] of the rendered figure (RGB). H' and W' follow
        the saved figure size, not the background image size.
    """
    # `all` shadows the builtin; keep the public parameter name, work on an alias.
    params = all
    if isinstance(params, torch.Tensor):
        params = params.detach().cpu().numpy()

    # Unpack parameters.
    cx, cy, a, b, theta_deg, x1, y1, x2, y2 = params
    theta = np.radians(theta_deg)

    # ====== Ellipse outline ======
    phi = np.linspace(0, np.pi * 2, 500)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    x_ellipse = cx + a * np.cos(phi) * cos_t - b * np.sin(phi) * sin_t
    y_ellipse = cy + a * np.cos(phi) * sin_t + b * np.sin(phi) * cos_t

    # ====== Draw and rasterise; always release the figure, even on error ======
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(background_img)
        ax.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)  # ellipse
        ax.plot(cx, cy, 'ko', markersize=8)  # centre
        ax.plot(x1, y1, 'ro', markersize=10)  # start point
        ax.plot(x2, y2, 'go', markersize=10)  # end point

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        with Image.open(buf) as png:
            result_img = png.convert('RGB')
    finally:
        plt.close(fig)

    # ====== HWC uint8 -> CHW tensor (TensorBoard's default layout) ======
    return torch.from_numpy(np.array(result_img)).permute(2, 0, 1)
- # from PIL import ImageDraw, Image
- # import io
- # # 绘制椭圆
- # def draw_el(all, background_img):
- # # 解析椭圆参数
- # if isinstance(all, torch.Tensor):
- # all = all.cpu().numpy()
- # print_params(all)
- # x, y, a, b, q, q1, q2 = all
- # theta = np.radians(q)
- # phi1 = np.radians(q1) # 第一个点的参数角
- # phi2 = np.radians(q2) # 第二个点的参数角
- #
- # # 生成椭圆上的点
- # phi = np.linspace(0, 2 * np.pi, 500)
- # x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
- # y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
- #
- # # 计算两个指定点的坐标
- # def param_to_point(phi, xc, yc, a, b, theta):
- # x = xc + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
- # y = yc + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
- # return x, y
- #
- # P1 = param_to_point(phi1, x, y, a, b, theta)
- # P2 = param_to_point(phi2, x, y, a, b, theta)
- #
- # # 创建画布并显示背景图片(使用传入的background_img,shape为[H, W, C])
- # plt.figure(figsize=(10, 10))
- # plt.imshow(background_img) # 直接显示背景图
- #
- # # 绘制椭圆及相关元素
- # plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
- # plt.plot(x, y, 'ko', markersize=8)
- # plt.plot(P1[0], P1[1], 'ro', markersize=10)
- # plt.plot(P2[0], P2[1], 'go', markersize=10)
- # 转换为TensorBoard所需的张量格式 [C, H, W]
- # buf = io.BytesIO()
- # plt.savefig(buf, format='png', bbox_inches='tight')
- # buf.seek(0)
- # result_img = Image.open(buf).convert('RGB')
- # img_tensor = torch.from_numpy(np.array(result_img)).permute(2, 0, 1)
- # plt.close()
- #
- # return img_tensor
- # 由低到高蓝黄红
def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
    """
    Draw line segments on an image, colouring each one by its score.

    :param tensor_image: (3, H, W) uint8 image tensor
    :param lines: (N, 2, 2) tensor, each line [[x1, y1], [x2, y2]]; only the first two
        coordinates of each endpoint are used
    :param scores: (N,) per-line scores, expected in [0, 1]. Values outside that range
        take the colormap's under/over colours.
    :param width: line width in pixels
    :param cmap: matplotlib colormap name, e.g. 'viridis', 'jet', 'coolwarm'
        (with 'jet', low to high runs blue -> yellow -> red)
    :return: (3, H, W) uint8 CPU tensor with the lines drawn
    :raises ValueError: if the image is not a 3-channel uint8 tensor, or if
        lines and scores have different lengths
    """
    # The PIL *module* (not the ImageDraw class imported at the top of the file).
    from PIL import Image as PILImage, ImageDraw as PILImageDraw

    if tensor_image.dtype != torch.uint8:
        raise ValueError(f"tensor_image must be uint8, got {tensor_image.dtype}")
    if tensor_image.shape[0] != 3:
        raise ValueError(f"tensor_image must have 3 channels, got shape {tuple(tensor_image.shape)}")
    if lines.shape[0] != scores.shape[0]:
        raise ValueError(f"lines ({lines.shape[0]}) and scores ({scores.shape[0]}) differ in length")

    # Per-line RGB colours from the colormap (alpha channel dropped).
    colormap = plt.get_cmap(cmap)
    colors = (colormap(scores.detach().cpu().numpy())[:, :3] * 255).astype(np.uint8)

    # CHW uint8 tensor -> PIL RGB image. torch.nn.functional (F) has no to_pil_image/to_tensor.
    hwc = tensor_image.detach().cpu().permute(1, 2, 0).contiguous().numpy()
    image_pil = PILImage.fromarray(hwc)
    draw = PILImageDraw.Draw(image_pil)
    for line, color in zip(lines, colors):
        start = tuple(map(float, line[0][:2].tolist()))
        end = tuple(map(float, line[1][:2].tolist()))
        draw.line([start, end], fill=tuple(int(v) for v in color), width=width)

    # Back to CHW uint8 without a lossy float round-trip.
    return torch.from_numpy(np.array(image_pil)).permute(2, 0, 1).contiguous()
class Trainer(BaseTrainer):
    """Training/validation driver built around ``LineDataset``.

    Reads dataset, optimiser and logging settings from a config dict (see
    ``init_params``) and runs train/val epochs. Losses and prediction
    visualisations are logged to TensorBoard. Checkpoints are written to
    ``last.pth``, ``best_train.pth`` and ``best_val.pth`` under a timestamped
    run directory.
    """

    def __init__(self, model=None, **kwargs):
        super().__init__(model, device, **kwargs)
        self.model = model
        self.init_params(**kwargs)

    def init_params(self, **kwargs):
        """Populate training settings from a config dict; a no-op when called without kwargs.

        Reads ``kwargs['train_params']`` (freeze_params, batch_size, num_workers,
        resume_from, max_epoch, augmentation) and ``kwargs['io']`` (datadir,
        data_type, logdir). Each call creates a new timestamped run directory
        path under ``logdir`` and a new SummaryWriter. A writer left over from
        an earlier call is closed first.
        """
        if not kwargs:
            return
        train_params = kwargs['train_params']
        print(f'train_params:{train_params}')
        self.freeze_config = train_params['freeze_params']
        print(f'freeze_config:{self.freeze_config}')
        io_params = kwargs['io']
        self.dataset_path = io_params['datadir']
        self.data_type = io_params['data_type']
        self.batch_size = train_params['batch_size']
        self.num_workers = train_params['num_workers']
        self.logdir = io_params['logdir']
        self.resume_from = train_params['resume_from']
        self.optim = ''
        # Attribute keeps its historical spelling ("ptath") in case other code reads it.
        self.train_result_ptath = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
        self.wts_path = os.path.join(self.train_result_ptath, 'weights')
        self.tb_path = os.path.join(self.train_result_ptath, 'logs')
        # init_params runs from both __init__ and train(); don't leak the first writer.
        previous_writer = getattr(self, 'writer', None)
        if isinstance(previous_writer, SummaryWriter):
            previous_writer.close()
        self.writer = SummaryWriter(self.tb_path)
        self.last_model_path = os.path.join(self.wts_path, 'last.pth')
        self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
        self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
        self.max_epoch = train_params['max_epoch']
        self.augmentation = train_params["augmentation"]

    def move_to_device(self, data, device):
        """Recursively move tensors inside lists, (named)tuples and dicts to ``device``.

        Non-tensor leaves are returned unchanged.
        """
        if isinstance(data, tuple) and hasattr(data, '_fields'):
            # namedtuple constructors take positional fields, not an iterable.
            return type(data)(*(self.move_to_device(item, device) for item in data))
        if isinstance(data, (list, tuple)):
            return type(data)(self.move_to_device(item, device) for item in data)
        if isinstance(data, dict):
            return {key: self.move_to_device(value, device) for key, value in data.items()}
        if isinstance(data, torch.Tensor):
            return data.to(device)
        return data  # leave non-tensor data untouched

    def freeze_params(self, model):
        """Freeze or unfreeze parameters of ``model`` according to the freeze config.

        ``self.freeze_config`` is merged over a built-in default; the merge
        replaces top-level keys wholesale.
        - ``True`` freezes the named child module (``requires_grad=False``).
        - ``False`` makes it trainable.
        - A nested dict descends into that module's children, at any depth.
        Children not named in the config keep their current state. Prints the
        applied settings and the parameter counts.
        """
        default_config = {
            'backbone': True,  # freeze the backbone
            'rpn': False,  # keep the RPN trainable
            'roi_heads': {
                'box_head': False,
                'box_predictor': False,
                'line_head': False,
                'line_predictor': {
                    'fc1': False,
                    'fc2': {
                        '0': False,
                        '2': False,
                        '4': False,
                    },
                },
            },
        }
        # Apply the user overrides on top of the defaults.
        config = {**default_config, **(getattr(self, 'freeze_config', None) or {})}
        print("\n===== Parameter Freezing Configuration =====")
        self._apply_freeze_config(model, config, prefix='', depth=0)

        # Parameter statistics.
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\nTotal Parameters: {total_params:,}")
        print(f"Trainable Parameters: {trainable_params:,}")
        print(f"Frozen Parameters: {total_params - trainable_params:,}")

    def _apply_freeze_config(self, module, config, prefix, depth):
        """Recursively apply a (possibly nested) freeze config to ``module``'s named children."""
        for child_name, child in module.named_children():
            if child_name not in config:
                continue
            setting = config[child_name]
            qualified_name = f"{prefix}.{child_name}" if prefix else child_name
            if isinstance(setting, bool):
                for param in child.parameters():
                    param.requires_grad = not setting
                level = "module" if depth == 0 else "sub-" * (depth - 1) + "submodule"
                print(f"{'Frozen' if setting else 'Trainable'} {level}: {qualified_name}")
            elif isinstance(setting, dict):
                # Descend; a non-empty dict must not be read as a truthy "freeze" flag.
                self._apply_freeze_config(child, setting, qualified_name, depth + 1)

    def load_best_model(self, model, optimizer, save_path, device):
        """Load model (and, if present, optimizer) state from ``save_path`` when the file exists.

        Returns the ``(model, optimizer)`` pair. When the file is missing, a
        message is printed and both are returned unchanged.
        """
        if not os.path.exists(save_path):
            print(f"No saved model found at {save_path}")
            return model, optimizer
        checkpoint = torch.load(save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # Checkpoints saved without an optimizer carry no optimizer state.
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint.get('epoch')
        loss = checkpoint.get('loss')  # absent in save_last_model() checkpoints
        loss_text = f"{loss:.4f}" if loss is not None else "n/a"
        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss_text}")
        return model, optimizer

    def writer_predict_result(self, img, result, epoch,):
        """Log an input image and the model's predictions for it to TensorBoard.

        ``result`` must contain ``'boxes'``. The optional keys are:
        - ``'points'``
        - ``'lines'`` (with ``'lines_scores'``)
        - ``'arcs'``
        - ``'ins_masks'`` (with ``'features'``)
        Each present key adds its own image. ``img`` is assumed to be a CHW
        float image in [0, 1] (it is scaled by 255 for drawing) — TODO confirm.
        """
        img = img.cpu().detach()
        im = img.permute(1, 2, 0)  # CHW -> HWC
        self.writer.add_image("z-ori", im, epoch, dataformats="HWC")

        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result["boxes"],
                                          colors="yellow", width=1)
        self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

        if 'points' in result:
            keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
            self.writer.add_image("z-output", keypoint_img, epoch)

        if 'lines' in result:
            print(f"shape of linescore:{result['lines_scores'].shape}")
            scores = result['lines_scores'].mean(dim=1)  # one score per line
            line_image = draw_lines_with_scores((img * 255).to(torch.uint8), result['lines'], scores,
                                                width=3, cmap='jet')
            self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")

        if 'arcs' in result:
            arcs = result['arcs'][0]
            print(f'arcs in dra w:{arcs}')
            ellipse_img = draw_el(arcs, background_img=im)
            self.writer.add_image('z-out-arc', ellipse_img, global_step=epoch)

        if 'ins_masks' in result:
            print(f"boxes shape:{result['boxes'].shape}")
            print(f"ppp:{result['ins_masks'].shape}")
            ins_masks = result['ins_masks'].squeeze(1)
            print(f'ins_masks shape:{ins_masks.shape}')
            features = result['features']

            # Sum all instance masks into one map, normalised to a max of ~1.
            sum_mask = ins_masks.sum(dim=0, keepdim=True)
            sum_mask = sum_mask / (sum_mask.max() + 1e-8)
            # squeeze(0) leaves a 2-D map, so the format must be 'HW' (default 'CHW' expects 3-D).
            self.writer.add_image('z-ins-masks', sum_mask.squeeze(0), global_step=epoch, dataformats='HW')

            result_imgs = draw_ellipses_on_image(img, ins_masks, threshold=0.5)
            self.writer.add_image('z-out-ellipses', result_imgs, dataformats='HWC', global_step=epoch)

            features = self.apply_gaussian_blur_to_tensor(features, sigma=3)
            self.writer.add_image('z-feature', features, global_step=epoch)

    def normalize_tensor(self, tensor):
        """Min-max normalise ``tensor`` to [0, 1]; a constant tensor maps to all zeros."""
        min_val = tensor.min()
        max_val = tensor.max()
        span = max_val - min_val
        if span == 0:
            # Avoid 0/0 -> NaN; match the floating dtype true division would give.
            return torch.zeros_like(tensor, dtype=torch.result_type(tensor, 1.0))
        return (tensor - min_val) / span

    def apply_gaussian_blur_to_tensor(self, feature_map, sigma=3):
        """
        Apply Gaussian blur to a feature map and convert it into an RGB heatmap.

        :param feature_map: Tensor of shape (H, W) or (1, H, W)
        :param sigma: Standard deviation for Gaussian kernel
        :return: float Tensor of shape (3, H, W) holding the 'jet' heatmap in [0, 1]
        :raises ValueError: if a 3-D input has more than one channel
        """
        if feature_map.dim() == 3:
            if feature_map.shape[0] != 1:
                raise ValueError("Only single-channel feature map supported.")
            feature_map = feature_map.squeeze(0)
        # Normalise to [0, 1]; detach so tensors that require grad can go to numpy.
        normalized_feat = self.normalize_tensor(feature_map.detach()).cpu().numpy()
        blurred_feat = gaussian_filter(normalized_feat, sigma=sigma)
        colored = plt.get_cmap('jet')(blurred_feat)  # (H, W, 4) RGBA
        colored_rgb = colored[:, :, :3]  # drop alpha -> (H, W, 3)
        return torch.from_numpy(colored_rgb).permute(2, 0, 1).float()  # (3, H, W)

    def writer_loss(self, losses, epoch, phase='train'):
        """Log each loss term as a TensorBoard scalar named ``{phase}/loss/{name}``.

        ``'loss_wirepoint'`` is expanded into the entries of its ``'losses'``
        list. Non-tensor top-level values are skipped. This is best-effort:
        logging errors are printed, not raised.
        """
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in value['losses']:
                        for subkey, subvalue in subdict.items():
                            scalar = subvalue.item() if hasattr(subvalue, 'item') else subvalue
                            self.writer.add_scalar(f'{phase}/loss/{subkey}', scalar, epoch)
                elif isinstance(value, torch.Tensor):
                    self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
        except Exception as e:  # logging must not interrupt training
            print(f"TensorBoard logging error: {e}")

    def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None):
        """Read a YAML config from ``cfg`` and train ``model`` with it.

        ``freeze_config`` is currently unused. Freezing settings come from the
        config's ``train_params.freeze_params``.
        """
        cfg = read_yaml(cfg)
        self.train(model, **cfg)

    def train(self, model, **kwargs):
        """Train ``model`` for ``max_epoch`` epochs, validating and checkpointing each epoch.

        ``kwargs`` is the full config (see ``init_params``). It must also hold
        ``train_params.optim.lr`` and ``train_params.optim.weight_decay``.
        """
        self.init_params(**kwargs)
        dataset_train = LineDataset(dataset_path=self.dataset_path, augmentation=self.augmentation,
                                    data_type=self.data_type, dataset_type='train')
        dataset_val = LineDataset(dataset_path=self.dataset_path, augmentation=self.augmentation,
                                  data_type=self.data_type, dataset_type='val')
        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        val_sampler = torch.utils.data.RandomSampler(dataset_val)
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=self.batch_size, drop_last=True)
        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=self.batch_size, drop_last=True)
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers,
            collate_fn=utils.collate_fn,
        )
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers,
            collate_fn=utils.collate_fn,
        )

        model.to(device)
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=kwargs['train_params']['optim']['lr'],
            weight_decay=kwargs['train_params']['optim']['weight_decay'],
        )
        # NOTE(review): hard-coded checkpoint on a network share; the config's
        # `resume_from` (self.resume_from) is read but not used here. Confirm which is intended.
        model, optimizer = self.load_best_model(model, optimizer,
                                                r"\\192.168.50.222\share\rlq\weights\250725_arc_res152_best_val.pth",
                                                device)
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=30)

        best_train_loss = float('inf')
        best_val_loss = float('inf')
        for epoch in range(self.max_epoch):
            print(f"train epoch:{epoch}")
            model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)

            # ========== Validation ==========
            with torch.no_grad():
                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')

            # One plateau step per epoch on a single metric. Stepping on both the train
            # and val losses counted every epoch twice and mixed two metrics.
            scheduler.step(epoch_val_loss)

            self.save_last_model(model, self.last_model_path, epoch, optimizer)
            best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch,
                                                   epoch_train_loss, best_train_loss, optimizer)
            best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch,
                                                 epoch_val_loss, best_val_loss, optimizer)

    def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
        """Run one pass over ``data_loader`` and return ``(model, average_loss)``.

        - ``'train'``: the model is in train mode and the optimizer steps every batch.
        - ``'val'``: the model is in eval mode and is expected to return
          ``(result, loss_dict)``. The first batch is also run without targets to
          log a timed prediction.
        Per-batch losses are logged at step ``epoch * len(data_loader) + batch_index``.

        Raises:
            ValueError: if ``data_loader`` yields no batches.
        """
        if phase == 'train':
            model.train()
        elif phase == 'val':
            model.eval()

        num_batches = len(data_loader)
        if num_batches == 0:
            raise ValueError(f"No batches in the {phase} data loader; "
                             f"check the dataset size against batch_size (drop_last=True).")

        total_loss = 0
        global_step = epoch * num_batches
        for epoch_step, (imgs, targets) in enumerate(data_loader):
            imgs = self.move_to_device(imgs, device)
            targets = self.move_to_device(targets, device)
            if phase == 'val':
                result, loss_dict = model(imgs, targets)
                loss = sum(loss_dict.values())
                print(f'val losses:{loss}')
            else:
                loss_dict = model(imgs, targets)
                loss = sum(loss_dict.values())
                print(f'train losses:{loss}')

            total_loss += loss.item()
            if phase == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            self.writer_loss(loss_dict, global_step, phase=phase)
            global_step += 1

            if epoch_step == 0 and phase == 'val':
                t_start = time.time()
                print(f'start to predict:{t_start}')
                result = model(imgs)  # imgs are already on `device`
                t_end = time.time()
                print(f'predict used:{t_end - t_start}')
                print_params(imgs[0], result[0], epoch)
                self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)

        avg_loss = total_loss / num_batches
        print(f'{phase}/loss epoch{epoch}:{avg_loss:.4f}')
        self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
        return model, avg_loss

    def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
        """Save a checkpoint when ``current_loss <= best_loss`` and return the new best loss."""
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        if current_loss > best_loss:
            return best_loss
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'loss': current_loss,
        }
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(checkpoint, save_path)
        print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
        return current_loss

    def save_last_model(self, model, save_path, epoch, optimizer=None):
        """Overwrite ``save_path`` with the current model (and optimizer) state."""
        # Remove a stale 'last.pt' (different extension from the 'last.pth' written here).
        stale_path = os.path.join(self.wts_path, 'last.pt')
        if os.path.exists(stale_path):
            os.remove(stale_path)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
        }
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(checkpoint, save_path)
if __name__ == "__main__":
    # Placeholder entry point: it only prints an empty line.
    print()
|