import math
import os
import sys
from datetime import datetime

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights

from models.config.config_tool import read_yaml
from models.ins.maskrcnn_dataset import MaskRCNNDataset
from tools import utils, presets

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
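    """Train `model` for one epoch and return the populated MetricLogger.

    A linear learning-rate warmup runs during epoch 0 only. Pass a
    torch.cuda.amp.GradScaler as `scaler` to enable mixed precision.
    """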
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    # Warm up the learning rate linearly during the first epoch only.
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # Reduce losses over all GPUs for logging purposes.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger
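
# Sketch (assumption: a CUDA device and a model/optimizer/data_loader built as
# in train() below): opting in to mixed precision, which train() itself does
# not do (it always passes scaler=None):
#
#     scaler = torch.cuda.amp.GradScaler()
#     train_one_epoch(model, optimizer, data_loader, torch.device('cuda'),
#                     epoch=0, print_freq=10, scaler=scaler)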

def load_train_parameter(cfg):
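    """Read a training-parameter dict from a YAML config file."""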
    parameters = read_yaml(cfg)
    return parameters

def train_cfg(model, cfg):
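    """Read hyperparameters from a YAML file and forward them to train()."""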
    parameters = read_yaml(cfg)
    print(f'train parameters: {parameters}')
    train(model, **parameters)

def train(model, **kwargs):
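    """Train `model` on a MaskRCNNDataset.

    Keyword arguments override the entries of `default_params` below; any
    unknown key raises ValueError. Checkpoints go to
    train_results/<timestamp>/weights and TensorBoard logs to
    train_results/<timestamp>/logs.
    """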
    # Default parameters
    default_params = {
        'dataset_path': '/path/to/dataset',
        'num_classes': 10,
        'opt': 'adamw',
        'batch_size': 2,
        'epochs': 10,
        'lr': 0.005,
        'momentum': 0.9,
        'weight_decay': 1e-4,
        'lr_step_size': 3,
        'lr_gamma': 0.1,
        'num_workers': 4,
        'print_freq': 10,
        'target_type': 'polygon',
        'enable_logs': True,
        'augmentation': False,
        'checkpoint': None
    }

    # Override defaults with user-supplied keyword arguments
    for key, value in kwargs.items():
        if key in default_params:
            default_params[key] = value
        else:
            raise ValueError(f"Unknown argument: {key}")

    # Unpack parameters. NOTE: the optimizer is hardcoded to AdamW below, so
    # 'opt', 'momentum', 'lr_step_size' and 'lr_gamma' are accepted but unused.
    dataset_path = default_params['dataset_path']
    num_classes = default_params['num_classes']
    batch_size = default_params['batch_size']
    epochs = default_params['epochs']
    lr = default_params['lr']
    momentum = default_params['momentum']
    weight_decay = default_params['weight_decay']
    lr_step_size = default_params['lr_step_size']
    lr_gamma = default_params['lr_gamma']
    num_workers = default_params['num_workers']
    print_freq = default_params['print_freq']
    target_type = default_params['target_type']
    augmentation = default_params['augmentation']

    # Select device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Create the per-run output directories before any writer touches them.
    train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
    wts_path = os.path.join(train_result_path, 'weights')
    tb_path = os.path.join(train_result_path, 'logs')
    img_results_path = os.path.join(train_result_path, 'img_results')
    os.makedirs(wts_path, exist_ok=True)
    os.makedirs(img_results_path, exist_ok=True)
    writer = SummaryWriter(tb_path)

    transforms = None
    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
    if augmentation:
        transforms = get_transform(is_train=True)
        print(f'transforms: {transforms}')

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    dataset = MaskRCNNDataset(dataset_path=dataset_path,
                              transforms=transforms, dataset_type='train', target_type=target_type)
    dataset_test = MaskRCNNDataset(dataset_path=dataset_path, transforms=None,
                                   dataset_type='val')

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
    train_collate_fn = utils.collate_fn
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
    )
    # data_loader_test = torch.utils.data.DataLoader(
    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
    # )

    best_loss = float('inf')
    for epoch in range(epochs):
        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
        losses = metric_logger.meters['loss'].global_avg
        print(f'epoch {epoch}: loss: {losses}')

        # Always keep the latest weights (torch.save overwrites in place) ...
        torch.save(model.state_dict(), f'{wts_path}/last.pt')
        write_metric_logs(epoch, metric_logger, writer)

        # ... and the best weights seen so far.
        if losses <= best_loss:
            best_loss = losses
            torch.save(model.state_dict(), f'{wts_path}/best.pt')
    writer.close()
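
# Sketch of a typical entry point (assumptions: the model constructor choice,
# config path, and dataset layout below are hypothetical; any YAML consumed by
# read_yaml() may only contain keys from `default_params` in train()):
#
#     from torchvision.models.detection import maskrcnn_resnet50_fpn_v2
#
#     model = maskrcnn_resnet50_fpn_v2(
#         weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
#     train_cfg(model, 'configs/maskrcnn_train.yaml')
#     # or, without a config file:
#     train(model, dataset_path='datasets/my_coco', batch_size=2, epochs=10)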

def get_transform(is_train, **kwargs):
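    """Build a detection transform pipeline.

    Returns presets.DetectionPresetTrain when `is_train` is True, otherwise
    presets.DetectionPresetEval.
    """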
    default_params = {
        'augmentation': 'multiscale',
        'backend': 'tensor',
        'use_v2': False,
    }
    # Override defaults with user-supplied keyword arguments
    for key, value in kwargs.items():
        if key in default_params:
            default_params[key] = value
        else:
            raise ValueError(f"Unknown argument: {key}")

    # Unpack parameters
    augmentation = default_params['augmentation']
    backend = default_params['backend']
    use_v2 = default_params['use_v2']

    if is_train:
        return presets.DetectionPresetTrain(
            data_augmentation=augmentation, backend=backend, use_v2=use_v2
        )
    # elif weights and test_only:
    #     weights = torchvision.models.get_weight(args.weights)
    #     trans = weights.transforms()
    #     return lambda img, target: (trans(img), target)
    else:
        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)

def write_metric_logs(epoch, metric_logger, writer):
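    """Log the per-epoch average of each Mask R-CNN loss term to TensorBoard."""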
    writer.add_scalar('loss_classifier', metric_logger.meters['loss_classifier'].global_avg, epoch)
    writer.add_scalar('loss_box_reg', metric_logger.meters['loss_box_reg'].global_avg, epoch)
    writer.add_scalar('loss_mask', metric_logger.meters['loss_mask'].global_avg, epoch)
    writer.add_scalar('loss_objectness', metric_logger.meters['loss_objectness'].global_avg, epoch)
    writer.add_scalar('loss_rpn_box_reg', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
    writer.add_scalar('train_loss', metric_logger.meters['loss'].global_avg, epoch)