|
|
@@ -93,16 +93,18 @@ def main():
|
|
|
|
|
|
datadir = C.io.datadir
|
|
|
kwargs = {
|
|
|
- "batch_size": M.batch_size,
|
|
|
"collate_fn": collate,
|
|
|
"num_workers": C.io.num_workers,
|
|
|
"pin_memory": True,
|
|
|
}
|
|
|
train_loader = torch.utils.data.DataLoader(
|
|
|
- WireframeDataset(datadir, split="train"), shuffle=True, **kwargs
|
|
|
+ WireframeDataset(datadir, split="train"),
|
|
|
+ shuffle=True,
|
|
|
+ batch_size=M.batch_size,
|
|
|
+ **kwargs,
|
|
|
)
|
|
|
val_loader = torch.utils.data.DataLoader(
|
|
|
- WireframeDataset(datadir, split="valid"), shuffle=False, **kwargs
|
|
|
+ WireframeDataset(datadir, split="valid"), shuffle=False, batch_size=2, **kwargs
|
|
|
)
|
|
|
epoch_size = len(train_loader)
|
|
|
# print("epoch_size (train):", epoch_size)
|