"""Training, evaluation and TensorBoard logging helpers for a keypoint detector.

The entry points are :func:`train_cfg` (YAML-configured) and :func:`train`
(keyword-configured). Each run writes its TensorBoard logs and weights under
``train_results/<timestamp>/``.
"""
import math
import os
import sys
import time
from datetime import datetime

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes

from models.wirenet.postprocess import postprocess_keypoint
from tools.coco_utils import get_coco_api_from_dataset
from tools.coco_eval import CocoEvaluator
from models.config.config_tool import read_yaml
from models.ins.maskrcnn_dataset import MaskRCNNDataset
from models.keypoint.keypoint_dataset import KeypointDataset
from tools import utils, presets

# Maps each key of the model's loss dict to the TensorBoard tag it is logged under.
_STEP_LOSS_TAGS = {
    'loss_classifier': 'Loss/classifier',
    'loss_box_reg': 'Loss/box_reg',
    'loss_keypoint': 'Loss/keypoint',
    'loss_objectness': 'Loss/objectness',
    'loss_rpn_box_reg': 'Loss/rpn_box_reg',
}


def _merge_params(defaults, overrides):
    """Return a copy of ``defaults`` updated with ``overrides``.

    Raises:
        ValueError: if ``overrides`` contains a key that is not in ``defaults``.
    """
    params = dict(defaults)
    for key, value in overrides.items():
        if key not in params:
            raise ValueError(f"Unknown argument: {key}")
        params[key] = value
    return params


def log_losses_to_tensorboard(writer, result, step):
    """Write the individual loss terms in ``result`` to TensorBoard.

    Args:
        writer: object exposing ``add_scalar(tag, value, step)``.
        result: mapping that must contain the keys ``loss_classifier``,
            ``loss_box_reg``, ``loss_keypoint``, ``loss_objectness`` and
            ``loss_rpn_box_reg``; each value must support ``.item()``.
        step: global step the scalars are recorded at.

    Raises:
        KeyError: if one of the expected loss keys is missing.
    """
    for loss_key, tag in _STEP_LOSS_TAGS.items():
        writer.add_scalar(tag, result[loss_key].item(), step)


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
    """Run one training epoch.

    Args:
        model: detection model; called as ``model(images, targets)`` and
            expected to return a dict of loss tensors.
        optimizer: optimizer updating ``model``'s parameters.
        data_loader: iterable yielding ``(images, targets)`` batches.
        device: device the images and tensor-valued target entries are moved to.
        epoch: zero-based epoch index; epoch 0 gets a linear LR warm-up.
        print_freq: progress print interval passed to ``MetricLogger.log_every``.
        writer: TensorBoard writer receiving the per-step losses.
        scaler: optional gradient scaler; when given, autocast is enabled.

    Returns:
        Tuple ``(metric_logger, total_train_loss)`` where ``total_train_loss``
        is the sum over all batches of the (local, un-reduced) total loss.
    """
    # NOTE(review): the original code called ``model.train1()``, which is not a
    # ``torch.nn.Module`` method. It is still used when the model defines it;
    # otherwise fall back to the standard ``train()``. Confirm which is intended.
    enter_train_mode = getattr(model, "train1", None)
    if enter_train_mode is None:
        enter_train_mode = model.train
    enter_train_mode()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    # Linear warm-up of the learning rate, only during the very first epoch.
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    total_train_loss = 0
    batches = metric_logger.log_every(data_loader, print_freq, header)
    for batch_idx, (images, targets) in enumerate(batches):
        global_step = epoch * len(data_loader) + batch_idx

        images = [image.to(device) for image in images]
        # Only tensor entries are moved; other target values are passed through.
        targets = [
            {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
            for t in targets
        ]

        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())

        total_train_loss += losses.item()
        log_losses_to_tensorboard(writer, loss_dict, global_step)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        # A NaN/inf loss cannot be recovered from; abort the whole process.
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger, total_train_loss


# Module-level colour mapping from a score in [0.4, 1.0] to a "jet" RGBA colour.
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])


def c(x):
    """Return the RGBA colour the module-level "jet" mapping assigns to ``x``."""
    return sm.to_rgba(x)


def show_line(img, pred, epoch, writer, score_threshold=0.5):
    """Log the input image, predicted boxes and predicted lines to TensorBoard.

    Three images are written: ``"ori"`` (the input), ``"boxes"`` (input with
    the predicted boxes drawn in yellow) and ``"output"`` (input with the
    post-processed lines drawn in red).

    Args:
        img: image tensor in channel-first layout (it is permuted to HWC for
            display); scaled by 255 before drawing, so presumably in [0, 1].
        pred: prediction mapping with ``"boxes"``, ``"keypoints"`` and
            ``"keypoints_scores"`` tensors.
        epoch: step value the images are logged at.
        writer: TensorBoard writer exposing ``add_image``.
        score_threshold: lines whose post-processed score is below this value
            are not drawn. Defaults to 0.5, the value that was hard-coded.
    """
    im = img.permute(1, 2, 0)  # channel-first -> HWC
    writer.add_image("ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes(
        (img * 255).to(torch.uint8), pred["boxes"], colors="yellow", width=1
    )
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    lines = pred["keypoints"].detach().cpu().numpy()
    scores = pred["keypoints_scores"].detach().cpu().numpy()

    # postprocess lines to remove overlapped lines
    # (threshold is 1% of the image diagonal; the meaning of the remaining
    # positional arguments is defined by postprocess_keypoint — TODO confirm)
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)

    # Use an explicit figure so nothing is drawn onto an unrelated "current"
    # figure, and so it is always closed even if rendering fails.
    fig, ax = plt.subplots()
    try:
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)

        for (a, b), s in zip(nlines, nscores):
            if s < score_threshold:
                continue
            ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=2, zorder=s)
            ax.scatter(a[0], a[1], **PLTOPTS)
            ax.scatter(b[0], b[1], **PLTOPTS)

        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        ax.imshow(im.cpu())
        fig.tight_layout()
        fig.canvas.draw()

        # buffer_rgba() replaces the deprecated tostring_rgb(); drop the alpha
        # channel and copy so the array outlives the figure.
        image_from_plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)

    img2 = transforms.ToTensor()(image_from_plot)
    writer.add_image("output", img2, epoch)


def _get_iou_types(model):
    """Return the IoU types matching ``model``'s class.

    Always contains ``"bbox"``; ``"segm"`` is added for ``MaskRCNN`` models and
    ``"keypoints"`` for ``KeypointRCNN`` models. A ``DistributedDataParallel``
    wrapper is unwrapped before the checks.
    """
    model_without_ddp = model
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_without_ddp = model.module

    iou_types = ["bbox"]
    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        iou_types.append("keypoints")
    return iou_types


@torch.inference_mode()
def evaluate(model, data_loader, epoch, writer, device):
    """Run the model over ``data_loader`` and log a visualisation of the first batch.

    Only the first image of the first batch is visualised (via
    :func:`show_line`). COCO metric accumulation is currently disabled, so
    the function returns ``None``.

    Args:
        model: detection model; switched to eval mode and called as ``model(images)``.
        data_loader: iterable of ``(images, targets)`` batches with a ``dataset`` attribute.
        epoch: step value used for the logged images.
        writer: TensorBoard writer.
        device: device the images are moved to.
    """
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    try:
        model.eval()
        metric_logger = utils.MetricLogger(delimiter=" ")
        header = "Test:"

        # NOTE(review): the evaluator is constructed but never updated because
        # the COCO update/accumulate/summarize steps are disabled. Kept so the
        # setup still runs as before; drop it or re-enable those steps.
        coco = get_coco_api_from_dataset(data_loader.dataset)
        iou_types = _get_iou_types(model)
        coco_evaluator = CocoEvaluator(coco, iou_types)

        print(f'start to evaluate!!!')
        batches = metric_logger.log_every(data_loader, 10, header)
        for batch_idx, (images, targets) in enumerate(batches):
            images = [img.to(device) for img in images]
            outputs = model(images)

            if batch_idx == 0:
                show_line(images[0], outputs[0], epoch, writer)
    finally:
        # Restore the thread count even if evaluation fails, so later
        # CPU work is not left single-threaded.
        torch.set_num_threads(n_threads)


def train_cfg(model, cfg):
    """Train ``model`` using the keyword arguments read from the YAML file ``cfg``."""
    parameters = read_yaml(cfg)
    print(f'train parameters:{parameters}')
    train(model, **parameters)


def train(model, **kwargs):
    """Train ``model`` on a ``KeypointDataset`` and save weights and logs.

    Outputs go to ``train_results/<YYYYmmdd_HHMMSS>/`` with sub-directories
    ``weights`` (``last.pt`` and ``best.pt`` state dicts), ``logs``
    (TensorBoard) and ``img_results``.

    Keyword Args:
        dataset_path, batch_size, epochs, lr, weight_decay, num_workers,
        print_freq, target_type, augmentation: used as named.
        num_classes, num_keypoints, opt, momentum, lr_step_size, lr_gamma,
        enable_logs, checkpoint: accepted but currently not used by this
        function (the optimizer is always AdamW and no LR schedule is applied
        beyond the first-epoch warm-up in :func:`train_one_epoch`).

    Raises:
        ValueError: on an unknown keyword argument, or when the training
            loader yields no batches.
    """
    # Default parameters
    default_params = {
        'dataset_path': '/path/to/dataset',
        'num_classes': 2,
        'num_keypoints': 2,
        'opt': 'adamw',
        'batch_size': 2,
        'epochs': 10,
        'lr': 0.005,
        'momentum': 0.9,
        'weight_decay': 1e-4,
        'lr_step_size': 3,
        'lr_gamma': 0.1,
        'num_workers': 4,
        'print_freq': 10,
        'target_type': 'polygon',
        'enable_logs': True,
        'augmentation': False,
        'checkpoint': None
    }
    # Override the defaults with the caller's values
    params = _merge_params(default_params, kwargs)

    # Unpack the parameters that are actually used
    dataset_path = params['dataset_path']
    batch_size = params['batch_size']
    epochs = params['epochs']
    lr = params['lr']
    weight_decay = params['weight_decay']
    num_workers = params['num_workers']
    print_freq = params['print_freq']
    target_type = params['target_type']
    augmentation = params['augmentation']

    # Select the device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Create the whole output tree up front; makedirs tolerates existing dirs.
    train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
    wts_path = os.path.join(train_result_path, 'weights')
    tb_path = os.path.join(train_result_path, 'logs')
    img_results_path = os.path.join(train_result_path, 'img_results')
    for directory in (wts_path, tb_path, img_results_path):
        os.makedirs(directory, exist_ok=True)

    # Named so it does not shadow the imported torchvision ``transforms`` module.
    train_transforms = None
    if augmentation:
        train_transforms = get_transform(is_train=True)
        print(f'transforms:{train_transforms}')

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    dataset = KeypointDataset(dataset_path=dataset_path, transforms=train_transforms,
                              dataset_type='train', target_type=target_type)
    # NOTE(review): the validation set does not receive ``target_type`` and so
    # uses the dataset's own default — confirm this is intended.
    dataset_test = KeypointDataset(dataset_path=dataset_path, transforms=None, dataset_type='val')

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.RandomSampler(dataset_test)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers,
        collate_fn=utils.collate_fn
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers,
        collate_fn=utils.collate_fn
    )

    # With drop_last=True a dataset smaller than batch_size yields no batches,
    # which would otherwise surface later as a division by zero.
    if len(data_loader) == 0:
        raise ValueError(
            f"Training loader is empty: dataset has {len(dataset)} samples "
            f"but batch_size is {batch_size} with drop_last=True"
        )

    last_weights = os.path.join(wts_path, 'last.pt')
    best_weights = os.path.join(wts_path, 'best.pt')
    best_loss = math.inf

    writer = SummaryWriter(tb_path)
    try:
        for epoch in range(epochs):
            metric_logger, total_train_loss = train_one_epoch(
                model, optimizer, data_loader, device, epoch, print_freq, writer, None
            )
            losses = metric_logger.meters['loss'].global_avg
            print(f'epoch {epoch}:loss:{losses}')

            # torch.save overwrites an existing file, so no prior removal is needed.
            torch.save(model.state_dict(), last_weights)

            # "Best" is judged on the averaged *training* loss; ties update it too.
            if losses <= best_loss:
                best_loss = losses
                torch.save(model.state_dict(), best_weights)

            evaluate(model, data_loader_test, epoch, writer, device=device)

            avg_train_loss = total_train_loss / len(data_loader)
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
    finally:
        # Flush pending events to disk even if training aborts.
        writer.close()


def get_transform(is_train, **kwargs):
    """Build the detection preset transform for training or evaluation.

    Args:
        is_train: if true return ``presets.DetectionPresetTrain``, otherwise
            ``presets.DetectionPresetEval``.

    Keyword Args:
        augmentation: data-augmentation name for the training preset
            (default ``'multiscale'``).
        backend: preset backend (default ``'tensor'``).
        use_v2: forwarded to the preset (default ``False``).

    Raises:
        ValueError: on an unknown keyword argument.
    """
    default_params = {
        'augmentation': 'multiscale',
        'backend': 'tensor',
        'use_v2': False,
    }
    # Override the defaults with the caller's values
    params = _merge_params(default_params, kwargs)

    if is_train:
        return presets.DetectionPresetTrain(
            data_augmentation=params['augmentation'],
            backend=params['backend'],
            use_v2=params['use_v2'],
        )
    return presets.DetectionPresetEval(backend=params['backend'], use_v2=params['use_v2'])


def write_metric_logs(epoch, metric_logger, writer):
    """Write the epoch-averaged loss meters of ``metric_logger`` to TensorBoard.

    Args:
        epoch: step value the scalars are recorded at.
        metric_logger: object whose ``meters`` mapping holds meters with a
            ``global_avg`` attribute for every logged loss.
        writer: TensorBoard writer exposing ``add_scalar``.
    """
    meters = metric_logger.meters
    writer.add_scalar('loss_classifier:', meters['loss_classifier'].global_avg, epoch)
    writer.add_scalar('loss_box_reg:', meters['loss_box_reg'].global_avg, epoch)
    # Was logged under 'Loss/box_reg' (copy-paste error); use a dedicated tag
    # in the same style as the other tags in this function.
    writer.add_scalar('loss_keypoint:', meters['loss_keypoint'].global_avg, epoch)
    writer.add_scalar('loss_objectness:', meters['loss_objectness'].global_avg, epoch)
    writer.add_scalar('loss_rpn_box_reg:', meters['loss_rpn_box_reg'].global_avg, epoch)
    writer.add_scalar('train loss:', meters['loss'].global_avg, epoch)