import math

import torch
import torch.nn.functional as F
from torch import nn


class ArcEquationPredictor(nn.Module):
    """Regress per-proposal arc parameters from a single-channel feature map.

    Each proposal's ``[1, H, W]`` map is flattened and passed through one
    fully connected layer. The first seven output channels are then squashed
    into bounded ranges (see :meth:`forward`); any channels beyond the
    seventh are returned as raw linear outputs.

    Args:
        h: Height of the input feature map; also the upper bound for ``cy``.
        w: Width of the input feature map; also the upper bound for ``cx``.
        num_outputs: Number of values predicted per proposal. ``forward``
            indexes channels 0..6, so this must be at least 7.
    """

    # Number of leading output channels that receive a range-limiting
    # activation in ``forward``.
    _NUM_ACTIVATED = 7

    def __init__(self, h=512, w=672, num_outputs=7):
        super().__init__()
        self.h = h
        self.w = w
        self.num_outputs = num_outputs
        # Fully connected layer mapping the flattened feature map to the
        # arc parameters.
        self.fc = nn.Linear(h * w, num_outputs)

    def forward(self, feature_logits, arc_pos_matched_idxs):
        """Predict arc parameters for every proposal.

        Args:
            feature_logits (Tensor): shape ``[total_num_boxes, 1, H, W]``,
                holding all proposals from all images in the batch. ``H * W``
                must equal ``self.h * self.w``; otherwise the linear layer
                raises a shape-mismatch ``RuntimeError``.
            arc_pos_matched_idxs (list[Tensor]): accepted for interface
                compatibility only; it is not read by this method.

        Returns:
            Tensor: shape ``[total_num_boxes, num_outputs]`` with channels

            * 0: ``cx`` = ``sigmoid(raw) * self.w``
            * 1: ``cy`` = ``sigmoid(raw) * self.h``
            * 2: ``long_axis`` = ``relu(raw)``
            * 3: ``short_axis`` = ``relu(raw)``
            * 4-6: ``theta1``, ``theta2``, ``theta3`` = ``sigmoid(raw) * 2*pi``
            * 7+: raw linear outputs (only present if ``num_outputs > 7``)

        Raises:
            AssertionError: if ``feature_logits`` is not 4-D with a channel
                dimension of size 1.
        """
        assert feature_logits.dim() == 4 and feature_logits.shape[1] == 1, \
            f"Expected [total_num_boxes, 1, H, W], got {feature_logits.shape}"

        # flatten() computes the target shape explicitly, so unlike
        # view(N, -1) it works for zero proposals (where -1 is ambiguous)
        # and for non-contiguous inputs (where view() refuses to alias).
        flat = feature_logits.flatten(start_dim=1)  # [total_num_boxes, H*W]

        raw = self.fc(flat)  # [total_num_boxes, num_outputs]

        # Build the result out of place instead of overwriting slices of
        # the linear layer's output.
        two_pi = 2.0 * math.pi
        cx = torch.sigmoid(raw[..., 0]) * self.w
        cy = torch.sigmoid(raw[..., 1]) * self.h
        long_axis = F.relu(raw[..., 2])
        short_axis = F.relu(raw[..., 3])
        theta1 = torch.sigmoid(raw[..., 4]) * two_pi
        theta2 = torch.sigmoid(raw[..., 5]) * two_pi
        theta3 = torch.sigmoid(raw[..., 6]) * two_pi

        activated = torch.stack(
            [cx, cy, long_axis, short_axis, theta1, theta2, theta3],
            dim=-1,
        )

        # Channels past the seventh (if any) are passed through untouched;
        # for num_outputs == 7 this slice is empty and the cat is a no-op.
        passthrough = raw[..., self._NUM_ACTIVATED:]
        return torch.cat([activated, passthrough], dim=-1)