|
@@ -789,7 +789,9 @@ def line_inference(x, line_boxes):
|
|
|
|
|
|
|
|
return lines_probs, lines_scores
|
|
return lines_probs, lines_scores
|
|
|
|
|
|
|
|
-def arc_inference(x, arc_boxes):
|
|
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def arc_inference(x, arc_boxes,th):
|
|
|
# type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
|
|
# type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
|
|
|
|
|
|
|
|
points_probs = []
|
|
points_probs = []
|
|
@@ -810,7 +812,39 @@ def arc_inference(x, arc_boxes):
|
|
|
|
|
|
|
|
points_probs_tensor=torch.cat(points_probs)
|
|
points_probs_tensor=torch.cat(points_probs)
|
|
|
print(f'points_probs shape:{points_probs_tensor.shape}')
|
|
print(f'points_probs shape:{points_probs_tensor.shape}')
|
|
|
- return points_probs,points_scores
|
|
|
|
|
|
|
+
|
|
|
|
|
+ feature_logits = x
|
|
|
|
|
+ batch_size = feature_logits.shape[0]
|
|
|
|
|
+ num_proposals = len(arc_boxes[0])
|
|
|
|
|
+
|
|
|
|
|
+ results = [[torch.empty(0, 2) for _ in range(num_proposals)] for _ in range(batch_size)]
|
|
|
|
|
+ proposals_list = arc_boxes[0] # [[tensor(...)]]
|
|
|
|
|
+
|
|
|
|
|
+ for proposal_idx, proposal in enumerate(proposals_list):
|
|
|
|
|
+ coords = proposal.tolist()
|
|
|
|
|
+ x1, y1, x2, y2 = map(int, coords)
|
|
|
|
|
+
|
|
|
|
|
+ x1 = max(0, x1)
|
|
|
|
|
+ y1 = max(0, y1)
|
|
|
|
|
+ x2 = min(feature_logits.shape[3], x2)
|
|
|
|
|
+ y2 = min(feature_logits.shape[2], y2)
|
|
|
|
|
+
|
|
|
|
|
+ for batch_idx in range(batch_size):
|
|
|
|
|
+ region = feature_logits[batch_idx, :, y1:y2, x1:x2]
|
|
|
|
|
+ mask = region > th
|
|
|
|
|
+ coords = torch.nonzero(mask)
|
|
|
|
|
+
|
|
|
|
|
+ if coords.numel() > 0:
|
|
|
|
|
+ # 取 (y, x),然后转换为全局坐标 (x, y)
|
|
|
|
|
+ local_coords = coords[:, [2, 1]] # (x, y)
|
|
|
|
|
+ local_coords[:, 0] += x1
|
|
|
|
|
+ local_coords[:, 1] += y1
|
|
|
|
|
+
|
|
|
|
|
+ results[batch_idx][proposal_idx] = local_coords
|
|
|
|
|
+
|
|
|
|
|
+ print(f're:{results}')
|
|
|
|
|
+
|
|
|
|
|
+ return points_probs,points_scores,results
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|