# ins_predictor.py
import math

import torch
import torch.nn.functional as F
from torch import nn
  5. class ArcEquationPredictor(nn.Module):
  6. def __init__(self, h=672, w=672, num_outputs=7):
  7. super().__init__()
  8. self.h = h
  9. self.w = w
  10. self.num_outputs = num_outputs
  11. # Fully connected layer to map flattened feature map to arc parameters
  12. self.fc = nn.Linear(h * w, num_outputs)
  13. def forward(self, feature_logits, arc_pos_matched_idxs):
  14. """
  15. Args:
  16. feature_logits (Tensor): shape [total_num_boxes, 1, H, W],
  17. contains all proposals from all images in the batch.
  18. arc_pos_matched_idxs (list[Tensor]): list of length B,
  19. only used for reference, not used for splitting.
  20. Returns:
  21. arc_params (Tensor): shape [total_num_boxes, num_outputs],
  22. predicted arc parameters for all proposals.
  23. """
  24. assert feature_logits.dim() == 4 and feature_logits.shape[1] == 1, \
  25. f"Expected [total_num_boxes, 1, H, W], got {feature_logits.shape}"
  26. total_num_boxes, _, H, W = feature_logits.shape
  27. # Flatten spatial dimensions
  28. x = feature_logits.view(total_num_boxes, -1) # [total_num_boxes, H*W]
  29. # Predict arc parameters for each proposal
  30. arc_params = self.fc(x) # [total_num_boxes, num_outputs]
  31. # Map raw outputs into valid ranges
  32. arc_params[..., 0] = torch.sigmoid(arc_params[..., 0]) * self.w # cx
  33. arc_params[..., 1] = torch.sigmoid(arc_params[..., 1]) * self.h # cy
  34. arc_params[..., 2] = F.relu(arc_params[..., 2]) # long_axis
  35. arc_params[..., 3] = F.relu(arc_params[..., 3]) # short_axis
  36. arc_params[..., 4] = torch.sigmoid(arc_params[..., 4]) * 2 * 3.1415926 # ¦È1
  37. arc_params[..., 5] = torch.sigmoid(arc_params[..., 5]) * 2 * 3.1415926 # ¦È2
  38. arc_params[..., 6] = torch.sigmoid(arc_params[..., 6]) * 2 * 3.1415926 # ¦È3
  39. # Directly return all predictions together
  40. return arc_params