|
@@ -1103,8 +1103,9 @@ class RoIHeads(nn.Module):
|
|
losses.update(loss_keypoint)
|
|
losses.update(loss_keypoint)
|
|
|
|
|
|
if self.has_wirepoint():
|
|
if self.has_wirepoint():
|
|
- # print(f'result:{result}')
|
|
|
|
|
|
+ print(f'wirepoint result:{result}')
|
|
wirepoint_proposals = [p["boxes"] for p in result]
|
|
wirepoint_proposals = [p["boxes"] for p in result]
|
|
|
|
+
|
|
if self.training:
|
|
if self.training:
|
|
# during training, only focus on positive boxes
|
|
# during training, only focus on positive boxes
|
|
num_images = len(proposals)
|
|
num_images = len(proposals)
|
|
@@ -1121,20 +1122,20 @@ class RoIHeads(nn.Module):
|
|
pos_matched_idxs = None
|
|
pos_matched_idxs = None
|
|
|
|
|
|
# print(f'proposals:{len(proposals)}')
|
|
# print(f'proposals:{len(proposals)}')
|
|
|
|
+ print(f'wirepoint_proposals:{wirepoint_proposals}')
|
|
wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
|
|
wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
|
|
|
|
|
|
# tmp = keypoint_features[0][0]
|
|
# tmp = keypoint_features[0][0]
|
|
# plt.imshow(tmp.detach().numpy())
|
|
# plt.imshow(tmp.detach().numpy())
|
|
- # print(f'keypoint_features from roi_pool:{wirepoint_features.shape}')
|
|
|
|
|
|
+ print(f'wirepoint_features from roi_pool:{wirepoint_features.shape}')
|
|
outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
|
|
outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
|
|
-
|
|
|
|
-
|
|
|
|
|
|
+ print(f'outputs1 from head:{outputs.shape}')
|
|
|
|
|
|
outputs = merge_features(outputs, wirepoint_proposals)
|
|
outputs = merge_features(outputs, wirepoint_proposals)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
|
|
|
|
|
|
+ # wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
|
|
|
|
|
|
print(f'outpust:{outputs.shape}')
|
|
print(f'outpust:{outputs.shape}')
|
|
|
|
|
|
@@ -1208,7 +1209,7 @@ class RoIHeads(nn.Module):
|
|
# return merged_features
|
|
# return merged_features
|
|
|
|
|
|
def merge_features(features, proposals):
|
|
def merge_features(features, proposals):
|
|
- print(f'features:{features.shape}')
|
|
|
|
|
|
+ print(f'features in merge_features:{features.shape}')
|
|
print(f'proposals:{len(proposals)}')
|
|
print(f'proposals:{len(proposals)}')
|
|
def diagnose_input(features, proposals):
|
|
def diagnose_input(features, proposals):
|
|
"""诊断输入数据"""
|
|
"""诊断输入数据"""
|