|
@@ -19,7 +19,7 @@ from models.config.config_tool import read_yaml
|
|
|
from models.keypoint.trainer import train_cfg
|
|
|
|
|
|
from tools import utils
|
|
|
-
|
|
|
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
|
|
|
|
|
class KeypointRCNNModel(nn.Module):
|
|
|
|
|
@@ -72,7 +72,7 @@ class KeypointRCNNModel(nn.Module):
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
- keypoint_model = KeypointRCNNModel(num_keypoints=17)
|
|
|
+ keypoint_model = KeypointRCNNModel(num_keypoints=2)
|
|
|
|
|
|
|
|
|
keypoint_model.train(cfg='train.yaml')
|