# line_rcnn.py

from typing import Any, Optional

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection
from .roi_heads import RoIHeads
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
from models.config.config_tool import read_yaml

import numpy as np
import torch.nn.functional as F

FEATURE_DIM = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

__all__ = [
    "LineRCNN",
    "LineRCNN_ResNet50_FPN_Weights",
    "linercnn_resnet50_fpn",
]


def non_maximum_suppression(a):
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float().clamp(min=0.0)
    return a * mask
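

# Illustrative sketch (not part of the model): non_maximum_suppression keeps only the
# local maxima of a heatmap, because a pixel survives the mask only when it equals the
# maximum of its own 3x3 neighbourhood. The tensor below is made up for demonstration.
def _demo_heatmap_nms():
    heatmap = torch.tensor([[[[0.1, 0.9, 0.2],
                              [0.3, 0.5, 0.4],
                              [0.0, 0.8, 0.1]]]])
    # Only 0.9 and 0.8 are 3x3-neighbourhood maxima; every other entry is zeroed out.
    return non_maximum_suppression(heatmap)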


class Bottleneck1D(nn.Module):
    def __init__(self, inplanes, outplanes):
        super(Bottleneck1D, self).__init__()
        planes = outplanes // 2
        self.op = nn.Sequential(
            nn.BatchNorm1d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv1d(inplanes, planes, kernel_size=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, outplanes, kernel_size=1),
        )

    def forward(self, x):
        return x + self.op(x)
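

# Illustrative sketch (not part of the model): the residual connection in Bottleneck1D
# requires inplanes == outplanes, and the block preserves the [batch, channels, length]
# shape of its input. The sizes below are made up for demonstration.
def _demo_bottleneck1d():
    block = Bottleneck1D(128, 128)
    x = torch.randn(4, 128, 32)        # [batch, channels, length]
    y = block(x)
    assert y.shape == x.shape          # shape is preserved by the 1x1 -> 3x3 -> 1x1 stack
    return y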


class LineRCNN(FasterRCNN):
    """
    Implements Line R-CNN, adapted from torchvision's Keypoint R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box
        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction
        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): during inference, only return proposals with a classification score
            greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
        line_head (nn.Module): module that computes the junction/line maps from the backbone features.
            Defaults to ``LineRCNNHeads`` when ``None``.
        line_predictor (nn.Module): module that scores the candidate lines sampled from the junction/line
            maps. Defaults to ``LineRCNNPredictor`` when ``None``.

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import KeypointRCNN
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # KeypointRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                          output_size=14,
        >>>                                                          sampling_ratio=2)
        >>> # put the pieces together inside a KeypointRCNN model
        >>> model = KeypointRCNN(backbone,
        >>>                      num_classes=2,
        >>>                      rpn_anchor_generator=anchor_generator,
        >>>                      box_roi_pool=roi_pooler,
        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=512,  # originally None
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # line parameters
        line_head=None,
        line_predictor=None,
        **kwargs,
    ):
        # if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
        #     raise TypeError(
        #         "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
        #     )
        # if min_size is None:
        #     min_size = (640, 672, 704, 736, 768, 800)
        #
        # if num_keypoints is not None:
        #     if keypoint_predictor is not None:
        #         raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
        #     else:
        #         num_keypoints = 17

        out_channels = backbone.out_channels

        if line_head is None:
            # keypoint_layers = tuple(512 for _ in range(8))
            num_class = 5
            line_head = LineRCNNHeads(out_channels, num_class)
        if line_predictor is None:
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            line_predictor = LineRCNNPredictor()

        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            **kwargs,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            line_head,
            line_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )
        # super().roi_heads = roi_heads
        self.roi_heads = roi_heads
        self.roi_heads.line_head = line_head
        self.roi_heads.line_predictor = line_predictor


class LineRCNNHeads(nn.Sequential):
    def __init__(self, input_channels, num_class):
        super(LineRCNNHeads, self).__init__()
        # print("input feature dimension:", input_channels)
        m = int(input_channels / 4)
        heads = []
        self.head_size = [[2], [1], [2]]
        for output_channels in sum(self.head_size, []):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(m, output_channels, kernel_size=1),
                )
            )
        self.heads = nn.ModuleList(heads)
        assert num_class == sum(sum(self.head_size, []))

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=1)

    # def __init__(self, in_channels, layers):
    #     d = []
    #     next_feature = in_channels
    #     for out_channels in layers:
    #         d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
    #         d.append(nn.ReLU(inplace=True))
    #         next_feature = out_channels
    #     super().__init__(*d)
    #     for m in self.children():
    #         if isinstance(m, nn.Conv2d):
    #             nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
    #             nn.init.constant_(m.bias, 0)
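

# Illustrative sketch (not part of the model): with head_size = [[2], [1], [2]] the three
# sub-heads produce 2 + 1 + 2 = 5 channels (junction map, line map, junction offset),
# concatenated along the channel dimension. The sizes below are made up for demonstration.
def _demo_line_rcnn_heads():
    head = LineRCNNHeads(input_channels=256, num_class=5)
    features = torch.randn(2, 256, 128, 128)    # [batch, channels, H, W]
    out = head(features)
    assert out.shape == (2, 5, 128, 128)        # spatial size preserved, 5 output channels
    return out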


class LineRCNNPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        # self.backbone = backbone
        # self.cfg = read_yaml(cfg)
        self.cfg = read_yaml(r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml')
        self.n_pts0 = self.cfg['model']['n_pts0']
        self.n_pts1 = self.cfg['model']['n_pts1']
        self.n_stc_posl = self.cfg['model']['n_stc_posl']
        self.dim_loi = self.cfg['model']['dim_loi']
        self.use_conv = self.cfg['model']['use_conv']
        self.dim_fc = self.cfg['model']['dim_fc']
        self.n_out_line = self.cfg['model']['n_out_line']
        self.n_out_junc = self.cfg['model']['n_out_junc']
        self.loss_weight = self.cfg['model']['loss_weight']
        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
        self.use_cood = self.cfg['model']['use_cood']
        self.use_slop = self.cfg['model']['use_slop']
        self.n_stc_negl = self.cfg['model']['n_stc_negl']
        self.head_size = self.cfg['model']['head_size']
        self.num_class = sum(sum(self.head_size, []))
        self.head_off = np.cumsum([sum(h) for h in self.head_size])

        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
        self.register_buffer("lambda_", lambda_)
        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0

        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
        scale_factor = self.n_pts0 // self.n_pts1
        if self.use_conv:
            self.pooling = nn.Sequential(
                nn.MaxPool1d(scale_factor, scale_factor),
                Bottleneck1D(self.dim_loi, self.dim_loi),
            )
            self.fc2 = nn.Sequential(
                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
            )
        else:
            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
            self.fc2 = nn.Sequential(
                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, 1),
            )
        self.loss = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, inputs, features, targets=None):
        # outputs, features = input
        # for out in outputs:
        #     print(f'out:{out.shape}')
        # outputs = merge_features(outputs, 100)
        batch, channel, row, col = inputs.shape
        # print(f'outputs:{inputs.shape}')
        # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')

        if targets is not None:
            self.training = True
            # print(f'target:{targets}')
            wires_targets = [t["wires"] for t in targets]
            # print(f'wires_target:{wires_targets}')
            # stack the 'junc_map', 'junc_offset' and 'line_map' tensors of every target
            junc_maps = [d["junc_map"] for d in wires_targets]
            junc_offsets = [d["junc_offset"] for d in wires_targets]
            line_maps = [d["line_map"] for d in wires_targets]

            junc_map_tensor = torch.stack(junc_maps, dim=0)
            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
            line_map_tensor = torch.stack(line_maps, dim=0)

            wires_meta = {
                "junc_map": junc_map_tensor,
                "junc_offset": junc_offset_tensor,
                # "line_map": line_map_tensor,
            }
        else:
            self.training = False
            t = {
                "junc_coords": torch.zeros(1, 2),
                "jtyp": torch.zeros(1, dtype=torch.uint8),
                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }
            wires_targets = [t for b in range(inputs.size(0))]

            wires_meta = {
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }

        T = wires_meta.copy()
        n_jtyp = T["junc_map"].shape[1]
        offset = self.head_off
        result = {}
        for stack, output in enumerate([inputs]):
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            # print(f"Stack {stack} output shape: {output.shape}")  # print the output shape of each stack
            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0]: offset[1]].squeeze(0)
            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
            if stack == 0:
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }
                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")

        h = result["preds"]
        # print(f'features shape:{features.shape}')
        x = self.fc1(features)
        # print(f'x:{x.shape}')
        n_batch, n_channel, row, col = x.shape
        # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')

        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []

        for i, meta in enumerate(wires_targets):
            p, label, feat, jc = self.sample_lines(
                meta, h["jmap"][i], h["joff"][i],
            )
            # print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
            ys.append(label)
            if self.training and self.do_static_sampling:
                p = torch.cat([p, meta["lpre"]])
                feat = torch.cat([feat, meta["lpre_feat"]])
                ys.append(meta["lpre_label"])
                del jc
            else:
                jcs.append(jc)
                ps.append(p)
            fs.append(feat)

            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
            px0 = px.floor().clamp(min=0, max=127)
            py0 = py.floor().clamp(min=0, max=127)
            px1 = (px0 + 1).clamp(min=0, max=127)
            py1 = (py0 + 1).clamp(min=0, max=127)
            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()

            # xp: [N_LINE, N_CHANNEL, N_POINT]
            xp = (
                (
                    x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
                    + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
                    + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
                    + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
                )
                .reshape(n_channel, -1, self.n_pts0)
                .permute(1, 0, 2)
            )
            xp = self.pooling(xp)
            # print(f'xp.shape:{xp.shape}')
            xs.append(xp)
            idx.append(idx[-1] + xp.shape[0])
            # print(f'idx__:{idx}')

        x, y = torch.cat(xs), torch.cat(ys)
        f = torch.cat(fs)
        x = x.reshape(-1, self.n_pts1 * self.dim_loi)
        # print("Weight dtype:", self.fc2.weight.dtype)
        x = torch.cat([x, f], 1)
        # print("Input dtype:", x.dtype)
        x = x.to(dtype=torch.float32)
        # print("Input dtype1:", x.dtype)
        x = self.fc2(x).flatten()

        # return x, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc
        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc

        # if mode != "training":
        #     self.inference(x, idx, jcs, n_batch, ps)
        # return result

    def sample_lines(self, meta, jmap, joff):
        with torch.no_grad():
            junc = meta["junc_coords"]  # [N, 2]
            jtyp = meta["jtyp"]  # [N]
            Lpos = meta["line_pos_idx"]
            Lneg = meta["line_neg_idx"]

            n_type = jmap.shape[0]
            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
            joff = joff.reshape(n_type, 2, -1)
            max_K = self.n_dyn_junc // n_type
            N = len(junc)
            # if mode != "training":
            if not self.training:
                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
            else:
                K = min(int(N * 2 + 2), max_K)
            if K < 2:
                K = 2
            device = jmap.device

            # index: [N_TYPE, K]
            score, index = torch.topk(jmap, k=K)
            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5

            # xy: [N_TYPE, K, 2]
            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
            xy_ = xy[..., None, :]
            del x, y, index

            # dist: [N_TYPE, K, N]
            dist = torch.sum((xy_ - junc) ** 2, -1)
            cost, match = torch.min(dist, -1)

            # xy: [N_TYPE * K, 2]
            # match: [N_TYPE, K]
            for t in range(n_type):
                match[t, jtyp[match[t]] != t] = N
            match[cost > 1.5 * 1.5] = N
            match = match.flatten()

            _ = torch.arange(n_type * K, device=device)
            u, v = torch.meshgrid(_, _)
            u, v = u.flatten(), v.flatten()
            up, vp = match[u], match[v]
            label = Lpos[up, vp]

            # if mode == "training":
            if self.training:
                c = torch.zeros_like(label, dtype=torch.bool)

                # sample positive lines
                cdx = label.nonzero().flatten()
                if len(cdx) > self.n_dyn_posl:
                    # print("too many positive lines")
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample negative lines
                cdx = Lneg[up, vp].nonzero().flatten()
                if len(cdx) > self.n_dyn_negl:
                    # print("too many negative lines")
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample other (unmatched) lines
                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
                c[cdx] = 1
            else:
                c = (u < v).flatten()

            # sample lines
            u, v, label = u[c], v[c], label[c]
            xy = xy.reshape(n_type * K, 2)
            xyu, xyv = xy[u], xy[v]

            u2v = xyu - xyv
            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
            feat = torch.cat(
                [
                    xyu / 128 * self.use_cood,
                    xyv / 128 * self.use_cood,
                    u2v * self.use_slop,
                    (u[:, None] > K).float(),
                    (v[:, None] > K).float(),
                ],
                1,
            )
            line = torch.cat([xyu[:, None], xyv[:, None]], 1)

            xy = xy.reshape(n_type, K, 2)
            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
            return line, label.float(), feat, jcs
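

# Illustrative sketch (not part of the model): the lambda_ buffer in LineRCNNPredictor turns
# the two endpoints of each sampled line into n_pts0 evenly spaced points, which forward()
# then uses for bilinear sampling of the LoI features. The endpoint coordinates and the
# n_pts0 value below are made up for demonstration.
def _demo_line_point_interpolation(n_pts0=8):
    lambda_ = torch.linspace(0, 1, n_pts0)[:, None]         # [n_pts0, 1]
    endpoints = torch.tensor([[[0.0, 0.0], [7.0, 14.0]]])   # one line: [1, 2, 2_yx]
    points = endpoints[:, 0:1, :] * lambda_ + endpoints[:, 1:2, :] * (1 - lambda_)
    # points: [1, n_pts0, 2] -- a straight walk from one endpoint to the other
    return points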


_COMMON_META = {
    "categories": _COCO_PERSON_CATEGORIES,
    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
    "min_size": (1, 1),
}


class LineRCNN_ResNet50_FPN_Weights(WeightsEnum):
    COCO_LEGACY = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/issues/1606",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 50.6,
                    "kp_map": 61.1,
                }
            },
            "_ops": 133.924,
            "_file_size": 226.054,
            "_docs": """
                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
                from an early epoch.
            """,
        },
    )
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 54.6,
                    "kp_map": 65.0,
                }
            },
            "_ops": 137.42,
            "_file_size": 226.054,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


@register_model()
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: LineRCNN_ResNet50_FPN_Weights.COCO_LEGACY
        if kwargs["pretrained"] == "legacy"
        else LineRCNN_ResNet50_FPN_Weights.COCO_V1,
    ),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def linercnn_resnet50_fpn(
    *,
    weights: Optional[LineRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    num_keypoints: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineRCNN:
  845. """
  846. Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
  847. .. betastatus:: detection module
  848. Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
  849. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  850. image, and should be in ``0-1`` range. Different images can have different sizes.
  851. The behavior of the model changes depending on if it is in training or evaluation mode.
  852. During training, the model expects both the input tensors and targets (list of dictionary),
  853. containing:
  854. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  855. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  856. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  857. - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
  858. format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
  859. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  860. losses for both the RPN and the R-CNN, and the keypoint loss.
  861. During inference, the model requires only the input tensors, and returns the post-processed
  862. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  863. follows, where ``N`` is the number of detected instances:
  864. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  865. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  866. - labels (``Int64Tensor[N]``): the predicted labels for each instance
  867. - scores (``Tensor[N]``): the scores or each instance
  868. - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
  869. For more details on the output, you may refer to :ref:`instance_seg_output`.
  870. Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
  871. Example::
  872. >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
  873. >>> model.eval()
  874. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  875. >>> predictions = model(x)
  876. >>>
  877. >>> # optionally, if you want to export the model to ONNX:
  878. >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
  879. Args:
  880. weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
  881. pretrained weights to use. See
  882. :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
  883. below for more details, and possible values. By default, no
  884. pre-trained weights are used.
  885. progress (bool): If True, displays a progress bar of the download to stderr
  886. num_classes (int, optional): number of output classes of the model (including the background)
  887. num_keypoints (int, optional): number of keypoints
  888. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  889. pretrained weights for the backbone.
  890. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
  891. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
  892. passed (the default) this value is set to 3.
  893. .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
  894. :members:
  895. """
    weights = LineRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
    else:
        if num_classes is None:
            num_classes = 2
        if num_keypoints is None:
            num_keypoints = 17

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
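

# Illustrative usage sketch (not part of the library API): building an untrained LineRCNN
# and running a dummy forward pass. This assumes the wireframe.yaml config read by
# LineRCNNPredictor exists at its hard-coded path and that the custom RoIHeads supports
# inference; the image size below is made up for demonstration.
if __name__ == "__main__":
    demo_model = linercnn_resnet50_fpn(weights=None, weights_backbone=None, num_classes=2)
    demo_model.eval()
    demo_images = [torch.rand(3, 512, 512)]
    with torch.no_grad():
        demo_outputs = demo_model(demo_images)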