@@ -19,8 +19,6 @@ from models.config.config_tool import read_yaml
 import numpy as np
 import torch.nn.functional as F

-# from scipy.ndimage import gaussian_filter
-
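+# Width of the hand-crafted geometric feature vector appended to each line's pooled features.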
 FEATURE_DIM = 8

 def non_maximum_suppression(a):
@@ -52,7 +50,7 @@ class LineRCNNPredictor(nn.Module):
     def __init__(self,n_pts0 = 32,
                  n_pts1 = 8,
                  n_stc_posl =300,
-                 dim_loi = 1,
+                 dim_loi = 128,
                  use_conv = 0,
                  dim_fc = 1024,
                  n_out_line = 2500,
@@ -134,7 +132,7 @@ class LineRCNNPredictor(nn.Module):
         else:
             self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
             self.fc2 = nn.Sequential(
-                nn.Linear(self.dim_loi * FEATURE_DIM, self.dim_fc),
+                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
                 nn.ReLU(inplace=True),
                 nn.Linear(self.dim_fc, self.dim_fc),
                 nn.ReLU(inplace=True),
@@ -149,7 +147,7 @@ class LineRCNNPredictor(nn.Module):
         # print(f'out:{out.shape}')
         # outputs=merge_features(outputs,100)
         batch, channel, row, col = inputs.shape
-        # print(f'outputs:{inputs.shape}')
+        print(f'outputs:{inputs.shape}')
         # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')

         if targets is not None:
@@ -192,13 +190,11 @@ class LineRCNNPredictor(nn.Module):
         n_jtyp = T["junc_map"].shape[1]
         offset = self.head_off
         result = {}
-        print(f' wires_targets len:{len(wires_targets)}')
         for stack, output in enumerate([inputs]):
             output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
             # print(f"Stack {stack} output shape: {output.shape}")  # print the output shape of each stack
             jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-            # lmap = output[offset[0]: offset[1]].squeeze(0)
-            lmap = output[offset[0]: offset[1]]
+            lmap = output[offset[0]: offset[1]].squeeze(0)
             joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)

             if stack == 0:
@@ -212,16 +208,12 @@ class LineRCNNPredictor(nn.Module):
                 # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")

         h = result["preds"]
-        print(f'features shape:{features.shape}')
-        print(f'inputs shape :{inputs.shape}')
-        # x = self.fc1(features)
-        lmap = inputs[:,2:3,:,:].sigmoid()
-        x=lmap
-        print(f'x:{lmap.shape}')
+        # print(f'features shape:{features.shape}')
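+        # fc1 projects the backbone feature map down to dim_loi channels; the LoI pooling below samples from it.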
+        x = self.fc1(features)

-        n_batch, n_channel, row, col = lmap.shape
-        # n_batch, n_channel, row, col = x.shape
+        # print(f'x:{x.shape}')

+        n_batch, n_channel, row, col = x.shape

         # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')

@@ -229,9 +221,9 @@ class LineRCNNPredictor(nn.Module):

         for i, meta in enumerate(wires_targets):
             p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i],lmap[i]
+                meta, h["jmap"][i], h["joff"][i],
             )
-            print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
+            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
             ys.append(label)
             if self.training and self.do_static_sampling:
                 p = torch.cat([p, meta["lpre"]])
@@ -243,27 +235,53 @@ class LineRCNNPredictor(nn.Module):
                 ps.append(p)
             fs.append(feat)

-            idx.append(idx[-1] + feat.shape[0])
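+            # LoI pooling: sample n_pts0 points along each candidate line by linear
+            # interpolation between its endpoints, then bilinearly sample the projected feature map x[i] at those points.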
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
+            xp = (
+                (
+                    x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+                    + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+                    + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+                    + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                .reshape(n_channel, -1, self.n_pts0)
+                .permute(1, 0, 2)
+            )
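+            # Max-pool the n_pts0 samples down to n_pts1 values per channel.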
+            xp = self.pooling(xp)
+            print(f'xp forward.shape:{xp.shape}')
+            xs.append(xp)
+            idx.append(idx[-1] + xp.shape[0])
             # print(f'idx__:{idx}')

-        # x, y = torch.cat(xs), torch.cat(ys)
-        y=torch.cat(ys)
+        x, y = torch.cat(xs), torch.cat(ys)
         f = torch.cat(fs)
-        print(f'f:{f.shape}')
-        # x = x.reshape(-1, self.n_pts1 * self.dim_loi)
-
+        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
+        print(f' x reshape:{x.shape}')

-        f= f.to(dtype=torch.float32)
-        # x = x.to(dtype=torch.float32)
+        # print("Weight dtype:", self.fc2.weight.dtype)
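+        # Concatenate pooled LoI features with the FEATURE_DIM geometric descriptor to match fc2's input width.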
+        x = torch.cat([x, f], 1)
+        # print("Input dtype:", x.dtype)
+        x = x.to(dtype=torch.float32)
         # print("Input dtype1:", x.dtype)
-        x = self.fc2(f).flatten()
+        x = self.fc2(x).flatten()

         # return x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
         return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc

+        # if mode != "training":
+        # self.inference(x, idx, jcs, n_batch, ps)
+
+        # return result

-    def sample_lines(self, meta, jmap, joff,lmap):
+    def sample_lines(self, meta, jmap, joff):
         device = jmap.device
         with torch.no_grad():
             junc = meta["junc_coords"].to(device)  # [N, 2]
@@ -279,14 +297,8 @@ class LineRCNNPredictor(nn.Module):
             # if mode != "training":
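+            # K junction proposals are kept per type: the score-thresholded count at eval time, 2N + 2 during training, both capped at max_K.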
             if not self.training:
                 K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
-                print(f'jmap max:{torch.max(jmap[0])}')
-                print(f'jmap min:{torch.min(jmap[0])}')
-                print(f'jmap num:{(jmap > self.eval_junc_thres).float().sum().item()}')
-                print(f'jmap:{jmap}')
-                print(f'predict K:{K}')
             else:
                 K = min(int(N * 2 + 2), max_K)
-                print(f'train K:{K}')
             if K < 2:
                 K = 2
             device = jmap.device
@@ -303,10 +315,7 @@ class LineRCNNPredictor(nn.Module):

             # dist: [N_TYPE, K, N]
             dist = torch.sum((xy_ - junc) ** 2, -1)
-            print(f'dist:{dist}')
-
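+            # For each junction proposal, find the nearest ground-truth junction and its squared distance.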
             cost, match = torch.min(dist, -1)
-            print(f'match:{match},cost:{cost}')

             # xy: [N_TYPE * K, 2]
             # match: [N_TYPE, K]
@@ -314,8 +323,6 @@ class LineRCNNPredictor(nn.Module):
             # match[t, jtyp[match[t]] != t] = N

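+            # Proposals whose squared distance to the nearest ground-truth junction exceeds 1.5 ** 2 are marked unmatched (index N).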
             match[cost > 1.5 * 1.5] = N
-
-            print(f'match__ : {match}')
             match = match.flatten()

             _ = torch.arange(n_type * K, device=device)
@@ -357,45 +364,21 @@ class LineRCNNPredictor(nn.Module):
|
|
|
|
|
|
u2v = xyu - xyv
|
|
|
u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
|
|
|
- # lmap = gaussian_filter(lmap, sigma=1)
|
|
|
- # lmap = torch.from_numpy(gaussian_filter(lmap.cpu().numpy(), sigma=1)).to('cuda:0')
|
|
|
-
|
|
|
- line = torch.cat([xyu[:, None], xyv[:, None]], 1)
|
|
|
- # print(f'line:{line.shape}')
|
|
|
- n_channel, row, col = lmap.shape
|
|
|
- p=line
|
|
|
- print(f'p.shape :{p.shape}')
|
|
|
- p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
|
|
|
- p = p.reshape(-1, 2) # [N_LINE x N_POINT, 2_XY]
|
|
|
- px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
|
|
|
- px0 = px.floor().clamp(min=0, max=127)
|
|
|
- py0 = py.floor().clamp(min=0, max=127)
|
|
|
- px1 = (px0 + 1).clamp(min=0, max=127)
|
|
|
- py1 = (py0 + 1).clamp(min=0, max=127)
|
|
|
- px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
|
|
|
-
|
|
|
- # xp: [N_LINE, N_CHANNEL, N_POINT]
|
|
|
- x=lmap
|
|
|
- xp = (
|
|
|
- (
|
|
|
- x[ :, px0l, py0l] * (px1 - px) * (py1 - py)
|
|
|
- + x[ :, px1l, py0l] * (px - px0) * (py1 - py)
|
|
|
- + x[ :, px0l, py1l] * (px1 - px) * (py - py0)
|
|
|
- + x[ :, px1l, py1l] * (px - px0) * (py - py0)
|
|
|
- )
|
|
|
- .reshape(n_channel, -1, self.n_pts0)
|
|
|
- .permute(1, 0, 2)
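+            # Per-line geometric descriptor: normalized endpoint coordinates (2 + 2 values),
+            # unit direction (2), and two flags for endpoints indexed beyond the first K junctions,
+            # i.e. FEATURE_DIM = 8 values in total.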
+            feat = torch.cat(
+                [
+                    xyu / 128 * self.use_cood,
+                    xyv / 128 * self.use_cood,
+                    u2v * self.use_slop,
+                    (u[:, None] > K).float(),
+                    (v[:, None] > K).float(),
+                ],
+                1,
             )
-            xp = self.pooling(xp).squeeze(1)
-            if not self.training:
-                print(f'predict xp values:{xp}')
-                print(f'xp shape:{xp.shape}')
-
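+            # line: [N_PAIR, 2, 2] endpoint coordinates for every candidate junction pair.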
+            line = torch.cat([xyu[:, None], xyv[:, None]], 1)

             xy = xy.reshape(n_type, K, 2)
             jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
-            print(f'line.shape :{line.shape}')
-            return line, label.float(), xp, jcs
+            return line, label.float(), feat, jcs
@@ -403,5 +386,4 @@ _COMMON_META = {
     "categories": _COCO_PERSON_CATEGORIES,
     "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
     "min_size": (1, 1),
-}
-
+}
|