| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
import math

import torch
import torch.nn.functional as F
from torch import nn
class ArcEquationPredictor(nn.Module):
    """Regress arc parameters from per-proposal single-channel feature maps.

    Each proposal's ``h x w`` map is flattened and fed through one fully
    connected layer that emits ``num_outputs`` raw values. The first seven
    are then mapped into valid ranges (see :meth:`forward`). Any further
    outputs are returned unchanged.
    """

    # Number of leading outputs that get a range mapping in forward().
    _NUM_MAPPED = 7

    def __init__(self, h=512, w=672, num_outputs=7):
        """
        Args:
            h (int): Expected feature-map height. Also the scale applied to
                ``cy`` (output 1).
            w (int): Expected feature-map width. Also the scale applied to
                ``cx`` (output 0).
            num_outputs (int): Number of values predicted per proposal. Must
                be at least 7 because forward() maps outputs 0..6.

        Raises:
            ValueError: If ``num_outputs`` is smaller than 7.
        """
        super().__init__()
        if num_outputs < self._NUM_MAPPED:
            raise ValueError(
                f"num_outputs must be >= {self._NUM_MAPPED}, got {num_outputs}"
            )
        self.h = h
        self.w = w
        self.num_outputs = num_outputs
        # Fully connected layer mapping the flattened feature map to arc parameters.
        self.fc = nn.Linear(h * w, num_outputs)

    def forward(self, feature_logits, arc_pos_matched_idxs):
        """
        Args:
            feature_logits (Tensor): Shape [total_num_boxes, 1, H, W] with
                ``H * W == h * w``. Contains all proposals from all images
                in the batch. It may be empty (total_num_boxes == 0) or
                non-contiguous.
            arc_pos_matched_idxs (list[Tensor]): Not used by this method.
                It is kept in the signature for reference only; it is not
                used for splitting.

        Returns:
            Tensor: Shape [total_num_boxes, num_outputs]. Columns:
                0: cx, sigmoid scaled to (0, w)
                1: cy, sigmoid scaled to (0, h)
                2: long_axis, ReLU (>= 0)
                3: short_axis, ReLU (>= 0)
                4-6: gamma1..gamma3, sigmoid scaled to (0, 2*pi)
                7+: raw linear outputs, present only if num_outputs > 7

        Raises:
            ValueError: If ``feature_logits`` is not [N, 1, H, W], or if
                ``H * W`` does not equal the ``h * w`` the layer was built for.
        """
        if feature_logits.dim() != 4 or feature_logits.shape[1] != 1:
            raise ValueError(
                f"Expected [total_num_boxes, 1, H, W], got {feature_logits.shape}"
            )
        _, _, height, width = feature_logits.shape
        if height * width != self.h * self.w:
            raise ValueError(
                f"Expected spatial size {self.h}x{self.w} (h*w={self.h * self.w}), "
                f"got {height}x{width}"
            )

        # flatten() instead of view(N, -1): view(0, -1) is ambiguous and
        # raises when there are no proposals, and view() also rejects
        # non-contiguous input.
        flat = feature_logits.flatten(start_dim=1)  # [N, H*W]
        raw = self.fc(flat)  # [N, num_outputs]

        # Build the result functionally rather than writing into `raw`
        # in place, so the autograd graph stays straightforward.
        cx = torch.sigmoid(raw[..., 0:1]) * self.w
        cy = torch.sigmoid(raw[..., 1:2]) * self.h
        axes = F.relu(raw[..., 2:4])  # long_axis, short_axis
        angles = torch.sigmoid(raw[..., 4:7]) * (2 * math.pi)  # gamma1..gamma3
        extra = raw[..., self._NUM_MAPPED:]  # raw passthrough; empty when num_outputs == 7
        return torch.cat([cx, cy, axes, angles, extra], dim=-1)
|