import logging
import math

import torch
import torch.nn.functional as F  # noqa: F401  (kept from the original import list)
from torch import nn

logger = logging.getLogger(__name__)


class ArcHeads(nn.Sequential):
    """Stack of 3x3 ``Conv2d`` + ``ReLU`` pairs.

    Each convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes.

    Args:
        in_channels: Number of channels of the input feature map.
        layers: Iterable of output channel counts, one entry per conv layer.
    """

    def __init__(self, in_channels, layers):
        modules = []
        next_feature = in_channels
        for out_channels in layers:
            modules.append(
                nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1)
            )
            modules.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super().__init__(*modules)

        # Kaiming-normal weights and zero biases for every convolution.
        for module in self.children():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                nn.init.constant_(module.bias, 0)


class ArcPredictor(nn.Module):
    """Transposed-convolution head that doubles the spatial resolution.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of output score maps (default 1).

    Attributes:
        kps_score_lowres: The ``ConvTranspose2d`` layer (kernel 4, stride 2,
            padding 1), i.e. ``[N, C, H, W] -> [N, out_channels, 2H, 2W]``.
        up_scale: Scale factor of the bilinear upsampling step that is
            currently disabled in ``forward``; kept as a public attribute.
        out_channels: Copy of the constructor argument.
    """

    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(
            self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
        )
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = out_channels

    def forward(self, x):
        """Apply the transposed convolution and return the upsampled map."""
        # Shape tracing goes through ``logging`` (lazy %-formatting) instead
        # of ``print`` so that it is silent unless DEBUG logging is enabled.
        logger.debug("before kps_score_lowres x:%s", x.shape)
        x = self.kps_score_lowres(x)
        logger.debug("kps_score_lowres x:%s", x.shape)
        return x
        # Disabled extra upsampling step from the original code:
        # return torch.nn.functional.interpolate(
        #     x, scale_factor=float(self.up_scale), mode="bilinear",
        #     align_corners=False, recompute_scale_factor=False
        # )


class ArcEquationHead(nn.Module):
    """Regress arc/ellipse parameters from a single-channel feature map.

    Input:
        feature_logits : [N, 1, H, W]

    Output:
        arc_params : [N, num_outputs] (9 by default), columns:
            0: center x (cx)            in (0, W)
            1: center y (cy)            in (0, H)
            2: long axis length (a)     in (0, W)
            3: short axis length (b)    in (0, W)
            4: ellipse angle (theta)    in (0, 2*pi)
            5: auxiliary point 1, x     in (0, W)
            6: auxiliary point 1, y     in (0, H)
            7: auxiliary point 2, x     in (0, W)
            8: auxiliary point 2, y     in (0, H)

        ``H`` and ``W`` are the spatial size of ``feature_logits``.  When
        ``num_outputs`` is smaller than 9 only the leading columns exist;
        columns beyond the ninth are returned unconstrained.

    Args:
        num_outputs: Number of regressed values per sample.
        hidden: Width of the hidden layer of the MLP.
    """

    def __init__(self, num_outputs=9, hidden=512):
        super().__init__()

        # Global average pooling removes the spatial dimensions.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        # MLP mapping the pooled (single-channel) feature to arc parameters.
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_outputs),
        )

    def forward(self, feature_logits):
        """Return constrained arc parameters for ``feature_logits``.

        Args:
            feature_logits: Tensor of shape ``[N, 1, H, W]``.

        Returns:
            Tensor of shape ``[N, num_outputs]``; see the class docstring for
            the meaning and range of each column.
        """
        _, _, height, width = feature_logits.shape

        # [N, 1, H, W] -> [N, 1, 1, 1] -> [N, 1].  ``flatten(1)`` (rather
        # than ``view(N, -1)``) stays well defined for an empty batch.
        pooled = self.gap(feature_logits).flatten(1)

        # Raw, unconstrained predictions: [N, num_outputs].
        raw = self.mlp(pooled)

        # Upper bound of each constrained column; every value is squashed
        # into (0, bound) with a sigmoid.
        # NOTE(review): the short axis (column 3) is scaled by W, exactly as
        # in the original code, while the disabled variant below uses H --
        # confirm which is intended.
        upper_bounds = (
            width,        # cx
            height,       # cy
            width,        # a
            width,        # b
            2 * math.pi,  # theta
            width,        # auxiliary point 1, x
            height,       # auxiliary point 1, y
            width,        # auxiliary point 2, x
            height,       # auxiliary point 2, y
        )

        # Constrain only the columns that exist, so that a head built with
        # num_outputs < 9 no longer fails with an IndexError.
        num_constrained = min(raw.shape[-1], len(upper_bounds))
        bounds = raw.new_tensor(upper_bounds[:num_constrained])

        # Out-of-place computation: no in-place writes into the MLP output.
        arc_params = torch.sigmoid(raw[..., :num_constrained]) * bounds
        if num_constrained < raw.shape[-1]:
            # Extra columns (beyond the ninth) are passed through untouched.
            arc_params = torch.cat(
                (arc_params, raw[..., num_constrained:]), dim=-1
            )

        # Lazy formatting: the tensor is only rendered when DEBUG is enabled.
        logger.debug("arc_params in head:%s", arc_params)
        return arc_params


# ---------------------------------------------------------------------------
# Disabled alternative implementation (kept for reference only).
# It flattens the whole feature map instead of using global average pooling.
# ---------------------------------------------------------------------------
# class ArcEquationHead(nn.Module):
#     """
#     Input:
#         feature_logits : [N, 1, H, W]
#             N: number of arcs, H/W: feature-map size (corresponding to
#             spatial positions in the original image)
#     Output:
#         arc_params : [N, 9]
#             [cx, cy, a, b, theta, x1, y1, x2, y2]
#     """
#
#     def __init__(self, num_outputs=9, hidden=1024, feat_size=(672, 672)):
#         super().__init__()
#         # Fixed feature-map size (must match feature_logits).
#         self.feat_H, self.feat_W = feat_size
#
#         self.flatten = nn.Flatten()
#         self.input_dim = self.feat_H * self.feat_W  # input dim: H*W (not 1)
#
#         self.mlp = nn.Sequential(
#             nn.Linear(self.input_dim, hidden),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.2),  # guard against overfitting
#             nn.Linear(hidden, hidden // 2),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.1),
#             nn.Linear(hidden // 2, num_outputs)
#         )
#
#         self._init_weights()
#
#     def _init_weights(self):
#         for m in self.mlp.modules():
#             if isinstance(m, nn.Linear):
#                 # Uniform init, to avoid concentrated outputs.
#                 nn.init.xavier_uniform_(m.weight)
#                 if m.bias is not None:
#                     nn.init.zeros_(m.bias)
#
#     def forward(self, feature_logits):
#         N, C, H, W = feature_logits.shape
#         assert H == self.feat_H and W == self.feat_W, \
#             "feature-map size must match the feat_size given at init"
#
#         # 1. Flatten spatial features: [N,1,H,W] -> [N, H*W]
#         #    (keeps the spatial information of every pixel)
#         x = self.flatten(feature_logits)  # [N, H*W]
#
#         # 2. MLP predicts the raw parameters
#         arc_params = self.mlp(x)  # [N,9]
#
#         # 3. Parameter constraints (avoid abnormal centers / axis lengths)
#         # Ellipse center: mapped to the feature-map size (if the feature
#         # map is a downsampled image, multiply by the scale factor)
#         arc_params[..., 0] = torch.sigmoid(arc_params[..., 0]) * W  # cx
#         arc_params[..., 1] = torch.sigmoid(arc_params[..., 1]) * H  # cy
#
#         arc_params[..., 2] = torch.sigmoid(arc_params[..., 2]) * W  # a
#         arc_params[..., 3] = torch.sigmoid(arc_params[..., 3]) * H  # b
#         # arc_params[..., 2] = F.relu(arc_params[..., 2]) + 1e-6  # a > 0
#         # arc_params[..., 3] = F.relu(arc_params[..., 3]) + 1e-6  # b > 0
#
#         # Angle between 0 and 2*pi
#         arc_params[..., 4] = torch.sigmoid(arc_params[..., 4]) * (2 * 3.1415926535)
#
#         # ------------------------------------------------
#         # Remaining values are auxiliary points,
#         # mapped to the same spatial range as the image
#         # ------------------------------------------------
#         arc_params[..., 5] = torch.sigmoid(arc_params[..., 5]) * W  # x auxiliary
#         arc_params[..., 6] = torch.sigmoid(arc_params[..., 6]) * H  # y auxiliary
#         arc_params[..., 7] = torch.sigmoid(arc_params[..., 7]) * W  # x auxiliary
#         arc_params[..., 8] = torch.sigmoid(arc_params[..., 8]) * H  # y auxiliary
#         print(f'arc_params in head:{arc_params}')
#         return arc_params