|
@@ -13,7 +13,7 @@ from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_
|
|
|
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
|
|
|
from libs.vision_libs.models.detection._utils import overwrite_eps
|
|
|
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
|
|
|
-from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN
|
|
|
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
|
|
|
|
|
|
from models.config.config_tool import read_yaml
|
|
|
import numpy as np
|
|
@@ -196,7 +196,7 @@ class LineRCNN(FasterRCNN):
|
|
|
backbone,
|
|
|
num_classes=None,
|
|
|
# transform parameters
|
|
|
- min_size=None,
|
|
|
+ min_size=512, # 原为None
|
|
|
max_size=1333,
|
|
|
image_mean=None,
|
|
|
image_std=None,
|
|
@@ -292,6 +292,18 @@ class LineRCNN(FasterRCNN):
|
|
|
**kwargs,
|
|
|
)
|
|
|
|
|
|
+ if box_roi_pool is None:
|
|
|
+ box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
|
|
|
+
|
|
|
+ if box_head is None:
|
|
|
+ resolution = box_roi_pool.output_size[0]
|
|
|
+ representation_size = 1024
|
|
|
+ box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
|
|
|
+
|
|
|
+ if box_predictor is None:
|
|
|
+ representation_size = 1024
|
|
|
+ box_predictor = FastRCNNPredictor(representation_size, num_classes)
|
|
|
+
|
|
|
roi_heads = RoIHeads(
|
|
|
# Box
|
|
|
box_roi_pool,
|
|
@@ -311,7 +323,6 @@ class LineRCNN(FasterRCNN):
|
|
|
)
|
|
|
# super().roi_heads = roi_heads
|
|
|
self.roi_heads = roi_heads
|
|
|
-
|
|
|
self.roi_heads.line_head = line_head
|
|
|
self.roi_heads.line_predictor = line_predictor
|
|
|
|
|
@@ -355,7 +366,7 @@ class LineRCNNPredictor(nn.Module):
|
|
|
super().__init__()
|
|
|
# self.backbone = backbone
|
|
|
# self.cfg = read_yaml(cfg)
|
|
|
- self.cfg = read_yaml(r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\models\line_detect\wireframe.yaml')
|
|
|
+ self.cfg = read_yaml(r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml')
|
|
|
self.n_pts0 = self.cfg['model']['n_pts0']
|
|
|
self.n_pts1 = self.cfg['model']['n_pts1']
|
|
|
self.n_stc_posl = self.cfg['model']['n_stc_posl']
|
|
@@ -402,12 +413,15 @@ class LineRCNNPredictor(nn.Module):
|
|
|
)
|
|
|
self.loss = nn.BCEWithLogitsLoss(reduction="none")
|
|
|
|
|
|
- def forward(self, result, targets=None):
|
|
|
+ def forward(self, inputs, features, targets=None):
|
|
|
|
|
|
- # result = self.backbone(input_dict)
|
|
|
- h = result["preds"]
|
|
|
- x = self.fc1(result["feature"])
|
|
|
- n_batch, n_channel, row, col = x.shape
|
|
|
+ # outputs, features = input
|
|
|
+ # for out in outputs:
|
|
|
+ # print(f'out:{out.shape}')
|
|
|
+ # outputs=merge_features(outputs,100)
|
|
|
+ batch, channel, row, col = inputs.shape
|
|
|
+ # print(f'outputs:{inputs.shape}')
|
|
|
+ # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
|
|
|
|
|
|
if targets is not None:
|
|
|
self.training = True
|
|
@@ -430,30 +444,61 @@ class LineRCNNPredictor(nn.Module):
|
|
|
}
|
|
|
else:
|
|
|
self.training = False
|
|
|
- # self.training = False
|
|
|
t = {
|
|
|
- "junc_coords": torch.zeros(1, 2).to(device),
|
|
|
- "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
|
|
|
- "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
- "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
- "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
- "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
+ "junc_coords": torch.zeros(1, 2),
|
|
|
+ "jtyp": torch.zeros(1, dtype=torch.uint8),
|
|
|
+ "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
+ "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
|
|
|
+ "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
+ "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
}
|
|
|
wires_targets = [t for b in range(inputs.size(0))]
|
|
|
|
|
|
wires_meta = {
|
|
|
- "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
- "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
+ "junc_map": torch.zeros([1, 1, 128, 128]),
|
|
|
+ "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
|
|
|
}
|
|
|
|
|
|
+ T = wires_meta.copy()
|
|
|
+ n_jtyp = T["junc_map"].shape[1]
|
|
|
+ offset = self.head_off
|
|
|
+ result = {}
|
|
|
+ for stack, output in enumerate([inputs]):
|
|
|
+ output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
|
|
|
+ # print(f"Stack {stack} output shape: {output.shape}") # print the output shape of each stack
|
|
|
+ jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
|
|
|
+ lmap = output[offset[0]: offset[1]].squeeze(0)
|
|
|
+ joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
|
|
|
+
|
|
|
+ if stack == 0:
|
|
|
+ result["preds"] = {
|
|
|
+ "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
|
|
|
+ "lmap": lmap.sigmoid(),
|
|
|
+ "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
|
|
|
+ }
|
|
|
+ # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
|
|
|
+ # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
|
|
|
+ # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
|
|
|
+
|
|
|
+ h = result["preds"]
|
|
|
+ # print(f'features shape:{features.shape}')
|
|
|
+ x = self.fc1(features)
|
|
|
+
|
|
|
+ # print(f'x:{x.shape}')
|
|
|
+
|
|
|
+ n_batch, n_channel, row, col = x.shape
|
|
|
+
|
|
|
+ # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
|
|
|
+
|
|
|
xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
|
|
|
- for i, meta in enumerate(input_dict["meta"]):
|
|
|
+
|
|
|
+ for i, meta in enumerate(wires_targets):
|
|
|
p, label, feat, jc = self.sample_lines(
|
|
|
- meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
|
|
|
+ meta, h["jmap"][i], h["joff"][i],
|
|
|
)
|
|
|
- # print("p.shape:", p.shape)
|
|
|
+ # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
|
|
|
ys.append(label)
|
|
|
- if input_dict["mode"] == "training" and self.do_static_sampling:
|
|
|
+ if self.training and self.do_static_sampling:
|
|
|
p = torch.cat([p, meta["lpre"]])
|
|
|
feat = torch.cat([feat, meta["lpre_feat"]])
|
|
|
ys.append(meta["lpre_label"])
|
|
@@ -480,25 +525,28 @@ class LineRCNNPredictor(nn.Module):
|
|
|
+ x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
|
|
|
+ x[i, :, px1l, py1l] * (px - px0) * (py - py0)
|
|
|
)
|
|
|
- .reshape(n_channel, -1, M.n_pts0)
|
|
|
+ .reshape(n_channel, -1, self.n_pts0)
|
|
|
.permute(1, 0, 2)
|
|
|
)
|
|
|
xp = self.pooling(xp)
|
|
|
+ # print(f'xp.shape:{xp.shape}')
|
|
|
xs.append(xp)
|
|
|
idx.append(idx[-1] + xp.shape[0])
|
|
|
-
|
|
|
+ # print(f'idx__:{idx}')
|
|
|
|
|
|
x, y = torch.cat(xs), torch.cat(ys)
|
|
|
f = torch.cat(fs)
|
|
|
x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
+
|
|
|
+ # print("Weight dtype:", self.fc2.weight.dtype)
|
|
|
x = torch.cat([x, f], 1)
|
|
|
+ # print("Input dtype:", x.dtype)
|
|
|
x = x.to(dtype=torch.float32)
|
|
|
+ # print("Input dtype1:", x.dtype)
|
|
|
x = self.fc2(x).flatten()
|
|
|
|
|
|
# return x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
|
|
|
- all=[x, ys, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc]
|
|
|
- return all
|
|
|
- # return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
|
|
|
+ return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
|
|
|
|
|
|
# if mode != "training":
|
|
|
# self.inference(x, idx, jcs, n_batch, ps)
|
|
@@ -536,9 +584,6 @@ class LineRCNNPredictor(nn.Module):
|
|
|
xy_ = xy[..., None, :]
|
|
|
del x, y, index
|
|
|
|
|
|
- # print(f"xy_.is_cuda: {xy_.is_cuda}")
|
|
|
- # print(f"junc.is_cuda: {junc.is_cuda}")
|
|
|
-
|
|
|
# dist: [N_TYPE, K, N]
|
|
|
dist = torch.sum((xy_ - junc) ** 2, -1)
|
|
|
cost, match = torch.min(dist, -1)
|
|
@@ -604,6 +649,208 @@ class LineRCNNPredictor(nn.Module):
|
|
|
xy = xy.reshape(n_type, K, 2)
|
|
|
jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
|
|
|
return line, label.float(), feat, jcs
|
|
|
+ # def forward(self, result, targets=None):
|
|
|
+ #
|
|
|
+ # # result = self.backbone(input_dict)
|
|
|
+ # h = result["preds"]
|
|
|
+ # x = self.fc1(result["feature"])
|
|
|
+ # n_batch, n_channel, row, col = x.shape
|
|
|
+ #
|
|
|
+ # if targets is not None:
|
|
|
+ # self.training = True
|
|
|
+ # # print(f'target:{targets}')
|
|
|
+ # wires_targets = [t["wires"] for t in targets]
|
|
|
+ # # print(f'wires_target:{wires_targets}')
|
|
|
+ # # Extract the 'junc_map', 'junc_offset', and 'line_map' tensors from every target
|
|
|
+ # junc_maps = [d["junc_map"] for d in wires_targets]
|
|
|
+ # junc_offsets = [d["junc_offset"] for d in wires_targets]
|
|
|
+ # line_maps = [d["line_map"] for d in wires_targets]
|
|
|
+ #
|
|
|
+ # junc_map_tensor = torch.stack(junc_maps, dim=0)
|
|
|
+ # junc_offset_tensor = torch.stack(junc_offsets, dim=0)
|
|
|
+ # line_map_tensor = torch.stack(line_maps, dim=0)
|
|
|
+ #
|
|
|
+ # wires_meta = {
|
|
|
+ # "junc_map": junc_map_tensor,
|
|
|
+ # "junc_offset": junc_offset_tensor,
|
|
|
+ # # "line_map": line_map_tensor,
|
|
|
+ # }
|
|
|
+ # else:
|
|
|
+ # self.training = False
|
|
|
+ # # self.training = False
|
|
|
+ # t = {
|
|
|
+ # "junc_coords": torch.zeros(1, 2).to(device),
|
|
|
+ # "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
|
|
|
+ # "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
+ # "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
+ # "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
+ # "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
+ # }
|
|
|
+ # wires_targets = [t for b in range(inputs.size(0))]
|
|
|
+ #
|
|
|
+ # wires_meta = {
|
|
|
+ # "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
+ # "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
+ # }
|
|
|
+ #
|
|
|
+ # xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
|
|
|
+ # for i, meta in enumerate(input_dict["meta"]):
|
|
|
+ # p, label, feat, jc = self.sample_lines(
|
|
|
+ # meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
|
|
|
+ # )
|
|
|
+ # # print("p.shape:", p.shape)
|
|
|
+ # ys.append(label)
|
|
|
+ # if input_dict["mode"] == "training" and self.do_static_sampling:
|
|
|
+ # p = torch.cat([p, meta["lpre"]])
|
|
|
+ # feat = torch.cat([feat, meta["lpre_feat"]])
|
|
|
+ # ys.append(meta["lpre_label"])
|
|
|
+ # del jc
|
|
|
+ # else:
|
|
|
+ # jcs.append(jc)
|
|
|
+ # ps.append(p)
|
|
|
+ # fs.append(feat)
|
|
|
+ #
|
|
|
+ # p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
|
|
|
+ # p = p.reshape(-1, 2) # [N_LINE x N_POINT, 2_XY]
|
|
|
+ # px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
|
|
|
+ # px0 = px.floor().clamp(min=0, max=127)
|
|
|
+ # py0 = py.floor().clamp(min=0, max=127)
|
|
|
+ # px1 = (px0 + 1).clamp(min=0, max=127)
|
|
|
+ # py1 = (py0 + 1).clamp(min=0, max=127)
|
|
|
+ # px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
|
|
|
+ #
|
|
|
+ # # xp: [N_LINE, N_CHANNEL, N_POINT]
|
|
|
+ # xp = (
|
|
|
+ # (
|
|
|
+ # x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
|
|
|
+ # + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
|
|
|
+ # + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
|
|
|
+ # + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
|
|
|
+ # )
|
|
|
+ # .reshape(n_channel, -1, M.n_pts0)
|
|
|
+ # .permute(1, 0, 2)
|
|
|
+ # )
|
|
|
+ # xp = self.pooling(xp)
|
|
|
+ # xs.append(xp)
|
|
|
+ # idx.append(idx[-1] + xp.shape[0])
|
|
|
+ #
|
|
|
+ #
|
|
|
+ # x, y = torch.cat(xs), torch.cat(ys)
|
|
|
+ # f = torch.cat(fs)
|
|
|
+ # x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
+ # x = torch.cat([x, f], 1)
|
|
|
+ # x = x.to(dtype=torch.float32)
|
|
|
+ # x = self.fc2(x).flatten()
|
|
|
+ #
|
|
|
+ # # return x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
|
|
|
+ # all=[x, ys, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc]
|
|
|
+ # return all
|
|
|
+ # # return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
|
|
|
+ #
|
|
|
+ # # if mode != "training":
|
|
|
+ # # self.inference(x, idx, jcs, n_batch, ps)
|
|
|
+ #
|
|
|
+ # # return result
|
|
|
+ #
|
|
|
+ # def sample_lines(self, meta, jmap, joff):
|
|
|
+ # with torch.no_grad():
|
|
|
+ # junc = meta["junc_coords"] # [N, 2]
|
|
|
+ # jtyp = meta["jtyp"] # [N]
|
|
|
+ # Lpos = meta["line_pos_idx"]
|
|
|
+ # Lneg = meta["line_neg_idx"]
|
|
|
+ #
|
|
|
+ # n_type = jmap.shape[0]
|
|
|
+ # jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
|
|
|
+ # joff = joff.reshape(n_type, 2, -1)
|
|
|
+ # max_K = self.n_dyn_junc // n_type
|
|
|
+ # N = len(junc)
|
|
|
+ # # if mode != "training":
|
|
|
+ # if not self.training:
|
|
|
+ # K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
|
|
|
+ # else:
|
|
|
+ # K = min(int(N * 2 + 2), max_K)
|
|
|
+ # if K < 2:
|
|
|
+ # K = 2
|
|
|
+ # device = jmap.device
|
|
|
+ #
|
|
|
+ # # index: [N_TYPE, K]
|
|
|
+ # score, index = torch.topk(jmap, k=K)
|
|
|
+ # y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
|
|
|
+ # x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
|
|
|
+ #
|
|
|
+ # # xy: [N_TYPE, K, 2]
|
|
|
+ # xy = torch.cat([y[..., None], x[..., None]], dim=-1)
|
|
|
+ # xy_ = xy[..., None, :]
|
|
|
+ # del x, y, index
|
|
|
+ #
|
|
|
+ # # print(f"xy_.is_cuda: {xy_.is_cuda}")
|
|
|
+ # # print(f"junc.is_cuda: {junc.is_cuda}")
|
|
|
+ #
|
|
|
+ # # dist: [N_TYPE, K, N]
|
|
|
+ # dist = torch.sum((xy_ - junc) ** 2, -1)
|
|
|
+ # cost, match = torch.min(dist, -1)
|
|
|
+ #
|
|
|
+ # # xy: [N_TYPE * K, 2]
|
|
|
+ # # match: [N_TYPE, K]
|
|
|
+ # for t in range(n_type):
|
|
|
+ # match[t, jtyp[match[t]] != t] = N
|
|
|
+ # match[cost > 1.5 * 1.5] = N
|
|
|
+ # match = match.flatten()
|
|
|
+ #
|
|
|
+ # _ = torch.arange(n_type * K, device=device)
|
|
|
+ # u, v = torch.meshgrid(_, _)
|
|
|
+ # u, v = u.flatten(), v.flatten()
|
|
|
+ # up, vp = match[u], match[v]
|
|
|
+ # label = Lpos[up, vp]
|
|
|
+ #
|
|
|
+ # # if mode == "training":
|
|
|
+ # if self.training:
|
|
|
+ # c = torch.zeros_like(label, dtype=torch.bool)
|
|
|
+ #
|
|
|
+ # # sample positive lines
|
|
|
+ # cdx = label.nonzero().flatten()
|
|
|
+ # if len(cdx) > self.n_dyn_posl:
|
|
|
+ # # print("too many positive lines")
|
|
|
+ # perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
|
|
|
+ # cdx = cdx[perm]
|
|
|
+ # c[cdx] = 1
|
|
|
+ #
|
|
|
+ # # sample negative lines
|
|
|
+ # cdx = Lneg[up, vp].nonzero().flatten()
|
|
|
+ # if len(cdx) > self.n_dyn_negl:
|
|
|
+ # # print("too many negative lines")
|
|
|
+ # perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
|
|
|
+ # cdx = cdx[perm]
|
|
|
+ # c[cdx] = 1
|
|
|
+ #
|
|
|
+ # # sample other (unmatched) lines
|
|
|
+ # cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
|
|
|
+ # c[cdx] = 1
|
|
|
+ # else:
|
|
|
+ # c = (u < v).flatten()
|
|
|
+ #
|
|
|
+ # # sample lines
|
|
|
+ # u, v, label = u[c], v[c], label[c]
|
|
|
+ # xy = xy.reshape(n_type * K, 2)
|
|
|
+ # xyu, xyv = xy[u], xy[v]
|
|
|
+ #
|
|
|
+ # u2v = xyu - xyv
|
|
|
+ # u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
|
|
|
+ # feat = torch.cat(
|
|
|
+ # [
|
|
|
+ # xyu / 128 * self.use_cood,
|
|
|
+ # xyv / 128 * self.use_cood,
|
|
|
+ # u2v * self.use_slop,
|
|
|
+ # (u[:, None] > K).float(),
|
|
|
+ # (v[:, None] > K).float(),
|
|
|
+ # ],
|
|
|
+ # 1,
|
|
|
+ # )
|
|
|
+ # line = torch.cat([xyu[:, None], xyv[:, None]], 1)
|
|
|
+ #
|
|
|
+ # xy = xy.reshape(n_type, K, 2)
|
|
|
+ # jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
|
|
|
+ # return line, label.float(), feat, jcs
|
|
|
|
|
|
|
|
|
|
|
@@ -746,7 +993,6 @@ def linercnn_resnet50_fpn(
|
|
|
"""
|
|
|
weights = LineRCNN_ResNet50_FPN_Weights.verify(weights)
|
|
|
weights_backbone = ResNet50_Weights.verify(weights_backbone)
|
|
|
-
|
|
|
if weights is not None:
|
|
|
weights_backbone = None
|
|
|
num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
|