|
|
@@ -51,11 +51,11 @@ class LineVectorizer(nn.Module):
|
|
|
xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
|
|
|
for i, meta in enumerate(input_dict["meta"]):
|
|
|
p, label, feat, jc = self.sample_lines(
|
|
|
- meta, h["jmap"][i], h["joff"][i], input_dict["do_evaluation"]
|
|
|
+ meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
|
|
|
)
|
|
|
# print("p.shape:", p.shape)
|
|
|
ys.append(label)
|
|
|
- if not input_dict["do_evaluation"] and self.do_static_sampling:
|
|
|
+ if input_dict["mode"] == "training" and self.do_static_sampling:
|
|
|
p = torch.cat([p, meta["lpre"]])
|
|
|
feat = torch.cat([feat, meta["lpre_feat"]])
|
|
|
ys.append(meta["lpre_label"])
|
|
|
@@ -95,7 +95,7 @@ class LineVectorizer(nn.Module):
|
|
|
x = torch.cat([x, f], 1)
|
|
|
x = self.fc2(x).flatten()
|
|
|
|
|
|
- if input_dict["do_evaluation"]:
|
|
|
+ if input_dict["mode"] != "training":
|
|
|
p = torch.cat(ps)
|
|
|
s = torch.sigmoid(x)
|
|
|
b = s > 0.5
|
|
|
@@ -128,7 +128,8 @@ class LineVectorizer(nn.Module):
|
|
|
result["preds"]["junts"] = torch.cat(
|
|
|
[jcs[i][1] for i in range(n_batch)]
|
|
|
)
|
|
|
- else:
|
|
|
+
|
|
|
+ if input_dict["mode"] != "testing":
|
|
|
y = torch.cat(ys)
|
|
|
loss = self.loss(x, y)
|
|
|
lpos_mask, lneg_mask = y, 1 - y
|
|
|
@@ -142,11 +143,13 @@ class LineVectorizer(nn.Module):
|
|
|
lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
|
|
|
result["losses"][0]["lpos"] = lpos * M.loss_weight["lpos"]
|
|
|
result["losses"][0]["lneg"] = lneg * M.loss_weight["lneg"]
|
|
|
+
|
|
|
+ if input_dict["mode"] == "training":
|
|
|
del result["preds"]
|
|
|
|
|
|
return result
|
|
|
|
|
|
- def sample_lines(self, meta, jmap, joff, do_evaluation):
|
|
|
+ def sample_lines(self, meta, jmap, joff, mode):
|
|
|
with torch.no_grad():
|
|
|
junc = meta["junc"] # [N, 2]
|
|
|
jtyp = meta["jtyp"] # [N]
|
|
|
@@ -158,7 +161,7 @@ class LineVectorizer(nn.Module):
|
|
|
joff = joff.reshape(n_type, 2, -1)
|
|
|
max_K = M.n_dyn_junc // n_type
|
|
|
N = len(junc)
|
|
|
- if do_evaluation:
|
|
|
+ if mode != "training":
|
|
|
K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
|
|
|
else:
|
|
|
K = min(int(N * 2 + 2), max_K)
|
|
|
@@ -193,9 +196,7 @@ class LineVectorizer(nn.Module):
|
|
|
up, vp = match[u], match[v]
|
|
|
label = Lpos[up, vp]
|
|
|
|
|
|
- if do_evaluation:
|
|
|
- c = (u < v).flatten()
|
|
|
- else:
|
|
|
+ if mode == "training":
|
|
|
c = torch.zeros_like(label, dtype=torch.bool)
|
|
|
|
|
|
# sample positive lines
|
|
|
@@ -217,6 +218,8 @@ class LineVectorizer(nn.Module):
|
|
|
# sample other (unmatched) lines
|
|
|
cdx = torch.randint(len(c), (M.n_dyn_othr,), device=device)
|
|
|
c[cdx] = 1
|
|
|
+ else:
|
|
|
+ c = (u < v).flatten()
|
|
|
|
|
|
# sample lines
|
|
|
u, v, label = u[c], v[c], label[c]
|