浏览代码

添加目标检测

RenLiqiang 5 月之前
父节点
当前提交
a3964d94ef
共有 4 个文件被更改,包括 4 次插入4 次删除
  1. 2 2
      models/keypoint/kepointrcnn.py
  2. 1 1
      models/keypoint/keypoint_dataset.py
  3. 1 1
      models/keypoint/train.yaml
  4. 0 0
      models/obj/__init__.py

+ 2 - 2
models/keypoint/kepointrcnn.py

@@ -19,7 +19,7 @@ from models.config.config_tool import read_yaml
 from models.keypoint.trainer import train_cfg
 
 from tools import utils
-
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
 
 class KeypointRCNNModel(nn.Module):
 
@@ -72,7 +72,7 @@ class KeypointRCNNModel(nn.Module):
 
 if __name__ == '__main__':
     # ins_model = MaskRCNNModel(num_classes=5)
-    keypoint_model = KeypointRCNNModel(num_keypoints=17)
+    keypoint_model = KeypointRCNNModel(num_keypoints=2)
     # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
     # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
     keypoint_model.train(cfg='train.yaml')

+ 1 - 1
models/keypoint/keypoint_dataset.py

@@ -127,7 +127,7 @@ class KeypointDataset(BaseDataset):
         # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
 
         keypoints= wire_labels["junc_coords"]
-        keypoints[:,2]=1
+        keypoints[:,2]=2
         # keypoints[:,0]=keypoints[:,0]/shape[0]
         # keypoints[:, 1] = keypoints[:, 1] / shape[1]
         target["keypoints"]=keypoints

+ 1 - 1
models/keypoint/train.yaml

@@ -4,7 +4,7 @@ dataset_path: I:/wirenet_dateset
 
 #train parameters
 num_classes: 2
-num_keypoints: 17
+num_keypoints: 2
 opt: 'adamw'
 batch_size: 2
 epochs: 10

+ 0 - 0
models/obj/__init__.py