```python
import torch
from torch import nn


class ArcHeads(nn.Sequential):
    """Stack of 3x3 conv + ReLU blocks applied to the incoming feature map."""

    def __init__(self, in_channels, layers):
        d = []
        next_feature = in_channels
        for out_channels in layers:
            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
            d.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super().__init__(*d)
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)
```
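A quick shape check for `ArcHeads`; the channel widths and the 14x14 input below are illustrative assumptions, not values from the original configuration:

```python
# Sketch: run ArcHeads on a dummy feature map (sizes are assumptions).
heads = ArcHeads(in_channels=256, layers=(512, 512))
roi_features = torch.randn(4, 256, 14, 14)   # [N, C, H, W]
out = heads(roi_features)
print(out.shape)   # torch.Size([4, 512, 14, 14]); 3x3/stride-1/pad-1 convs keep H and W
```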
```python
class ArcPredictor(nn.Module):
    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        input_features = in_channels
        deconv_kernel = 4
        # Transposed conv with kernel 4, stride 2, padding 1: doubles the spatial size.
        self.kps_score_lowres = nn.ConvTranspose2d(
            input_features,
            out_channels,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = out_channels

    def forward(self, x):
        print(f"before kps_score_lowres x: {x.shape}")  # debug
        x = self.kps_score_lowres(x)
        print(f"kps_score_lowres x: {x.shape}")         # debug
        return x
        # Optional extra 2x bilinear upsampling, currently disabled:
        # return torch.nn.functional.interpolate(
        #     x, scale_factor=float(self.up_scale), mode="bilinear",
        #     align_corners=False, recompute_scale_factor=False,
        # )
```
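For a transposed convolution, the output size is `(H - 1) * stride - 2 * padding + kernel`, so kernel 4, stride 2, padding 1 maps H to 2H. A quick sketch with assumed sizes:

```python
# Sketch: verify the 2x upsampling of ArcPredictor (sizes are assumptions).
predictor = ArcPredictor(in_channels=512, out_channels=1)
x = torch.randn(4, 512, 14, 14)
logits = predictor(x)
print(logits.shape)   # torch.Size([4, 1, 28, 28]): (14 - 1) * 2 - 2 * 1 + 4 = 28
```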
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcEquationHead(nn.Module):
    """
    Input:
        feature_logits : [N, 1, H, W]
    Output:
        arc_params : [N, 9]
            0: center x (cx)
            1: center y (cy)
            2: long axis length (a)
            3: short axis length (b)
            4: ellipse angle (theta)
            5: first auxiliary point, x coordinate
            6: first auxiliary point, y coordinate
            7: second auxiliary point, x coordinate
            8: second auxiliary point, y coordinate
    """

    def __init__(self, num_outputs=9, hidden=512):
        super().__init__()
        # Use GAP to remove spatial dependency
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # Final MLP that maps the pooled feature -> arc parameters
        # (the input is a single channel, hence Linear(1, hidden))
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_outputs),
        )

    def forward(self, feature_logits):
        """
        feature_logits: [N, 1, H, W]
        """
        N, _, H, W = feature_logits.shape

        # Global average pooling: [N, 1, H, W] -> [N, 1]
        x = self.gap(feature_logits)
        x = x.view(N, -1)

        # Predict raw parameters
        raw = self.mlp(x)  # [N, 9]

        # --------------------------------------------
        # Parameter constraints
        # (assembled with torch.cat rather than slice assignment,
        #  so the graph stays free of in-place writes)
        # --------------------------------------------
        # Ellipse center, mapped into the image range
        cx = torch.sigmoid(raw[..., 0:1]) * W   # cx in image width range
        cy = torch.sigmoid(raw[..., 1:2]) * H   # cy in image height range
        # Axis lengths must be positive
        a = F.relu(raw[..., 2:3]) + 1e-6        # a > 0
        b = F.relu(raw[..., 3:4]) + 1e-6        # b > 0
        # Angle in [0, 2*pi)
        theta = torch.sigmoid(raw[..., 4:5]) * (2 * math.pi)
        # The last four values are two auxiliary points,
        # mapped to the same spatial range as the image
        aux_x1 = torch.sigmoid(raw[..., 5:6]) * W
        aux_y1 = torch.sigmoid(raw[..., 6:7]) * H
        aux_x2 = torch.sigmoid(raw[..., 7:8]) * W
        aux_y2 = torch.sigmoid(raw[..., 8:9]) * H

        arc_params = torch.cat(
            [cx, cy, a, b, theta, aux_x1, aux_y1, aux_x2, aux_y2], dim=-1
        )  # [N, 9]
        return arc_params
```
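A minimal end-to-end sketch chaining the three modules; the channel widths and input size are illustrative assumptions, not taken from the original configuration:

```python
# Sketch: ArcHeads -> ArcPredictor -> ArcEquationHead on dummy features.
# All sizes are assumptions for illustration.
heads = ArcHeads(in_channels=256, layers=(512, 512))
predictor = ArcPredictor(in_channels=512, out_channels=1)
eq_head = ArcEquationHead(num_outputs=9)

roi_features = torch.randn(4, 256, 14, 14)
logits = predictor(heads(roi_features))   # [4, 1, 28, 28]
arc_params = eq_head(logits)              # [4, 9], constrained to image/angle ranges
print(arc_params.shape)
```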