
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import torch

from ultralytics.utils.downloads import attempt_download_asset

from .modules.decoders import MaskDecoder
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
from .modules.sam import SAM2Model, SAMModel
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer


def build_sam_vit_h(checkpoint=None):
    """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


def build_sam_vit_l(checkpoint=None):
    """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b(checkpoint=None):
    """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


def build_mobile_sam(checkpoint=None):
    """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
    return _build_sam(
        encoder_embed_dim=[64, 128, 160, 320],
        encoder_depth=[2, 2, 6, 2],
        encoder_num_heads=[2, 4, 5, 10],
        encoder_global_attn_indexes=None,
        mobile_sam=True,
        checkpoint=checkpoint,
    )
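

# Note: unlike the ViT builders above, build_mobile_sam passes per-stage lists (embedding dims,
# depths, heads) and no global-attention indexes, because _build_sam below swaps the ViT image
# encoder for a TinyViT trunk when mobile_sam=True. A minimal usage sketch (assumes the
# "mobile_sam.pt" asset is resolvable by attempt_download_asset):
#
#     model = build_mobile_sam("mobile_sam.pt")  # TinyViT-backed SAM, returned in eval mode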


def build_sam2_t(checkpoint=None):
    """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
    return _build_sam2(
        encoder_embed_dim=96,
        encoder_stages=[1, 2, 7, 2],
        encoder_num_heads=1,
        encoder_global_att_blocks=[5, 7, 9],
        encoder_window_spec=[8, 4, 14, 7],
        encoder_backbone_channel_list=[768, 384, 192, 96],
        checkpoint=checkpoint,
    )


def build_sam2_s(checkpoint=None):
    """Builds and returns a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
    return _build_sam2(
        encoder_embed_dim=96,
        encoder_stages=[1, 2, 11, 2],
        encoder_num_heads=1,
        encoder_global_att_blocks=[7, 10, 13],
        encoder_window_spec=[8, 4, 14, 7],
        encoder_backbone_channel_list=[768, 384, 192, 96],
        checkpoint=checkpoint,
    )


def build_sam2_b(checkpoint=None):
    """Builds and returns a base-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
    return _build_sam2(
        encoder_embed_dim=112,
        encoder_stages=[2, 3, 16, 3],
        encoder_num_heads=2,
        encoder_global_att_blocks=[12, 16, 20],
        encoder_window_spec=[8, 4, 14, 7],
        encoder_window_spatial_size=[14, 14],
        encoder_backbone_channel_list=[896, 448, 224, 112],
        checkpoint=checkpoint,
    )


def build_sam2_l(checkpoint=None):
    """Builds and returns a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
    return _build_sam2(
        encoder_embed_dim=144,
        encoder_stages=[2, 6, 36, 4],
        encoder_num_heads=2,
        encoder_global_att_blocks=[23, 33, 43],
        encoder_window_spec=[8, 4, 16, 8],
        encoder_backbone_channel_list=[1152, 576, 288, 144],
        checkpoint=checkpoint,
    )
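

# The SAM2 size variants above differ only in their Hiera trunk configuration (embedding width,
# stage depths, attention heads, global-attention block indices, window specs) and the matching
# FPN backbone channel list; every other component comes from the shared defaults in _build_sam2().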


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    mobile_sam=False,
):
    """
    Builds a Segment Anything Model (SAM) with specified encoder parameters.

    Args:
        encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
        encoder_depth (int | List[int]): Depth of the encoder.
        encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
        encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
        checkpoint (str | None): Path to the model checkpoint file.
        mobile_sam (bool): Whether to build a Mobile-SAM model.

    Returns:
        (SAMModel): A Segment Anything Model instance with the specified architecture.

    Examples:
        >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
        >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    image_encoder = (
        TinyViT(
            img_size=1024,
            in_chans=3,
            num_classes=1000,
            embed_dims=encoder_embed_dim,
            depths=encoder_depth,
            num_heads=encoder_num_heads,
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.0,
            drop_rate=0.0,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        )
        if mobile_sam
        else ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        )
    )
    sam = SAMModel(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    if checkpoint is not None:
        checkpoint = attempt_download_asset(checkpoint)
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    sam.eval()
    return sam
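

# _build_sam always returns the model in eval mode; weights are loaded only when a checkpoint is
# given, otherwise the model keeps its random initialization. A minimal sketch, mirroring the
# ViT-B example from the docstring above:
#
#     sam = _build_sam(768, 12, 12, [2, 5, 8, 11])                          # ViT-B SAM, random weights
#     sam = _build_sam(768, 12, 12, [2, 5, 8, 11], checkpoint="sam_b.pt")   # loads pretrained weights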


def _build_sam2(
    encoder_embed_dim=1280,
    encoder_stages=[2, 6, 36, 4],
    encoder_num_heads=2,
    encoder_global_att_blocks=[7, 15, 23, 31],
    encoder_backbone_channel_list=[1152, 576, 288, 144],
    encoder_window_spatial_size=[7, 7],
    encoder_window_spec=[8, 4, 16, 8],
    checkpoint=None,
):
    """
    Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.

    Args:
        encoder_embed_dim (int): Embedding dimension for the encoder.
        encoder_stages (List[int]): Number of blocks in each stage of the encoder.
        encoder_num_heads (int): Number of attention heads in the encoder.
        encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
        encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
        encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
        encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
        checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.

    Returns:
        (SAM2Model): A configured and initialized SAM2 model.

    Examples:
        >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
        >>> sam2_model.eval()
    """
    image_encoder = ImageEncoder(
        trunk=Hiera(
            embed_dim=encoder_embed_dim,
            num_heads=encoder_num_heads,
            stages=encoder_stages,
            global_att_blocks=encoder_global_att_blocks,
            window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
            window_spec=encoder_window_spec,
        ),
        neck=FpnNeck(
            d_model=256,
            backbone_channel_list=encoder_backbone_channel_list,
            fpn_top_down_levels=[2, 3],
            fpn_interp_model="nearest",
        ),
        scalp=1,
    )
    memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
    memory_encoder = MemoryEncoder(out_dim=64)

    is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
    sam2 = SAM2Model(
        image_encoder=image_encoder,
        memory_attention=memory_attention,
        memory_encoder=memory_encoder,
        num_maskmem=7,
        image_size=1024,
        sigmoid_scale_for_mem_enc=20.0,
        sigmoid_bias_for_mem_enc=-10.0,
        use_mask_input_as_output_without_sam=True,
        directly_add_no_mem_embed=True,
        use_high_res_features_in_sam=True,
        multimask_output_in_sam=True,
        iou_prediction_use_sigmoid=True,
        use_obj_ptrs_in_encoder=True,
        add_tpos_enc_to_obj_ptrs=True,
        only_obj_ptrs_in_the_past_for_eval=True,
        pred_obj_scores=True,
        pred_obj_scores_mlp=True,
        fixed_no_obj_ptr=True,
        multimask_output_for_tracking=True,
        use_multimask_token_for_obj_ptr=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        use_mlp_for_obj_ptr_proj=True,
        compile_image_encoder=False,
        no_obj_embed_spatial=is_sam2_1,
        proj_tpos_enc_in_obj_ptrs=is_sam2_1,
        use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
        sam_mask_decoder_extra_args=dict(
            dynamic_multimask_via_stability=True,
            dynamic_multimask_stability_delta=0.05,
            dynamic_multimask_stability_thresh=0.98,
        ),
    )
    if checkpoint is not None:
        checkpoint = attempt_download_asset(checkpoint)
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)["model"]
        sam2.load_state_dict(state_dict)
    sam2.eval()
    return sam2
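

# SAM 2.1 checkpoints (filenames containing "sam2.1") switch on the extra no_obj_embed_spatial,
# proj_tpos_enc_in_obj_ptrs, and use_signed_tpos_enc_to_obj_ptrs options above; plain SAM2
# checkpoints leave them disabled. As in _build_sam, calling without a checkpoint returns a
# randomly initialized model, e.g. the tiny variant from the docstring example:
#
#     sam2_t = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])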


sam_model_map = {
    "sam_h.pt": build_sam_vit_h,
    "sam_l.pt": build_sam_vit_l,
    "sam_b.pt": build_sam_vit_b,
    "mobile_sam.pt": build_mobile_sam,
    "sam2_t.pt": build_sam2_t,
    "sam2_s.pt": build_sam2_s,
    "sam2_b.pt": build_sam2_b,
    "sam2_l.pt": build_sam2_l,
    "sam2.1_t.pt": build_sam2_t,
    "sam2.1_s.pt": build_sam2_s,
    "sam2.1_b.pt": build_sam2_b,
    "sam2.1_l.pt": build_sam2_l,
}
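
# build_sam() below resolves checkpoints by filename suffix (ckpt.endswith), so a local path such
# as "weights/sam2.1_l.pt" still maps to build_sam2_l, while names matching no key raise
# FileNotFoundError.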


def build_sam(ckpt="sam_b.pt"):
    """
    Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.

    Args:
        ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.

    Returns:
        (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.

    Raises:
        FileNotFoundError: If the provided checkpoint is not a supported SAM model.

    Examples:
        >>> sam_model = build_sam("sam_b.pt")
        >>> sam_model = build_sam("path/to/sam_b.pt")  # custom path ending with a supported model name

    Notes:
        Supported pre-defined models include:
        - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
        - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
        - SAM2.1: 'sam2.1_t.pt', 'sam2.1_s.pt', 'sam2.1_b.pt', 'sam2.1_l.pt'
    """
    model_builder = None
    ckpt = str(ckpt)  # to allow Path ckpt types
    for k in sam_model_map.keys():
        if ckpt.endswith(k):
            model_builder = sam_model_map.get(k)

    if not model_builder:
        raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")

    return model_builder(ckpt)
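

# Minimal usage sketch (not part of the public API): build the default ViT-B SAM by checkpoint
# name. Assumes the "sam_b.pt" asset can be fetched by attempt_download_asset; guarded so nothing
# runs on import.
if __name__ == "__main__":
    model = build_sam("sam_b.pt")  # suffix-matched to build_sam_vit_b
    print(type(model).__name__, sum(p.numel() for p in model.parameters()))  # model class and parameter count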