浏览代码

A100 training edition: raise train/val batch size from 10 to 64; fix loss logging to use a per-iteration global_step instead of epoch

lstrlq 3 月之前
父节点
当前提交
60aff90b5c
共有 1 个文件被更改,包括 4 次插入和 3 次删除
  1. 4 3
      models/line_detect/trainer.py

+ 4 - 3
models/line_detect/trainer.py

@@ -95,7 +95,7 @@ class Trainer(BaseTrainer):
         dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=10, drop_last=True)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=64, drop_last=True)
         train_collate_fn = utils.collate_fn_wirepoint
         data_loader_train = torch.utils.data.DataLoader(
             dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
@@ -104,7 +104,7 @@ class Trainer(BaseTrainer):
         dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=10, drop_last=True)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=64, drop_last=True)
         val_collate_fn = utils.collate_fn_wirepoint
         data_loader_val = torch.utils.data.DataLoader(
             dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
@@ -148,7 +148,8 @@ class Trainer(BaseTrainer):
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
-                self.writer_loss(writer, losses, epoch)
+                self.writer_loss(writer, losses, global_step)
+                global_step+=1
 
 
             avg_train_loss = total_train_loss / len(data_loader_train)