import math
import os
import sys
from datetime import datetime

import cv2
import numpy as np
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes

from libs.vision_libs.utils import draw_segmentation_masks
from models.config.config_tool import read_yaml
from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
from tools import utils, presets


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Run one training epoch over ``data_loader``.

    Args:
        model: detection model; in train mode it returns a dict of losses.
        optimizer: optimizer updating ``model``'s parameters.
        data_loader: yields ``(images, targets)`` batches.
        device: device the batches are moved to.
        epoch: current epoch index; epoch 0 gets a linear LR warmup.
        print_freq: iteration interval for metric logging.
        scaler: optional ``torch.cuda.amp.GradScaler`` for mixed precision.

    Returns:
        The ``utils.MetricLogger`` holding accumulated loss statistics.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    # Linear warmup during the first epoch to stabilize early training.
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        # Targets may contain non-tensor metadata; only tensors are moved.
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
                   for t in targets]
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # Reduce losses over all GPUs for logging purposes only; the
        # backward pass uses the local (unreduced) losses.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        # Abort on NaN/inf loss rather than training on garbage gradients.
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger


def load_train_parameter(cfg):
    """Load training parameters from a YAML config file path."""
    parameters = read_yaml(cfg)
    return parameters


def train_cfg(model, cfg):
    """Read a YAML config and launch :func:`train` with its parameters."""
    parameters = read_yaml(cfg)
    print(f'train parameters:{parameters}')
    train(model, **parameters)


def move_to_device(data, device):
    """Recursively move tensors inside nested lists/tuples/dicts to ``device``.

    Non-tensor leaves are returned unchanged.
    """
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    elif isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data


def train(model, **kwargs):
    """Train an instance-segmentation model and log results to TensorBoard.

    Keyword Args:
        Mirrors ``default_params`` below (dataset_path, batch_size, epochs,
        lr, weight_decay, num_workers, print_freq, target_type, augmentation,
        ...). Unknown keys raise ``ValueError``.

    Side effects:
        Creates ``train_results/<timestamp>/{weights,logs,img_results}``,
        saves ``last.pt`` every epoch and ``best.pt`` on loss improvement,
        and writes scalar/image summaries to TensorBoard.
    """
    # Defaults, overridable via **kwargs (unknown keys rejected below).
    default_params = {
        'dataset_path': '/path/to/dataset',
        'num_classes': 2,
        'num_keypoints': 2,
        'opt': 'adamw',
        'batch_size': 2,
        'epochs': 10,
        'lr': 0.005,
        'momentum': 0.9,
        'weight_decay': 1e-4,
        'lr_step_size': 3,
        'lr_gamma': 0.1,
        'num_workers': 4,
        'print_freq': 10,
        'target_type': 'polygon',
        'enable_logs': True,
        'augmentation': False,
        'checkpoint': None
    }
    for key, value in kwargs.items():
        if key in default_params:
            default_params[key] = value
        else:
            raise ValueError(f"Unknown argument: {key}")

    dataset_path = default_params['dataset_path']
    batch_size = default_params['batch_size']
    epochs = default_params['epochs']
    lr = default_params['lr']
    weight_decay = default_params['weight_decay']
    num_workers = default_params['num_workers']
    print_freq = default_params['print_freq']
    target_type = default_params['target_type']
    augmentation = default_params['augmentation']
    # NOTE(review): 'opt', 'momentum', 'lr_step_size', 'lr_gamma',
    # 'num_classes', 'num_keypoints', 'enable_logs' and 'checkpoint' are
    # accepted for backward compatibility but are currently unused — the
    # optimizer is fixed to AdamW and no LR step schedule is applied.

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # One timestamped result directory per run. makedirs creates the whole
    # tree idempotently (replaces the previous fragile exists/mkdir chains,
    # which skipped creating the weights/img dirs on a fresh run).
    train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
    wts_path = os.path.join(train_result_path, 'weights')
    tb_path = os.path.join(train_result_path, 'logs')
    img_results_path = os.path.join(train_result_path, 'img_results')
    os.makedirs(wts_path, exist_ok=True)
    os.makedirs(img_results_path, exist_ok=True)

    writer = SummaryWriter(tb_path)

    transforms = None
    if augmentation:
        transforms = get_transform(is_train=True)
        print(f'transforms:{transforms}')

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    dataset = MaskRCNNDataset(dataset_path=dataset_path, transforms=transforms,
                              dataset_type='train', target_type=target_type)
    val_dataset = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
                                  dataset_type='val')

    train_sampler = torch.utils.data.RandomSampler(dataset)
    val_sampler = torch.utils.data.RandomSampler(val_dataset)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers,
        collate_fn=utils.collate_fn
    )
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_sampler=val_batch_sampler, num_workers=num_workers,
        collate_fn=utils.collate_fn
    )

    # float('inf') sentinel replaces the old `if epoch == 0: best_loss = losses`
    # hack; the save condition (losses <= best so far) is unchanged.
    best_loss = float('inf')
    for epoch in range(epochs):
        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
        losses = metric_logger.meters['loss'].global_avg
        print(f'epoch {epoch}:loss:{losses}')

        # Always refresh the rolling "last" checkpoint.
        if os.path.exists(f'{wts_path}/last.pt'):
            os.remove(f'{wts_path}/last.pt')
        torch.save(model.state_dict(), f'{wts_path}/last.pt')
        write_metric_logs(epoch, metric_logger, writer)

        if losses <= best_loss:
            best_loss = losses
            if os.path.exists(f'{wts_path}/best.pt'):
                os.remove(f'{wts_path}/best.pt')
            torch.save(model.state_dict(), f'{wts_path}/best.pt')

        # Visualize one validation prediction per epoch in TensorBoard.
        # (model.train() is re-entered at the top of the next epoch.)
        with torch.no_grad():
            model.eval()
            for step, (imgs, targets) in enumerate(val_data_loader):
                imgs = move_to_device(imgs, device)
                result = model([imgs[0]])
                print(f'model eval result :{result}')
                write_val_imgs(epoch, imgs[0], result, writer)
                break  # only the first batch was ever used; skip the rest


def get_transform(is_train, **kwargs):
    """Build the detection transform preset for train or eval.

    Args:
        is_train: select the training preset (with augmentation) vs eval.

    Keyword Args:
        augmentation: name of the train-time augmentation scheme.
        backend: image backend passed to the preset ('tensor' by default).
        use_v2: use torchvision transforms v2.
    """
    default_params = {
        'augmentation': 'multiscale',
        'backend': 'tensor',
        'use_v2': False,
    }
    for key, value in kwargs.items():
        if key in default_params:
            default_params[key] = value
        else:
            raise ValueError(f"Unknown argument: {key}")

    augmentation = default_params['augmentation']
    backend = default_params['backend']
    use_v2 = default_params['use_v2']

    if is_train:
        return presets.DetectionPresetTrain(
            data_augmentation=augmentation, backend=backend, use_v2=use_v2
        )
    else:
        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)


def write_metric_logs(epoch, metric_logger, writer):
    """Write per-epoch average losses from ``metric_logger`` to TensorBoard."""
    # Tag strings kept exactly as before so existing TensorBoard runs line up.
    writer.add_scalar('loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
    writer.add_scalar('loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
    writer.add_scalar('loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
    writer.add_scalar('loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
    writer.add_scalar('loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
    writer.add_scalar('train loss:', metric_logger.meters['loss'].global_avg, epoch)


def generate_colors(n):
    """Generate ``n`` colors evenly spread over OpenCV's HSV hue range.

    :param n: number of colors required
    :return: list of ``n`` colors as BGR integer tuples
    """
    # OpenCV hue spans 0-179, hence the factor 180.
    hsv_colors = [(i / n * 180, 1 / 3 * 255, 2 / 3 * 255) for i in range(n)]
    bgr_colors = [tuple(map(int, cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0]))
                  for hsv in hsv_colors]
    return bgr_colors


def overlay_masks_on_image(image, masks, alpha=0.6):
    """Blend several masks onto an image, each in a distinct color.

    :param image: source image as a NumPy array (BGR)
    :param masks: iterable of mask tensors (each CxHxW, moved to CPU here)
    :param alpha: mask opacity in [0.0, 1.0]
    :return: image with all masks alpha-blended on top
    """
    colors = generate_colors(len(masks))

    # Work on a copy so the caller's image is untouched.
    overlay = image.copy()
    for mask, color in zip(masks, colors):
        # Binarize the mask; adjust the threshold here if needed.
        mask = mask.cpu().detach().permute(1, 2, 0).numpy()
        binary_mask = (mask > 0).astype(np.uint8) * 255

        # Solid-color image restricted to the mask region.
        colored_mask = np.zeros_like(image)
        colored_mask[:] = color
        colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)

        # Blend the colored mask into the running overlay.
        overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)

    return overlay


def write_val_imgs(epoch, img, results, writer):
    """Log one validation image with predicted boxes and masks to TensorBoard.

    :param epoch: global step for the image summaries
    :param img: float image tensor in [0, 1] (CxHxW)
    :param results: model output list; ``results[0]`` holds 'boxes' and 'masks'
    :param writer: ``SummaryWriter`` to log to
    """
    masks = results[0]['masks'].squeeze(1).to(torch.bool)
    print(f'masks shape:{masks.shape}')
    boxes = results[0]['boxes']
    print(f'boxes shape:{boxes.shape}')
    print(f'writer img shape:{img.shape}')

    boxes = boxes.cpu().detach()
    drawn_boxes = draw_bounding_boxes((img * 255).to(torch.uint8), boxes, colors="red", width=5)
    print(f'drawn_boxes:{drawn_boxes.shape}')
    drawn_boxes = drawn_boxes.cpu()
    writer.add_image("z-boxes", drawn_boxes, epoch)

    masked_img = draw_segmentation_masks((img * 255).to(torch.uint8), masks)
    writer.add_image("z-masks", masked_img, epoch)