Kaynağa Gözat

reduce evaluation batch size

Yichao Zhou 6 yıl önce
ebeveyn
işleme
0082f1738c
1 değiştirilmiş dosya ile 5 ekleme ve 3 silme
  1. 5 3
      train.py

+ 5 - 3
train.py

@@ -93,16 +93,18 @@ def main():
 
     datadir = C.io.datadir
     kwargs = {
-        "batch_size": M.batch_size,
         "collate_fn": collate,
         "num_workers": C.io.num_workers,
         "pin_memory": True,
     }
     train_loader = torch.utils.data.DataLoader(
-        WireframeDataset(datadir, split="train"), shuffle=True, **kwargs
+        WireframeDataset(datadir, split="train"),
+        shuffle=True,
+        batch_size=M.batch_size,
+        **kwargs,
     )
     val_loader = torch.utils.data.DataLoader(
-        WireframeDataset(datadir, split="valid"), shuffle=False, **kwargs
+        WireframeDataset(datadir, split="valid"), shuffle=False, batch_size=2, **kwargs
     )
     epoch_size = len(train_loader)
     # print("epoch_size (train):", epoch_size)