|
@@ -13,7 +13,7 @@ from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_
|
|
|
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
|
|
|
from libs.vision_libs.models.detection._utils import overwrite_eps
|
|
|
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
|
|
|
-from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN
|
|
|
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
|
|
|
|
|
|
from models.config.config_tool import read_yaml
|
|
|
import numpy as np
|
|
@@ -196,7 +196,7 @@ class LineRCNN(FasterRCNN):
|
|
|
backbone,
|
|
|
num_classes=None,
|
|
|
|
|
|
- min_size=None,
|
|
|
+ min_size=512,
|
|
|
max_size=1333,
|
|
|
image_mean=None,
|
|
|
image_std=None,
|
|
@@ -292,6 +292,18 @@ class LineRCNN(FasterRCNN):
|
|
|
**kwargs,
|
|
|
)
|
|
|
|
|
|
+ if box_roi_pool is None:
|
|
|
+ box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
|
|
|
+
|
|
|
+ if box_head is None:
|
|
|
+ resolution = box_roi_pool.output_size[0]
|
|
|
+ representation_size = 1024
|
|
|
+ box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
|
|
|
+
|
|
|
+ if box_predictor is None:
|
|
|
+ representation_size = 1024
|
|
|
+ box_predictor = FastRCNNPredictor(representation_size, num_classes)
|
|
|
+
|
|
|
roi_heads = RoIHeads(
|
|
|
|
|
|
box_roi_pool,
|
|
@@ -311,7 +323,6 @@ class LineRCNN(FasterRCNN):
|
|
|
)
|
|
|
|
|
|
self.roi_heads = roi_heads
|
|
|
-
|
|
|
self.roi_heads.line_head = line_head
|
|
|
self.roi_heads.line_predictor = line_predictor
|
|
|
|
|
@@ -355,7 +366,7 @@ class LineRCNNPredictor(nn.Module):
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
- self.cfg = read_yaml(r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\models\line_detect\wireframe.yaml')
|
|
|
+ self.cfg = read_yaml(r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml')
|
|
|
self.n_pts0 = self.cfg['model']['n_pts0']
|
|
|
self.n_pts1 = self.cfg['model']['n_pts1']
|
|
|
self.n_stc_posl = self.cfg['model']['n_stc_posl']
|
|
@@ -402,12 +413,15 @@ class LineRCNNPredictor(nn.Module):
|
|
|
)
|
|
|
self.loss = nn.BCEWithLogitsLoss(reduction="none")
|
|
|
|
|
|
- def forward(self, result, targets=None):
|
|
|
+ def forward(self, inputs, features, targets=None):
|
|
|
|
|
|
-
|
|
|
- h = result["preds"]
|
|
|
- x = self.fc1(result["feature"])
|
|
|
- n_batch, n_channel, row, col = x.shape
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ batch, channel, row, col = inputs.shape
|
|
|
+
|
|
|
+
|
|
|
|
|
|
if targets is not None:
|
|
|
self.training = True
|
|
@@ -430,30 +444,61 @@ class LineRCNNPredictor(nn.Module):
|
|
|
}
|
|
|
else:
|
|
|
self.training = False
|
|
|
-
|
|
|
t = {
|
|
|
- "junc_coords": torch.zeros(1, 2).to(device),
|
|
|
- "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
|
|
|
- "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
- "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
- "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
- "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
+ "junc_coords": torch.zeros(1, 2),
|
|
|
+ "jtyp": torch.zeros(1, dtype=torch.uint8),
|
|
|
+ "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
+ "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
+ "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
+ "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
}
|
|
|
wires_targets = [t for b in range(inputs.size(0))]
|
|
|
|
|
|
wires_meta = {
|
|
|
- "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
- "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
+ "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
+ "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
}
|
|
|
|
|
|
+ T = wires_meta.copy()
|
|
|
+ n_jtyp = T["junc_map"].shape[1]
|
|
|
+ offset = self.head_off
|
|
|
+ result = {}
|
|
|
+ for stack, output in enumerate([inputs]):
|
|
|
+ output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
|
|
|
+
|
|
|
+ jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
|
|
|
+ lmap = output[offset[0]: offset[1]].squeeze(0)
|
|
|
+ joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
|
|
|
+
|
|
|
+ if stack == 0:
|
|
|
+ result["preds"] = {
|
|
|
+ "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
|
|
|
+ "lmap": lmap.sigmoid(),
|
|
|
+ "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ h = result["preds"]
|
|
|
+
|
|
|
+ x = self.fc1(features)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ n_batch, n_channel, row, col = x.shape
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
|
|
|
- for i, meta in enumerate(input_dict["meta"]):
|
|
|
+
|
|
|
+ for i, meta in enumerate(wires_targets):
|
|
|
p, label, feat, jc = self.sample_lines(
|
|
|
- meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
|
|
|
+ meta, h["jmap"][i], h["joff"][i],
|
|
|
)
|
|
|
-
|
|
|
+
|
|
|
ys.append(label)
|
|
|
- if input_dict["mode"] == "training" and self.do_static_sampling:
|
|
|
+ if self.training and self.do_static_sampling:
|
|
|
p = torch.cat([p, meta["lpre"]])
|
|
|
feat = torch.cat([feat, meta["lpre_feat"]])
|
|
|
ys.append(meta["lpre_label"])
|
|
@@ -480,25 +525,28 @@ class LineRCNNPredictor(nn.Module):
|
|
|
+ x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
|
|
|
+ x[i, :, px1l, py1l] * (px - px0) * (py - py0)
|
|
|
)
|
|
|
- .reshape(n_channel, -1, M.n_pts0)
|
|
|
+ .reshape(n_channel, -1, self.n_pts0)
|
|
|
.permute(1, 0, 2)
|
|
|
)
|
|
|
xp = self.pooling(xp)
|
|
|
+
|
|
|
xs.append(xp)
|
|
|
idx.append(idx[-1] + xp.shape[0])
|
|
|
-
|
|
|
+
|
|
|
|
|
|
x, y = torch.cat(xs), torch.cat(ys)
|
|
|
f = torch.cat(fs)
|
|
|
x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
+
|
|
|
+
|
|
|
x = torch.cat([x, f], 1)
|
|
|
+
|
|
|
x = x.to(dtype=torch.float32)
|
|
|
+
|
|
|
x = self.fc2(x).flatten()
|
|
|
|
|
|
|
|
|
- all=[x, ys, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc]
|
|
|
- return all
|
|
|
-
|
|
|
+ return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
|
|
|
|
|
|
|
|
|
|
|
@@ -536,9 +584,6 @@ class LineRCNNPredictor(nn.Module):
|
|
|
xy_ = xy[..., None, :]
|
|
|
del x, y, index
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
|
|
|
dist = torch.sum((xy_ - junc) ** 2, -1)
|
|
|
cost, match = torch.min(dist, -1)
|
|
@@ -604,6 +649,208 @@ class LineRCNNPredictor(nn.Module):
|
|
|
xy = xy.reshape(n_type, K, 2)
|
|
|
jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
|
|
|
return line, label.float(), feat, jcs
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
|
|
|
|
|
@@ -746,7 +993,6 @@ def linercnn_resnet50_fpn(
|
|
|
"""
|
|
|
weights = LineRCNN_ResNet50_FPN_Weights.verify(weights)
|
|
|
weights_backbone = ResNet50_Weights.verify(weights_backbone)
|
|
|
-
|
|
|
if weights is not None:
|
|
|
weights_backbone = None
|
|
|
num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
|