
Refactor LineNet-related code

RenLiqiang, 3 months ago
Parent
Current commit
52b716d1b3

+ 3 - 3
config/wireframe.yaml

@@ -1,9 +1,9 @@
 io:
   logdir: logs/
-  datadir: I:/datasets/wirenet_1000
+  datadir: I:/datasets/wirenet_lm
   resume_from:
-  num_workers: 0
-  tensorboard_port: 0
+  num_workers: 8
+  tensorboard_port: 6000
   validation_interval: 300
 
 model:
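
Note: the line predictor added later in this commit loads this file via read_yaml('./config/wireframe.yaml') and indexes the result as a nested dict (cfg['io'], cfg['model'], ...). A minimal sketch of reading the updated io block with plain PyYAML, assuming read_yaml behaves the same way:

import yaml

# Load config/wireframe.yaml as a nested dict, mirroring how the project's read_yaml helper is used.
with open('config/wireframe.yaml') as f:
    cfg = yaml.safe_load(f)

assert cfg['io']['num_workers'] == 8          # raised from 0 in this commit
assert cfg['io']['tensorboard_port'] == 6000  # was 0
print(cfg['io']['datadir'])                   # I:/datasets/wirenet_lm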

+ 1 - 0
libs/vision_libs/models/detection/rpn.py

@@ -370,6 +370,7 @@ class RegionProposalNetwork(torch.nn.Module):
         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
         proposals = proposals.view(num_images, -1, 4)
         boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
+        # print(f'boxes:{boxes.shape},scores:{scores.shape}')
 
         losses = {}
         if self.training:
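
Note on the commented-out debug line above: in torchvision's RegionProposalNetwork, filter_proposals returns per-image lists of tensors, so boxes.shape would not work directly. An illustrative per-image variant that could stand in for that print (not part of the commit):

# boxes and scores are List[Tensor], one entry per image in the batch.
for i, (b, s) in enumerate(zip(boxes, scores)):
    print(f'image {i}: boxes {tuple(b.shape)}, scores {tuple(s.shape)}')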

+ 1 - 0
models/line_detect/dataset_LD.py

@@ -117,6 +117,7 @@ class WirePointDataset(BaseDataset):
         # return wire_labels, target
         target["wires"] = wire_labels
         target["boxes"] = line_boxes(target)
+        # print(f'boxes:{target["boxes"].shape}')
         return target
 
     def show(self, idx):

+ 0 - 120
models/line_detect/fasterrcnn_resnet50.py

@@ -1,120 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision
-from typing import Dict, List, Optional, Tuple
-import torch.nn.functional as F
-from torchvision.ops import MultiScaleRoIAlign
-from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
-from torchvision.models.detection.transform import GeneralizedRCNNTransform
-
-
-def get_model(num_classes):
-    # Load a pretrained ResNet-50 FPN backbone
-    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
-
-    # Get the number of input features of the classifier
-    in_features = model.roi_heads.box_predictor.cls_score.in_features
-
-    # Replace the classifier head to match the new number of classes
-    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
-
-    return model
-
-
-def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
-    """
-    Computes the loss for Faster R-CNN.
-
-    Args:
-        class_logits (Tensor)
-        box_regression (Tensor)
-        labels (list[BoxList])
-        regression_targets (Tensor)
-
-    Returns:
-        classification_loss (Tensor)
-        box_loss (Tensor)
-    """
-
-    labels = torch.cat(labels, dim=0)
-    regression_targets = torch.cat(regression_targets, dim=0)
-
-    classification_loss = F.cross_entropy(class_logits, labels)
-
-    # get indices that correspond to the regression targets for
-    # the corresponding ground truth labels, to be used with
-    # advanced indexing
-    sampled_pos_inds_subset = torch.where(labels > 0)[0]
-    labels_pos = labels[sampled_pos_inds_subset]
-    N, num_classes = class_logits.shape
-    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
-
-    box_loss = F.smooth_l1_loss(
-        box_regression[sampled_pos_inds_subset, labels_pos],
-        regression_targets[sampled_pos_inds_subset],
-        beta=1 / 9,
-        reduction="sum",
-    )
-    box_loss = box_loss / labels.numel()
-
-    return classification_loss, box_loss
-
-
-class Fasterrcnn_resnet50(nn.Module):
-    def __init__(self, num_classes=5, num_stacks=1):
-        super(Fasterrcnn_resnet50, self).__init__()
-
-        self.model = get_model(num_classes=5)
-        self.backbone = self.model.backbone
-
-        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
-
-        out_channels = self.backbone.out_channels
-        resolution = self.box_roi_pool.output_size[0]
-        representation_size = 1024
-        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
-
-        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
-
-        # Multi-task output layers
-        self.score_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(256, 128, kernel_size=3, padding=1),
-                nn.BatchNorm2d(128),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(128, num_classes, kernel_size=1)
-            )
-            for _ in range(num_stacks)
-        ])
-
-    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
-
-        transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
-                                             image_std=[0.229, 0.224, 0.225])
-        images, targets = transform(x, target1)
-        x_ = self.backbone(images.tensors)
-
-        # x_ = self.backbone(x)  # '0'  '1'  '2'  '3'   'pool'
-        # print(f'backbone:{self.backbone}')
-        # print(f'Fasterrcnn_resnet50 x_:{x_}')
-        feature_ = x_['0']  # image feature map
-        outputs = []
-        for score_layer in self.score_layers:
-            output = score_layer(feature_)
-            outputs.append(output)  # multiple heads
-
-        if train_or_val == "training":
-            loss_box = self.model(x, target1)
-            return outputs, feature_, loss_box
-        else:
-            box_all = self.model(x, target1)
-            return outputs, feature_, box_all
-
-
-def fasterrcnn_resnet50(**kwargs):
-    model = Fasterrcnn_resnet50(
-        num_classes=kwargs.get("num_classes", 5),
-        num_stacks=kwargs.get("num_stacks", 1)
-    )
-    return model

+ 24 - 0
models/line_detect/line_head.py

@@ -0,0 +1,24 @@
+import torch
+from torch import nn
+
+
+class LineRCNNHeads(nn.Sequential):
+    def __init__(self, input_channels, num_class):
+        super(LineRCNNHeads, self).__init__()
+        # print("输入的维度是:", input_channels)
+        m = int(input_channels / 4)
+        heads = []
+        self.head_size = [[2], [1], [2]]
+        for output_channels in sum(self.head_size, []):
+            heads.append(
+                nn.Sequential(
+                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(m, output_channels, kernel_size=1),
+                )
+            )
+        self.heads = nn.ModuleList(heads)
+        assert num_class == sum(sum(self.head_size, []))
+
+    def forward(self, x):
+        return torch.cat([head(x) for head in self.heads], dim=1)
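
A quick usage sketch of the new head (illustrative values; 256 matches the FPN out_channels that LineNet passes in, and num_class must equal sum(sum(head_size, [])) = 5):

import torch
from models.line_detect.line_head import LineRCNNHeads

head = LineRCNNHeads(input_channels=256, num_class=5)
feat = torch.randn(2, 256, 128, 128)   # e.g. an FPN feature map
out = head(feat)
print(out.shape)                        # torch.Size([2, 5, 128, 128])
# Channel layout consumed by LineRCNNPredictor: 2 junction-map channels,
# 1 line-map channel, 2 junction-offset channels.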

+ 77 - 12
models/line_detect/LineNet.py → models/line_detect/line_net.py

@@ -1,25 +1,33 @@
-from typing import Any, Callable, List, Optional, Tuple, Union
 
+from typing import Any, Callable, List, Optional, Tuple, Union
 import torch
-import torch.nn.functional as F
 from torch import nn
 from torchvision.ops import MultiScaleRoIAlign
 
-from  libs.vision_libs.ops import misc as misc_nn_ops
+from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
+from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
+from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
+from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
+from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
+from libs.vision_libs.ops import misc as misc_nn_ops
 from libs.vision_libs.transforms._presets import ObjectDetection
+from .line_head import LineRCNNHeads
+from .line_predictor import LineRCNNPredictor
+from .roi_heads import RoIHeads
 from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
-from libs.vision_libs.models._meta import _COCO_CATEGORIES
+from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
-from libs.vision_libs.models.mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
 from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
 from libs.vision_libs.models.detection._utils import overwrite_eps
-from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
-from libs.vision_libs.models.detection.backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 
-from libs.vision_libs.models.detection.rpn import RegionProposalNetwork, RPNHead
-from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
+from models.config.config_tool import read_yaml
+import numpy as np
+import torch.nn.functional as F
 
-######## Deprecated ###########
+FEATURE_DIM = 8
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 __all__ = [
     "LineNet",
@@ -28,10 +36,50 @@ __all__ = [
     "LineNet_MobileNet_V3_Large_FPN_Weights",
     "LineNet_MobileNet_V3_Large_320_FPN_Weights",
     "linenet_resnet50_fpn",
-    "fasterrcnn_resnet50_fpn_v2",
+    "linenet_resnet50_fpn_v2",
     "linenet_mobilenet_v3_large_fpn",
     "linenet_mobilenet_v3_large_320_fpn",
 ]
+# __all__ = [
+#     "LineNet",
+#     "LineRCNN_ResNet50_FPN_Weights",
+#     "linercnn_resnet50_fpn",
+# ]
+
+
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+
+# class Bottleneck1D(nn.Module):
+#     def __init__(self, inplanes, outplanes):
+#         super(Bottleneck1D, self).__init__()
+#
+#         planes = outplanes // 2
+#         self.op = nn.Sequential(
+#             nn.BatchNorm1d(inplanes),
+#             nn.ReLU(inplace=True),
+#             nn.Conv1d(inplanes, planes, kernel_size=1),
+#             nn.BatchNorm1d(planes),
+#             nn.ReLU(inplace=True),
+#             nn.Conv1d(planes, planes, kernel_size=3, padding=1),
+#             nn.BatchNorm1d(planes),
+#             nn.ReLU(inplace=True),
+#             nn.Conv1d(planes, outplanes, kernel_size=1),
+#         )
+#
+#     def forward(self, x):
+#         return x + self.op(x)
+
+
+
+
+
+
+
+
 
 from .roi_heads import RoIHeads
 
@@ -199,6 +247,9 @@ class LineNet(BaseDetectionNet):
         box_batch_size_per_image=512,
         box_positive_fraction=0.25,
         bbox_reg_weights=None,
+        # line parameters
+        line_head=None,
+        line_predictor=None,
         **kwargs,
     ):
 
@@ -227,6 +278,13 @@ class LineNet(BaseDetectionNet):
 
         out_channels = backbone.out_channels
 
+        if line_head is None:
+            num_class = 5
+            line_head = LineRCNNHeads(out_channels, num_class)
+
+        if line_predictor is None:
+            line_predictor = LineRCNNPredictor()
+
         if rpn_anchor_generator is None:
             rpn_anchor_generator = _default_anchorgen()
         if rpn_head is None:
@@ -265,6 +323,8 @@ class LineNet(BaseDetectionNet):
             box_roi_pool,
             box_head,
             box_predictor,
+            line_head,
+            line_predictor,
             box_fg_iou_thresh,
             box_bg_iou_thresh,
             box_batch_size_per_image,
@@ -283,6 +343,10 @@ class LineNet(BaseDetectionNet):
 
         super().__init__(backbone, rpn, roi_heads, transform)
 
+        self.roi_heads = roi_heads
+        # self.roi_heads.line_head = line_head
+        # self.roi_heads.line_predictor = line_predictor
+
 
 class TwoMLPHead(nn.Module):
     """
@@ -587,7 +651,7 @@ def linenet_resnet50_fpn(
     weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
 )
-def fasterrcnn_resnet50_fpn_v2(
+def linenet_resnet50_fpn_v2(
     *,
     weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
     progress: bool = True,
@@ -845,3 +909,4 @@ def linenet_mobilenet_v3_large_fpn(
         trainable_backbone_layers=trainable_backbone_layers,
         **kwargs,
     )
+
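
The non_maximum_suppression helper kept in line_net.py suppresses everything except local maxima by comparing a heatmap against its 3x3 max-pooled copy. A small self-contained illustration (the function body is copied verbatim from this file):

import torch
import torch.nn.functional as F

def non_maximum_suppression(a):
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float().clamp(min=0.0)
    return a * mask

heat = torch.zeros(1, 1, 5, 5)
heat[0, 0, 2, 2] = 0.9   # a local maximum survives
heat[0, 0, 2, 3] = 0.4   # its smaller neighbour is zeroed out
print(non_maximum_suppression(heat)[0, 0])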

+ 324 - 0
models/line_detect/line_predictor.py

@@ -0,0 +1,324 @@
+from typing import Any, Optional
+
+import torch
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from libs.vision_libs.ops import misc as misc_nn_ops
+from libs.vision_libs.transforms._presets import ObjectDetection
+from .roi_heads import RoIHeads
+from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
+from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+
+from models.config.config_tool import read_yaml
+import numpy as np
+import torch.nn.functional as F
+
+FEATURE_DIM = 8
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+class LineRCNNPredictor(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # self.backbone = backbone
+        # self.cfg = read_yaml(cfg)
+        self.cfg = read_yaml(r'./config/wireframe.yaml')
+        self.n_pts0 = self.cfg['model']['n_pts0']
+        self.n_pts1 = self.cfg['model']['n_pts1']
+        self.n_stc_posl = self.cfg['model']['n_stc_posl']
+        self.dim_loi = self.cfg['model']['dim_loi']
+        self.use_conv = self.cfg['model']['use_conv']
+        self.dim_fc = self.cfg['model']['dim_fc']
+        self.n_out_line = self.cfg['model']['n_out_line']
+        self.n_out_junc = self.cfg['model']['n_out_junc']
+        self.loss_weight = self.cfg['model']['loss_weight']
+        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
+        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
+        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
+        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
+        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
+        self.use_cood = self.cfg['model']['use_cood']
+        self.use_slop = self.cfg['model']['use_slop']
+        self.n_stc_negl = self.cfg['model']['n_stc_negl']
+        self.head_size = self.cfg['model']['head_size']
+        self.num_class = sum(sum(self.head_size, []))
+        self.head_off = np.cumsum([sum(h) for h in self.head_size])
+
+        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
+        self.register_buffer("lambda_", lambda_)
+        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0
+
+        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
+        scale_factor = self.n_pts0 // self.n_pts1
+        if self.use_conv:
+            self.pooling = nn.Sequential(
+                nn.MaxPool1d(scale_factor, scale_factor),
+                Bottleneck1D(self.dim_loi, self.dim_loi),
+            )
+            self.fc2 = nn.Sequential(
+                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
+            )
+        else:
+            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
+            self.fc2 = nn.Sequential(
+                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, self.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(self.dim_fc, 1),
+            )
+        self.loss = nn.BCEWithLogitsLoss(reduction="none")
+
+    def forward(self, inputs, features, targets=None):
+
+        # outputs, features = input
+        # for out in outputs:
+        #     print(f'out:{out.shape}')
+        # outputs=merge_features(outputs,100)
+        batch, channel, row, col = inputs.shape
+        # print(f'outputs:{inputs.shape}')
+        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
+
+        if targets is not None:
+            self.training = True
+            # print(f'target:{targets}')
+            wires_targets = [t["wires"] for t in targets]
+            # print(f'wires_target:{wires_targets}')
+            # Extract the 'junc_map', 'junc_offset' and 'line_map' tensors from every target
+            junc_maps = [d["junc_map"] for d in wires_targets]
+            junc_offsets = [d["junc_offset"] for d in wires_targets]
+            line_maps = [d["line_map"] for d in wires_targets]
+
+            junc_map_tensor = torch.stack(junc_maps, dim=0)
+            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
+            line_map_tensor = torch.stack(line_maps, dim=0)
+
+            wires_meta = {
+                "junc_map": junc_map_tensor,
+                "junc_offset": junc_offset_tensor,
+                # "line_map": line_map_tensor,
+            }
+        else:
+            self.training = False
+            t = {
+                "junc_coords": torch.zeros(1, 2),
+                "jtyp": torch.zeros(1, dtype=torch.uint8),
+                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
+            wires_targets = [t for b in range(inputs.size(0))]
+
+            wires_meta = {
+                "junc_map": torch.zeros([1, 1, 128, 128]),
+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
+            }
+
+        T = wires_meta.copy()
+        n_jtyp = T["junc_map"].shape[1]
+        offset = self.head_off
+        result = {}
+        for stack, output in enumerate([inputs]):
+            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+            # print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
+            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+            lmap = output[offset[0]: offset[1]].squeeze(0)
+            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+
+            if stack == 0:
+                result["preds"] = {
+                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                    "lmap": lmap.sigmoid(),
+                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+                }
+                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
+                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
+                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
+
+        h = result["preds"]
+        # print(f'features shape:{features.shape}')
+        x = self.fc1(features)
+
+        # print(f'x:{x.shape}')
+
+        n_batch, n_channel, row, col = x.shape
+
+        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
+
+        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
+
+        for i, meta in enumerate(wires_targets):
+            p, label, feat, jc = self.sample_lines(
+                meta, h["jmap"][i], h["joff"][i],
+            )
+            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
+            ys.append(label)
+            if self.training and self.do_static_sampling:
+                p = torch.cat([p, meta["lpre"]])
+                feat = torch.cat([feat, meta["lpre_feat"]])
+                ys.append(meta["lpre_label"])
+                del jc
+            else:
+                jcs.append(jc)
+                ps.append(p)
+            fs.append(feat)
+
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
+            xp = (
+                (
+                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                .reshape(n_channel, -1, self.n_pts0)
+                .permute(1, 0, 2)
+            )
+            xp = self.pooling(xp)
+            # print(f'xp.shape:{xp.shape}')
+            xs.append(xp)
+            idx.append(idx[-1] + xp.shape[0])
+            # print(f'idx__:{idx}')
+
+        x, y = torch.cat(xs), torch.cat(ys)
+        f = torch.cat(fs)
+        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+
+        # print("Weight dtype:", self.fc2.weight.dtype)
+        x = torch.cat([x, f], 1)
+        # print("Input dtype:", x.dtype)
+        x = x.to(dtype=torch.float32)
+        # print("Input dtype1:", x.dtype)
+        x = self.fc2(x).flatten()
+
+        # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
+        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
+
+        # if mode != "training":
+        # self.inference(x, idx, jcs, n_batch, ps)
+
+        # return result
+
+    def sample_lines(self, meta, jmap, joff):
+        device = jmap.device
+        with torch.no_grad():
+            junc = meta["junc_coords"].to(device)  # [N, 2]
+            jtyp = meta["jtyp"].to(device)  # [N]
+            Lpos = meta["line_pos_idx"].to(device)
+            Lneg = meta["line_neg_idx"].to(device)
+
+            n_type = jmap.shape[0]
+            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+            joff = joff.reshape(n_type, 2, -1)
+            max_K = self.n_dyn_junc // n_type
+            N = len(junc)
+            # if mode != "training":
+            if not self.training:
+                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
+            else:
+                K = min(int(N * 2 + 2), max_K)
+            if K < 2:
+                K = 2
+            device = jmap.device
+
+            # index: [N_TYPE, K]
+            score, index = torch.topk(jmap, k=K)
+            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+
+            # xy: [N_TYPE, K, 2]
+            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
+            xy_ = xy[..., None, :]
+            del x, y, index
+
+            # dist: [N_TYPE, K, N]
+            dist = torch.sum((xy_ - junc) ** 2, -1)
+            cost, match = torch.min(dist, -1)
+
+            # xy: [N_TYPE * K, 2]
+            # match: [N_TYPE, K]
+            for t in range(n_type):
+                match[t, jtyp[match[t]] != t] = N
+            match[cost > 1.5 * 1.5] = N
+            match = match.flatten()
+
+            _ = torch.arange(n_type * K, device=device)
+            u, v = torch.meshgrid(_, _)
+            u, v = u.flatten(), v.flatten()
+            up, vp = match[u], match[v]
+            label = Lpos[up, vp]
+
+            # if mode == "training":
+            if self.training:
+                c = torch.zeros_like(label, dtype=torch.bool)
+
+                # sample positive lines
+                cdx = label.nonzero().flatten()
+                if len(cdx) > self.n_dyn_posl:
+                    # print("too many positive lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample negative lines
+                cdx = Lneg[up, vp].nonzero().flatten()
+                if len(cdx) > self.n_dyn_negl:
+                    # print("too many negative lines")
+                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample other (unmatched) lines
+                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
+                c[cdx] = 1
+            else:
+                c = (u < v).flatten()
+
+            # sample lines
+            u, v, label = u[c], v[c], label[c]
+            xy = xy.reshape(n_type * K, 2)
+            xyu, xyv = xy[u], xy[v]
+
+            u2v = xyu - xyv
+            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
+            feat = torch.cat(
+                [
+                    xyu / 128 * self.use_cood,
+                    xyv / 128 * self.use_cood,
+                    u2v * self.use_slop,
+                    (u[:, None] > K).float(),
+                    (v[:, None] > K).float(),
+                ],
+                1,
+            )
+            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+
+            xy = xy.reshape(n_type, K, 2)
+            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
+            return line, label.float(), feat, jcs
+
+
+
+_COMMON_META = {
+    "categories": _COCO_PERSON_CATEGORIES,
+    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
+    "min_size": (1, 1),
+}
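
The heart of LineRCNNPredictor.forward is a line-of-interest pooling step: each candidate line's two endpoints are expanded into n_pts0 evenly spaced sample points via the lambda_ buffer, features are bilinearly interpolated at those points, max-pooled down to n_pts1 positions, and scored per line by fc2. A stripped-down sketch of just the endpoint interpolation (n_pts0=32 is an assumed value; the real one comes from cfg['model']['n_pts0']):

import torch

n_pts0 = 32                                    # assumed; read from config/wireframe.yaml in the predictor
lambda_ = torch.linspace(0, 1, n_pts0)[:, None]

# lines: [N_LINE, 2, 2] -- two (y, x) endpoints per candidate line, as returned by sample_lines
lines = torch.tensor([[[10.0, 20.0], [40.0, 80.0]]])

# Same expression as in forward(): n_pts0 points per line, shifted by 0.5 onto the sampling grid.
p = lines[:, 0:1, :] * lambda_ + lines[:, 1:2, :] * (1 - lambda_) - 0.5
p = p.reshape(-1, 2)                           # [N_LINE * n_pts0, 2]
print(p.shape)                                 # torch.Size([32, 2])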

+ 0 - 806
models/line_detect/line_rcnn.py

@@ -1,806 +0,0 @@
-from typing import Any, Optional
-
-import torch
-from torch import nn
-from torchvision.ops import MultiScaleRoIAlign
-
-from libs.vision_libs.ops import misc as misc_nn_ops
-from libs.vision_libs.transforms._presets import ObjectDetection
-from .roi_heads import RoIHeads
-from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
-from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
-from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
-from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
-from libs.vision_libs.models.detection._utils import overwrite_eps
-from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
-from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
-
-from models.config.config_tool import read_yaml
-import numpy as np
-import torch.nn.functional as F
-
-FEATURE_DIM = 8
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-__all__ = [
-    "LineRCNN",
-    "LineRCNN_ResNet50_FPN_Weights",
-    "linercnn_resnet50_fpn",
-]
-
-
-def non_maximum_suppression(a):
-    ap = F.max_pool2d(a, 3, stride=1, padding=1)
-    mask = (a == ap).float().clamp(min=0.0)
-    return a * mask
-
-
-class Bottleneck1D(nn.Module):
-    def __init__(self, inplanes, outplanes):
-        super(Bottleneck1D, self).__init__()
-
-        planes = outplanes // 2
-        self.op = nn.Sequential(
-            nn.BatchNorm1d(inplanes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(inplanes, planes, kernel_size=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
-            nn.BatchNorm1d(planes),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(planes, outplanes, kernel_size=1),
-        )
-
-    def forward(self, x):
-        return x + self.op(x)
-
-
-class LineRCNN(FasterRCNN):
-    """
-    Implements Keypoint R-CNN.
-
-    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
-    image, and should be in 0-1 range. Different images can have different sizes.
-
-    The behavior of the model changes depending on if it is in training or evaluation mode.
-
-    During training, the model expects both the input tensors and targets (list of dictionary),
-    containing:
-
-        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
-            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (Int64Tensor[N]): the class label for each ground-truth box
-        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
-          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
-
-    The model returns a Dict[Tensor] during training, containing the classification and regression
-    losses for both the RPN and the R-CNN, and the keypoint loss.
-
-    During inference, the model requires only the input tensors, and returns the post-processed
-    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
-    follows:
-
-        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
-            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (Int64Tensor[N]): the predicted labels for each image
-        - scores (Tensor[N]): the scores or each prediction
-        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
-
-    Args:
-        backbone (nn.Module): the network used to compute the features for the model.
-            It should contain an out_channels attribute, which indicates the number of output
-            channels that each feature map has (and it should be the same for all feature maps).
-            The backbone should return a single Tensor or and OrderedDict[Tensor].
-        num_classes (int): number of output classes of the model (including the background).
-            If box_predictor is specified, num_classes should be None.
-        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
-        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
-        image_mean (Tuple[float, float, float]): mean values used for input normalization.
-            They are generally the mean values of the dataset on which the backbone has been trained
-            on
-        image_std (Tuple[float, float, float]): std values used for input normalization.
-            They are generally the std values of the dataset on which the backbone has been trained on
-        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
-            maps.
-        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
-        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
-        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
-        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
-        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
-        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
-        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
-            considered as positive during training of the RPN.
-        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
-            considered as negative during training of the RPN.
-        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
-            for computing the loss
-        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
-            of the RPN
-        rpn_score_thresh (float): during inference, only return proposals with a classification score
-            greater than rpn_score_thresh
-        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
-            the locations indicated by the bounding boxes
-        box_head (nn.Module): module that takes the cropped feature maps as input
-        box_predictor (nn.Module): module that takes the output of box_head and returns the
-            classification logits and box regression deltas.
-        box_score_thresh (float): during inference, only return proposals with a classification score
-            greater than box_score_thresh
-        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
-        box_detections_per_img (int): maximum number of detections per image, for all classes.
-        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
-            considered as positive during training of the classification head
-        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
-            considered as negative during training of the classification head
-        box_batch_size_per_image (int): number of proposals that are sampled during training of the
-            classification head
-        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
-            of the classification head
-        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
-            bounding boxes
-        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
-             the locations indicated by the bounding boxes, which will be used for the keypoint head.
-        keypoint_head (nn.Module): module that takes the cropped feature maps as input
-        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
-            heatmap logits
-
-    Example::
-
-        >>> import torch
-        >>> import torchvision
-        >>> from torchvision.models.detection import KeypointRCNN
-        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
-        >>>
-        >>> # load a pre-trained model for classification and return
-        >>> # only the features
-        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
-        >>> # KeypointRCNN needs to know the number of
-        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
-        >>> # so we need to add it here
-        >>> backbone.out_channels = 1280
-        >>>
-        >>> # let's make the RPN generate 5 x 3 anchors per spatial
-        >>> # location, with 5 different sizes and 3 different aspect
-        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
-        >>> # map could potentially have different sizes and
-        >>> # aspect ratios
-        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
-        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
-        >>>
-        >>> # let's define what are the feature maps that we will
-        >>> # use to perform the region of interest cropping, as well as
-        >>> # the size of the crop after rescaling.
-        >>> # if your backbone returns a Tensor, featmap_names is expected to
-        >>> # be ['0']. More generally, the backbone should return an
-        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
-        >>> # feature maps to use.
-        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
-        >>>                                                 output_size=7,
-        >>>                                                 sampling_ratio=2)
-        >>>
-        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
-        >>>                                                          output_size=14,
-        >>>                                                          sampling_ratio=2)
-        >>> # put the pieces together inside a KeypointRCNN model
-        >>> model = KeypointRCNN(backbone,
-        >>>                      num_classes=2,
-        >>>                      rpn_anchor_generator=anchor_generator,
-        >>>                      box_roi_pool=roi_pooler,
-        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
-        >>> model.eval()
-        >>> model.eval()
-        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
-        >>> predictions = model(x)
-    """
-
-    def __init__(
-            self,
-            backbone,
-            num_classes=None,
-            # transform parameters
-            min_size=512,  # originally None
-            max_size=1333,
-            image_mean=None,
-            image_std=None,
-            # RPN parameters
-            rpn_anchor_generator=None,
-            rpn_head=None,
-            rpn_pre_nms_top_n_train=2000,
-            rpn_pre_nms_top_n_test=1000,
-            rpn_post_nms_top_n_train=2000,
-            rpn_post_nms_top_n_test=1000,
-            rpn_nms_thresh=0.7,
-            rpn_fg_iou_thresh=0.7,
-            rpn_bg_iou_thresh=0.3,
-            rpn_batch_size_per_image=256,
-            rpn_positive_fraction=0.5,
-            rpn_score_thresh=0.0,
-            # Box parameters
-            box_roi_pool=None,
-            box_head=None,
-            box_predictor=None,
-            box_score_thresh=0.05,
-            box_nms_thresh=0.5,
-            box_detections_per_img=100,
-            box_fg_iou_thresh=0.5,
-            box_bg_iou_thresh=0.5,
-            box_batch_size_per_image=512,
-            box_positive_fraction=0.25,
-            bbox_reg_weights=None,
-            # line parameters
-            line_head=None,
-            line_predictor=None,
-
-            **kwargs,
-    ):
-
-        # if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
-        #     raise TypeError(
-        #         "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
-        #     )
-        # if min_size is None:
-        #     min_size = (640, 672, 704, 736, 768, 800)
-        #
-        # if num_keypoints is not None:
-        #     if keypoint_predictor is not None:
-        #         raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
-        # else:
-        #     num_keypoints = 17
-
-        out_channels = backbone.out_channels
-
-        if line_head is None:
-            num_class = 5
-            line_head = LineRCNNHeads(out_channels, num_class)
-
-        if line_predictor is None:
-            line_predictor = LineRCNNPredictor()
-
-        super().__init__(
-            backbone,
-            num_classes,
-            # transform parameters
-            min_size,
-            max_size,
-            image_mean,
-            image_std,
-            # RPN-specific parameters
-            rpn_anchor_generator,
-            rpn_head,
-            rpn_pre_nms_top_n_train,
-            rpn_pre_nms_top_n_test,
-            rpn_post_nms_top_n_train,
-            rpn_post_nms_top_n_test,
-            rpn_nms_thresh,
-            rpn_fg_iou_thresh,
-            rpn_bg_iou_thresh,
-            rpn_batch_size_per_image,
-            rpn_positive_fraction,
-            rpn_score_thresh,
-            # Box parameters
-            box_roi_pool,
-            box_head,
-            box_predictor,
-            box_score_thresh,
-            box_nms_thresh,
-            box_detections_per_img,
-            box_fg_iou_thresh,
-            box_bg_iou_thresh,
-            box_batch_size_per_image,
-            box_positive_fraction,
-            bbox_reg_weights,
-            **kwargs,
-        )
-
-        if box_roi_pool is None:
-            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
-
-        if box_head is None:
-            resolution = box_roi_pool.output_size[0]
-            representation_size = 1024
-            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
-
-        if box_predictor is None:
-            representation_size = 1024
-            box_predictor = FastRCNNPredictor(representation_size, num_classes)
-
-        roi_heads = RoIHeads(
-            # Box
-            box_roi_pool,
-            box_head,
-            box_predictor,
-            line_head,
-            line_predictor,
-            box_fg_iou_thresh,
-            box_bg_iou_thresh,
-            box_batch_size_per_image,
-            box_positive_fraction,
-            bbox_reg_weights,
-            box_score_thresh,
-            box_nms_thresh,
-            box_detections_per_img,
-
-        )
-        # super().roi_heads = roi_heads
-        self.roi_heads = roi_heads
-        self.roi_heads.line_head = line_head
-        self.roi_heads.line_predictor = line_predictor
-
-
-class LineRCNNHeads(nn.Sequential):
-    def __init__(self, input_channels, num_class):
-        super(LineRCNNHeads, self).__init__()
-        # print("输入的维度是:", input_channels)
-        m = int(input_channels / 4)
-        heads = []
-        self.head_size = [[2], [1], [2]]
-        for output_channels in sum(self.head_size, []):
-            heads.append(
-                nn.Sequential(
-                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
-                    nn.ReLU(inplace=True),
-                    nn.Conv2d(m, output_channels, kernel_size=1),
-                )
-            )
-        self.heads = nn.ModuleList(heads)
-        assert num_class == sum(sum(self.head_size, []))
-
-    def forward(self, x):
-        return torch.cat([head(x) for head in self.heads], dim=1)
-
-
-
-class LineRCNNPredictor(nn.Module):
-    def __init__(self):
-        super().__init__()
-        # self.backbone = backbone
-        # self.cfg = read_yaml(cfg)
-        self.cfg = read_yaml(r'./config/wireframe.yaml')
-        self.n_pts0 = self.cfg['model']['n_pts0']
-        self.n_pts1 = self.cfg['model']['n_pts1']
-        self.n_stc_posl = self.cfg['model']['n_stc_posl']
-        self.dim_loi = self.cfg['model']['dim_loi']
-        self.use_conv = self.cfg['model']['use_conv']
-        self.dim_fc = self.cfg['model']['dim_fc']
-        self.n_out_line = self.cfg['model']['n_out_line']
-        self.n_out_junc = self.cfg['model']['n_out_junc']
-        self.loss_weight = self.cfg['model']['loss_weight']
-        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
-        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
-        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
-        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
-        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
-        self.use_cood = self.cfg['model']['use_cood']
-        self.use_slop = self.cfg['model']['use_slop']
-        self.n_stc_negl = self.cfg['model']['n_stc_negl']
-        self.head_size = self.cfg['model']['head_size']
-        self.num_class = sum(sum(self.head_size, []))
-        self.head_off = np.cumsum([sum(h) for h in self.head_size])
-
-        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
-        self.register_buffer("lambda_", lambda_)
-        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0
-
-        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
-        scale_factor = self.n_pts0 // self.n_pts1
-        if self.use_conv:
-            self.pooling = nn.Sequential(
-                nn.MaxPool1d(scale_factor, scale_factor),
-                Bottleneck1D(self.dim_loi, self.dim_loi),
-            )
-            self.fc2 = nn.Sequential(
-                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
-            )
-        else:
-            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
-            self.fc2 = nn.Sequential(
-                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.dim_fc, self.dim_fc),
-                nn.ReLU(inplace=True),
-                nn.Linear(self.dim_fc, 1),
-            )
-        self.loss = nn.BCEWithLogitsLoss(reduction="none")
-
-    def forward(self, inputs, features, targets=None):
-
-        # outputs, features = input
-        # for out in outputs:
-        #     print(f'out:{out.shape}')
-        # outputs=merge_features(outputs,100)
-        batch, channel, row, col = inputs.shape
-        # print(f'outputs:{inputs.shape}')
-        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
-
-        if targets is not None:
-            self.training = True
-            # print(f'target:{targets}')
-            wires_targets = [t["wires"] for t in targets]
-            # print(f'wires_target:{wires_targets}')
-            # Extract the 'junc_map', 'junc_offset' and 'line_map' tensors from every target
-            junc_maps = [d["junc_map"] for d in wires_targets]
-            junc_offsets = [d["junc_offset"] for d in wires_targets]
-            line_maps = [d["line_map"] for d in wires_targets]
-
-            junc_map_tensor = torch.stack(junc_maps, dim=0)
-            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
-            line_map_tensor = torch.stack(line_maps, dim=0)
-
-            wires_meta = {
-                "junc_map": junc_map_tensor,
-                "junc_offset": junc_offset_tensor,
-                # "line_map": line_map_tensor,
-            }
-        else:
-            self.training = False
-            t = {
-                "junc_coords": torch.zeros(1, 2),
-                "jtyp": torch.zeros(1, dtype=torch.uint8),
-                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
-                "junc_map": torch.zeros([1, 1, 128, 128]),
-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
-            }
-            wires_targets = [t for b in range(inputs.size(0))]
-
-            wires_meta = {
-                "junc_map": torch.zeros([1, 1, 128, 128]),
-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
-            }
-
-        T = wires_meta.copy()
-        n_jtyp = T["junc_map"].shape[1]
-        offset = self.head_off
-        result = {}
-        for stack, output in enumerate([inputs]):
-            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
-            # print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
-            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-            lmap = output[offset[0]: offset[1]].squeeze(0)
-            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
-
-            if stack == 0:
-                result["preds"] = {
-                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
-                    "lmap": lmap.sigmoid(),
-                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
-                }
-                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
-                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
-                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
-
-        h = result["preds"]
-        # print(f'features shape:{features.shape}')
-        x = self.fc1(features)
-
-        # print(f'x:{x.shape}')
-
-        n_batch, n_channel, row, col = x.shape
-
-        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
-
-        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
-
-        for i, meta in enumerate(wires_targets):
-            p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i],
-            )
-            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
-            ys.append(label)
-            if self.training and self.do_static_sampling:
-                p = torch.cat([p, meta["lpre"]])
-                feat = torch.cat([feat, meta["lpre_feat"]])
-                ys.append(meta["lpre_label"])
-                del jc
-            else:
-                jcs.append(jc)
-                ps.append(p)
-            fs.append(feat)
-
-            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
-            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
-            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
-            px0 = px.floor().clamp(min=0, max=127)
-            py0 = py.floor().clamp(min=0, max=127)
-            px1 = (px0 + 1).clamp(min=0, max=127)
-            py1 = (py0 + 1).clamp(min=0, max=127)
-            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
-
-            # xp: [N_LINE, N_CHANNEL, N_POINT]
-            xp = (
-                (
-                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
-                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
-                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
-                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
-                )
-                .reshape(n_channel, -1, self.n_pts0)
-                .permute(1, 0, 2)
-            )
-            xp = self.pooling(xp)
-            # print(f'xp.shape:{xp.shape}')
-            xs.append(xp)
-            idx.append(idx[-1] + xp.shape[0])
-            # print(f'idx__:{idx}')
-
-        x, y = torch.cat(xs), torch.cat(ys)
-        f = torch.cat(fs)
-        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
-
-        # print("Weight dtype:", self.fc2.weight.dtype)
-        x = torch.cat([x, f], 1)
-        # print("Input dtype:", x.dtype)
-        x = x.to(dtype=torch.float32)
-        # print("Input dtype1:", x.dtype)
-        x = self.fc2(x).flatten()
-
-        # return  x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
-        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
-
-        # if mode != "training":
-        # self.inference(x, idx, jcs, n_batch, ps)
-
-        # return result
-
-    def sample_lines(self, meta, jmap, joff):
-        with torch.no_grad():
-            junc = meta["junc_coords"]  # [N, 2]
-            jtyp = meta["jtyp"]  # [N]
-            Lpos = meta["line_pos_idx"]
-            Lneg = meta["line_neg_idx"]
-
-            n_type = jmap.shape[0]
-            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
-            joff = joff.reshape(n_type, 2, -1)
-            max_K = self.n_dyn_junc // n_type
-            N = len(junc)
-            # if mode != "training":
-            if not self.training:
-                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
-            else:
-                K = min(int(N * 2 + 2), max_K)
-            if K < 2:
-                K = 2
-            device = jmap.device
-
-            # index: [N_TYPE, K]
-            score, index = torch.topk(jmap, k=K)
-            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
-            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
-
-            # xy: [N_TYPE, K, 2]
-            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
-            xy_ = xy[..., None, :]
-            del x, y, index
-
-            # dist: [N_TYPE, K, N]
-            dist = torch.sum((xy_ - junc) ** 2, -1)
-            cost, match = torch.min(dist, -1)
-
-            # xy: [N_TYPE * K, 2]
-            # match: [N_TYPE, K]
-            for t in range(n_type):
-                match[t, jtyp[match[t]] != t] = N
-            match[cost > 1.5 * 1.5] = N
-            match = match.flatten()
-
-            _ = torch.arange(n_type * K, device=device)
-            u, v = torch.meshgrid(_, _)
-            u, v = u.flatten(), v.flatten()
-            up, vp = match[u], match[v]
-            label = Lpos[up, vp]
-
-            # if mode == "training":
-            if self.training:
-                c = torch.zeros_like(label, dtype=torch.bool)
-
-                # sample positive lines
-                cdx = label.nonzero().flatten()
-                if len(cdx) > self.n_dyn_posl:
-                    # print("too many positive lines")
-                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample negative lines
-                cdx = Lneg[up, vp].nonzero().flatten()
-                if len(cdx) > self.n_dyn_negl:
-                    # print("too many negative lines")
-                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
-                    cdx = cdx[perm]
-                c[cdx] = 1
-
-                # sample other (unmatched) lines
-                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
-                c[cdx] = 1
-            else:
-                c = (u < v).flatten()
-
-            # sample lines
-            u, v, label = u[c], v[c], label[c]
-            xy = xy.reshape(n_type * K, 2)
-            xyu, xyv = xy[u], xy[v]
-
-            u2v = xyu - xyv
-            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
-            feat = torch.cat(
-                [
-                    xyu / 128 * self.use_cood,
-                    xyv / 128 * self.use_cood,
-                    u2v * self.use_slop,
-                    (u[:, None] > K).float(),
-                    (v[:, None] > K).float(),
-                ],
-                1,
-            )
-            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
-
-            xy = xy.reshape(n_type, K, 2)
-            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
-            return line, label.float(), feat, jcs
-
-
-
-_COMMON_META = {
-    "categories": _COCO_PERSON_CATEGORIES,
-    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
-    "min_size": (1, 1),
-}
-
-
-class LineRCNN_ResNet50_FPN_Weights(WeightsEnum):
-    COCO_LEGACY = Weights(
-        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
-        transforms=ObjectDetection,
-        meta={
-            **_COMMON_META,
-            "num_params": 59137258,
-            "recipe": "https://github.com/pytorch/vision/issues/1606",
-            "_metrics": {
-                "COCO-val2017": {
-                    "box_map": 50.6,
-                    "kp_map": 61.1,
-                }
-            },
-            "_ops": 133.924,
-            "_file_size": 226.054,
-            "_docs": """
-                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
-                from an early epoch.
-            """,
-        },
-    )
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
-        transforms=ObjectDetection,
-        meta={
-            **_COMMON_META,
-            "num_params": 59137258,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
-            "_metrics": {
-                "COCO-val2017": {
-                    "box_map": 54.6,
-                    "kp_map": 65.0,
-                }
-            },
-            "_ops": 137.42,
-            "_file_size": 226.054,
-            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-@register_model()
-@handle_legacy_interface(
-    weights=(
-            "pretrained",
-            lambda kwargs: LineRCNN_ResNet50_FPN_Weights.COCO_LEGACY
-            if kwargs["pretrained"] == "legacy"
-            else LineRCNN_ResNet50_FPN_Weights.COCO_V1,
-    ),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
-def linercnn_resnet50_fpn(
-        *,
-        weights: Optional[LineRCNN_ResNet50_FPN_Weights] = None,
-        progress: bool = True,
-        num_classes: Optional[int] = None,
-        num_keypoints: Optional[int] = None,
-        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
-        trainable_backbone_layers: Optional[int] = None,
-        **kwargs: Any,
-) -> LineRCNN:
-    """
-    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
-
-    .. betastatus:: detection module
-
-    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
-
-    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
-    image, and should be in ``0-1`` range. Different images can have different sizes.
-
-    The behavior of the model changes depending on if it is in training or evaluation mode.
-
-    During training, the model expects both the input tensors and targets (list of dictionary),
-    containing:
-
-        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
-          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
-        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
-          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
-
-    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
-    losses for both the RPN and the R-CNN, and the keypoint loss.
-
-    During inference, the model requires only the input tensors, and returns the post-processed
-    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
-    follows, where ``N`` is the number of detected instances:
-
-        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
-          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (``Int64Tensor[N]``): the predicted labels for each instance
-        - scores (``Tensor[N]``): the scores of each instance
-        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
-
-    For more details on the output, you may refer to :ref:`instance_seg_output`.
-
-    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
-
-    Example::
-
-        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
-        >>> model.eval()
-        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
-        >>> predictions = model(x)
-        >>>
-        >>> # optionally, if you want to export the model to ONNX:
-        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
-
-    Args:
-        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
-            pretrained weights to use. See
-            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
-            below for more details, and possible values. By default, no
-            pre-trained weights are used.
-        progress (bool): If True, displays a progress bar of the download to stderr
-        num_classes (int, optional): number of output classes of the model (including the background)
-        num_keypoints (int, optional): number of keypoints
-        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
-            pretrained weights for the backbone.
-        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
-            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
-            passed (the default) this value is set to 3.
-
-    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
-        :members:
-    """
-    weights = LineRCNN_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
-        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
-    else:
-        if num_classes is None:
-            num_classes = 2
-        if num_keypoints is None:
-            num_keypoints = 17
-
-    is_trained = weights is not None or weights_backbone is not None
-    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
-    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-
-    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-    model = LineRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
-        if weights == LineRCNN_ResNet50_FPN_Weights.COCO_V1:
-            overwrite_eps(model, 0.0)
-
-    return model
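
Note: the linercnn_resnet50_fpn factory removed above is superseded by linenet_resnet50_fpn from models.line_detect.line_net, which the updated training script further down imports and calls with no arguments. A minimal usage sketch, assuming the new factory keeps the no-argument call and the list-of-[C, H, W]-tensors, 0-1-range input convention described in the removed docstring:

import torch
from models.line_detect.line_net import linenet_resnet50_fpn

# Build the replacement model the same way the training script below does.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = linenet_resnet50_fpn().to(device)
model.eval()

# A list of [C, H, W] tensors in the 0-1 range; per-image sizes may differ.
images = [torch.rand(3, 512, 512, device=device), torch.rand(3, 480, 640, device=device)]

with torch.no_grad():
    predictions = model(images)  # expected: one post-processed result per input image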

+ 7 - 0
models/line_detect/roi_heads.py

@@ -1053,6 +1053,7 @@ class RoIHeads(nn.Module):
 
         features_lcnn = features['0']
         if self.has_line():
+            # print('has line_head')
             outputs = self.line_head(features_lcnn)
             loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
             x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
@@ -1072,10 +1073,16 @@ class RoIHeads(nn.Module):
                 rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
                 loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
             else:
+
                 pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
                 result.append(pred)
                 loss_wirepoint = {}
             losses.update(loss_wirepoint)
+        else:
+            pass
+            # print('has not line_head')
+
+
 
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
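
With the change above, the wirepoint branch follows the same convention as the other RoIHeads branches: in training mode it adds a "loss_wirepoint" entry to the loss dict, and in eval mode the output of wirepoint_inference is appended to the per-image results. A minimal training-step sketch, assuming the model keeps torchvision's detection calling convention (images plus targets in training mode, loss dict out); the optimizer handling is illustrative, not code from this repository:

import torch

def train_step(model, images, targets, optimizer):
    # In train mode the model returns a dict of losses; "loss_wirepoint" is present
    # whenever the line head is configured (the self.has_line() branch above).
    model.train()
    loss_dict = model(images, targets)
    total_loss = sum(loss for loss in loss_dict.values())
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return {name: value.item() for name, value in loss_dict.items()}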

+ 6 - 167
train——line_rcnn.py

@@ -1,164 +1,3 @@
-# Training script written based on LCNN    2025/2/7
-'''
-#!/usr/bin/env python3
-import datetime
-import glob
-import os
-import os.path as osp
-import platform
-import pprint
-import random
-import shlex
-import shutil
-import subprocess
-import sys
-import numpy as np
-import torch
-import torchvision
-import yaml
-import lcnn
-from lcnn.config import C, M
-from lcnn.datasets import WireframeDataset, collate
-from lcnn.models.line_vectorizer import LineVectorizer
-from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
-from torchvision.models import resnet50
-
-from models.line_detect.line_rcnn import linercnn_resnet50_fpn
-
-
-
-def main():
-
-    # Training configuration parameters
-    config = {
-        # Dataset configuration
-        'datadir': r'D:\python\PycharmProjects\data',  # dataset directory
-        'config_file': 'config/wireframe.yaml',  # path to the config file
-
-        # GPU configuration
-        'devices': '0',  # GPU device(s) to use
-        'identifier': 'fasterrcnn_resnet50',  # training identifier: stacked_hourglass / unet
-
-        # Pretrained model path
-        # 'pretrained_model': './190418-201834-f8934c6-lr4d10-312k.pth',  # path to pretrained weights
-    }
-
-    # Update the configuration
-    C.update(C.from_yaml(filename=config['config_file']))
-    M.update(C.model)
-
-    # Set random seeds
-    random.seed(0)
-    np.random.seed(0)
-    torch.manual_seed(0)
-
-    # Device configuration
-    device_name = "cpu"
-    os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
-
-    if torch.cuda.is_available():
-        device_name = "cuda"
-        torch.backends.cudnn.deterministic = True
-        torch.cuda.manual_seed(0)
-        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
-    else:
-        print("CUDA is not available")
-
-    device = torch.device(device_name)
-
-    # Data loading
-    kwargs = {
-        "collate_fn": collate,
-        "num_workers": C.io.num_workers if os.name != "nt" else 0,
-        "pin_memory": True,
-    }
-
-    train_loader = torch.utils.data.DataLoader(
-        WireframeDataset(config['datadir'], dataset_type="train"),
-        shuffle=True,
-        batch_size=M.batch_size,
-        **kwargs,
-    )
-
-    val_loader = torch.utils.data.DataLoader(
-        WireframeDataset(config['datadir'], dataset_type="val"),
-        shuffle=False,
-        batch_size=M.batch_size_eval,
-        **kwargs,
-    )
-
-    model = linercnn_resnet50_fpn().to(device)
-
-    # Load pretrained weights
-
-    try:
-        # Load the model weights
-        checkpoint = torch.load(config['pretrained_model'], map_location=device)
-
-        # Pick the loading path based on the actual checkpoint structure
-        if 'model_state_dict' in checkpoint:
-            # Full training checkpoint
-            model.load_state_dict(checkpoint['model_state_dict'])
-        elif 'state_dict' in checkpoint:
-            # Checkpoint that only contains a state dict
-            model.load_state_dict(checkpoint['state_dict'])
-        else:
-            # Load the weight dict directly
-            model.load_state_dict(checkpoint)
-
-        print("Successfully loaded pre-trained model weights.")
-    except Exception as e:
-        print(f"Error loading model weights: {e}")
-
-
-    # Optimizer configuration
-    if C.optim.name == "Adam":
-        optim = torch.optim.Adam(
-            filter(lambda p: p.requires_grad, model.parameters()),
-            lr=C.optim.lr,
-            weight_decay=C.optim.weight_decay,
-            amsgrad=C.optim.amsgrad,
-        )
-    elif C.optim.name == "SGD":
-        optim = torch.optim.SGD(
-            filter(lambda p: p.requires_grad, model.parameters()),
-            lr=C.optim.lr,
-            weight_decay=C.optim.weight_decay,
-            momentum=C.optim.momentum,
-        )
-    else:
-        raise NotImplementedError
-
-    # Output directory
-    outdir = osp.join(
-        osp.expanduser(C.io.logdir),
-        f"{datetime.datetime.now().strftime('%y%m%d-%H%M%S')}-{config['identifier']}"
-    )
-    os.makedirs(outdir, exist_ok=True)
-
-    try:
-        trainer = lcnn.trainer.Trainer(
-            device=device,
-            model=model,
-            optimizer=optim,
-            train_loader=train_loader,
-            val_loader=val_loader,
-            out=outdir,
-        )
-
-        print("Starting training...")
-        trainer.train()
-        print("Training completed.")
-
-    except BaseException:
-        if len(glob.glob(f"{outdir}/viz/*")) <= 1:
-            shutil.rmtree(outdir)
-        raise
-
-
-if __name__ == "__main__":
-    main()
-'''
 
 
 # 2025/2/9
@@ -178,7 +17,7 @@ import matplotlib.pyplot as plt
 import matplotlib as mpl
 from skimage import io
 
-from models.line_detect.line_rcnn import linercnn_resnet50_fpn
+from models.line_detect.line_net import  linenet_resnet50_fpn
 from torchvision.utils import draw_bounding_boxes
 from models.wirenet.postprocess import postprocess
 from torchvision import transforms
@@ -277,22 +116,22 @@ if __name__ == '__main__':
     dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
     train_sampler = torch.utils.data.RandomSampler(dataset_train)
     # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
     train_collate_fn = utils.collate_fn_wirepoint
     data_loader_train = torch.utils.data.DataLoader(
-        dataset_train, batch_sampler=train_batch_sampler, num_workers=0, collate_fn=train_collate_fn
+        dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
     )
 
     dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
     val_sampler = torch.utils.data.RandomSampler(dataset_val)
     # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
+    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
     val_collate_fn = utils.collate_fn_wirepoint
     data_loader_val = torch.utils.data.DataLoader(
-        dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
+        dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
     )
 
-    model = linercnn_resnet50_fpn().to(device)
+    model = linenet_resnet50_fpn().to(device)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
     writer = SummaryWriter(cfg['io']['logdir'])
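
For context, a sketch of how the loaders, model, optimizer and writer set up above could be driven in an epoch loop; the epoch count, the (imgs, targets) unpacking of collate_fn_wirepoint, the device transfer of target tensors and the scalar tag are illustrative assumptions rather than part of this commit:

num_epochs = 10  # assumed; the real value would come from cfg['optim']
for epoch in range(num_epochs):
    model.train()
    for imgs, targets in data_loader_train:
        imgs = [img.to(device) for img in imgs]
        # Move top-level target tensors to the device; nested entries such as
        # target["wires"] would need their own transfer in real code.
        targets = [{k: (v.to(device) if torch.is_tensor(v) else v) for k, v in t.items()}
                   for t in targets]
        loss_dict = model(imgs, targets)
        total_loss = sum(loss_dict.values())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    writer.add_scalar('train/total_loss', total_loss.item(), epoch)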