|
@@ -66,7 +66,7 @@ class KeypointDataset(BaseDataset):
|
|
|
return len(self.imgs)
|
|
|
|
|
|
def read_target(self, item, lbl_path, shape, extra=None):
|
|
|
- print(f'shape:{shape}')
|
|
|
+ # print(f'shape:{shape}')
|
|
|
# print(f'lbl_path:{lbl_path}')
|
|
|
with open(lbl_path, 'r') as file:
|
|
|
lable_all = json.load(file)
|
|
@@ -123,17 +123,18 @@ class KeypointDataset(BaseDataset):
|
|
|
|
|
|
target["labels"] = torch.stack(labels)
|
|
|
# print(f'labels:{target["labels"]}')
|
|
|
- target["boxes"] = line_boxes(target)
|
|
|
+ # target["boxes"] = line_boxes(target)
|
|
|
+ target["boxes"], keypoints = line_boxes(target)
|
|
|
# visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
|
|
|
|
|
|
- keypoints= wire_labels["junc_coords"]
|
|
|
- keypoints[:,2]=2
|
|
|
- # keypoints[:,0]=keypoints[:,0]/shape[0]
|
|
|
- # keypoints[:, 1] = keypoints[:, 1] / shape[1]
|
|
|
- target["keypoints"]=keypoints
|
|
|
+ # keypoints= wire_labels["junc_coords"]
|
|
|
+ a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
|
|
|
+ keypoints = torch.cat((keypoints, a), dim=1)
|
|
|
+ target["keypoints"] = keypoints.to(torch.float32).view(-1,2,3)
|
|
|
+ # print(f'boxes:{target["boxes"].shape}')
|
|
|
# 在 __getitem__ 方法中调用此函数
|
|
|
validate_keypoints(keypoints, shape[0], shape[1])
|
|
|
- print(f'keypoints:{target["keypoints"].shape}')
|
|
|
+ # print(f'keypoints:{target["keypoints"].shape}')
|
|
|
return target
|
|
|
|
|
|
def show(self, idx):
|