|
@@ -132,7 +132,7 @@ class LineRCNNPredictor(nn.Module):
|
|
|
else:
|
|
else:
|
|
|
self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
|
|
self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
|
|
|
self.fc2 = nn.Sequential(
|
|
self.fc2 = nn.Sequential(
|
|
|
- nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
|
|
|
|
|
|
|
+ nn.Linear(self.dim_loi * FEATURE_DIM, self.dim_fc),
|
|
|
nn.ReLU(inplace=True),
|
|
nn.ReLU(inplace=True),
|
|
|
nn.Linear(self.dim_fc, self.dim_fc),
|
|
nn.Linear(self.dim_fc, self.dim_fc),
|
|
|
nn.ReLU(inplace=True),
|
|
nn.ReLU(inplace=True),
|
|
@@ -213,10 +213,11 @@ class LineRCNNPredictor(nn.Module):
|
|
|
print(f'features shape:{features.shape}')
|
|
print(f'features shape:{features.shape}')
|
|
|
print(f'inputs shape :{inputs.shape}')
|
|
print(f'inputs shape :{inputs.shape}')
|
|
|
# x = self.fc1(features)
|
|
# x = self.fc1(features)
|
|
|
- x = inputs[:,2:3,:,:].sigmoid()
|
|
|
|
|
- print(f'x:{x.shape}')
|
|
|
|
|
|
|
+ lmap = inputs[:,2:3,:,:].sigmoid()
|
|
|
|
|
+ x=lmap
|
|
|
|
|
+ print(f'x:{lmap.shape}')
|
|
|
|
|
|
|
|
- n_batch, n_channel, row, col = x.shape
|
|
|
|
|
|
|
+ n_batch, n_channel, row, col = lmap.shape
|
|
|
# n_batch, n_channel, row, col = x.shape
|
|
# n_batch, n_channel, row, col = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
@@ -226,9 +227,9 @@ class LineRCNNPredictor(nn.Module):
|
|
|
|
|
|
|
|
for i, meta in enumerate(wires_targets):
|
|
for i, meta in enumerate(wires_targets):
|
|
|
p, label, feat, jc = self.sample_lines(
|
|
p, label, feat, jc = self.sample_lines(
|
|
|
- meta, h["jmap"][i], h["joff"][i],
|
|
|
|
|
|
|
+ meta, h["jmap"][i], h["joff"][i],lmap[i]
|
|
|
)
|
|
)
|
|
|
- # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
|
|
|
|
|
|
|
+ print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
|
|
|
ys.append(label)
|
|
ys.append(label)
|
|
|
if self.training and self.do_static_sampling:
|
|
if self.training and self.do_static_sampling:
|
|
|
p = torch.cat([p, meta["lpre"]])
|
|
p = torch.cat([p, meta["lpre"]])
|
|
@@ -239,43 +240,47 @@ class LineRCNNPredictor(nn.Module):
|
|
|
jcs.append(jc)
|
|
jcs.append(jc)
|
|
|
ps.append(p)
|
|
ps.append(p)
|
|
|
fs.append(feat)
|
|
fs.append(feat)
|
|
|
-
|
|
|
|
|
- p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
|
|
|
|
|
- p = p.reshape(-1, 2) # [N_LINE x N_POINT, 2_XY]
|
|
|
|
|
- px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
|
|
|
|
|
- px0 = px.floor().clamp(min=0, max=127)
|
|
|
|
|
- py0 = py.floor().clamp(min=0, max=127)
|
|
|
|
|
- px1 = (px0 + 1).clamp(min=0, max=127)
|
|
|
|
|
- py1 = (py0 + 1).clamp(min=0, max=127)
|
|
|
|
|
- px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
|
|
|
|
|
-
|
|
|
|
|
- # xp: [N_LINE, N_CHANNEL, N_POINT]
|
|
|
|
|
- xp = (
|
|
|
|
|
- (
|
|
|
|
|
- x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
|
|
|
|
|
- + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
|
|
|
|
|
- + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
|
|
|
|
|
- + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
|
|
|
|
|
- )
|
|
|
|
|
- .reshape(n_channel, -1, self.n_pts0)
|
|
|
|
|
- .permute(1, 0, 2)
|
|
|
|
|
- )
|
|
|
|
|
- xp = self.pooling(xp)
|
|
|
|
|
- # print(f'xp.shape:{xp.shape}')
|
|
|
|
|
- xs.append(xp)
|
|
|
|
|
- idx.append(idx[-1] + xp.shape[0])
|
|
|
|
|
|
|
+ #
|
|
|
|
|
+ # p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
|
|
|
|
|
+ # p = p.reshape(-1, 2) # [N_LINE x N_POINT, 2_XY]
|
|
|
|
|
+ # px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
|
|
|
|
|
+ # px0 = px.floor().clamp(min=0, max=127)
|
|
|
|
|
+ # py0 = py.floor().clamp(min=0, max=127)
|
|
|
|
|
+ # px1 = (px0 + 1).clamp(min=0, max=127)
|
|
|
|
|
+ # py1 = (py0 + 1).clamp(min=0, max=127)
|
|
|
|
|
+ # px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
|
|
|
|
|
+ #
|
|
|
|
|
+ # # xp: [N_LINE, N_CHANNEL, N_POINT]
|
|
|
|
|
+ # xp = (
|
|
|
|
|
+ # (
|
|
|
|
|
+ # x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
|
|
|
|
|
+ # + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
|
|
|
|
|
+ # + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
|
|
|
|
|
+ # + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
|
|
|
|
|
+ # )
|
|
|
|
|
+ # .reshape(n_channel, -1, self.n_pts0)
|
|
|
|
|
+ # .permute(1, 0, 2)
|
|
|
|
|
+ # )
|
|
|
|
|
+ # xp = self.pooling(xp)
|
|
|
|
|
+ # # print(f'xp.shape:{xp.shape}')
|
|
|
|
|
+ # xs.append(xp)
|
|
|
|
|
+ idx.append(idx[-1] + feat.shape[0])
|
|
|
# print(f'idx__:{idx}')
|
|
# print(f'idx__:{idx}')
|
|
|
|
|
|
|
|
- x, y = torch.cat(xs), torch.cat(ys)
|
|
|
|
|
|
|
+ # x, y = torch.cat(xs), torch.cat(ys)
|
|
|
|
|
+ y=torch.cat(ys)
|
|
|
f = torch.cat(fs)
|
|
f = torch.cat(fs)
|
|
|
- x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
|
|
|
|
+ print(f'f:{f.shape}')
|
|
|
|
|
+ # x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
|
|
|
|
|
# print("Weight dtype:", self.fc2.weight.dtype)
|
|
# print("Weight dtype:", self.fc2.weight.dtype)
|
|
|
- x = torch.cat([x, f], 1)
|
|
|
|
|
|
|
+ # x = torch.cat([x, f], 1)
|
|
|
|
|
+ # print(f'x3:{x.shape}')
|
|
|
# print("Input dtype:", x.dtype)
|
|
# print("Input dtype:", x.dtype)
|
|
|
- x = x.to(dtype=torch.float32)
|
|
|
|
|
|
|
+ f= f.to(dtype=torch.float32)
|
|
|
|
|
+ # x = x.to(dtype=torch.float32)
|
|
|
# print("Input dtype1:", x.dtype)
|
|
# print("Input dtype1:", x.dtype)
|
|
|
- x = self.fc2(x).flatten()
|
|
|
|
|
|
|
+ x = self.fc2(f).flatten()
|
|
|
|
|
|
|
|
# return x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
|
|
# return x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
|
|
|
return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
|
|
return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
|
|
@@ -285,7 +290,7 @@ class LineRCNNPredictor(nn.Module):
|
|
|
|
|
|
|
|
# return result
|
|
# return result
|
|
|
|
|
|
|
|
- def sample_lines(self, meta, jmap, joff):
|
|
|
|
|
|
|
+ def sample_lines(self, meta, jmap, joff,lmap):
|
|
|
device = jmap.device
|
|
device = jmap.device
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
|
junc = meta["junc_coords"].to(device) # [N, 2]
|
|
junc = meta["junc_coords"].to(device) # [N, 2]
|
|
@@ -367,6 +372,10 @@ class LineRCNNPredictor(nn.Module):
|
|
|
|
|
|
|
|
u2v = xyu - xyv
|
|
u2v = xyu - xyv
|
|
|
u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
|
|
u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ # print(f'xp.shape:{xp.shape}')
|
|
|
|
|
+
|
|
|
feat = torch.cat(
|
|
feat = torch.cat(
|
|
|
[
|
|
[
|
|
|
xyu / 128 * self.use_cood,
|
|
xyu / 128 * self.use_cood,
|
|
@@ -377,11 +386,41 @@ class LineRCNNPredictor(nn.Module):
|
|
|
],
|
|
],
|
|
|
1,
|
|
1,
|
|
|
)
|
|
)
|
|
|
|
|
+ print(f'feat shape:{feat.shape}')
|
|
|
|
|
+
|
|
|
line = torch.cat([xyu[:, None], xyv[:, None]], 1)
|
|
line = torch.cat([xyu[:, None], xyv[:, None]], 1)
|
|
|
|
|
+ # print(f'line:{line.shape}')
|
|
|
|
|
+ n_channel, row, col = lmap.shape
|
|
|
|
|
+ p=line
|
|
|
|
|
+ print(f'p.shape :{p.shape}')
|
|
|
|
|
+ p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
|
|
|
|
|
+ p = p.reshape(-1, 2) # [N_LINE x N_POINT, 2_XY]
|
|
|
|
|
+ px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
|
|
|
|
|
+ px0 = px.floor().clamp(min=0, max=127)
|
|
|
|
|
+ py0 = py.floor().clamp(min=0, max=127)
|
|
|
|
|
+ px1 = (px0 + 1).clamp(min=0, max=127)
|
|
|
|
|
+ py1 = (py0 + 1).clamp(min=0, max=127)
|
|
|
|
|
+ px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
|
|
|
|
|
+
|
|
|
|
|
+ # xp: [N_LINE, N_CHANNEL, N_POINT]
|
|
|
|
|
+ x=lmap
|
|
|
|
|
+ xp = (
|
|
|
|
|
+ (
|
|
|
|
|
+ x[ :, px0l, py0l] * (px1 - px) * (py1 - py)
|
|
|
|
|
+ + x[ :, px1l, py0l] * (px - px0) * (py1 - py)
|
|
|
|
|
+ + x[ :, px0l, py1l] * (px1 - px) * (py - py0)
|
|
|
|
|
+ + x[ :, px1l, py1l] * (px - px0) * (py - py0)
|
|
|
|
|
+ )
|
|
|
|
|
+ .reshape(n_channel, -1, self.n_pts0)
|
|
|
|
|
+ .permute(1, 0, 2)
|
|
|
|
|
+ )
|
|
|
|
|
+ xp = self.pooling(xp).squeeze(1)
|
|
|
|
|
+ print(f'xp shape:{xp.shape}')
|
|
|
|
|
+
|
|
|
|
|
|
|
|
xy = xy.reshape(n_type, K, 2)
|
|
xy = xy.reshape(n_type, K, 2)
|
|
|
jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
|
|
jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
|
|
|
- return line, label.float(), feat, jcs
|
|
|
|
|
|
|
+ return line, label.float(), xp, jcs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|