from typing import Any, Optional

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection
from .roi_heads import RoIHeads
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.backbone_utils import (
    _resnet_fpn_extractor,
    _validate_trainable_layers,
)
from libs.vision_libs.models.detection.faster_rcnn import (
    FasterRCNN,
    TwoMLPHead,
    FastRCNNPredictor,
)
from models.config.config_tool import read_yaml
import numpy as np
import torch.nn.functional as F

# Width of the per-line geometric feature vector built in
# ``LineRCNNPredictor.sample_lines``: start (2) + end (2) + direction (2) + two flags.
FEATURE_DIM = 8

# Side length the junction heatmaps are assumed to have: flat heatmap indices are
# decoded with ``// HEATMAP_SIZE`` / ``% HEATMAP_SIZE``, coordinates are normalised
# by it, and line sampling points are clamped to ``HEATMAP_SIZE - 1``.
HEATMAP_SIZE = 128


def non_maximum_suppression(a):
    """Keep only the 3x3 local maxima of ``a`` and zero every other value.

    ``a`` is passed directly to ``F.max_pool2d``, so it must be a 3-D (C, H, W)
    or 4-D (N, C, H, W) tensor. Values tied with their neighbourhood maximum
    are all kept.
    """
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float()
    return a * mask


class LineRCNNPredictor(nn.Module):
    """Line-verification head.

    Decodes junction candidates from the junction heatmap/offset heads, pairs
    them into candidate lines, samples ``fc1`` features along every line and
    scores each line with ``fc2``.

    ``cfg`` is a mapping that must provide every key read in ``__init__``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Points sampled along each line before (n_pts0) and after (n_pts1) pooling.
        self.n_pts0 = self.cfg['n_pts0']
        self.n_pts1 = self.cfg['n_pts1']
        self.n_stc_posl = self.cfg['n_stc_posl']
        self.dim_loi = self.cfg['dim_loi']
        self.use_conv = self.cfg['use_conv']
        self.dim_fc = self.cfg['dim_fc']
        self.n_out_line = self.cfg['n_out_line']
        self.n_out_junc = self.cfg['n_out_junc']
        self.loss_weight = self.cfg['loss_weight']
        self.n_dyn_junc = self.cfg['n_dyn_junc']
        self.eval_junc_thres = self.cfg['eval_junc_thres']
        self.n_dyn_posl = self.cfg['n_dyn_posl']
        self.n_dyn_negl = self.cfg['n_dyn_negl']
        self.n_dyn_othr = self.cfg['n_dyn_othr']
        self.use_cood = self.cfg['use_cood']
        self.use_slop = self.cfg['use_slop']
        self.n_stc_negl = self.cfg['n_stc_negl']
        self.head_size = self.cfg['head_size']

        # head_size is a list of lists of channel counts; sum(..., []) flattens it.
        self.num_class = sum(sum(self.head_size, []))
        # Cumulative channel offsets delimiting each head's slice of ``inputs``
        # (junction map, line map, junction offset, in that order).
        self.head_off = np.cumsum([sum(h) for h in self.head_size])

        # Interpolation weights of shape (n_pts0, 1) running from 0 to 1 along a line.
        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
        self.register_buffer("lambda_", lambda_)
        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0

        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
        scale_factor = self.n_pts0 // self.n_pts1
        if self.use_conv:
            # NOTE(review): Bottleneck1D is not imported in the visible part of this
            # file; presumably it is defined further down -- confirm.
            self.pooling = nn.Sequential(
                nn.MaxPool1d(scale_factor, scale_factor),
                Bottleneck1D(self.dim_loi, self.dim_loi),
            )
            self.fc2 = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1),
            )
        else:
            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
            self.fc2 = nn.Sequential(
                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, 1),
            )
        self.loss = nn.BCEWithLogitsLoss(reduction="none")

    def _prepare_targets(self, inputs, targets):
        """Return ``(per-image wire metas, number of junction types)``.

        With targets, each target's ``"wires"`` dict is used and the junction-type
        count is dim 1 of the stacked ``"junc_map"`` tensors. Without targets,
        every image shares one placeholder meta (a single junction at the origin,
        2x2 adjacency tensors) and one junction type is assumed.
        """
        if targets is not None:
            wires_targets = [t["wires"] for t in targets]
            junc_map_tensor = torch.stack([d["junc_map"] for d in wires_targets], dim=0)
            return wires_targets, junc_map_tensor.shape[1]

        placeholder = {
            "junc_coords": torch.zeros(1, 2),
            "jtyp": torch.zeros(1, dtype=torch.uint8),
            "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
            "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
            "junc_map": torch.zeros([1, 1, 128, 128]),
            "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
        }
        # One junction type: the placeholder junction map is (1, 1, 128, 128).
        return [placeholder] * inputs.size(0), 1

    def forward(self, inputs, features, targets=None):
        """Sample candidate lines from the junction heads and score them.

        Args:
            inputs: head output of shape (batch, channels, row, col); channels are
                split by ``self.head_off`` into junction-map, line-map and
                junction-offset heads.
            features: feature map passed to ``fc1`` (256 input channels).
            targets: optional list of per-image dicts, each holding a ``"wires"``
                dict. Its presence switches this module to training-style sampling.

        Returns:
            Tuple ``(logits, labels, idx, jcs, n_batch, ps, n_out_line, n_out_junc)``:
            flattened line logits for all images, their labels, cumulative line
            counts per image (starting at 0), per-image junction candidates and
            line endpoints (both left empty when static sampling is used in
            training), the batch size of ``features``, and the two cfg values.
        """
        batch, _, row, col = inputs.shape

        # NOTE(review): this overwrites only this module's flag based on whether
        # targets were passed; it does not call self.train()/self.eval(), so
        # submodules (e.g. self.pooling) keep their own mode.
        self.training = targets is not None
        wires_targets, n_jtyp = self._prepare_targets(inputs, targets)

        offset = self.head_off
        # (batch, C, row, col) -> (C, batch, row, col) so heads can be sliced by channel.
        output = inputs.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
        jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
        joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
        # The line-map channels output[offset[0]:offset[1]] are not used here.
        # jmap -> (batch, n_jtyp, row, col) foreground probability.
        # joff -> (batch, n_jtyp, 2, row, col) in [-0.5, 0.5).
        h_jmap = jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1]
        h_joff = joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5

        x = self.fc1(features)
        n_batch, n_channel, _, _ = x.shape

        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
        for i, meta in enumerate(wires_targets):
            p, label, feat, jc = self.sample_lines(meta, h_jmap[i], h_joff[i])
            ys.append(label)
            if self.training and self.do_static_sampling:
                # Append the statically pre-sampled lines from the target.
                p = torch.cat([p, meta["lpre"]])
                feat = torch.cat([feat, meta["lpre_feat"]])
                ys.append(meta["lpre_label"])
            else:
                jcs.append(jc)
                ps.append(p)
            fs.append(feat)

            xp = self._sample_line_features(x[i], p, n_channel)
            xs.append(xp)
            idx.append(idx[-1] + xp.shape[0])

        x, y = torch.cat(xs), torch.cat(ys)
        f = torch.cat(fs)
        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
        x = torch.cat([x, f], 1).to(dtype=torch.float32)
        x = self.fc2(x).flatten()
        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc

    def _sample_line_features(self, fmap, lines, n_channel):
        """Bilinearly sample ``fmap`` at ``n_pts0`` points along each line, then pool.

        Args:
            fmap: one image's ``fc1`` output, shape (n_channel, H, W).
            lines: (N_LINE, 2, 2) segment endpoints in (row, col) order.
            n_channel: channel count of ``fmap``.

        Returns:
            ``self.pooling`` applied to the (N_LINE, n_channel, n_pts0) samples.
        """
        # The -0.5 undoes the +0.5 added when junction coordinates are decoded
        # in sample_lines.
        p = lines[:, 0:1, :] * self.lambda_ + lines[:, 1:2, :] * (1 - self.lambda_) - 0.5
        p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2] in (row, col) order
        pr, pc = p[:, 0].contiguous(), p[:, 1].contiguous()
        max_idx = HEATMAP_SIZE - 1
        pr0 = pr.floor().clamp(min=0, max=max_idx)
        pc0 = pc.floor().clamp(min=0, max=max_idx)
        pr1 = (pr0 + 1).clamp(min=0, max=max_idx)
        pc1 = (pc0 + 1).clamp(min=0, max=max_idx)
        pr0l, pc0l, pr1l, pc1l = pr0.long(), pc0.long(), pr1.long(), pc1.long()

        # xp: [N_LINE, N_CHANNEL, N_POINT]
        xp = (
            (
                fmap[:, pr0l, pc0l] * (pr1 - pr) * (pc1 - pc)
                + fmap[:, pr1l, pc0l] * (pr - pr0) * (pc1 - pc)
                + fmap[:, pr0l, pc1l] * (pr1 - pr) * (pc - pc0)
                + fmap[:, pr1l, pc1l] * (pr - pr0) * (pc - pc0)
            )
            .reshape(n_channel, -1, self.n_pts0)
            .permute(1, 0, 2)
        )
        return self.pooling(xp)

    def sample_lines(self, meta, jmap, joff):
        """Decode junction candidates and pair them into candidate lines.

        Args:
            meta: per-image dict with "junc_coords" (N, 2), "jtyp" (N,) and the
                adjacency tensors "line_pos_idx" / "line_neg_idx". Index N in the
                adjacency tensors stands for "no matching ground-truth junction".
            jmap: junction probability maps, (n_type, H, W).
            joff: junction offsets, (n_type, 2, H, W).

        Returns:
            line: (N_LINE, 2, 2) endpoints in (row, col) order.
            label: (N_LINE,) float labels read from "line_pos_idx".
            feat: (N_LINE, FEATURE_DIM) geometric line features.
            jcs: list with one tensor per junction type of candidates scoring > 0.03.
        """
        device = jmap.device
        with torch.no_grad():
            junc = meta["junc_coords"].to(device)  # [N, 2]
            jtyp = meta["jtyp"].to(device)  # [N]
            Lpos = meta["line_pos_idx"].to(device)
            Lneg = meta["line_neg_idx"].to(device)

            n_type = jmap.shape[0]
            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
            joff = joff.reshape(n_type, 2, -1)
            max_K = self.n_dyn_junc // n_type
            N = len(junc)
            if self.training:
                K = min(int(N * 2 + 2), max_K)
            else:
                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
            K = max(K, 2)

            # score, index: [N_TYPE, K]
            score, index = torch.topk(jmap, k=K)
            y = (index // HEATMAP_SIZE).float() + torch.gather(joff[:, 0], 1, index) + 0.5
            x = (index % HEATMAP_SIZE).float() + torch.gather(joff[:, 1], 1, index) + 0.5

            # xy: [N_TYPE, K, 2] in (row, col) order
            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
            del x, y, index

            # match: [N_TYPE, K]; value N marks "unmatched".
            if N > 0:
                # dist: [N_TYPE, K, N]
                dist = torch.sum((xy[..., None, :] - junc) ** 2, -1)
                cost, match = torch.min(dist, -1)
                for t in range(n_type):
                    match[t, jtyp[match[t]] != t] = N
                match[cost > 1.5 * 1.5] = N
            else:
                # No ground-truth junctions: torch.min over an empty dim would
                # raise, so every candidate is unmatched.
                match = torch.full((n_type, K), N, dtype=torch.long, device=device)
            match = match.flatten()

            ids = torch.arange(n_type * K, device=device)
            u, v = torch.meshgrid(ids, ids, indexing="ij")
            u, v = u.flatten(), v.flatten()
            up, vp = match[u], match[v]
            label = Lpos[up, vp]

            if self.training:
                c = torch.zeros_like(label, dtype=torch.bool)

                # sample positive lines (at most n_dyn_posl)
                cdx = label.nonzero().flatten()
                if len(cdx) > self.n_dyn_posl:
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
                    cdx = cdx[perm]
                c[cdx] = True

                # sample negative lines (at most n_dyn_negl)
                cdx = Lneg[up, vp].nonzero().flatten()
                if len(cdx) > self.n_dyn_negl:
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
                    cdx = cdx[perm]
                c[cdx] = True

                # sample other (random) pairs
                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
                c[cdx] = True
            else:
                # Evaluation: every unordered pair of candidates once.
                c = (u < v).flatten()

            u, v, label = u[c], v[c], label[c]
            xy = xy.reshape(n_type * K, 2)
            xyu, xyv = xy[u], xy[v]

            u2v = xyu - xyv
            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
            # NOTE(review): candidates of type t occupy indices [t*K, (t+1)*K), so
            # ``> K`` misses index K itself; ``>= K`` would be exact, but changing it
            # alters the features fed to fc2 -- left as-is.
            feat = torch.cat(
                [
                    xyu / HEATMAP_SIZE * self.use_cood,
                    xyv / HEATMAP_SIZE * self.use_cood,
                    u2v * self.use_slop,
                    (u[:, None] > K).float(),
                    (v[:, None] > K).float(),
                ],
                1,
            )
            line = torch.cat([xyu[:, None], xyv[:, None]], 1)

            xy = xy.reshape(n_type, K, 2)
            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
            return line, label.float(), feat, jcs


_COMMON_META = {
    "categories": _COCO_PERSON_CATEGORIES,
    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
    "min_size": (1, 1),
}