123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 |
- import math
- import os
- import sys
- from datetime import datetime
- import torch
- import torchvision
- from torch.utils.tensorboard import SummaryWriter
- from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
- from models.wirenet.postprocess import postprocess_keypoint
- from torchvision.utils import draw_bounding_boxes
- from torchvision import transforms
- import matplotlib.pyplot as plt
- import numpy as np
- import matplotlib as mpl
- from tools.coco_utils import get_coco_api_from_dataset
- from tools.coco_eval import CocoEvaluator
- import time
- from models.config.config_tool import read_yaml
- from models.ins.maskrcnn_dataset import MaskRCNNDataset
- from models.keypoint.keypoint_dataset import KeypointDataset
- from tools import utils, presets
def log_losses_to_tensorboard(writer, result, step):
    """Write each per-component detection loss in *result* to TensorBoard.

    Args:
        writer: a ``SummaryWriter``-like object with ``add_scalar``.
        result: loss dict whose values expose ``.item()``.
        step: global step used as the x-axis value.
    """
    tag_to_key = (
        ('Loss/classifier', 'loss_classifier'),
        ('Loss/box_reg', 'loss_box_reg'),
        ('Loss/keypoint', 'loss_keypoint'),
        ('Loss/objectness', 'loss_objectness'),
        ('Loss/rpn_box_reg', 'loss_rpn_box_reg'),
    )
    for tag, key in tag_to_key:
        writer.add_scalar(tag, result[key].item(), step)
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
    """Train *model* for a single epoch.

    Args:
        model: detection model that returns a dict of losses in train mode.
        optimizer: optimizer over ``model.parameters()``.
        data_loader: yields ``(images, targets)`` batches.
        device: device the images/targets are moved to.
        epoch: current epoch index (drives warmup and the global step).
        print_freq: iteration logging frequency for the metric logger.
        writer: TensorBoard ``SummaryWriter`` for per-step loss curves.
        scaler: optional ``torch.cuda.amp.GradScaler``; when given, mixed
            precision + scaled backward is used.

    Returns:
        ``(metric_logger, total_train_loss)`` — the logger holding averaged
        metrics and the sum of the (unreduced) batch losses over the epoch.

    Raises:
        SystemExit: if the reduced loss becomes non-finite.
    """
    # BUG FIX: was ``model.train1()`` — not a torch.nn.Module method, so the
    # first call raised AttributeError; ``train()`` enables training mode.
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    # Linear learning-rate warmup, applied during the very first epoch only.
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    total_train_loss = 0
    for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        global_step = epoch * len(data_loader) + batch_idx
        images = list(image.to(device) for image in images)
        # Only tensors are moved to the device; other target fields pass through.
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]

        # Autocast is a no-op when no scaler was supplied.
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        total_train_loss += losses.item()
        log_losses_to_tensorboard(writer, loss_dict, global_step)

        # Reduce losses over all GPUs for logging purposes.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger, total_train_loss
# Shared colormap machinery for mapping keypoint confidence scores to RGBA
# colors during visualization; the normalizer clips scores to [0.4, 1.0].
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])


def c(x):
    # Map a scalar confidence score to an RGBA tuple via the shared mappable.
    return sm.to_rgba(x)
def show_line(img, pred, epoch, writer):
    """Log one prediction to TensorBoard: the original image ("ori"), the
    predicted boxes ("boxes") and the post-processed keypoint line segments
    rendered with matplotlib ("output").

    Args:
        img: image tensor in CHW layout with values in [0, 1]
             (the inline comment assumes 3x512x512 — TODO confirm caller).
        pred: model output dict with "boxes", "keypoints" and
              "keypoints_scores".
        epoch: step value for the TensorBoard tags.
        writer: ``SummaryWriter`` receiving the images.
    """
    im = img.permute(1, 2, 0)  # [512, 512, 3]
    writer.add_image("ori", im, epoch, dataformats="HWC")
    # draw_bounding_boxes expects a uint8 image in [0, 255].
    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
                                      colors="yellow", width=1)
    # plt.imshow(boxed_image.permute(1, 2, 0).detach().cpu().numpy())
    # plt.show()
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    lines = pred["keypoints"].detach().cpu().numpy()
    scores = pred["keypoints_scores"].detach().cpu().numpy()
    # for i in range(1, len(lines)):
    #     if (lines[i] == lines[0]).all():
    #         lines = lines[:i]
    #         scores = scores[:i]
    #         break
    # postprocess lines to remove overlapped lines
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    # NOTE(review): only the (x, y) channels of each keypoint are kept;
    # the tolerance is 1% of the image diagonal.
    nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)
    # print(f'nscores:{nscores}')
    for i, t in enumerate([0.5]):  # score threshold(s) to render; currently one
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < t:
                continue  # skip low-confidence segments
            # plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            # plt.scatter(a[1], a[0], **PLTOPTS)
            # plt.scatter(b[1], b[0], **PLTOPTS)
            plt.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=2, zorder=s)
            plt.scatter(a[0], a[1], **PLTOPTS)
            plt.scatter(b[0], b[1], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im.cpu())
        plt.tight_layout()
        fig = plt.gcf()
        fig.canvas.draw()
        # Convert the rendered figure into an HxWx3 uint8 RGB array.
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
            fig.canvas.get_width_height()[::-1] + (3,))
        plt.close()
        img2 = transforms.ToTensor()(image_from_plot)
        writer.add_image("output", img2, epoch)
def _get_iou_types(model):
    """Return the COCO IoU types to evaluate for *model*.

    Always includes "bbox"; adds "segm" for Mask R-CNN and "keypoints"
    for Keypoint R-CNN. A DistributedDataParallel wrapper is unwrapped
    before the type checks.
    """
    net = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    types = ["bbox"]
    if isinstance(net, torchvision.models.detection.MaskRCNN):
        types.append("segm")
    if isinstance(net, torchvision.models.detection.KeypointRCNN):
        types.append("keypoints")
    return types
def evaluate(model, data_loader, epoch, writer, device):
    """Run *model* over *data_loader* in eval mode, logging the first
    batch's prediction visualization to TensorBoard via ``show_line``.

    COCO metric accumulation is currently disabled (commented out below);
    the evaluator is still built so re-enabling it stays a local change.

    Args:
        model: detection model (eval-mode inference on image lists).
        data_loader: validation loader yielding ``(images, targets)``.
        epoch: step value forwarded to TensorBoard.
        writer: ``SummaryWriter`` for the visualization image.
        device: device the images are moved to before inference.
    """
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"
    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(coco, iou_types)
    print('start to evaluate!!!')
    for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)):
        images = list(img.to(device) for img in images)
        model_time = time.time()
        outputs = model(images)
        # Only the first batch of each epoch is visualized.
        if batch_idx == 0:
            show_line(images[0], outputs[0], epoch, writer)
        # outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        # model_time = time.time() - model_time
        #
        # res = {target["image_id"]: output for target, output in zip(targets, outputs)}
        # evaluator_time = time.time()
        # coco_evaluator.update(res)
        # evaluator_time = time.time() - evaluator_time
        # metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
    #
    # # gather the stats from all processes
    # metric_logger.synchronize_between_processes()
    # print("Averaged stats:", metric_logger)
    # coco_evaluator.synchronize_between_processes()
    #
    # # accumulate predictions from all images
    # coco_evaluator.accumulate()
    # coco_evaluator.summarize()
    # return coco_evaluator
    # BUG FIX: the original clamped the thread count to 1 and never restored
    # it (the restore line was commented out), leaking the setting into all
    # subsequent training steps. Restore the saved value before returning.
    torch.set_num_threads(n_threads)
def train_cfg(model, cfg):
    """Load training hyper-parameters from the YAML file *cfg* and run train()."""
    loaded = read_yaml(cfg)
    print('train parameters:{}'.format(loaded))
    train(model, **loaded)
def train(model, **kwargs):
    """Train *model* on a KeypointDataset.

    All hyper-parameters are keyword arguments; any key not present in
    ``default_params`` raises ``ValueError``. Per-run outputs are written to
    ``train_results/<timestamp>/``: ``weights/`` (last.pt / best.pt),
    ``logs/`` (TensorBoard) and ``img_results/``.
    """
    # Default hyper-parameters (each may be overridden via kwargs).
    default_params = {
        'dataset_path': '/path/to/dataset',
        'num_classes': 2,
        'num_keypoints': 2,
        'opt': 'adamw',
        'batch_size': 2,
        'epochs': 10,
        'lr': 0.005,
        'momentum': 0.9,
        'weight_decay': 1e-4,
        'lr_step_size': 3,
        'lr_gamma': 0.1,
        'num_workers': 4,
        'print_freq': 10,
        'target_type': 'polygon',
        'enable_logs': True,
        'augmentation': False,
        'checkpoint': None
    }
    # Merge caller overrides, rejecting unknown keys.
    for key, value in kwargs.items():
        if key in default_params:
            default_params[key] = value
        else:
            raise ValueError(f"Unknown argument: {key}")

    # Unpack only the parameters actually used below.
    # NOTE(review): 'opt', 'momentum', 'lr_step_size' and 'lr_gamma' are
    # accepted but currently unused — the optimizer is hard-coded to AdamW
    # and no LR scheduler is created here; kept for config compatibility.
    dataset_path = default_params['dataset_path']
    batch_size = default_params['batch_size']
    epochs = default_params['epochs']
    lr = default_params['lr']
    weight_decay = default_params['weight_decay']
    num_workers = default_params['num_workers']
    print_freq = default_params['print_freq']
    target_type = default_params['target_type']
    augmentation = default_params['augmentation']

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # One timestamped result directory per run.
    train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
    wts_path = os.path.join(train_result_ptath, 'weights')
    tb_path = os.path.join(train_result_ptath, 'logs')
    img_results_path = os.path.join(train_result_ptath, 'img_results')
    # BUG FIX: the original built the directory tree through a fragile chain
    # of exists()/mkdir() checks (and constructed the SummaryWriter before
    # the parent directory existed); makedirs creates each tree idempotently.
    os.makedirs(wts_path, exist_ok=True)
    os.makedirs(tb_path, exist_ok=True)
    os.makedirs(img_results_path, exist_ok=True)
    writer = SummaryWriter(tb_path)

    # Optional training-time augmentation; None lets the dataset use its
    # defaults. Renamed from ``transforms`` so the local no longer shadows
    # the imported torchvision ``transforms`` module.
    train_transforms = None
    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
    if augmentation:
        train_transforms = get_transform(is_train=True)
        print(f'transforms:{train_transforms}')

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    dataset = KeypointDataset(dataset_path=dataset_path,
                              transforms=train_transforms, dataset_type='train', target_type=target_type)
    # NOTE(review): the val dataset does not receive target_type — confirm
    # that the KeypointDataset default matches the training setting.
    dataset_test = KeypointDataset(dataset_path=dataset_path, transforms=None,
                                   dataset_type='val')

    train_sampler = torch.utils.data.RandomSampler(dataset)
    # NOTE(review): the test loader also uses a RandomSampler; a
    # SequentialSampler is more conventional for evaluation — confirm intent.
    test_sampler = torch.utils.data.RandomSampler(dataset_test)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
    )

    # Track the lowest epoch-average loss; starting at +inf replaces the
    # original epoch-0 special case with identical behavior.
    best_loss = float('inf')
    for epoch in range(epochs):
        metric_logger, total_train_loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, None)
        losses = metric_logger.meters['loss'].global_avg
        print(f'epoch {epoch}:loss:{losses}')

        # Always keep the most recent weights.
        if os.path.exists(f'{wts_path}/last.pt'):
            os.remove(f'{wts_path}/last.pt')
        torch.save(model.state_dict(), f'{wts_path}/last.pt')

        # Keep the best (lowest average loss) weights.
        if losses <= best_loss:
            best_loss = losses
            if os.path.exists(f'{wts_path}/best.pt'):
                os.remove(f'{wts_path}/best.pt')
            torch.save(model.state_dict(), f'{wts_path}/best.pt')

        evaluate(model, data_loader_test, epoch, writer, device=device)

        avg_train_loss = total_train_loss / len(data_loader)
        writer.add_scalar('Loss/train', avg_train_loss, epoch)
def get_transform(is_train, **kwargs):
    """Build a detection preset transform pipeline.

    Args:
        is_train: when True, return the training preset with data
            augmentation; otherwise return the evaluation preset.
        **kwargs: optional overrides for 'augmentation' (default
            'multiscale'), 'backend' (default 'tensor') and 'use_v2'
            (default False). Unknown keys raise ``ValueError``.
    """
    options = {
        'augmentation': 'multiscale',
        'backend': 'tensor',
        'use_v2': False,
    }
    # Apply caller overrides, rejecting anything not in the option set.
    for key, value in kwargs.items():
        if key not in options:
            raise ValueError(f"Unknown argument: {key}")
        options[key] = value

    if is_train:
        return presets.DetectionPresetTrain(
            data_augmentation=options['augmentation'],
            backend=options['backend'],
            use_v2=options['use_v2'],
        )
    # elif weights and test_only:
    #     weights = torchvision.models.get_weight(args.weights)
    #     trans = weights.transforms()
    #     return lambda img, target: (trans(img), target)
    return presets.DetectionPresetEval(backend=options['backend'], use_v2=options['use_v2'])
def write_metric_logs(epoch, metric_logger, writer):
    """Write the epoch-averaged training losses to TensorBoard.

    Args:
        epoch: step value for the scalar curves.
        metric_logger: logger whose ``meters[name].global_avg`` holds the
            running average of each loss component.
        writer: TensorBoard ``SummaryWriter``.
    """
    writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
    writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
    # writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
    # BUG FIX: the keypoint loss was logged under the tag 'Loss/box_reg',
    # colliding with the box-regression curve; give it its own tag in the
    # same style as the other components.
    writer.add_scalar(f'loss_keypoint:', metric_logger.meters['loss_keypoint'].global_avg, epoch)
    writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
    writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
- # def log_losses_to_tensorboard(writer, result, step):
- # writer.add_scalar('Loss/classifier', result['loss_classifier'].item(), step)
- # writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
- # writer.add_scalar('Loss/box_reg', result['loss_keypoint'].item(), step)
- # writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
- # writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
|