
keypoint tensorboard

xue50 5 months ago
parent
commit
5aa2887b3d

+ 1 - 0
models/ins/maskrcnn.py

@@ -125,6 +125,7 @@ class MaskRCNNModel(nn.Module):
 
             # create a colored mask
             colored_mask = np.zeros_like(image)
+
             colored_mask[:] = color
             colored_mask = cv2.bitwise_and(colored_mask, colored_mask, mask=binary_mask)
 

+ 10 - 7
models/keypoint/keypoint_dataset.py

@@ -24,7 +24,6 @@ from torch.utils.data.dataloader import default_collate
 import matplotlib.pyplot as plt
 from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
 
-
 def validate_keypoints(keypoints, image_width, image_height):
     for kp in keypoints:
         x, y, v = kp
@@ -67,7 +66,7 @@ class KeypointDataset(BaseDataset):
         return len(self.imgs)
 
     def read_target(self, item, lbl_path, shape, extra=None):
-        print(f'shape:{shape}')
+        # print(f'shape:{shape}')
         # print(f'lbl_path:{lbl_path}')
         with open(lbl_path, 'r') as file:
             lable_all = json.load(file)
@@ -124,17 +123,18 @@ class KeypointDataset(BaseDataset):
 
         target["labels"] = torch.stack(labels)
         # print(f'labels:{target["labels"]}')
+        # target["boxes"] = line_boxes(target)
         target["boxes"], keypoints = line_boxes(target)
         # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
 
         # keypoints= wire_labels["junc_coords"]
         a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
         keypoints = torch.cat((keypoints, a), dim=1)
-        target["keypoints"] = keypoints.to(torch.float32)
-        print(f'boxes:{target["boxes"].shape}')
+        target["keypoints"] = keypoints.to(torch.float32).view(-1,2,3)
+        # print(f'boxes:{target["boxes"].shape}')
         # call this function from __getitem__
         validate_keypoints(keypoints, shape[0], shape[1])
-        print(f'keypoints:{target["keypoints"].shape}')
+        # print(f'keypoints:{target["keypoints"].shape}')
         return target
 
     def show(self, idx):
@@ -167,6 +167,7 @@ class KeypointDataset(BaseDataset):
                         break
                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
 
+
             img_path = os.path.join(self.img_path, self.imgs[idx])
             img = PIL.Image.open(img_path).convert('RGB')
             boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
@@ -190,11 +191,13 @@ class KeypointDataset(BaseDataset):
         # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
         draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
 
+
     def show_img(self, img_path):
         pass
 
 
+
 if __name__ == '__main__':
-    path = r"I:\wirenet_dateset"
-    dataset = KeypointDataset(dataset_path=path, dataset_type='train')
+    path=r"D:\python\PycharmProjects\data"
+    dataset= KeypointDataset(dataset_path=path, dataset_type='train')
     dataset.show(0)
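
For reference: torchvision's KeypointRCNN expects target["keypoints"] with shape [num_instances, num_keypoints, 3], each keypoint given as (x, y, visibility). The .view(-1, 2, 3) above therefore turns the flat endpoint tensor into one two-keypoint instance per line. A minimal sketch with dummy values:

import torch

# Two hypothetical line segments -> four endpoints, each (x, y).
endpoints = torch.tensor([[10., 20.], [30., 40.],
                          [50., 60.], [70., 80.]])

# Append visibility v=2 (COCO convention: labeled and visible).
vis = torch.full((endpoints.shape[0], 1), 2.0)
keypoints = torch.cat((endpoints, vis), dim=1)           # shape [4, 3]

# One instance per line, two keypoints per instance.
keypoints = keypoints.to(torch.float32).view(-1, 2, 3)   # shape [2, 2, 3]
print(keypoints.shape)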

+ 6 - 2
models/keypoint/test.py

@@ -17,7 +17,7 @@ def show(imgs):
         axs[0, i].imshow(np.asarray(img))
         axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
 
-img_path=r"D:\python\PycharmProjects\data\images\val\00031591_0.png"
+img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
 # img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
 img_int = read_image(img_path)
 
@@ -58,4 +58,8 @@ keypoints = kpts[idx]
 
 res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
 show(res)
-plt.show()
+plt.show()
+
+
+
+

+ 156 - 10
models/keypoint/trainer.py

@@ -8,12 +8,33 @@ import torchvision
 from torch.utils.tensorboard import SummaryWriter
 from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
 
+from models.wirenet.postprocess import postprocess_keypoint
+from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+import matplotlib.pyplot as plt
+import numpy as np
+import matplotlib as mpl
+from tools.coco_utils import get_coco_api_from_dataset
+from tools.coco_eval import CocoEvaluator
+import time
+
 from models.config.config_tool import read_yaml
 from models.ins.maskrcnn_dataset import MaskRCNNDataset
 from models.keypoint.keypoint_dataset import KeypointDataset
 from tools import utils, presets
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+
+
+def log_losses_to_tensorboard(writer, result, step):
+    writer.add_scalar('Loss/classifier', result['loss_classifier'].item(), step)
+    writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
+    writer.add_scalar('Loss/keypoint', result['loss_keypoint'].item(), step)
+    writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
+    writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
     model.train()
+    total_train_loss = 0.0
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"
@@ -27,15 +48,21 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
             optimizer, start_factor=warmup_factor, total_iters=warmup_iters
         )
 
-    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+    for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+        global_step = epoch * len(data_loader) + batch_idx
         # print(f'images:{images}')
         images = list(image.to(device) for image in images)
         targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             loss_dict = model(images, targets)
-            print(f'loss_dict:{loss_dict}')
+            # print(f'loss_dict:{loss_dict}')
+
             losses = sum(loss for loss in loss_dict.values())
 
+            total_train_loss += losses.item()
+            log_losses_to_tensorboard(writer, loss_dict, global_step)
+
         # reduce losses over all GPUs for logging purposes
         loss_dict_reduced = utils.reduce_dict(loss_dict)
         losses_reduced = sum(loss for loss in loss_dict_reduced.values())
@@ -64,17 +91,133 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
 
     return metric_logger
 
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    # H = pred[1]['wires']
+    lines = pred["keypoints"].detach().cpu().numpy()
+    scores = pred["keypoints_scores"].detach().cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)
+    print(f'nscores:{nscores}')
+
+    for i, t in enumerate([0.5]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im.cpu())
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
+
+
+def _get_iou_types(model):
+    model_without_ddp = model
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        model_without_ddp = model.module
+    iou_types = ["bbox"]
+    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
+        iou_types.append("segm")
+    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
+        iou_types.append("keypoints")
+    return iou_types
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, epoch, writer, device):
+    n_threads = torch.get_num_threads()
+    # FIXME remove this and make paste_masks_in_image run on the GPU
+    torch.set_num_threads(1)
+    cpu_device = torch.device("cpu")
+    model.eval()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    header = "Test:"
+
+    coco = get_coco_api_from_dataset(data_loader.dataset)
+    iou_types = _get_iou_types(model)
+    coco_evaluator = CocoEvaluator(coco, iou_types)
+
+    print('start to evaluate')
+    for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)):
+        images = list(img.to(device) for img in images)
+
+        model_time = time.time()
+        outputs = model(images)
+        print(f'outputs:{outputs}')
+
+        if batch_idx == 0:
+            show_line(images[0], outputs[0], epoch, writer)
+
+        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+        model_time = time.time() - model_time
+
+        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
+        evaluator_time = time.time()
+        coco_evaluator.update(res)
+        evaluator_time = time.time() - evaluator_time
+        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger)
+    coco_evaluator.synchronize_between_processes()
+
+    # accumulate predictions from all images
+    coco_evaluator.accumulate()
+    coco_evaluator.summarize()
+    torch.set_num_threads(n_threads)
+    return coco_evaluator
+
+
 def train_cfg(model, cfg):
     parameters = read_yaml(cfg)
     print(f'train parameters:{parameters}')
     train(model, **parameters)
 
+
 def train(model, **kwargs):
     # default parameters
     default_params = {
         'dataset_path': '/path/to/dataset',
         'num_classes': 2,
-        'num_keypoints':2,
+        'num_keypoints': 2,
         'opt': 'adamw',
         'batch_size': 2,
         'epochs': 10,
@@ -88,7 +231,7 @@ def train(model, **kwargs):
         'target_type': 'polygon',
         'enable_logs': True,
         'augmentation': False,
-        'checkpoint':None
+        'checkpoint': None
     }
     # update the defaults with the provided kwargs
     for key, value in kwargs.items():
@@ -142,9 +285,9 @@ def train(model, **kwargs):
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
     )
-    # data_loader_test = torch.utils.data.DataLoader(
-    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
-    # )
+    data_loader_test = torch.utils.data.DataLoader(
+        dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    )
 
     img_results_path = os.path.join(train_result_ptath, 'img_results')
     if os.path.exists(train_result_ptath):
@@ -158,7 +301,7 @@ def train(model, **kwargs):
         os.mkdir(img_results_path)
 
     for epoch in range(epochs):
-        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, None)
         losses = metric_logger.meters['loss'].global_avg
         print(f'epoch {epoch}:loss:{losses}')
         if os.path.exists(f'{wts_path}/last.pt'):
@@ -173,6 +316,9 @@ def train(model, **kwargs):
                 os.remove(f'{wts_path}/best.pt')
             torch.save(model.state_dict(), f'{wts_path}/best.pt')
 
+        evaluate(model, data_loader_test, epoch, writer, device=device)
+
+
 def get_transform(is_train, **kwargs):
     default_params = {
         'augmentation': 'multiscale',
@@ -209,4 +355,4 @@ def write_metric_logs(epoch, metric_logger, writer):
     writer.add_scalar(f'loss_mask:', metric_logger.meters['loss_mask'].global_avg, epoch)
     writer.add_scalar(f'loss_objectness:', metric_logger.meters['loss_objectness'].global_avg, epoch)
     writer.add_scalar(f'loss_rpn_box_reg:', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
-    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
+    writer.add_scalar(f'train loss:', metric_logger.meters['loss'].global_avg, epoch)
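
For reference, a standalone sketch of the per-step loss logging that train_one_epoch now performs (logdir and loss values are placeholders; the keys match what KeypointRCNN returns in training mode, and the tag naming here is illustrative):

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs/keypoint_demo')   # placeholder logdir

loss_dict = {
    'loss_classifier': torch.tensor(0.7),
    'loss_box_reg': torch.tensor(0.3),
    'loss_keypoint': torch.tensor(1.2),
    'loss_objectness': torch.tensor(0.1),
    'loss_rpn_box_reg': torch.tensor(0.05),
}

epoch, batch_idx, num_batches = 0, 5, 100
global_step = epoch * num_batches + batch_idx  # same indexing as in train_one_epoch

for name, value in loss_dict.items():
    writer.add_scalar(f'Loss/{name}', value.item(), global_step)
writer.close()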

File diff suppressed because it is too large
+ 0 - 2693
models/wirenet/head.py


+ 54 - 5
models/wirenet/postprocess.py

@@ -34,11 +34,11 @@ def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
         start, end = 0, 1
         for a, b in nlines:
             if (
-                min(
-                    max(pline(*p, *q, *a), pline(*p, *q, *b)),
-                    max(pline(*a, *b, *p), pline(*a, *b, *q)),
-                )
-                > threshold ** 2
+                    min(
+                        max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                        max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                    )
+                    > threshold ** 2
             ):
                 continue
             lambda_a = plambda(*p, *q, *a)
@@ -75,3 +75,52 @@ def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
         nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
         nscores.append(score)
     return np.array(nlines), np.array(nscores)
+
+
+def postprocess_keypoint(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+    nlines, nscores = [], []
+    for (p, q), score in zip(lines, scores):
+        start, end = 0, 1
+        for a, b in nlines:
+            if (
+                    min(
+                        max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                        max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                    )
+                    > threshold ** 2
+            ):
+                continue
+            lambda_a = plambda(*p, *q, *a)
+            lambda_b = plambda(*p, *q, *b)
+            if lambda_a > lambda_b:
+                lambda_a, lambda_b = lambda_b, lambda_a
+            lambda_a -= tol
+            lambda_b += tol
+
+            # case 1: skip (if not do_clip)
+            if start < lambda_a and lambda_b < end:
+                continue
+
+            # not intersect
+            if lambda_b < start or lambda_a > end:
+                continue
+
+            # cover
+            if lambda_a <= start and end <= lambda_b:
+                start = 10
+                break
+
+            # case 2 & 3:
+            if lambda_a <= start and start <= lambda_b:
+                start = lambda_b
+            if lambda_a <= end and end <= lambda_b:
+                end = lambda_a
+
+            if start >= end:
+                break
+
+        if start >= end:
+            continue
+        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
+    nscores.append(min(score[0], score[1]))
+    return np.array(nlines), np.array(nscores)
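
A small usage sketch for the new postprocess_keypoint (inputs are hypothetical: segments as endpoint pairs, one score per endpoint as in KeypointRCNN's keypoints_scores; the second segment nearly duplicates the first and is dropped, and each surviving score is the min over the two endpoints):

import numpy as np
from models.wirenet.postprocess import postprocess_keypoint

lines = np.array([
    [[0.0, 0.0], [0.0, 100.0]],
    [[0.5, 0.0], [0.5, 100.0]],    # near-duplicate of the first segment
    [[50.0, 0.0], [80.0, 30.0]],
])
scores = np.array([[0.9, 0.8], [0.7, 0.6], [0.95, 0.9]])

diag = (512 ** 2 + 512 ** 2) ** 0.5
nlines, nscores = postprocess_keypoint(lines, scores, diag * 0.01, 0, False)
print(len(nlines), nscores)        # 2 segments survive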

+ 0 - 1
models/wirenet/roi_head.py

@@ -767,7 +767,6 @@ class RoIHeads(nn.Module):
             regression_targets = None
             matched_idxs = None
 
-
         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
         class_logits, box_regression = self.box_predictor(box_features)

+ 2 - 11
models/wirenet/test.py

@@ -1,8 +1,6 @@
 from models.wirenet.wirepoint_dataset import WirePointDataset
 from models.config.config_tool import read_yaml
 
-import matplotlib.pyplot as plt
-
 # image_file = "D:/python/PycharmProjects/data"
 #
 # label_file = "D:/python/PycharmProjects/data/labels/train"
@@ -10,8 +8,6 @@ import matplotlib.pyplot as plt
 # dataset_test.show(0)
 # for i in dataset_test:
 #     print(i)
-
-
 cfg = 'wirenet.yaml'
 cfg = read_yaml(cfg)
 print(f'cfg:{cfg}')
@@ -21,12 +17,7 @@ print(cfg['model']['n_dyn_negl'])
 dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
 # dataset.show(0)
 
-# for i in range(len(dataset)):
-#     # dataset.show(i)
-for i in dataset:
-    # dataset.show(i)
-    # print(i)
-    print(i[1]['wires']['line_map'].shape)
-    plt.show(dataset['wires']['line_map'])
+for i in range(len(dataset)):
+    dataset.show(i)
 
 

+ 14 - 0
models/wirenet/test_mask.py

@@ -0,0 +1,14 @@
+import torch
+from matplotlib import pyplot as plt
+
+img = torch.ones((128, 128, 3))
+mask = torch.zeros((128, 128, 3))
+
+mask[0:30, :, :] = 1
+
+
+img[mask == 1] = 0
+
+
+plt.imshow(img)
+plt.show()

+ 1 - 1
models/wirenet/wirenet.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: D:\python\PycharmProjects\data
+  datadir: I:/wirenet_dateset
   resume_from:
   num_workers: 4
   tensorboard_port: 0

+ 16 - 46
models/wirenet/wirepoint_rcnn.py

@@ -10,7 +10,6 @@ import torch.nn.functional as F
 # from torchinfo import summary
 from torchvision.io import read_image
 from torchvision.models import resnet50, ResNet50_Weights
-from torchvision.models import resnet18, ResNet18_Weights
 from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
 from torchvision.models.detection._utils import overwrite_eps
 from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
@@ -431,7 +430,9 @@ class WirepointPredictor(nn.Module):
             Lneg = meta["line_neg_idx"]
 
             n_type = jmap.shape[0]
+            print(f'jmap:{jmap.shape}')
             jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+
             joff = joff.reshape(n_type, 2, -1)
             max_K = self.n_dyn_junc // n_type
             N = len(junc)
@@ -523,55 +524,26 @@ class WirepointPredictor(nn.Module):
             jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
             return line, label.float(), feat, jcs
 
-# def wirepointrcnn_resnet50_fpn(
-#         *,
-#         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
-#         progress: bool = True,
-#         num_classes: Optional[int] = None,
-#         num_keypoints: Optional[int] = None,
-#         weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
-#         trainable_backbone_layers: Optional[int] = None,
-#         **kwargs: Any,
-# ) -> WirepointRCNN:
-#     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
-#     weights_backbone = ResNet50_Weights.verify(weights_backbone)
-#
-#     is_trained = weights is not None or weights_backbone is not None
-#     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
-#
-#     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-#
-#     backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-#     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-#     model = WirepointRCNN(backbone, num_classes=5, **kwargs)
-#
-#     if weights is not None:
-#         model.load_state_dict(weights.get_state_dict(progress=progress))
-#         if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
-#             overwrite_eps(model, 0.0)
-#
-#     return model
-
 
-def wirepointrcnn_resnet18_fpn(
+def wirepointrcnn_resnet50_fpn(
         *,
         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
         progress: bool = True,
         num_classes: Optional[int] = None,
         num_keypoints: Optional[int] = None,
-        weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
         trainable_backbone_layers: Optional[int] = None,
         **kwargs: Any,
 ) -> WirepointRCNN:
     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet18_Weights.verify(weights_backbone)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
 
     is_trained = weights is not None or weights_backbone is not None
     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
 
     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
 
-    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = WirepointRCNN(backbone, num_classes=5, **kwargs)
 
@@ -758,6 +730,7 @@ def show_line(img, pred, prefix, epoch, write):
         plt.show()
         plt.close()
 
+
         img2 = cv2.imread(fn)  # predicted image
         # img1 = im.resize(img2.shape)  # original image
 
@@ -800,11 +773,7 @@ if __name__ == '__main__':
         dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
     )
 
-    model = wirepointrcnn_resnet18_fpn().to(device)
-    # print(model)
-
-    # model1 = wirepointrcnn_resnet50_fpn().to(device)
-    # print(model1)
+    model = wirepointrcnn_resnet50_fpn().to(device)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
     writer = SummaryWriter(cfg['io']['logdir'])
@@ -845,26 +814,27 @@ if __name__ == '__main__':
         model.train()
 
         for imgs, targets in data_loader_train:
+            print(f'targets:{targets[0]["wires"]["line_map"].shape}')
             losses = model(move_to_device(imgs, device), move_to_device(targets, device))
             loss = _loss(losses)
-            print(f"loss:{loss}")
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            writer_loss(writer, losses, epoch)
+            print(loss)
+        # optimizer.zero_grad()
+        # loss.backward()
+        # optimizer.step()
+        # writer_loss(writer, losses, epoch)
 
         # model.eval()
         # with torch.no_grad():
         #     for batch_idx, (imgs, targets) in enumerate(data_loader_val):
         #         pred = model(move_to_device(imgs, device))
-        #         print(f"pred:{pred}")
+        #         # print(f"pred:{pred}")
         #
         #         if batch_idx == 0:
         #             result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
         #             print(imgs[0].shape)  # [3,512,512]
         #             # imshow(imgs[0].permute(1, 2, 0))  # changed to (512, 512, 3)
         #             _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
-        #             show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)
+                    # show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)
 
 # imgs, targets = next(iter(data_loader))
 #

+ 4 - 4
tools/coco_utils.py

@@ -139,8 +139,8 @@ def convert_to_coco_api(ds):
         bboxes[:, 2:] -= bboxes[:, :2]
         bboxes = bboxes.tolist()
         labels = targets["labels"].tolist()
-        areas = targets["area"].tolist()
-        iscrowd = targets["iscrowd"].tolist()
+        # areas = targets["area"].tolist()
+        # iscrowd = targets["iscrowd"].tolist()
         if "masks" in targets:
             masks = targets["masks"]
             # make masks Fortran contiguous for coco_mask
@@ -155,8 +155,8 @@ def convert_to_coco_api(ds):
             ann["bbox"] = bboxes[i]
             ann["category_id"] = labels[i]
             categories.add(labels[i])
-            ann["area"] = areas[i]
-            ann["iscrowd"] = iscrowd[i]
+            # ann["area"] = areas[i]
+            # ann["iscrowd"] = iscrowd[i]
             ann["id"] = ann_id
             if "masks" in targets:
                 ann["segmentation"] = coco_mask.encode(masks[i].numpy())

Some files were not shown because too many files changed in this diff