import torch
from torch import nn


class PointHeads(nn.Sequential):
    """Stack of 3x3 ``Conv2d`` + in-place ``ReLU`` layers that preserves spatial size.

    Args:
        in_channels: number of channels of the input feature map.
        layers: iterable of output channel counts, one entry per conv layer.
            Each conv is followed by ``nn.ReLU(inplace=True)``. An empty
            iterable yields an empty (identity) ``nn.Sequential``.

    Conv weights use Kaiming-normal init (``fan_out``, ReLU gain) and biases
    are zero-initialised.
    """

    def __init__(self, in_channels, layers):
        modules = []
        next_feature = in_channels
        for out_channels in layers:
            # kernel 3 / stride 1 / padding 1 keeps H and W unchanged.
            modules.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
            modules.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super().__init__(*modules)
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)


class PointPredictor(nn.Module):
    """Upsampling prediction head: transposed conv (x2) followed by bilinear resize.

    The transposed convolution (kernel 4, stride 2, padding 1) maps an
    ``H x W`` map to exactly ``2H x 2W`` with ``out_channels`` channels. The
    result is then bilinearly upsampled by ``up_scale``. With the default
    settings the total upsampling is 4x.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the predicted map (default 1). The attribute
            name ``kps_score_lowres`` suggests per-keypoint score maps, but that
            is an assumption and should be confirmed against callers.
        up_scale: bilinear upsampling factor applied after the transposed
            conv (default 2, the previously hard-coded value). A value of 1
            skips the interpolation.

    Raises:
        ValueError: if ``up_scale`` is not positive.
    """

    def __init__(self, in_channels, out_channels=1, up_scale=2):
        super().__init__()
        if up_scale <= 0:
            raise ValueError(f"up_scale must be positive, got {up_scale!r}")
        deconv_kernel = 4
        # Output size = (H - 1) * 2 - 2 * 1 + 4 = 2H, i.e. an exact 2x upsample.
        self.kps_score_lowres = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = up_scale
        self.out_channels = out_channels

    def forward(self, x):
        """Predict the upsampled map for ``x``.

        Bilinear ``interpolate`` requires a 4-D ``(N, C, H, W)`` input.
        The output has shape ``(N, out_channels, 2*H*up_scale, 2*W*up_scale)``.
        """
        x = self.kps_score_lowres(x)
        if self.up_scale == 1:
            # Bilinear resize by a factor of 1 is a no-op; skip it.
            return x
        return torch.nn.functional.interpolate(
            x,
            scale_factor=float(self.up_scale),
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        )