|
@@ -8,12 +8,33 @@ import torchvision
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
|
|
|
|
|
|
+from models.wirenet.postprocess import postprocess_keypoint
|
|
|
+from torchvision.utils import draw_bounding_boxes
|
|
|
+from torchvision import transforms
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import numpy as np
|
|
|
+import matplotlib as mpl
|
|
|
+from tools.coco_utils import get_coco_api_from_dataset
|
|
|
+from tools.coco_eval import CocoEvaluator
|
|
|
+import time
|
|
|
+
|
|
|
from models.config.config_tool import read_yaml
|
|
|
from models.ins.maskrcnn_dataset import MaskRCNNDataset
|
|
|
from models.keypoint.keypoint_dataset import KeypointDataset
|
|
|
from tools import utils, presets
|
|
|
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
|
|
|
+
|
|
|
+
|
|
|
def log_losses_to_tensorboard(writer, result, step):
    """Write each individual training loss to TensorBoard as its own scalar.

    Args:
        writer: Object exposing ``add_scalar(tag, value, step)``
            (e.g. ``SummaryWriter``).
        result: Mapping from loss name to a value exposing ``.item()``.
            Must contain ``loss_classifier``, ``loss_box_reg``,
            ``loss_keypoint``, ``loss_objectness`` and ``loss_rpn_box_reg``.
        step: Global step the scalars are recorded at.

    Raises:
        KeyError: If one of the expected loss names is missing from ``result``.
    """
    # One distinct tag per loss. The keypoint loss previously reused
    # 'Loss/box_reg', which mixed two different curves into one series.
    tag_by_loss = (
        ('loss_classifier', 'Loss/classifier'),
        ('loss_box_reg', 'Loss/box_reg'),
        ('loss_keypoint', 'Loss/keypoint'),
        ('loss_objectness', 'Loss/objectness'),
        ('loss_rpn_box_reg', 'Loss/rpn_box_reg'),
    )
    for loss_name, tag in tag_by_loss:
        writer.add_scalar(tag, result[loss_name].item(), step)
|
|
|
+
|
|
|
+
|
|
|
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq,writer, scaler=None):
|
|
|
model.train()
|
|
|
+ total_train_loss=0.0
|
|
|
metric_logger = utils.MetricLogger(delimiter=" ")
|
|
|
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
|
|
|
header = f"Epoch: [{epoch}]"
|
|
@@ -27,15 +48,21 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
|
|
|
optimizer, start_factor=warmup_factor, total_iters=warmup_iters
|
|
|
)
|
|
|
|
|
|
- for images, targets in metric_logger.log_every(data_loader, print_freq, header):
|
|
|
+ for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
|
|
+
|
|
|
+ global_step = epoch * len(data_loader) + batch_idx
|
|
|
# print(f'images:{images}')
|
|
|
images = list(image.to(device) for image in images)
|
|
|
targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
|
|
|
with torch.cuda.amp.autocast(enabled=scaler is not None):
|
|
|
loss_dict = model(images, targets)
|
|
|
- print(f'loss_dict:{loss_dict}')
|
|
|
+ # print(f'loss_dict:{loss_dict}')
|
|
|
+
|
|
|
losses = sum(loss for loss in loss_dict.values())
|
|
|
|
|
|
+ total_train_loss += losses.item()
|
|
|
+ log_losses_to_tensorboard(writer, loss_dict, global_step)
|
|
|
+
|
|
|
# reduce losses over all GPUs for logging purposes
|
|
|
loss_dict_reduced = utils.reduce_dict(loss_dict)
|
|
|
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
|
|
@@ -64,17 +91,133 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
|
|
|
|
|
|
return metric_logger
|
|
|
|
|
|
+
|
|
|
# Shared colour scale used when drawing predicted lines: values between
# 0.4 and 1.0 are normalised and mapped onto the "jet" colormap.
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
# An empty array is attached so the mappable can be used without image data.
sm.set_array([])
|
|
|
+
|
|
|
+
|
|
|
def c(x):
    """Map the value ``x`` to an RGBA colour via the module-level ``sm``."""
    rgba = sm.to_rgba(x)
    return rgba
|
|
|
+
|
|
|
+
|
|
|
def show_line(img, pred, epoch, writer):
    """Log one image and its predictions to TensorBoard.

    Three images are written at step ``epoch``: the input (``"ori"``), the
    input with predicted boxes drawn in yellow (``"boxes"``), and a
    matplotlib rendering of the predicted lines over the input (``"output"``).

    Args:
        img: Image tensor; it is permuted from (C, H, W) to (H, W, C) and
            scaled by 255 for box drawing.
            NOTE(review): presumably float values in [0, 1] -- confirm.
        pred: Prediction mapping with ``"boxes"``, ``"keypoints"`` and
            ``"keypoints_scores"`` tensors.
        epoch: Step value passed to every ``writer.add_image`` call.
        writer: Object exposing ``add_image`` (e.g. ``SummaryWriter``).
    """
    score_threshold = 0.5  # lines scoring below this are not drawn
    endpoint_style = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}

    im = img.permute(1, 2, 0)
    writer.add_image("ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
                                      colors="yellow", width=1)
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    lines = pred["keypoints"].detach().cpu().numpy()
    scores = pred["keypoints_scores"].detach().cpu().numpy()
    # Truncate at the first entry identical to entry 0.
    # NOTE(review): presumably repeated entries are padding -- confirm
    # against the model's output format.
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break

    # Post-process lines to remove overlapped lines; the distance threshold
    # is 1% of the image diagonal.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)

    # Use an explicit figure instead of pyplot's implicit "current" figure so
    # that unrelated open figures are never drawn on or closed, and make sure
    # the figure is released even if rendering fails.
    fig, ax = plt.subplots()
    try:
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < score_threshold:
                continue
            # NOTE(review): index 1 is used as x and index 0 as y -- confirm
            # this matches the coordinate order postprocess_keypoint returns.
            ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            ax.scatter(a[1], a[0], **endpoint_style)
            ax.scatter(b[1], b[0], **endpoint_style)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        ax.imshow(im.cpu())
        fig.tight_layout()
        fig.canvas.draw()
        # tostring_rgb() was deprecated in Matplotlib 3.8 and later removed;
        # buffer_rgba() is the supported way to read the rendered pixels.
        # Copy out the RGB channels before the figure is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        image_from_plot = np.ascontiguousarray(rgba[..., :3])
    finally:
        plt.close(fig)

    img2 = transforms.ToTensor()(image_from_plot)
    writer.add_image("output", img2, epoch)
|
|
|
+
|
|
|
+
|
|
|
def _get_iou_types(model):
    """Return the IoU types to evaluate for ``model``.

    ``"bbox"`` is always included; ``"segm"`` is added for a torchvision
    ``MaskRCNN`` and ``"keypoints"`` for a torchvision ``KeypointRCNN``.
    A ``DistributedDataParallel`` wrapper is unwrapped before the checks.
    """
    detection = torchvision.models.detection
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        core_model = model.module
    else:
        core_model = model

    # Order matters: "segm" is listed before "keypoints".
    optional_types = (
        (detection.MaskRCNN, "segm"),
        (detection.KeypointRCNN, "keypoints"),
    )
    return ["bbox"] + [name for cls, name in optional_types if isinstance(core_model, cls)]
|
|
|
+
|
|
|
+
|
|
|
def evaluate(model, data_loader, epoch, writer, device):
    """Evaluate ``model`` on ``data_loader`` with COCO metrics.

    The model is switched to eval mode, every batch is run without gradient
    tracking, and predictions are fed to a ``CocoEvaluator``. The first
    image of the first batch is additionally visualised via ``show_line``.

    Args:
        model: Detection model; called as ``model(images)`` and expected to
            return one dict of tensors per image.
        data_loader: Loader yielding ``(images, targets)`` batches; each
            target must provide ``"image_id"``.
        epoch: Step used for the TensorBoard visualisation.
        writer: TensorBoard writer passed through to ``show_line``.
        device: Device the images are moved to before inference.

    Returns:
        The ``CocoEvaluator`` after ``accumulate()`` and ``summarize()``.
    """
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    try:
        model.eval()
        metric_logger = utils.MetricLogger(delimiter=" ")
        header = "Test:"

        coco = get_coco_api_from_dataset(data_loader.dataset)
        iou_types = _get_iou_types(model)
        coco_evaluator = CocoEvaluator(coco, iou_types)

        print('start to evaluate!!!')
        # Evaluation never backpropagates; disabling autograd avoids building
        # a graph (and holding its activations) for every batch.
        with torch.no_grad():
            for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)):
                images = [img.to(device) for img in images]

                # Wait for pending GPU work so it is not charged to this batch.
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                model_time = time.time()
                outputs = model(images)
                cpu_outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
                model_time = time.time() - model_time

                # Visualise once per evaluation, outside the timed region so
                # plotting does not inflate the reported model time.
                if batch_idx == 0:
                    show_line(images[0], outputs[0], epoch, writer)

                res = {target["image_id"]: output for target, output in zip(targets, cpu_outputs)}
                evaluator_time = time.time()
                coco_evaluator.update(res)
                evaluator_time = time.time() - evaluator_time
                metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        coco_evaluator.synchronize_between_processes()

        # accumulate predictions from all images
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    finally:
        # Restore the caller's thread count even if evaluation raised.
        torch.set_num_threads(n_threads)
    return coco_evaluator
|
|
|
+
|
|
|
+
|
|
|
def train_cfg(model, cfg):
    """Train ``model`` using the keyword arguments loaded from ``cfg``.

    ``cfg`` is handed to ``read_yaml``; the result is printed and then
    unpacked as keyword arguments for ``train``.
    """
    yaml_params = read_yaml(cfg)
    print(f'train parameters:{yaml_params}')
    train(model, **yaml_params)
|
|
|
|
|
|
+
|
|
|
def train(model, **kwargs):
|
|
|
# 默认参数
|
|
|
default_params = {
|
|
|
'dataset_path': '/path/to/dataset',
|
|
|
'num_classes': 2,
|
|
|
- 'num_keypoints':2,
|
|
|
+ 'num_keypoints': 2,
|
|
|
'opt': 'adamw',
|
|
|
'batch_size': 2,
|
|
|
'epochs': 10,
|
|
@@ -88,7 +231,7 @@ def train(model, **kwargs):
|
|
|
'target_type': 'polygon',
|
|
|
'enable_logs': True,
|
|
|
'augmentation': False,
|
|
|
- 'checkpoint':None
|
|
|
+ 'checkpoint': None
|
|
|
}
|
|
|
# 更新默认参数
|
|
|
for key, value in kwargs.items():
|
|
@@ -142,9 +285,9 @@ def train(model, **kwargs):
|
|
|
data_loader = torch.utils.data.DataLoader(
|
|
|
dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
|
|
|
)
|
|
|
- # data_loader_test = torch.utils.data.DataLoader(
|
|
|
- # dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
|
|
|
- # )
|
|
|
+ data_loader_test = torch.utils.data.DataLoader(
|
|
|
+ dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
|
|
|
+ )
|
|
|
|
|
|
img_results_path = os.path.join(train_result_ptath, 'img_results')
|
|
|
if os.path.exists(train_result_ptath):
|
|
@@ -158,7 +301,7 @@ def train(model, **kwargs):
|
|
|
os.mkdir(img_results_path)
|
|
|
|
|
|
for epoch in range(epochs):
|
|
|
- metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
|
|
|
+ metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, None)
|
|
|
losses = metric_logger.meters['loss'].global_avg
|
|
|
print(f'epoch {epoch}:loss:{losses}')
|
|
|
if os.path.exists(f'{wts_path}/last.pt'):
|
|
@@ -173,6 +316,9 @@ def train(model, **kwargs):
|
|
|
os.remove(f'{wts_path}/best.pt')
|
|
|
torch.save(model.state_dict(), f'{wts_path}/best.pt')
|
|
|
|
|
|
+ evaluate(model, data_loader_test, epoch, writer, device=device)
|
|
|
+
|
|
|
+
|
|
|
def get_transform(is_train, **kwargs):
|
|
|
default_params = {
|
|
|
'augmentation': 'multiscale',
|
|
@@ -209,4 +355,4 @@ def write_metric_logs(epoch, metric_logger, writer):
|
|
|
writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
|
|
|
writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
|
|
|
writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
|
|
|
- writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
|
|
|
+ writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
|