| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319 |
- import math
- import os
- import sys
- from datetime import datetime
- import cv2
- import numpy as np
- import torch
- import torchvision
- from torch.utils.tensorboard import SummaryWriter
- from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
- from torchvision.utils import draw_bounding_boxes
- from libs.vision_libs.utils import draw_segmentation_masks
- from models.config.config_tool import read_yaml
- from models.ins_detect.maskrcnn_dataset import MaskRCNNDataset
- from tools import utils, presets
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Run one training epoch and return the populated MetricLogger.

    Applies a linear LR warmup during epoch 0 and optional mixed precision
    when a GradScaler is supplied. Exits the process (status 1) if the
    reduced loss becomes non-finite.

    :param model: detection model; in train mode it returns a dict of losses.
    :param optimizer: optimizer updating ``model``'s parameters.
    :param data_loader: yields ``(images, targets)`` batches.
    :param device: device images/target tensors are moved to.
    :param epoch: current epoch index (used for the log header and warmup).
    :param print_freq: iteration interval between progress log lines.
    :param scaler: optional ``torch.cuda.amp.GradScaler``; enables autocast.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"
    lr_scheduler = None
    if epoch == 0:
        # Linear warmup over at most the first 1000 iterations of the run.
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        # print(f'images:{images}')
        images = list(image.to(device) for image in images)
        # Only tensor-valued target fields are moved; others are passed through.
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        # Autocast is active only on the mixed-precision path (scaler given).
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()
        if not math.isfinite(loss_value):
            # A NaN/Inf loss means training has diverged; abort hard.
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)
        optimizer.zero_grad()
        if scaler is not None:
            # Mixed precision: scale, step, then update the scale factor.
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()
        if lr_scheduler is not None:
            # Warmup scheduler steps per-iteration, not per-epoch.
            lr_scheduler.step()
        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    return metric_logger
def load_train_parameter(cfg):
    """Parse the YAML training config at ``cfg`` and return its contents."""
    return read_yaml(cfg)
def train_cfg(model, cfg):
    """Load hyperparameters from the YAML file ``cfg`` and start training."""
    params = read_yaml(cfg)
    print(f'train parameters:{params}')
    train(model, **params)
def move_to_device(data, device):
    """Recursively transfer every tensor inside ``data`` to ``device``.

    Lists/tuples and dicts are rebuilt (keeping their concrete type);
    non-tensor leaves are returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    # Non-tensor scalars/strings/etc. pass through untouched.
    return data
def train(model, **kwargs):
    """Train ``model`` on a MaskRCNNDataset and log results to TensorBoard.

    Keyword args (all optional, see ``default_params`` for defaults):
        dataset_path, num_classes, num_keypoints, opt, batch_size, epochs,
        lr, momentum, weight_decay, lr_step_size, lr_gamma, num_workers,
        print_freq, target_type, enable_logs, augmentation, checkpoint.
    Unknown keys raise ``ValueError``.

    Side effects: creates a timestamped run directory under
    ``train_results/`` containing ``weights/`` (last.pt / best.pt),
    ``logs/`` (TensorBoard) and ``img_results/``.
    """
    default_params = {
        'dataset_path': '/path/to/dataset',
        'num_classes': 2,
        'num_keypoints': 2,
        'opt': 'adamw',
        'batch_size': 2,
        'epochs': 10,
        'lr': 0.005,
        'momentum': 0.9,
        'weight_decay': 1e-4,
        'lr_step_size': 3,
        'lr_gamma': 0.1,
        'num_workers': 4,
        'print_freq': 10,
        'target_type': 'polygon',
        'enable_logs': True,
        'augmentation': False,
        'checkpoint': None,
    }
    # Merge caller overrides, rejecting unknown keys early.
    for key, value in kwargs.items():
        if key not in default_params:
            raise ValueError(f"Unknown argument: {key}")
        default_params[key] = value

    dataset_path = default_params['dataset_path']
    batch_size = default_params['batch_size']
    epochs = default_params['epochs']
    lr = default_params['lr']
    weight_decay = default_params['weight_decay']
    num_workers = default_params['num_workers']
    print_freq = default_params['print_freq']
    target_type = default_params['target_type']
    augmentation = default_params['augmentation']
    # NOTE(review): num_classes, num_keypoints, opt, momentum, lr_step_size,
    # lr_gamma, enable_logs and checkpoint are accepted but currently unused.

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Timestamped run directory: train_results/<YYYYmmdd_HHMMSS>/{weights,logs,img_results}
    train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
    wts_path = os.path.join(train_result_path, 'weights')
    tb_path = os.path.join(train_result_path, 'logs')
    img_results_path = os.path.join(train_result_path, 'img_results')
    # BUGFIX: create the whole directory tree up front. The original created
    # the subdirectories with os.mkdir only after constructing SummaryWriter,
    # and the mkdir calls could fail because the parent did not exist yet.
    os.makedirs(wts_path, exist_ok=True)
    os.makedirs(img_results_path, exist_ok=True)
    writer = SummaryWriter(tb_path)

    transforms = None
    if augmentation:
        transforms = get_transform(is_train=True)
        print(f'transforms:{transforms}')

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    dataset = MaskRCNNDataset(dataset_path=dataset_path,
                              transforms=transforms, dataset_type='train', target_type=target_type)
    val_dataset = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
                                  dataset_type='val')

    train_sampler = torch.utils.data.RandomSampler(dataset)
    val_sampler = torch.utils.data.RandomSampler(val_dataset)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size, drop_last=True)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
    )
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_sampler=val_batch_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
    )

    # BUGFIX: best_loss is initialised before the loop; the original only
    # bound it inside an `if epoch == 0` branch.
    best_loss = math.inf
    for epoch in range(epochs):
        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
        losses = metric_logger.meters['loss'].global_avg
        print(f'epoch {epoch}:loss:{losses}')

        # Always refresh the "last" checkpoint.
        if os.path.exists(f'{wts_path}/last.pt'):
            os.remove(f'{wts_path}/last.pt')
        torch.save(model.state_dict(), f'{wts_path}/last.pt')
        write_metric_logs(epoch, metric_logger, writer)

        # Keep the checkpoint with the lowest average training loss so far.
        if losses <= best_loss:
            best_loss = losses
            if os.path.exists(f'{wts_path}/best.pt'):
                os.remove(f'{wts_path}/best.pt')
            torch.save(model.state_dict(), f'{wts_path}/best.pt')

        # Visualise one validation prediction per epoch.
        with torch.no_grad():
            model.eval()
            for imgs, targets in val_data_loader:
                imgs = move_to_device(imgs, device)
                result = model([imgs[0]])
                print(f'model eval result :{result}')
                write_val_imgs(epoch, imgs[0], result, writer)
                # Only the first batch is used; stop instead of iterating
                # (and moving to device) the whole validation set.
                break
def get_transform(is_train, **kwargs):
    """Build a detection transform preset.

    Accepted keyword overrides: ``augmentation`` (default ``'multiscale'``),
    ``backend`` (default ``'tensor'``), ``use_v2`` (default ``False``).
    Any other keyword raises ``ValueError``.

    :param is_train: when true, return the training preset (with data
        augmentation); otherwise the evaluation preset.
    """
    options = {
        'augmentation': 'multiscale',
        'backend': 'tensor',
        'use_v2': False,
    }
    # Merge caller overrides, rejecting unknown keys.
    for key, value in kwargs.items():
        if key not in options:
            raise ValueError(f"Unknown argument: {key}")
        options[key] = value

    if is_train:
        return presets.DetectionPresetTrain(
            data_augmentation=options['augmentation'],
            backend=options['backend'],
            use_v2=options['use_v2'],
        )
    return presets.DetectionPresetEval(backend=options['backend'], use_v2=options['use_v2'])
def write_metric_logs(epoch, metric_logger, writer):
    """Write the epoch-averaged Mask R-CNN losses to TensorBoard.

    :param epoch: global step used for each scalar.
    :param metric_logger: logger whose ``meters`` hold the averaged losses.
    :param writer: TensorBoard ``SummaryWriter`` (or compatible object).
    """
    # Meter name -> TensorBoard tag. Tags are kept byte-identical to the
    # original (including the trailing colons) so existing runs line up.
    tags = {
        'loss_classifier': 'loss_classifier:',
        'loss_box_reg': 'loss_box_reg:',
        'loss_mask': 'loss_mask:',
        'loss_objectness': 'loss_objectness:',
        'loss_rpn_box_reg': 'loss_rpn_box_reg:',
        'loss': 'train loss:',
    }
    for meter_name, tag in tags.items():
        writer.add_scalar(tag, metric_logger.meters[meter_name].global_avg, epoch)
def generate_colors(n):
    """Generate ``n`` evenly spaced hues and return them as BGR tuples.

    Hues are spread uniformly over OpenCV's HSV hue range [0, 180);
    saturation and value are fixed at 1/3 and 2/3 of 255 so the colors stay
    muted enough to overlay on an image.

    :param n: number of colors wanted.
    :return: list of ``n`` ``(B, G, R)`` int tuples.
    """
    bgr_colors = []
    for i in range(n):
        hsv = (i / n * 180, 1 / 3 * 255, 2 / 3 * 255)
        # Convert a single 1x1 HSV pixel to BGR via OpenCV.
        pixel = cv2.cvtColor(np.uint8([[hsv]]), cv2.COLOR_HSV2BGR)[0][0]
        bgr_colors.append(tuple(int(c) for c in pixel))
    return bgr_colors
def overlay_masks_on_image(image, masks, alpha=0.6):
    """Overlay multiple instance masks on an image, one color per mask.

    :param image: original image as a NumPy array (BGR).
    :param masks: iterable of mask tensors; each is moved to CPU and
        permuted from (C, H, W) to (H, W, C) before thresholding.
        NOTE(review): assumes single-channel (1, H, W) torch masks — confirm
        against the caller.
    :param alpha: mask opacity (0.0 to 1.0).
    :return: a new image with the colored masks blended in; the input
        image is not modified.
    """
    # One distinct color per mask. (The original also compared len(masks)
    # against len(colors), but that check could never fail since the colors
    # are generated from len(masks); the dead branch — and the stale
    # ``colors`` docstring parameter — have been removed.)
    colors = generate_colors(len(masks))

    # Work on a copy so the caller's image is left untouched.
    overlay = image.copy()
    for mask, color in zip(masks, colors):
        # Threshold to a binary uint8 mask (0 or 255).
        mask = mask.cpu().detach().permute(1, 2, 0).numpy()
        binary_mask = (mask > 0).astype(np.uint8) * 255
        # Paint the mask area in this instance's color.
        colored_mask = np.zeros_like(image)
        colored_mask[:] = color
        colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
        # Blend the colored region into the running overlay.
        overlay = cv2.addWeighted(overlay, 1 - alpha, colored_mask, alpha, 0)
    return overlay
def write_val_imgs(epoch, img, results, writer):
    """Log one validation image with predicted boxes and masks to TensorBoard.

    :param epoch: global step used for the logged images.
    :param img: image tensor; scaled by 255 and cast to uint8 for drawing
        (presumably float values in [0, 1] — confirm against the caller).
    :param results: model output list; only ``results[0]`` is used.
    :param writer: TensorBoard ``SummaryWriter``.
    """
    prediction = results[0]
    masks = prediction['masks'].squeeze(1).to(torch.bool)
    print(f'masks shape:{masks.shape}')
    boxes = prediction['boxes']
    print(f'boxes shape:{boxes.shape}')
    print(f'writer img shape:{img.shape}')
    boxes = boxes.cpu().detach()
    img_u8 = (img * 255).to(torch.uint8)
    drawn_boxes = draw_bounding_boxes(img_u8, boxes, colors="red", width=5)
    print(f'drawn_boxes:{drawn_boxes.shape}')
    writer.add_image("z-boxes", drawn_boxes.cpu(), epoch)
    masked_img = draw_segmentation_masks(img_u8, masks)
    writer.add_image("z-masks", masked_img, epoch)
|