import math

import torch
import torch.nn.functional as F
from torch import nn


class ArcHeads(nn.Sequential):
    """Stack of 3x3 convolution + ReLU layers.

    Each entry of ``layers`` appends ``Conv2d(prev, out, 3, stride=1, padding=1)``
    followed by an in-place ``ReLU``. With a 3x3 kernel and padding 1 the
    spatial size of the feature map is unchanged.

    Args:
        in_channels: Number of channels of the input feature map.
        layers: Iterable of output channel counts, one per conv layer.
    """

    def __init__(self, in_channels, layers):
        modules = []
        next_feature = in_channels
        for out_channels in layers:
            modules.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
            modules.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super().__init__(*modules)
        # He-initialise each conv for the ReLU that follows it; start biases at zero.
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)


class ArcPredictor(nn.Module):
    """Upsample a feature map 2x with a single transposed convolution.

    ``ConvTranspose2d(kernel=4, stride=2, padding=1)`` maps
    ``[N, in_channels, H, W]`` to ``[N, out_channels, 2H, 2W]``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of output channels (default 1).
    """

    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        # Not used by forward(): the extra bilinear interpolation step is disabled.
        # Kept so code that reads this attribute keeps working.
        self.up_scale = 2
        self.out_channels = out_channels

    def forward(self, x):
        """Return the 2x-upsampled transposed-convolution output for ``x``."""
        return self.kps_score_lowres(x)


class ArcEquationHead(nn.Module):
    """Regress arc / ellipse parameters from a feature map.

    Input:
        feature_logits: ``[N, in_channels, H, W]`` (``in_channels`` defaults to 1).

    Output:
        arc_params: ``[N, num_outputs]``. The first nine columns are constrained:
            0: center x (cx), in [0, W]
            1: center y (cy), in [0, H]
            2: long axis length (a), > 0
            3: short axis length (b), > 0
            4: ellipse angle (theta), in [0, 2*pi]
            5, 6: first auxiliary point (x, y), in [0, W] x [0, H]
            7, 8: second auxiliary point (x, y), in [0, W] x [0, H]
        Columns beyond index 8 (when ``num_outputs > 9``) are the raw MLP output.

    Args:
        num_outputs: Width of the MLP output; must be at least 9.
        hidden: Hidden width of the MLP.
        in_channels: Channel count of ``feature_logits`` (default 1).

    Raises:
        ValueError: If ``num_outputs`` is smaller than 9.
    """

    # Number of leading output columns that forward() maps into valid ranges.
    _NUM_CONSTRAINED = 9

    def __init__(self, num_outputs=9, hidden=512, in_channels=1):
        super().__init__()
        if num_outputs < self._NUM_CONSTRAINED:
            raise ValueError(
                f"num_outputs must be >= {self._NUM_CONSTRAINED}, got {num_outputs}"
            )
        # Global average pooling removes the spatial dependency: [N, C, H, W] -> [N, C, 1, 1].
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # NOTE(review): after pooling, each sample is reduced to ``in_channels`` scalars
        # (a single number with the default of 1), so the MLP sees no spatial
        # information at all - confirm this is the intended design.
        # Final MLP that maps the pooled feature -> arc parameters.
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_outputs),
        )

    def forward(self, feature_logits):
        """Predict constrained arc parameters.

        Args:
            feature_logits: Tensor of shape ``[N, in_channels, H, W]``.

        Returns:
            Tensor of shape ``[N, num_outputs]``; see the class docstring for the
            meaning and range of each column.
        """
        n, _, h, w = feature_logits.shape

        # [N, C, H, W] -> [N, C, 1, 1] -> [N, C]
        pooled = self.gap(feature_logits).reshape(n, -1)

        # Raw, unconstrained parameters: [N, num_outputs]
        raw = self.mlp(pooled)

        # Ellipse center, squashed into the image extent.
        cx = torch.sigmoid(raw[:, 0]) * w
        cy = torch.sigmoid(raw[:, 1]) * h

        # Axis lengths must be strictly positive.
        a = F.relu(raw[:, 2]) + 1e-6
        b = F.relu(raw[:, 3]) + 1e-6

        # Angle in [0, 2*pi].
        theta = torch.sigmoid(raw[:, 4]) * (2 * math.pi)

        # Two auxiliary points, mapped to the same spatial range as the image.
        aux1_x = torch.sigmoid(raw[:, 5]) * w
        aux1_y = torch.sigmoid(raw[:, 6]) * h
        aux2_x = torch.sigmoid(raw[:, 7]) * w
        aux2_y = torch.sigmoid(raw[:, 8]) * h

        # Build a fresh tensor rather than writing into ``raw`` in place, which would
        # mutate an autograd-tracked activation through overlapping views.
        arc_params = torch.stack(
            [cx, cy, a, b, theta, aux1_x, aux1_y, aux2_x, aux2_y], dim=1
        )
        if raw.shape[1] > self._NUM_CONSTRAINED:
            # Pass any extra outputs through unchanged.
            arc_params = torch.cat([arc_params, raw[:, self._NUM_CONSTRAINED:]], dim=1)
        return arc_params