|
@@ -150,7 +150,7 @@ class Trainer(BaseTrainer):
|
|
|
def train(self, model, **kwargs):
|
|
def train(self, model, **kwargs):
|
|
|
dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
|
|
dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
|
|
|
train_sampler = torch.utils.data.RandomSampler(dataset_train)
|
|
train_sampler = torch.utils.data.RandomSampler(dataset_train)
|
|
|
- train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
|
|
|
|
|
|
|
+ train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=8, drop_last=True)
|
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
train_collate_fn = utils.collate_fn_wirepoint
|
|
|
data_loader_train = torch.utils.data.DataLoader(
|
|
data_loader_train = torch.utils.data.DataLoader(
|
|
|
dataset_train, batch_sampler=train_batch_sampler, num_workers=1, collate_fn=train_collate_fn
|
|
dataset_train, batch_sampler=train_batch_sampler, num_workers=1, collate_fn=train_collate_fn
|
|
@@ -234,8 +234,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
|
# model = LineNet('line_net.yaml')
|
|
# model = LineNet('line_net.yaml')
|
|
|
# model = linenet_resnet50_fpn().to(device)
|
|
# model = linenet_resnet50_fpn().to(device)
|
|
|
- # model=get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
|
|
|
|
|
- model=get_line_net_convnext_fpn(num_classes=2).to(device)
|
|
|
|
|
|
|
+ model=get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
|
|
|
|
|
+ # model=get_line_net_convnext_fpn(num_classes=2).to(device)
|
|
|
# model=linenet_resnet18_fpn()
|
|
# model=linenet_resnet18_fpn()
|
|
|
trainer = Trainer()
|
|
trainer = Trainer()
|
|
|
trainer.train_cfg(model,cfg='./train.yaml')
|
|
trainer.train_cfg(model,cfg='./train.yaml')
|