|
@@ -0,0 +1,61 @@
|
|
|
+import time
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
|
|
|
+from torchvision.io import decode_image, read_image
|
|
|
+import torchvision.transforms.functional as F
|
|
|
+from torchvision.utils import draw_keypoints
|
|
|
+def show(imgs):
|
|
|
+ if not isinstance(imgs, list):
|
|
|
+ imgs = [imgs]
|
|
|
+ fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
|
|
|
+ for i, img in enumerate(imgs):
|
|
|
+ img = img.detach()
|
|
|
+ img = F.to_pil_image(img)
|
|
|
+ axs[0, i].imshow(np.asarray(img))
|
|
|
+ axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
|
|
|
+
|
|
|
+img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
|
|
|
+# img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
|
|
|
+img_int = read_image(img_path)
|
|
|
+
|
|
|
+
|
|
|
+# person_int = decode_image(r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg")
|
|
|
+
|
|
|
+weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
|
|
|
+transforms = weights.transforms()
|
|
|
+print(f'transforms:{transforms}')
|
|
|
+img = transforms(img_int)
|
|
|
+
|
|
|
+person_float = transforms(img)
|
|
|
+
|
|
|
+model = keypointrcnn_resnet50_fpn(weights=None, progress=False)
|
|
|
+model = model.eval()
|
|
|
+t1=time.time()
|
|
|
+# img = torch.ones((3, 3, 512, 512))
|
|
|
+
|
|
|
+
|
|
|
+outputs = model([img])
|
|
|
+t2=time.time()
|
|
|
+print(f'time:{t2-t1}')
|
|
|
+# print(f'outputs:{outputs}')
|
|
|
+
|
|
|
+kpts = outputs[0]['keypoints']
|
|
|
+scores = outputs[0]['scores']
|
|
|
+
|
|
|
+print(f'kpts:{kpts}')
|
|
|
+print(f'scores:{scores}')
|
|
|
+
|
|
|
+detect_threshold = 0.75
|
|
|
+idx = torch.where(scores > detect_threshold)
|
|
|
+keypoints = kpts[idx]
|
|
|
+
|
|
|
+# print(f'keypoints:{keypoints}')
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
|
|
|
+show(res)
|
|
|
+plt.show()
|