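"""LineRCNN predictor components.

Defines `LineRCNNPredictor`, a line-of-interest (LOI) head that scores
candidate line segments from junction heatmaps (`jmap`) and offsets
(`joff`) plus pooled backbone features, along with its `Bottleneck1D`
helper block and a max-pool based heatmap non-maximum suppression.
"""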
from typing import Any, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection
from models.config.config_tool import read_yaml

from .roi_heads import RoIHeads
# Dimensionality of the hand-crafted per-line feature vector assembled in
# `sample_lines`: two endpoint coordinates (2 + 2), a direction vector (2),
# and two endpoint-index flags (1 + 1).
FEATURE_DIM = 8


def non_maximum_suppression(a):
    """Zero out every heatmap value that is not the maximum of its own
    3x3 neighbourhood."""
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float()
    return a * mask
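# For example, a row [0.2, 0.9, 0.3] pools to [0.9, 0.9, 0.9], so only the
# local maximum 0.9 survives and its neighbours are zeroed. Equal values on
# a plateau all survive, since each one matches its pooled maximum.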
class Bottleneck1D(nn.Module):
    """Residual 1-D bottleneck block (BN -> ReLU -> Conv stack); the input is
    added back to the output, so ``inplanes`` must equal ``outplanes``."""

    def __init__(self, inplanes, outplanes):
        super().__init__()
        planes = outplanes // 2
        self.op = nn.Sequential(
            nn.BatchNorm1d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv1d(inplanes, planes, kernel_size=1),  # squeeze channels
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, outplanes, kernel_size=1),  # expand channels
        )

    def forward(self, x):
        return x + self.op(x)
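# A quick shape check (a sketch; (4, 128, 8) is an arbitrary example of the
# (batch, channels, length) layout used by the LOI pooling below):
#   Bottleneck1D(128, 128)(torch.randn(4, 128, 8)).shape == (4, 128, 8)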
class LineRCNNPredictor(nn.Module):
    def __init__(
        self,
        n_pts0=32,
        n_pts1=8,
        n_stc_posl=300,
        dim_loi=128,
        use_conv=0,
        dim_fc=1024,
        n_out_line=2500,
        n_out_junc=250,
        n_dyn_junc=300,
        eval_junc_thres=0.008,
        n_dyn_posl=300,
        n_dyn_negl=80,
        n_dyn_othr=600,
        use_cood=0,
        use_slop=0,
        n_stc_negl=40,
        head_size=[[2], [1], [2]],
        **kwargs,
    ):
        super().__init__()
        self.n_pts0 = n_pts0
        self.n_pts1 = n_pts1
        self.n_stc_posl = n_stc_posl
        self.dim_loi = dim_loi
        self.use_conv = use_conv
        self.dim_fc = dim_fc
        self.n_out_line = n_out_line
        self.n_out_junc = n_out_junc
        self.n_dyn_junc = n_dyn_junc
        self.eval_junc_thres = eval_junc_thres
        self.n_dyn_posl = n_dyn_posl
        self.n_dyn_negl = n_dyn_negl
        self.n_dyn_othr = n_dyn_othr
        self.use_cood = use_cood
        self.use_slop = use_slop
        self.n_stc_negl = n_stc_negl
        self.head_size = head_size

        # Total output channels and per-head offsets into them, e.g. the
        # default head_size=[[2], [1], [2]] gives num_class=5, head_off=[2, 3, 5].
        self.num_class = sum(sum(self.head_size, []))
        self.head_off = np.cumsum([sum(h) for h in self.head_size])

        # Interpolation weights for sampling n_pts0 points along each line.
        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
        self.register_buffer("lambda_", lambda_)
        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0

        # Project backbone features to dim_loi channels before LOI pooling.
        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
        scale_factor = self.n_pts0 // self.n_pts1
        if self.use_conv:
            self.pooling = nn.Sequential(
                nn.MaxPool1d(scale_factor, scale_factor),
                Bottleneck1D(self.dim_loi, self.dim_loi),
            )
            self.fc2 = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1),
            )
        else:
            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
            # Input width: dim_loi * n_pts1 pooled LOI features plus the
            # FEATURE_DIM hand-crafted features (128 * 8 + 8 = 1032 by default).
            self.fc2 = nn.Sequential(
                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, 1),
            )
        self.loss = nn.BCEWithLogitsLoss(reduction="none")
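    # forward() runs the LOI (line-of-interest) scoring pipeline:
    #   1. split the raw head output into jmap / lmap / joff predictions,
    #   2. sample candidate line segments from the junction predictions,
    #   3. pool backbone features along each line by bilinear interpolation,
    #   4. score each pooled line feature with fc2.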
    def forward(self, inputs, features, targets=None):
        batch, channel, row, col = inputs.shape

        if targets is not None:
            # Presence of targets marks training mode; this deliberately
            # overrides the flag maintained by nn.Module.train()/.eval().
            self.training = True
            wires_targets = [t["wires"] for t in targets]

            junc_maps = [d["junc_map"] for d in wires_targets]
            junc_offsets = [d["junc_offset"] for d in wires_targets]
            line_maps = [d["line_map"] for d in wires_targets]
            junc_map_tensor = torch.stack(junc_maps, dim=0)
            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
            line_map_tensor = torch.stack(line_maps, dim=0)
            wires_meta = {
                "junc_map": junc_map_tensor,
                "junc_offset": junc_offset_tensor,
            }
        else:
            self.training = False
            # Placeholder targets so sample_lines has well-typed inputs at
            # inference time.
            t = {
                "junc_coords": torch.zeros(1, 2),
                "jtyp": torch.zeros(1, dtype=torch.uint8),
                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }
            wires_targets = [t for _ in range(inputs.size(0))]
            wires_meta = {
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }

        T = wires_meta.copy()
        n_jtyp = T["junc_map"].shape[1]
        offset = self.head_off
        result = {}
        for stack, output in enumerate([inputs]):
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()

            # Slice the channel dimension into the per-head predictions.
            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0]: offset[1]].squeeze(0)
            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
            if stack == 0:
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }

        h = result["preds"]
        x = self.fc1(features)
        n_batch, n_channel, row, col = x.shape
        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
        for i, meta in enumerate(wires_targets):
            p, label, feat, jc = self.sample_lines(meta, h["jmap"][i], h["joff"][i])

            ys.append(label)
            if self.training and self.do_static_sampling:
                # Mix in the statically sampled lines provided by the targets.
                p = torch.cat([p, meta["lpre"]])
                feat = torch.cat([feat, meta["lpre_feat"]])
                ys.append(meta["lpre_label"])
                del jc
            else:
                jcs.append(jc)
                ps.append(p)
            fs.append(feat)

            # Sample n_pts0 points along each line by interpolating between
            # its two endpoints.
            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
            p = p.reshape(-1, 2)
            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
            px0 = px.floor().clamp(min=0, max=127)
            py0 = py.floor().clamp(min=0, max=127)
            px1 = (px0 + 1).clamp(min=0, max=127)
            py1 = (py0 + 1).clamp(min=0, max=127)
            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()

            # Bilinear interpolation of the feature map at the sampled points.
            xp = (
                (
                    x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
                    + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
                    + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
                    + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
                )
                .reshape(n_channel, -1, self.n_pts0)
                .permute(1, 0, 2)
            )
            xp = self.pooling(xp)
            xs.append(xp)
            idx.append(idx[-1] + xp.shape[0])

        x, y = torch.cat(xs), torch.cat(ys)
        f = torch.cat(fs)
        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
        x = torch.cat([x, f], 1)
        x = x.to(dtype=torch.float32)
        x = self.fc2(x).flatten()

        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc

    def sample_lines(self, meta, jmap, joff):
        device = jmap.device
        with torch.no_grad():
            junc = meta["junc_coords"].to(device)
            jtyp = meta["jtyp"].to(device)
            Lpos = meta["line_pos_idx"].to(device)
            Lneg = meta["line_neg_idx"].to(device)

            n_type = jmap.shape[0]
            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
            joff = joff.reshape(n_type, 2, -1)
            max_K = self.n_dyn_junc // n_type
            N = len(junc)

            # Keep at most max_K junction candidates per type.
            if not self.training:
                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
            else:
                K = min(int(N * 2 + 2), max_K)
            K = max(K, 2)

            score, index = torch.topk(jmap, k=K)
            # Recover sub-pixel junction coordinates from the flat indices
            # and the predicted offsets.
            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5

            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
            xy_ = xy[..., None, :]
            del x, y, index

            # Match each candidate to the nearest ground-truth junction;
            # wrong-type matches and matches farther than 1.5 px are assigned
            # the sentinel index N.
            dist = torch.sum((xy_ - junc) ** 2, -1)
            cost, match = torch.min(dist, -1)
            for t in range(n_type):
                match[t, jtyp[match[t]] != t] = N
            match[cost > 1.5 * 1.5] = N
            match = match.flatten()

            # Enumerate all ordered candidate pairs and label them via Lpos.
            ar = torch.arange(n_type * K, device=device)
            u, v = torch.meshgrid(ar, ar, indexing="ij")
            u, v = u.flatten(), v.flatten()
            up, vp = match[u], match[v]
            label = Lpos[up, vp]

            if self.training:
                c = torch.zeros_like(label, dtype=torch.bool)

                # Sample at most n_dyn_posl positive pairs.
                cdx = label.nonzero().flatten()
                if len(cdx) > self.n_dyn_posl:
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # Sample at most n_dyn_negl hard negative pairs.
                cdx = Lneg[up, vp].nonzero().flatten()
                if len(cdx) > self.n_dyn_negl:
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # Add n_dyn_othr random pairs.
                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
                c[cdx] = 1
            else:
                # At inference, keep each unordered pair exactly once.
                c = (u < v).flatten()

            u, v, label = u[c], v[c], label[c]
            xy = xy.reshape(n_type * K, 2)
            xyu, xyv = xy[u], xy[v]

            # Unit direction vector between the two endpoints.
            u2v = xyu - xyv
            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)

            # Hand-crafted per-line features (FEATURE_DIM = 8 values in all).
            feat = torch.cat(
                [
                    xyu / 128 * self.use_cood,
                    xyv / 128 * self.use_cood,
                    u2v * self.use_slop,
                    (u[:, None] > K).float(),
                    (v[:, None] > K).float(),
                ],
                1,
            )
            line = torch.cat([xyu[:, None], xyv[:, None]], 1)

            xy = xy.reshape(n_type, K, 2)
            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
            return line, label.float(), feat, jcs
_COMMON_META = {
    "categories": _COCO_PERSON_CATEGORIES,
    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
    "min_size": (1, 1),
}
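# A minimal inference smoke test (a sketch, not part of the original module).
# The shapes are assumptions implied by the defaults above: 5 input channels
# (the sum over head_size), 256 backbone channels expected by fc1, and
# 128x128 maps. n_dyn_junc is lowered so random inputs yield few line pairs.
# Because of the relative import above, run this with
# `python -m <package>.<module>` rather than as a plain script.
if __name__ == "__main__":
    predictor = LineRCNNPredictor(n_dyn_junc=20)
    inputs = torch.randn(2, 5, 128, 128)      # raw multitask-head output
    features = torch.randn(2, 256, 128, 128)  # backbone feature map
    with torch.no_grad():
        scores, labels, idx, jcs, n_batch, ps, _, _ = predictor(inputs, features)
    print(scores.shape, labels.shape, idx)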