lstrlq 5 달 전
부모
커밋
e91f43aad4
3개의 변경된 파일7개의 추가작업 그리고 6개의 파일을 삭제
  1. 2 2
      models/keypoint/keypoint_dataset.py
  2. 3 3
      models/keypoint/train.yaml
  3. 2 1
      models/keypoint/trainer.py

+ 2 - 2
models/keypoint/keypoint_dataset.py

@@ -38,8 +38,8 @@ class KeypointDataset(BaseDataset):
         self.data_path = dataset_path
         print(f'data_path:{dataset_path}')
         self.transforms = transforms
-        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
-        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
         self.imgs = os.listdir(self.img_path)
         self.lbls = os.listdir(self.lbl_path)
         self.target_type = target_type

+ 3 - 3
models/keypoint/train.yaml

@@ -1,13 +1,13 @@
 
 
-dataset_path: I:/wirenet_dateset
+dataset_path: /home/admin/tmp/wirenet_1000
 
 #train parameters
 num_classes: 2
 num_keypoints: 2
 opt: 'adamw'
-batch_size: 2
-epochs: 10
+batch_size: 4
+epochs: 50000
 lr: 0.005
 momentum: 0.9
 weight_decay: 0.0001

+ 2 - 1
models/keypoint/trainer.py

@@ -278,7 +278,8 @@ def train(model, **kwargs):
                                    dataset_type='val')
 
     train_sampler = torch.utils.data.RandomSampler(dataset)
-    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    test_sampler = torch.utils.data.RandomSampler(dataset_test)
     train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
     train_collate_fn = utils.collate_fn
     data_loader = torch.utils.data.DataLoader(