# test.py — Keypoint R-CNN (ResNet-50 FPN) keypoint-detection demo.
import time

import numpy as np
import torch
import torchvision.transforms.functional as F
from matplotlib import pyplot as plt
from torchvision.io import ImageReadMode, decode_image, read_image
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights, keypointrcnn_resnet50_fpn
from torchvision.utils import draw_keypoints
  9. def show(imgs):
  10. if not isinstance(imgs, list):
  11. imgs = [imgs]
  12. fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
  13. for i, img in enumerate(imgs):
  14. img = img.detach()
  15. img = F.to_pil_image(img)
  16. axs[0, i].imshow(np.asarray(img))
  17. axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
  18. img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
  19. # img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
  20. img_int = read_image(img_path)
  21. # person_int = decode_image(r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg")
  22. weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
  23. transforms = weights.transforms()
  24. print(f'transforms:{transforms}')
  25. img = transforms(img_int)
  26. person_float = transforms(img)
  27. model = keypointrcnn_resnet50_fpn(weights=None, progress=False)
  28. model = model.eval()
  29. t1=time.time()
  30. # img = torch.ones((3, 3, 512, 512))
  31. outputs = model([img])
  32. t2=time.time()
  33. print(f'time:{t2-t1}')
  34. # print(f'outputs:{outputs}')
  35. kpts = outputs[0]['keypoints']
  36. scores = outputs[0]['scores']
  37. print(f'kpts:{kpts}')
  38. print(f'scores:{scores}')
  39. detect_threshold = 0.75
  40. idx = torch.where(scores > detect_threshold)
  41. keypoints = kpts[idx]
  42. # print(f'keypoints:{keypoints}')
  43. res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
  44. show(res)
  45. plt.show()