|
@@ -95,7 +95,7 @@ class Trainer(BaseTrainer):
|
|
dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
|
|
dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
|
|
train_sampler = torch.utils.data.RandomSampler(dataset_train)
|
|
train_sampler = torch.utils.data.RandomSampler(dataset_train)
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
- train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=10, drop_last=True)
|
|
|
|
|
|
+ train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=64, drop_last=True)
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
data_loader_train = torch.utils.data.DataLoader(
|
|
data_loader_train = torch.utils.data.DataLoader(
|
|
dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
|
|
dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
|
|
@@ -104,7 +104,7 @@ class Trainer(BaseTrainer):
|
|
dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
|
|
dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
|
|
val_sampler = torch.utils.data.RandomSampler(dataset_val)
|
|
val_sampler = torch.utils.data.RandomSampler(dataset_val)
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
# test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
|
- val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=10, drop_last=True)
|
|
|
|
|
|
+ val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=64, drop_last=True)
|
|
val_collate_fn = utils.collate_fn_wirepoint
|
|
val_collate_fn = utils.collate_fn_wirepoint
|
|
data_loader_val = torch.utils.data.DataLoader(
|
|
data_loader_val = torch.utils.data.DataLoader(
|
|
dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
|
|
dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
|
|
@@ -148,7 +148,8 @@ class Trainer(BaseTrainer):
|
|
optimizer.zero_grad()
|
|
optimizer.zero_grad()
|
|
loss.backward()
|
|
loss.backward()
|
|
optimizer.step()
|
|
optimizer.step()
|
|
- self.writer_loss(writer, losses, epoch)
|
|
|
|
|
|
+ self.writer_loss(writer, losses, global_step)
|
|
|
|
+ global_step+=1
|
|
|
|
|
|
|
|
|
|
avg_train_loss = total_train_loss / len(data_loader_train)
|
|
avg_train_loss = total_train_loss / len(data_loader_train)
|