Pārlūkot izejas kodu

Add batch_size_eval. Closes #20.

Yichao Zhou 6 gadi atpakaļ
vecāks
revīzija
c044515dfd
3 mainītis faili ar 9 papildinājumiem un 5 dzēšanām
  1. 1 0
      config/wireframe.yaml
  2. 4 4
      lcnn/trainer.py
  3. 4 1
      train.py

+ 1 - 0
config/wireframe.yaml

@@ -12,6 +12,7 @@ model:
       stddev: [22.275, 22.124, 23.229]
 
   batch_size: 6
+  batch_size_eval: 2
 
   # backbone multi-task parameters
   head_size: [[2], [1], [2]]

+ 4 - 4
lcnn/trainer.py

@@ -16,7 +16,7 @@ import torch.nn.functional as F
 from skimage import io
 from tensorboardX import SummaryWriter
 
-from lcnn.config import C
+from lcnn.config import C, M
 from lcnn.utils import recursive_to
 
 
@@ -103,8 +103,8 @@ class Trainer(object):
         training = self.model.training
         self.model.eval()
 
-        viz = osp.join(self.out, "viz", f"{self.iteration * self.batch_size:09d}")
-        npz = osp.join(self.out, "npz", f"{self.iteration * self.batch_size:09d}")
+        viz = osp.join(self.out, "viz", f"{self.iteration * M.batch_size_eval:09d}")
+        npz = osp.join(self.out, "npz", f"{self.iteration * M.batch_size_eval:09d}")
         osp.exists(viz) or os.makedirs(viz)
         osp.exists(npz) or os.makedirs(npz)
 
@@ -124,7 +124,7 @@ class Trainer(object):
 
                 H = result["preds"]
                 for i in range(H["jmap"].shape[0]):
-                    index = batch_idx * self.batch_size + i
+                    index = batch_idx * M.batch_size_eval + i
                     np.savez(
                         f"{npz}/{index:06}.npz",
                         **{k: v[i].cpu().numpy() for k, v in H.items()},

+ 4 - 1
train.py

@@ -104,7 +104,10 @@ def main():
         **kwargs,
     )
     val_loader = torch.utils.data.DataLoader(
-        WireframeDataset(datadir, split="valid"), shuffle=False, batch_size=2, **kwargs
+        WireframeDataset(datadir, split="valid"),
+        shuffle=False,
+        batch_size=M.batch_size_eval,
+        **kwargs,
     )
     epoch_size = len(train_loader)
     # print("epoch_size (train):", epoch_size)