소스 검색

reduce evaluation batch size

Yichao Zhou 6 년 전
부모
커밋
0082f1738c
1개의 변경된 파일5개의 추가작업 그리고 3개의 파일을 삭제
  1. 5 3
      train.py

+ 5 - 3
train.py

@@ -93,16 +93,18 @@ def main():
 
     datadir = C.io.datadir
     kwargs = {
-        "batch_size": M.batch_size,
         "collate_fn": collate,
         "num_workers": C.io.num_workers,
         "pin_memory": True,
     }
     train_loader = torch.utils.data.DataLoader(
-        WireframeDataset(datadir, split="train"), shuffle=True, **kwargs
+        WireframeDataset(datadir, split="train"),
+        shuffle=True,
+        batch_size=M.batch_size,
+        **kwargs,
     )
     val_loader = torch.utils.data.DataLoader(
-        WireframeDataset(datadir, split="valid"), shuffle=False, **kwargs
+        WireframeDataset(datadir, split="valid"), shuffle=False, batch_size=2, **kwargs
     )
     epoch_size = len(train_loader)
     # print("epoch_size (train):", epoch_size)