|
@@ -19,7 +19,7 @@ from models.config.config_tool import read_yaml
|
|
|
from models.keypoint.trainer import train_cfg
|
|
|
|
|
|
from tools import utils
|
|
|
-
|
|
|
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
|
|
|
|
|
class KeypointRCNNModel(nn.Module):
|
|
|
|
|
@@ -72,7 +72,7 @@ class KeypointRCNNModel(nn.Module):
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
# ins_model = MaskRCNNModel(num_classes=5)
|
|
|
- keypoint_model = KeypointRCNNModel(num_keypoints=17)
|
|
|
+ keypoint_model = KeypointRCNNModel(num_keypoints=2)
|
|
|
# data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
|
|
|
# ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
|
|
|
keypoint_model.train(cfg='train.yaml')
|