|
|
@@ -45,7 +45,7 @@ class LineVectorizer(nn.Module):
|
|
|
|
|
|
def forward(self, input_dict):
|
|
|
result = self.backbone(input_dict)
|
|
|
- h = result["heatmaps"]
|
|
|
+ h = result["preds"]
|
|
|
x = self.fc1(result["feature"])
|
|
|
n_batch, n_channel, row, col = x.shape
|
|
|
|
|
|
@@ -134,16 +134,16 @@ class LineVectorizer(nn.Module):
|
|
|
jcs[i][j] = jcs[i][j][
|
|
|
None, torch.arange(M.n_out_junc) % len(jcs[i][j])
|
|
|
]
|
|
|
- result["heatmaps"]["lines"] = torch.cat(lines)
|
|
|
- result["heatmaps"]["score"] = torch.cat(score)
|
|
|
- result["heatmaps"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
|
|
|
+ result["preds"]["lines"] = torch.cat(lines)
|
|
|
+ result["preds"]["score"] = torch.cat(score)
|
|
|
+ result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
|
|
|
if len(jcs[i]) > 1:
|
|
|
- result["heatmaps"]["junts"] = torch.cat(
|
|
|
+ result["preds"]["junts"] = torch.cat(
|
|
|
[jcs[i][1] for i in range(n_batch)]
|
|
|
)
|
|
|
else:
|
|
|
- if "heatmaps" in result:
|
|
|
- del result["heatmaps"]
|
|
|
+ if "preds" in result:
|
|
|
+ del result["preds"]
|
|
|
return result
|
|
|
|
|
|
def sample_lines(self, meta, jmap, joff, do_evaluation):
|