|
|
@@ -242,30 +242,7 @@ class LineRCNNPredictor(nn.Module):
|
|
|
jcs.append(jc)
|
|
|
ps.append(p)
|
|
|
fs.append(feat)
|
|
|
- #
|
|
|
- # p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
|
|
|
- # p = p.reshape(-1, 2) # [N_LINE x N_POINT, 2_XY]
|
|
|
- # px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
|
|
|
- # px0 = px.floor().clamp(min=0, max=127)
|
|
|
- # py0 = py.floor().clamp(min=0, max=127)
|
|
|
- # px1 = (px0 + 1).clamp(min=0, max=127)
|
|
|
- # py1 = (py0 + 1).clamp(min=0, max=127)
|
|
|
- # px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
|
|
|
- #
|
|
|
- # # xp: [N_LINE, N_CHANNEL, N_POINT]
|
|
|
- # xp = (
|
|
|
- # (
|
|
|
- # x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
|
|
|
- # + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
|
|
|
- # + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
|
|
|
- # + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
|
|
|
- # )
|
|
|
- # .reshape(n_channel, -1, self.n_pts0)
|
|
|
- # .permute(1, 0, 2)
|
|
|
- # )
|
|
|
- # xp = self.pooling(xp)
|
|
|
- # # print(f'xp.shape:{xp.shape}')
|
|
|
- # xs.append(xp)
|
|
|
+
|
|
|
idx.append(idx[-1] + feat.shape[0])
|
|
|
# print(f'idx__:{idx}')
|
|
|
|
|
|
@@ -275,10 +252,7 @@ class LineRCNNPredictor(nn.Module):
|
|
|
print(f'f:{f.shape}')
|
|
|
# x = x.reshape(-1, self.n_pts1 * self.dim_loi)
|
|
|
|
|
|
- # print("Weight dtype:", self.fc2.weight.dtype)
|
|
|
- # x = torch.cat([x, f], 1)
|
|
|
- # print(f'x3:{x.shape}')
|
|
|
- # print("Input dtype:", x.dtype)
|
|
|
+
|
|
|
f= f.to(dtype=torch.float32)
|
|
|
# x = x.to(dtype=torch.float32)
|
|
|
# print("Input dtype1:", x.dtype)
|
|
|
@@ -287,10 +261,7 @@ class LineRCNNPredictor(nn.Module):
|
|
|
# return x,idx,jcs,n_batch,ps,self.n_out_line,self.n_out_junc
|
|
|
return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
|
|
|
|
|
|
- # if mode != "training":
|
|
|
- # self.inference(x, idx, jcs, n_batch, ps)
|
|
|
|
|
|
- # return result
|
|
|
|
|
|
def sample_lines(self, meta, jmap, joff,lmap):
|
|
|
device = jmap.device
|
|
|
@@ -312,9 +283,10 @@ class LineRCNNPredictor(nn.Module):
|
|
|
print(f'jmap min:{torch.min(jmap[0])}')
|
|
|
print(f'jmap num:{(jmap > self.eval_junc_thres).float().sum().item()}')
|
|
|
print(f'jmap:{jmap}')
|
|
|
- print(f'K:{K}')
|
|
|
+ print(f'predict K:{K}')
|
|
|
else:
|
|
|
K = min(int(N * 2 + 2), max_K)
|
|
|
+ print(f'train K:{K}')
|
|
|
if K < 2:
|
|
|
K = 2
|
|
|
device = jmap.device
|
|
|
@@ -331,13 +303,19 @@ class LineRCNNPredictor(nn.Module):
|
|
|
|
|
|
# dist: [N_TYPE, K, N]
|
|
|
dist = torch.sum((xy_ - junc) ** 2, -1)
|
|
|
+ print(f'dist:{dist}')
|
|
|
+
|
|
|
cost, match = torch.min(dist, -1)
|
|
|
+ print(f'match:{match},cost:{cost}')
|
|
|
|
|
|
# xy: [N_TYPE * K, 2]
|
|
|
# match: [N_TYPE, K]
|
|
|
- for t in range(n_type):
|
|
|
- match[t, jtyp[match[t]] != t] = N
|
|
|
+ # for t in range(n_type):
|
|
|
+ # match[t, jtyp[match[t]] != t] = N
|
|
|
+
|
|
|
match[cost > 1.5 * 1.5] = N
|
|
|
+
|
|
|
+ print(f'match__ : {match}')
|
|
|
match = match.flatten()
|
|
|
|
|
|
_ = torch.arange(n_type * K, device=device)
|
|
|
@@ -379,22 +357,6 @@ class LineRCNNPredictor(nn.Module):
|
|
|
|
|
|
u2v = xyu - xyv
|
|
|
u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
|
|
|
-
|
|
|
-
|
|
|
- # print(f'xp.shape:{xp.shape}')
|
|
|
-
|
|
|
- feat = torch.cat(
|
|
|
- [
|
|
|
- xyu / 128 * self.use_cood,
|
|
|
- xyv / 128 * self.use_cood,
|
|
|
- u2v * self.use_slop,
|
|
|
- (u[:, None] > K).float(),
|
|
|
- (v[:, None] > K).float(),
|
|
|
- ],
|
|
|
- 1,
|
|
|
- )
|
|
|
- print(f'feat shape:{feat.shape}')
|
|
|
-
|
|
|
# lmap = gaussian_filter(lmap, sigma=1)
|
|
|
# lmap = torch.from_numpy(gaussian_filter(lmap.cpu().numpy(), sigma=1)).to('cuda:0')
|
|
|
|
|
|
@@ -425,11 +387,14 @@ class LineRCNNPredictor(nn.Module):
|
|
|
.permute(1, 0, 2)
|
|
|
)
|
|
|
xp = self.pooling(xp).squeeze(1)
|
|
|
+ if not self.training:
|
|
|
+ print(f'predict xp values:{xp}')
|
|
|
print(f'xp shape:{xp.shape}')
|
|
|
|
|
|
|
|
|
xy = xy.reshape(n_type, K, 2)
|
|
|
jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
|
|
|
+ print(f'line.shape :{line.shape}')
|
|
|
return line, label.float(), xp, jcs
|
|
|
|
|
|
|