| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
import math

import torch
from torch import nn
class ArcHeads(nn.Sequential):
    """Sequential stack of (3x3 Conv2d, ReLU) pairs, one pair per entry in ``layers``.

    Args:
        in_channels: channel count of the incoming feature map.
        layers: iterable of output channel counts; each conv maps the previous
            layer's channel count to the next entry.

    Every conv uses stride 1 and padding 1, so spatial size is preserved.
    After construction, conv weights are re-initialised with Kaiming-normal
    (mode="fan_out", nonlinearity="relu") and conv biases are set to 0.
    """

    def __init__(self, in_channels, layers):
        stack = []
        prev = in_channels
        for width in layers:
            stack.extend(
                (nn.Conv2d(prev, width, 3, stride=1, padding=1), nn.ReLU(inplace=True))
            )
            prev = width
        super().__init__(*stack)

        # Override PyTorch's default conv initialisation once all layers exist.
        convs = (child for child in self.children() if isinstance(child, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            nn.init.constant_(conv.bias, 0)
class ArcPredictor(nn.Module):
    """Transposed-convolution predictor that doubles spatial resolution.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: number of predicted maps (default 1).

    ``kps_score_lowres`` is a ConvTranspose2d with kernel 4, stride 2 and
    padding 1, so an input of shape [N, in_channels, H, W] produces
    [N, out_channels, 2H, 2W]. Its weight is Kaiming-normal initialised
    (mode="fan_out", nonlinearity="relu") and its bias is set to 0.
    """

    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(
            self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
        )
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        # NOTE: an extra bilinear upsampling by `up_scale` was disabled in
        # forward(); the attribute is kept as part of the public surface.
        self.up_scale = 2
        self.out_channels = out_channels

    def forward(self, x):
        """Return the transposed-conv output for ``x`` (no extra upsampling).

        Debug prints of the input/output shapes were removed: they wrote to
        stdout on every forward pass.
        """
        return self.kps_score_lowres(x)
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class ArcEquationHead(nn.Module):
    """Regress constrained arc/ellipse parameters from a per-arc feature map.

    Input:
        feature_logits : [N, in_channels, H, W]  (in_channels defaults to 1)
    Output:
        arc_params : [N, num_outputs]. The first nine columns are squashed with
        a sigmoid and scaled into these ranges:
            0: center x (cx)                 in [0, W]
            1: center y (cy)                 in [0, H]
            2: long axis length (a)          in [0, W]
            3: short axis length (b)         in [0, W]
            4: ellipse angle (theta)         in [0, 2*pi]
            5: first auxiliary point, x      in [0, W]
            6: first auxiliary point, y      in [0, H]
            7: second auxiliary point, x     in [0, W]
            8: second auxiliary point, y     in [0, H]
        Columns beyond index 8 (if num_outputs > 9) are returned unconstrained.

    H and W are read from ``feature_logits.shape``, so positions and lengths are
    in feature-map units; callers needing image-pixel units must rescale.
    Nothing enforces a >= b.
    """

    # Number of leading output columns that receive a range constraint.
    _NUM_CONSTRAINED = 9

    def __init__(self, num_outputs=9, hidden=512, in_channels=1):
        super().__init__()
        if num_outputs < self._NUM_CONSTRAINED:
            # forward() always constrains 9 columns; fail early and clearly.
            raise ValueError(
                f"num_outputs must be >= {self._NUM_CONSTRAINED}, got {num_outputs}"
            )
        # Global average pooling removes the dependence on H and W.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # MLP mapping the pooled feature -> raw (unconstrained) arc parameters.
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_outputs),
        )

    def forward(self, feature_logits):
        """Predict arc parameters.

        Args:
            feature_logits: tensor of shape [N, in_channels, H, W]; N may be 0.

        Returns:
            Tensor of shape [N, num_outputs]; see the class docstring for the
            meaning and range of each column.
        """
        _, _, H, W = feature_logits.shape
        # [N, C, H, W] -> [N, C, 1, 1] -> [N, C]. flatten() (unlike view(N, -1))
        # also works for an empty batch (N == 0).
        pooled = torch.flatten(self.gap(feature_logits), start_dim=1)
        raw = self.mlp(pooled)

        k = self._NUM_CONSTRAINED
        # Upper bound of each constrained column, in column order 0..8.
        bounds = raw.new_tensor([W, H, W, W, 2.0 * math.pi, W, H, W, H])
        # Build a fresh tensor rather than writing into `raw` slice by slice.
        constrained = torch.sigmoid(raw[:, :k]) * bounds
        if raw.shape[1] == k:
            return constrained
        return torch.cat([constrained, raw[:, k:]], dim=1)
- # class ArcEquationHead(nn.Module):
- # """
- # Input:
- # feature_logits : [N, 1, H, W] # N: number of arcs; H/W: feature-map size (corresponds to spatial positions in the original image)
- # Output:
- # arc_params : [N, 9] # [cx, cy, a, b, theta, x1, y1, x2, y2]
- # """
- #
- # def __init__(self, num_outputs=9, hidden=1024, feat_size=(672, 672)):
- # super().__init__()
- # self.feat_H, self.feat_W = feat_size # fixed feature-map size (must match feature_logits)
- #
- # self.flatten = nn.Flatten()
- # self.input_dim = self.feat_H * self.feat_W # input dimension: H*W (instead of 1)
- #
- # self.mlp = nn.Sequential(
- # nn.Linear(self.input_dim, hidden),
- # nn.ReLU(inplace=True),
- # nn.Dropout(0.2), # prevent overfitting
- # nn.Linear(hidden, hidden // 2),
- # nn.ReLU(inplace=True),
- # nn.Dropout(0.1),
- # nn.Linear(hidden // 2, num_outputs)
- # )
- #
- # self._init_weights()
- #
- # def _init_weights(self):
- # for m in self.mlp.modules():
- # if isinstance(m, nn.Linear):
- # nn.init.xavier_uniform_(m.weight) # uniform init, to keep outputs from clustering in a narrow range
- # if m.bias is not None:
- # nn.init.zeros_(m.bias)
- #
- # def forward(self, feature_logits):
- # N, C, H, W = feature_logits.shape
- # assert H == self.feat_H and W == self.feat_W, "feature map size must match the feat_size given at init"
- #
- # # 1. Flatten spatial features: [N,1,H,W] -> [N, H*W] (keeps each pixel's spatial information)
- # x = self.flatten(feature_logits) # [N, H*W]
- #
- # # 2. MLP predicts the raw parameters
- # arc_params = self.mlp(x) # [N,9]
- #
- # # 3. Constrain the parameters (avoid out-of-range centers / axis lengths)
- # # Ellipse center: mapped to the feature-map size (if the feature map is a downsampled image, multiply by the scale factor)
- # arc_params[..., 0] = torch.sigmoid(arc_params[..., 0]) * W # cx
- # arc_params[..., 1] = torch.sigmoid(arc_params[..., 1]) * H # cy
- #
- # arc_params[..., 2] = torch.sigmoid(arc_params[..., 2]) * W # cx in image width range
- # arc_params[..., 3] = torch.sigmoid(arc_params[..., 3]) * H # cy in image height range
- # # arc_params[..., 2] = F.relu(arc_params[..., 2]) + 1e-6 # a > 0
- # # arc_params[..., 3] = F.relu(arc_params[..., 3]) + 1e-6 # b > 0
- #
- # # Angle between 0 and 2*pi
- # arc_params[..., 4] = torch.sigmoid(arc_params[..., 4]) * (2 * 3.1415926535)
- #
- # # ------------------------------------------------
- # # Last two values are auxiliary points
- # # Now mapped to the same spatial range as image
- # # ------------------------------------------------
- # arc_params[..., 5] = torch.sigmoid(arc_params[..., 5]) * W # x auxiliary
- # arc_params[..., 6] = torch.sigmoid(arc_params[..., 6]) * H # y auxiliary
- # arc_params[..., 7] = torch.sigmoid(arc_params[..., 7]) * W # x auxiliary
- # arc_params[..., 8] = torch.sigmoid(arc_params[..., 8]) * H # y auxiliary
- # print(f'arc_params in head:{arc_params}')
- # return arc_params
|