1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
import time

import numpy as np
import torch
import torchvision.transforms.functional as F
from matplotlib import pyplot as plt
from torchvision.io import ImageReadMode, decode_image, read_image
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights, keypointrcnn_resnet50_fpn
from torchvision.utils import draw_bounding_boxes, draw_keypoints

from models.keypoint.kepointrcnn import KeypointRCNNModel
def show(imgs):
    """Draw one image tensor, or a list of them, side by side in a single row.

    Each tensor is detached, converted with ``F.to_pil_image`` and rendered on
    its own matplotlib axis with all ticks and tick labels hidden. The figure
    is only built here; displaying it is left to a later ``plt.show()``.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for col, tensor in enumerate(images):
        pil_image = F.to_pil_image(tensor.detach())
        ax = axes[0, col]
        ax.imshow(np.asarray(pil_image))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
# Demo: run the fine-tuned keypoint model on a single image, time the forward
# pass, and visualise the detected keypoints and bounding boxes.

# Sample image and fine-tuned checkpoint used by this demo.
IMG_PATH = r"I:\datasets\wirenet_1000\images\train\00031644_0.png"
WEIGHT_PATH = './train_results/20241226_171710/weights/best.pt'
# Minimum detection score for an instance to be drawn.
# NOTE(review): 0.001 keeps almost every detection -- presumably for debugging;
# raise it for cleaner output.
DETECT_THRESHOLD = 0.001

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Force 3 channels: with the default (unchanged) mode a PNG may decode as
# grayscale or RGBA, while the detector and the drawing utilities expect RGB.
img_int = read_image(IMG_PATH, mode=ImageReadMode.RGB)

# Reuse the preprocessing preset bundled with the torchvision Keypoint R-CNN
# weights (uint8 image -> float image). It must be applied exactly once.
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()
print(f'transforms:{transforms}')
img = transforms(img_int).to(device)

model = KeypointRCNNModel(num_keypoints=2)
print('start to load pretraine weight!')
model.load_weight(WEIGHT_PATH)
print('loaded weight !!!')
# Model and input must live on the same device (the input was moved above).
model.to(device)
model.eval()

# Inference only: disable autograd so no graph is recorded (saves memory and
# keeps the measured time representative).
with torch.no_grad():
    t1 = time.perf_counter()
    outputs = model([img])
    if device.type == 'cuda':
        # CUDA kernels run asynchronously; wait for them before stopping the timer.
        torch.cuda.synchronize()
    t2 = time.perf_counter()
print(f'time:{t2 - t1}')

# Bring predictions to the CPU so they can be drawn onto the CPU uint8 image.
kpts = outputs[0]['keypoints'].cpu()
scores = outputs[0]['scores'].cpu()
boxes = outputs[0]['boxes'].cpu()
print(f'kpts:{kpts}')
print(f'scores:{scores}')

# Apply the SAME score mask to keypoints and boxes so both describe the same
# instances.
keep = scores > DETECT_THRESHOLD
# draw_keypoints documents keypoints as (num_instances, K, 2); torchvision's
# Keypoint R-CNN emits (x, y, visibility) -- assumed the same for
# KeypointRCNNModel -- so keep only the x, y columns.
keypoints = kpts[keep][..., :2]
kept_boxes = boxes[keep]

# Draw keypoints first, then boxes on top of that result, so both are visible.
res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
res = draw_bounding_boxes(res, kept_boxes)
show(res)
plt.show()
|