| 12345678910111213141516171819202122232425262728293031323334353637383940414243 |
- import torch
- from torch import nn
class ArcHeads(nn.Sequential):
    """Sequential stack of ``Conv2d(3x3, stride 1, padding 1) -> ReLU(inplace)`` pairs.

    With a 3x3 kernel, stride 1 and padding 1, each convolution keeps the
    spatial size of its input. Only the channel count changes.

    Args:
        in_channels: channel count of the tensor fed to the first convolution.
        layers: iterable of output channel counts, one per conv/ReLU pair.
            The final entry is the channel count of the stack's output.
            An empty iterable produces an empty Sequential.
    """

    def __init__(self, in_channels, layers):
        # Channel "widths" along the stack: input width followed by each layer's width.
        widths = [in_channels, *layers]
        modules = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            modules += [
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            ]
        super().__init__(*modules)

        # Re-initialise every convolution after construction: Kaiming-normal
        # weights (fan_out mode, tuned for the following ReLU) and zero bias.
        convs = (child for child in self.children() if isinstance(child, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            nn.init.constant_(conv.bias, 0)
class ArcPredictor(nn.Module):
    """Learned 2x spatial upsampling via a single transposed convolution.

    The transposed convolution uses kernel 4, stride 2 and padding 1, so an
    input of spatial size ``H x W`` maps to exactly ``2H x 2W``:
    ``(H - 1) * 2 - 2 * 1 + 4 == 2H``.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count of the output map (default 1).
    """

    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            deconv_kernel,
            stride=2,
            # padding = k // 2 - 1 makes the output exactly twice the input size.
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        # NOTE: retained for backward compatibility only. forward() applies no
        # additional interpolation; the only upsampling is the 2x from the deconv.
        self.up_scale = 2
        self.out_channels = out_channels

    def forward(self, x):
        """Upsample ``x`` 2x with the transposed convolution.

        Args:
            x: feature map with ``in_channels`` channels, e.g. ``(N, C, H, W)``.

        Returns:
            Tensor with ``out_channels`` channels and doubled spatial size.
        """
        # Debug prints of tensor shapes were removed: they spammed stdout on
        # every forward pass.
        return self.kps_score_lowres(x)
|