Procházet zdrojové kódy

Rename heatmaps to preds to avoid confusion

Yichao Zhou před 6 roky
rodič
revize
33176e4f2b

+ 7 - 7
lcnn/models/line_vectorizer.py

@@ -45,7 +45,7 @@ class LineVectorizer(nn.Module):
 
     def forward(self, input_dict):
         result = self.backbone(input_dict)
-        h = result["heatmaps"]
+        h = result["preds"]
         x = self.fc1(result["feature"])
         n_batch, n_channel, row, col = x.shape
 
@@ -134,16 +134,16 @@ class LineVectorizer(nn.Module):
                     jcs[i][j] = jcs[i][j][
                         None, torch.arange(M.n_out_junc) % len(jcs[i][j])
                     ]
-            result["heatmaps"]["lines"] = torch.cat(lines)
-            result["heatmaps"]["score"] = torch.cat(score)
-            result["heatmaps"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+            result["preds"]["lines"] = torch.cat(lines)
+            result["preds"]["score"] = torch.cat(score)
+            result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
             if len(jcs[i]) > 1:
-                result["heatmaps"]["junts"] = torch.cat(
+                result["preds"]["junts"] = torch.cat(
                     [jcs[i][1] for i in range(n_batch)]
                 )
         else:
-            if "heatmaps" in result:
-                del result["heatmaps"]
+            if "preds" in result:
+                del result["preds"]
         return result
 
     def sample_lines(self, meta, jmap, joff, do_evaluation):

+ 1 - 1
lcnn/models/multitask_learner.py

@@ -61,7 +61,7 @@ class MultitaskLearner(nn.Module):
             lmap = output[offset[0] : offset[1]].squeeze(0)
             joff = output[offset[1] : offset[2]].reshape(n_jtyp, 2, batch, row, col)
             if stack == 0:
-                result["heatmaps"] = {
+                result["preds"] = {
                     "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                     "lmap": lmap.sigmoid(),
                     "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,

+ 1 - 1
lcnn/trainer.py

@@ -122,7 +122,7 @@ class Trainer(object):
 
                 total_loss += self._loss(result)
 
-                H = result["heatmaps"]
+                H = result["preds"]
                 for i in range(H["jmap"].shape[0]):
                     index = batch_idx * self.batch_size + i
                     np.savez(

+ 1 - 1
process.py

@@ -99,7 +99,7 @@ def main():
                 "target": recursive_to(target, device),
                 "do_evaluation": True,
             }
-            H = model(input_dict)["heatmaps"]
+            H = model(input_dict)["preds"]
             for i in range(M.batch_size):
                 index = batch_idx * M.batch_size + i
                 np.savez(