Selaa lähdekoodia

Split training, testing, validation

Yichao Zhou 6 vuotta sitten
vanhempi
commit
c57814e79f
5 muutettua tiedostoa jossa 17 lisäystä ja 14 poistoa
  1. 1 1
      demo.py
  2. 12 9
      lcnn/models/line_vectorizer.py
  3. 1 1
      lcnn/models/multitask_learner.py
  4. 2 2
      lcnn/trainer.py
  5. 1 1
      process.py

+ 1 - 1
demo.py

@@ -107,7 +107,7 @@ def main():
                     "jmap": torch.zeros([1, 1, 128, 128]).to(device),
                     "joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
                 },
-                "do_evaluation": True,
+                "mode": "testing",
             }
             H = model(input_dict)["preds"]
 

+ 12 - 9
lcnn/models/line_vectorizer.py

@@ -51,11 +51,11 @@ class LineVectorizer(nn.Module):
         xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
         for i, meta in enumerate(input_dict["meta"]):
             p, label, feat, jc = self.sample_lines(
-                meta, h["jmap"][i], h["joff"][i], input_dict["do_evaluation"]
+                meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
             )
             # print("p.shape:", p.shape)
             ys.append(label)
-            if not input_dict["do_evaluation"] and self.do_static_sampling:
+            if input_dict["mode"] == "training" and self.do_static_sampling:
                 p = torch.cat([p, meta["lpre"]])
                 feat = torch.cat([feat, meta["lpre_feat"]])
                 ys.append(meta["lpre_label"])
@@ -95,7 +95,7 @@ class LineVectorizer(nn.Module):
         x = torch.cat([x, f], 1)
         x = self.fc2(x).flatten()
 
-        if input_dict["do_evaluation"]:
+        if input_dict["mode"] != "training":
             p = torch.cat(ps)
             s = torch.sigmoid(x)
             b = s > 0.5
@@ -128,7 +128,8 @@ class LineVectorizer(nn.Module):
                 result["preds"]["junts"] = torch.cat(
                     [jcs[i][1] for i in range(n_batch)]
                 )
-        else:
+
+        if input_dict["mode"] != "testing":
             y = torch.cat(ys)
             loss = self.loss(x, y)
             lpos_mask, lneg_mask = y, 1 - y
@@ -142,11 +143,13 @@ class LineVectorizer(nn.Module):
             lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
             result["losses"][0]["lpos"] = lpos * M.loss_weight["lpos"]
             result["losses"][0]["lneg"] = lneg * M.loss_weight["lneg"]
+
+        if input_dict["mode"] == "training":
             del result["preds"]
 
         return result
 
-    def sample_lines(self, meta, jmap, joff, do_evaluation):
+    def sample_lines(self, meta, jmap, joff, mode):
         with torch.no_grad():
             junc = meta["junc"]  # [N, 2]
             jtyp = meta["jtyp"]  # [N]
@@ -158,7 +161,7 @@ class LineVectorizer(nn.Module):
             joff = joff.reshape(n_type, 2, -1)
             max_K = M.n_dyn_junc // n_type
             N = len(junc)
-            if do_evaluation:
+            if mode != "training":
                 K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
             else:
                 K = min(int(N * 2 + 2), max_K)
@@ -193,9 +196,7 @@ class LineVectorizer(nn.Module):
             up, vp = match[u], match[v]
             label = Lpos[up, vp]
 
-            if do_evaluation:
-                c = (u < v).flatten()
-            else:
+            if mode == "training":
                 c = torch.zeros_like(label, dtype=torch.bool)
 
                 # sample positive lines
@@ -217,6 +218,8 @@ class LineVectorizer(nn.Module):
                 # sample other (unmatched) lines
                 cdx = torch.randint(len(c), (M.n_dyn_othr,), device=device)
                 c[cdx] = 1
+            else:
+                c = (u < v).flatten()
 
             # sample lines
             u, v, label = u[c], v[c], label[c]

+ 1 - 1
lcnn/models/multitask_learner.py

@@ -66,7 +66,7 @@ class MultitaskLearner(nn.Module):
                     "lmap": lmap.sigmoid(),
                     "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                 }
-                if input_dict["do_evaluation"]:
+                if input_dict["mode"] == "testing":
                     return result
 
             L = OrderedDict()

+ 2 - 2
lcnn/trainer.py

@@ -116,7 +116,7 @@ class Trainer(object):
                     "image": recursive_to(image, self.device),
                     "meta": recursive_to(meta, self.device),
                     "target": recursive_to(target, self.device),
-                    "do_evaluation": True,
+                    "mode": "validation",
                 }
                 result = self.model(input_dict)
 
@@ -173,7 +173,7 @@ class Trainer(object):
                 "image": recursive_to(image, self.device),
                 "meta": recursive_to(meta, self.device),
                 "target": recursive_to(target, self.device),
-                "do_evaluation": False,
+                "mode": "training",
             }
             result = self.model(input_dict)
 

+ 1 - 1
process.py

@@ -97,7 +97,7 @@ def main():
                 "image": recursive_to(image, device),
                 "meta": recursive_to(meta, device),
                 "target": recursive_to(target, device),
-                "do_evaluation": True,
+                "mode": "validation",
             }
             H = model(input_dict)["preds"]
             for i in range(M.batch_size):