Преглед на файлове

reduce evaluation batch size

Yichao Zhou преди 6 години
родител
ревизия
0082f1738c
променени са 1 файла, в които са добавени 5 реда и са изтрити 3 реда
  1. 5 3
      train.py

+ 5 - 3
train.py

@@ -93,16 +93,18 @@ def main():
 
     datadir = C.io.datadir
     kwargs = {
-        "batch_size": M.batch_size,
         "collate_fn": collate,
         "num_workers": C.io.num_workers,
         "pin_memory": True,
     }
     train_loader = torch.utils.data.DataLoader(
-        WireframeDataset(datadir, split="train"), shuffle=True, **kwargs
+        WireframeDataset(datadir, split="train"),
+        shuffle=True,
+        batch_size=M.batch_size,
+        **kwargs,
     )
     val_loader = torch.utils.data.DataLoader(
-        WireframeDataset(datadir, split="valid"), shuffle=False, **kwargs
+        WireframeDataset(datadir, split="valid"), shuffle=False, batch_size=2, **kwargs
     )
     epoch_size = len(train_loader)
     # print("epoch_size (train):", epoch_size)