|
@@ -171,7 +171,7 @@ class LineVectorizer(nn.Module):
|
|
|
|
|
|
|
|
# index: [N_TYPE, K]
|
|
# index: [N_TYPE, K]
|
|
|
score, index = torch.topk(jmap, k=K)
|
|
score, index = torch.topk(jmap, k=K)
|
|
|
- y = (index / 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
|
|
|
|
|
|
|
+ y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
|
|
|
x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
|
|
x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
|
|
|
|
|
|
|
|
# xy: [N_TYPE, K, 2]
|
|
# xy: [N_TYPE, K, 2]
|