
Merge keypoint training code

RenLiqiang committed 5 months ago
df978a5957

+ 1 - 1
models/keypoint/keypoint_dataset.py

@@ -198,6 +198,6 @@ class KeypointDataset(BaseDataset):
 
 
 if __name__ == '__main__':
-    path=r"I:\wirenet_dateset"
+    path=r"D:\python\PycharmProjects\data"
     dataset= KeypointDataset(dataset_path=path, dataset_type='train')
     dataset.show(0)
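
The smoke test still hardcodes an absolute, machine-specific path; the commit only swaps one developer's directory for another's. A minimal sketch of making it portable (the KEYPOINT_DATA variable name and the fallback are assumptions, not part of the commit):

import os

if __name__ == '__main__':
    # Hypothetical: let an environment variable override the committed default
    path = os.environ.get('KEYPOINT_DATA', r"D:\python\PycharmProjects\data")
    dataset = KeypointDataset(dataset_path=path, dataset_type='train')
    dataset.show(0)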

+ 164 - 10
models/keypoint/trainer.py

@@ -8,12 +8,33 @@ import torchvision
 from torch.utils.tensorboard import SummaryWriter
 from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 
+from models.wirenet.postprocess import postprocess_keypoint
+from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+import matplotlib.pyplot as plt
+import numpy as np
+import matplotlib as mpl
+from tools.coco_utils import get_coco_api_from_dataset
+from tools.coco_eval import CocoEvaluator
+import time
+
 from models.config.config_tool import read_yaml
 from models.ins.maskrcnn_dataset import MaskRCNNDataset
 from models.keypoint.keypoint_dataset import KeypointDataset
 from tools import utils, presets
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+
+
+def log_losses_to_tensorboard(writer, result, step):
+    writer.add_scalar('Loss/classifier', result['loss_classifier'].item(), step)
+    writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
+    writer.add_scalar('Loss/keypoint', result['loss_keypoint'].item(), step)
+    writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
+    writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
     model.train()
+    total_train_loss = 0.0
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"
@@ -27,15 +48,21 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
             optimizer, start_factor=warmup_factor, total_iters=warmup_iters
         )
 
-    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+    for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+        global_step = epoch * len(data_loader) + batch_idx
         # print(f'images:{images}')
         images = list(image.to(device) for image in images)
         targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             loss_dict = model(images, targets)
             # print(f'loss_dict:{loss_dict}')
+
             losses = sum(loss for loss in loss_dict.values())
 
+            total_train_loss += losses.item()
+            log_losses_to_tensorboard(writer, loss_dict, global_step)
+
         # reduce losses over all GPUs for logging purposes
         loss_dict_reduced = utils.reduce_dict(loss_dict)
         losses_reduced = sum(loss for loss in loss_dict_reduced.values())
@@ -64,17 +91,133 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
 
     return metric_logger
 
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    # H = pred[1]['wires']
+    lines = pred["keypoints"].detach().cpu().numpy()
+    scores = pred["keypoints_scores"].detach().cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)
+    print(f'nscores:{nscores}')
+
+    for i, t in enumerate([0.5]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im.cpu())
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
+
+
+def _get_iou_types(model):
+    model_without_ddp = model
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        model_without_ddp = model.module
+    iou_types = ["bbox"]
+    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
+        iou_types.append("segm")
+    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
+        iou_types.append("keypoints")
+    return iou_types
+
+
+def evaluate(model, data_loader, epoch, writer, device):
+    n_threads = torch.get_num_threads()
+    # FIXME remove this and make paste_masks_in_image run on the GPU
+    torch.set_num_threads(1)
+    cpu_device = torch.device("cpu")
+    model.eval()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    header = "Test:"
+
+    coco = get_coco_api_from_dataset(data_loader.dataset)
+    iou_types = _get_iou_types(model)
+    coco_evaluator = CocoEvaluator(coco, iou_types)
+
+    print('start to evaluate')
+    for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)):
+        images = list(img.to(device) for img in images)
+
+        model_time = time.time()
+        outputs = model(images)
+        # print(f'outputs:{outputs}')
+
+        if batch_idx == 0:
+            show_line(images[0], outputs[0], epoch, writer)
+
+    # COCO evaluation is disabled for now:
+    #     outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+    #     model_time = time.time() - model_time
+    #
+    #     res = {target["image_id"]: output for target, output in zip(targets, outputs)}
+    #     evaluator_time = time.time()
+    #     coco_evaluator.update(res)
+    #     evaluator_time = time.time() - evaluator_time
+    #     metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+    #
+    # # gather the stats from all processes
+    # metric_logger.synchronize_between_processes()
+    # print("Averaged stats:", metric_logger)
+    # coco_evaluator.synchronize_between_processes()
+    #
+    # # accumulate predictions from all images
+    # coco_evaluator.accumulate()
+    # coco_evaluator.summarize()
+    # torch.set_num_threads(n_threads)
+    # return coco_evaluator
+
+
 def train_cfg(model, cfg):
     parameters = read_yaml(cfg)
     print(f'train parameters:{parameters}')
     train(model, **parameters)
 
+
 def train(model, **kwargs):
     # default parameters
     default_params = {
         'dataset_path': '/path/to/dataset',
         'num_classes': 2,
-        'num_keypoints':2,
+        'num_keypoints': 2,
         'opt': 'adamw',
         'batch_size': 2,
         'epochs': 10,
@@ -88,7 +231,7 @@ def train(model, **kwargs):
         'target_type': 'polygon',
         'enable_logs': True,
         'augmentation': False,
-        'checkpoint':None
+        'checkpoint': None
     }
     # update the defaults with caller overrides
     for key, value in kwargs.items():
@@ -142,9 +285,9 @@ def train(model, **kwargs):
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
     )
-    # data_loader_test = torch.utils.data.DataLoader(
-    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
-    # )
+    data_loader_test = torch.utils.data.DataLoader(
+        dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    )
 
     img_results_path = os.path.join(train_result_ptath, 'img_results')
     if os.path.exists(train_result_ptath):
@@ -158,7 +301,7 @@ def train(model, **kwargs):
         os.mkdir(img_results_path)
 
     for epoch in range(epochs):
-        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, None)
         losses = metric_logger.meters['loss'].global_avg
         print(f'epoch {epoch}:loss:{losses}')
         if os.path.exists(f'{wts_path}/last.pt'):
@@ -173,6 +316,9 @@ def train(model, **kwargs):
                 os.remove(f'{wts_path}/best.pt')
             torch.save(model.state_dict(), f'{wts_path}/best.pt')
 
+        evaluate(model, data_loader_test, epoch, writer, device=device)
+
+
 def get_transform(is_train, **kwargs):
     default_params = {
         'augmentation': 'multiscale',
@@ -206,7 +352,15 @@ def get_transform(is_train, **kwargs):
 def write_metric_logs(epoch, metric_logger, writer):
     writer.add_scalar(f'loss_classifier:', metric_logger.meters['loss_classifier'].global_avg, epoch)
     writer.add_scalar(f'loss_box_reg:', metric_logger.meters['loss_box_reg'].global_avg, epoch)
-    writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
+    # writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
+    writer.add_scalar(f'loss_keypoint:', metric_logger.meters['loss_keypoint'].global_avg, epoch)
     writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
     writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
-    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
+    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
+
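
For reference, the per-step logging added in train_one_epoch is the standard SummaryWriter pattern. A self-contained sketch (the log directory and the dummy loss values are assumptions; the keys mirror what torchvision's KeypointRCNN returns in training mode), written as a loop so every loss gets its own tag:

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/keypoint_demo')  # hypothetical log dir

# dummy losses with the keys KeypointRCNN produces during training
loss_dict = {
    'loss_classifier': torch.tensor(0.7),
    'loss_box_reg': torch.tensor(0.4),
    'loss_keypoint': torch.tensor(1.2),
    'loss_objectness': torch.tensor(0.1),
    'loss_rpn_box_reg': torch.tensor(0.05),
}

# same global-step formula as train_one_epoch: epoch * len(data_loader) + batch_idx
epoch, steps_per_epoch, batch_idx = 0, 100, 3
global_step = epoch * steps_per_epoch + batch_idx

for name, value in loss_dict.items():
    writer.add_scalar(f'Loss/{name}', value.item(), global_step)
writer.close()

Iterating over the dict also rules out the copy-paste hazard fixed above, where loss_keypoint was logged under the Loss/box_reg tag and silently overwrote it.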

+ 54 - 5
models/wirenet/postprocess.py

@@ -34,11 +34,11 @@ def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
         start, end = 0, 1
         for a, b in nlines:
             if (
-                min(
-                    max(pline(*p, *q, *a), pline(*p, *q, *b)),
-                    max(pline(*a, *b, *p), pline(*a, *b, *q)),
-                )
-                > threshold ** 2
+                    min(
+                        max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                        max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                    )
+                    > threshold ** 2
             ):
                 continue
             lambda_a = plambda(*p, *q, *a)
@@ -75,3 +75,52 @@ def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
         nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
         nscores.append(score)
     return np.array(nlines), np.array(nscores)
+
+
+def postprocess_keypoint(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+    nlines, nscores = [], []
+    for (p, q), score in zip(lines, scores):
+        start, end = 0, 1
+        for a, b in nlines:
+            if (
+                    min(
+                        max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                        max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                    )
+                    > threshold ** 2
+            ):
+                continue
+            lambda_a = plambda(*p, *q, *a)
+            lambda_b = plambda(*p, *q, *b)
+            if lambda_a > lambda_b:
+                lambda_a, lambda_b = lambda_b, lambda_a
+            lambda_a -= tol
+            lambda_b += tol
+
+            # case 1: skip (if not do_clip)
+            if start < lambda_a and lambda_b < end:
+                continue
+
+            # not intersect
+            if lambda_b < start or lambda_a > end:
+                continue
+
+            # cover
+            if lambda_a <= start and end <= lambda_b:
+                start = 10  # an existing segment covers the whole candidate; force start > end so it is dropped
+                break
+
+            # case 2 & 3:
+            if lambda_a <= start and start <= lambda_b:
+                start = lambda_b
+            if lambda_a <= end and end <= lambda_b:
+                end = lambda_a
+
+            if start >= end:
+                break
+
+        if start >= end:
+            continue
+        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
+        nscores.append(min(score[0], score[1]))  # a segment is only as confident as its weaker endpoint
+    return np.array(nlines), np.array(nscores)
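
A toy check of postprocess_keypoint (illustrative numbers only; this assumes pline/plambda give the squared point-to-line distance and the projection parameter, as in the existing postprocess): two overlapping collinear segments collapse to one, a distant segment survives, and each kept segment is scored by its weaker endpoint:

import numpy as np
from models.wirenet.postprocess import postprocess_keypoint

lines = np.array([
    [[0.0, 0.0], [10.0, 0.0]],  # kept: the first segment claims the span
    [[0.1, 0.0], [9.9, 0.0]],   # fully covered by the first -> suppressed
    [[0.0, 5.0], [10.0, 5.0]],  # far away -> kept
])
scores = np.array([  # one score per endpoint, like keypoints_scores
    [0.9, 0.8],
    [0.7, 0.6],
    [0.95, 0.85],
])

nlines, nscores = postprocess_keypoint(lines, scores, threshold=1.0, tol=0)
print(nscores)  # expected: [0.8, 0.85] -- the min of each surviving pair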

+ 4 - 4
tools/coco_utils.py

@@ -139,8 +139,8 @@ def convert_to_coco_api(ds):
         bboxes[:, 2:] -= bboxes[:, :2]
         bboxes = bboxes.tolist()
         labels = targets["labels"].tolist()
-        areas = targets["area"].tolist()
-        iscrowd = targets["iscrowd"].tolist()
+        # areas = targets["area"].tolist()
+        # iscrowd = targets["iscrowd"].tolist()
         if "masks" in targets:
             masks = targets["masks"]
             # make masks Fortran contiguous for coco_mask
@@ -155,8 +155,8 @@ def convert_to_coco_api(ds):
             ann["bbox"] = bboxes[i]
             ann["category_id"] = labels[i]
             categories.add(labels[i])
-            ann["area"] = areas[i]
-            ann["iscrowd"] = iscrowd[i]
+            # ann["area"] = areas[i]
+            # ann["iscrowd"] = iscrowd[i]
             ann["id"] = ann_id
             if "masks" in targets:
                 ann["segmentation"] = coco_mask.encode(masks[i].numpy())