"""Run a trained KeypointRCNNModel on one image and visualize its detections.

Loads a checkpoint, runs inference on a single image, keeps detections whose
score exceeds a threshold, draws their keypoints and bounding boxes on the
original image, and displays the result with matplotlib.
"""
import time

import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.io import ImageReadMode, decode_image, read_image
import torchvision.transforms.functional as F
from torchvision.utils import draw_keypoints, draw_bounding_boxes

from models.keypoint.kepointrcnn import KeypointRCNNModel

# Other images used while debugging:
# img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
# img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
IMG_PATH = r"I:\datasets\wirenet_1000\images\train\00031644_0.png"
WEIGHT_PATH = './train_results/20241226_171710/weights/best.pt'
NUM_KEYPOINTS = 2
# Deliberately very low so that almost every detection is drawn (debugging aid).
DETECT_THRESHOLD = 0.001


def show(imgs):
    """Display one image tensor, or a list of them, side by side.

    Args:
        imgs: An image tensor accepted by ``torchvision.transforms.functional.to_pil_image``
            (e.g. uint8 (C, H, W)), or a list of such tensors.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    _, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        pil_img = F.to_pil_image(img.detach())
        axs[0, i].imshow(np.asarray(pil_img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


def main(img_path=IMG_PATH, weight_path=WEIGHT_PATH, detect_threshold=DETECT_THRESHOLD):
    """Load the model, run it on ``img_path`` and show detections above ``detect_threshold``.

    Args:
        img_path: Path of the image to run inference on.
        weight_path: Checkpoint passed to ``KeypointRCNNModel.load_weight``.
        detect_threshold: Detections with ``score <= detect_threshold`` are discarded.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Force 3-channel RGB: PNGs may decode as RGBA or grayscale with the default
    # (unchanged) mode, and draw_keypoints only accepts RGB images.
    img_int = read_image(img_path, mode=ImageReadMode.RGB)

    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    transforms = weights.transforms()
    print(f'transforms:{transforms}')
    # Apply the detection preset (uint8 -> float image) exactly once.
    img = transforms(img_int).to(device)

    model = KeypointRCNNModel(num_keypoints=NUM_KEYPOINTS)
    print('start to load pretraine weight!')
    model.load_weight(weight_path)
    print('loaded weight !!!')
    model.to(device)
    model.eval()
    # Alternative baseline: model = keypointrcnn_resnet50_fpn(weights=None, progress=False).eval()

    # No gradients are needed for inference; skipping autograd saves time and memory.
    with torch.no_grad():
        t1 = time.perf_counter()
        outputs = model([img])
        if device.type == 'cuda':
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize()
        t2 = time.perf_counter()
    print(f'time:{t2 - t1}')

    # Bring results back to the CPU, where the drawing utilities and image live.
    detection = outputs[0]
    kpts = detection['keypoints'].cpu()
    scores = detection['scores'].cpu()
    boxes = detection['boxes'].cpu()
    print(f'kpts:{kpts}')
    print(f'scores:{scores}')

    keep = scores > detect_threshold
    # Keypoints are presumably (N, K, 3) as (x, y, visibility), as in torchvision's
    # KeypointRCNN; draw_keypoints documents (N, K, 2), so pass only the coordinates.
    keypoints = kpts[keep][..., :2]
    kept_boxes = boxes[keep]

    # Draw both overlays on the same image so keypoints and the kept boxes show together.
    res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
    res = draw_bounding_boxes(res, kept_boxes)
    show(res)
    plt.show()


if __name__ == '__main__':
    main()