|
@@ -95,7 +95,7 @@ class Trainer(BaseTrainer):
|
|
dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
|
|
dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
|
|
train_sampler = torch.utils.data.RandomSampler(dataset_train)
|
|
train_sampler = torch.utils.data.RandomSampler(dataset_train)
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
- train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
|
|
|
|
|
|
+ train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
data_loader_train = torch.utils.data.DataLoader(
|
|
data_loader_train = torch.utils.data.DataLoader(
|
|
dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
|
|
dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
|
|
@@ -104,7 +104,7 @@ class Trainer(BaseTrainer):
|
|
dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
|
|
dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
|
|
val_sampler = torch.utils.data.RandomSampler(dataset_val)
|
|
val_sampler = torch.utils.data.RandomSampler(dataset_val)
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
- val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
|
|
|
|
|
|
+ val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
|
|
val_collate_fn = utils.collate_fn_wirepoint
|
|
val_collate_fn = utils.collate_fn_wirepoint
|
|
data_loader_val = torch.utils.data.DataLoader(
|
|
data_loader_val = torch.utils.data.DataLoader(
|
|
dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
|
|
dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
|