test_predict.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import time
  2. import numpy as np
  3. import torch
  4. from matplotlib import pyplot as plt
  5. from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
  6. from torchvision.io import decode_image, read_image
  7. import torchvision.transforms.functional as F
  8. from torchvision.utils import draw_keypoints, draw_bounding_boxes
  9. from models.keypoint.kepointrcnn import KeypointRCNNModel
  def show(imgs):
      """Render one image tensor, or a list of them, side by side in one figure row.

      Each tensor is detached from autograd, converted with
      ``torchvision.transforms.functional.to_pil_image`` and drawn with
      ``imshow``; ticks and tick labels are removed. The figure is only built
      here -- the caller must invoke ``plt.show()`` to display it.

      Args:
          imgs: a single image tensor or a list of image tensors. Anything that
              is not a ``list`` is wrapped as a single image. Presumably each
              tensor is in a layout ``to_pil_image`` accepts, e.g. (C, H, W) --
              confirm against callers.
      """
      images = imgs if isinstance(imgs, list) else [imgs]
      # squeeze=False keeps `axes` 2-D even for a single column.
      _, axes = plt.subplots(ncols=len(images), squeeze=False)
      for column, tensor in enumerate(images):
          pil_image = F.to_pil_image(tensor.detach())
          axis = axes[0, column]
          axis.imshow(np.asarray(pil_image))
          axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
  # Single-image inference smoke test for the project's KeypointRCNNModel:
  # read an image, run the model once, keep detections above a score threshold
  # and display their keypoints and boxes on the original image.
  from torchvision.io import ImageReadMode

  img_path = r"I:\datasets\wirenet_1000\images\train\00031644_0.png"
  # Force 3-channel RGB so grayscale/RGBA PNGs still yield the 3-channel input
  # that torchvision-style detectors normalise (assumed to hold for
  # KeypointRCNNModel too -- confirm). Also needed by draw_bounding_boxes.
  img_int = read_image(img_path, mode=ImageReadMode.RGB)

  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

  weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
  transforms = weights.transforms()
  print(f'transforms:{transforms}')
  # Apply the preset exactly once; the model input must be on the model's device.
  img = transforms(img_int).to(device)

  model = KeypointRCNNModel(num_keypoints=2)
  print('start to load pretraine weight!')
  model.load_weight('./train_results/20241226_171710/weights/best.pt')
  print('loaded weight !!!')
  # Model and input have to share a device.
  model.to(device)
  model.eval()

  t1 = time.time()
  # Inference only: skip building an autograd graph.
  with torch.no_grad():
      outputs = model([img])
  if device.type == 'cuda':
      # CUDA work is asynchronous; wait for it so the measured time is real.
      torch.cuda.synchronize()
  t2 = time.time()
  print(f'time:{t2 - t1}')

  # The drawing utilities operate on CPU tensors.
  kpts = outputs[0]['keypoints'].cpu()
  scores = outputs[0]['scores'].cpu()
  boxes = outputs[0]['boxes'].cpu()
  print(f'kpts:{kpts}')
  print(f'scores:{scores}')

  # One mask filters both keypoints and boxes, so the overlays refer to the
  # same detections.
  detect_threshold = 0.001
  keep = scores > detect_threshold
  keypoints = kpts[keep]
  # Assumes keypoints are (N, K, >=2) with x, y first; draw_keypoints only
  # needs the (x, y) columns.
  res = draw_keypoints(img_int, keypoints[..., :2], colors="blue", radius=3)
  # Draw the filtered boxes on top of the keypoint overlay so both are shown.
  res_box = draw_bounding_boxes(res, boxes[keep])
  show(res_box)
  plt.show()