block.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. """Block modules."""
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. import math
  7. from ultralytics.utils.torch_utils import fuse_conv_and_bn
  8. from .conv import Conv, DSConv, DWConv, GhostConv, LightConv, RepConv, autopad
  9. from .transformer import TransformerBlock
  10. __all__ = (
  11. "DFL",
  12. "HGBlock",
  13. "HGStem",
  14. "SPP",
  15. "SPPF",
  16. "C1",
  17. "C2",
  18. "C3",
  19. "C2f",
  20. "C2fAttn",
  21. "ImagePoolingAttn",
  22. "ContrastiveHead",
  23. "BNContrastiveHead",
  24. "C3x",
  25. "C3TR",
  26. "C3Ghost",
  27. "GhostBottleneck",
  28. "Bottleneck",
  29. "BottleneckCSP",
  30. "Proto",
  31. "RepC3",
  32. "ResNetLayer",
  33. "RepNCSPELAN4",
  34. "ELAN1",
  35. "ADown",
  36. "AConv",
  37. "SPPELAN",
  38. "CBFuse",
  39. "CBLinear",
  40. "C3k2",
  41. "C2fPSA",
  42. "C2PSA",
  43. "RepVGGDW",
  44. "CIB",
  45. "C2fCIB",
  46. "Attention",
  47. "PSA",
  48. "SCDown",
  49. "TorchVision",
  50. "HyperACE",
  51. "DownsampleConv",
  52. "FullPAD_Tunnel",
  53. "DSC3k2"
  54. )
  55. class DFL(nn.Module):
  56. """
  57. Integral module of Distribution Focal Loss (DFL).
  58. Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
  59. """
  60. def __init__(self, c1=16):
  61. """Initialize a convolutional layer with a given number of input channels."""
  62. super().__init__()
  63. self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
  64. x = torch.arange(c1, dtype=torch.float)
  65. self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
  66. self.c1 = c1
  67. def forward(self, x):
  68. """Applies a transformer layer on input tensor 'x' and returns a tensor."""
  69. b, _, a = x.shape # batch, channels, anchors
  70. return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
  71. # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
  72. class Proto(nn.Module):
  73. """YOLOv8 mask Proto module for segmentation models."""
  74. def __init__(self, c1, c_=256, c2=32):
  75. """
  76. Initializes the YOLOv8 mask Proto module with specified number of protos and masks.
  77. Input arguments are ch_in, number of protos, number of masks.
  78. """
  79. super().__init__()
  80. self.cv1 = Conv(c1, c_, k=3)
  81. self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
  82. self.cv2 = Conv(c_, c_, k=3)
  83. self.cv3 = Conv(c_, c2)
  84. def forward(self, x):
  85. """Performs a forward pass through layers using an upsampled input image."""
  86. return self.cv3(self.cv2(self.upsample(self.cv1(x))))
  87. class HGStem(nn.Module):
  88. """
  89. StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
  90. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
  91. """
  92. def __init__(self, c1, cm, c2):
  93. """Initialize the SPP layer with input/output channels and specified kernel sizes for max pooling."""
  94. super().__init__()
  95. self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())
  96. self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())
  97. self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())
  98. self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())
  99. self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())
  100. self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)
  101. def forward(self, x):
  102. """Forward pass of a PPHGNetV2 backbone layer."""
  103. x = self.stem1(x)
  104. x = F.pad(x, [0, 1, 0, 1])
  105. x2 = self.stem2a(x)
  106. x2 = F.pad(x2, [0, 1, 0, 1])
  107. x2 = self.stem2b(x2)
  108. x1 = self.pool(x)
  109. x = torch.cat([x1, x2], dim=1)
  110. x = self.stem3(x)
  111. x = self.stem4(x)
  112. return x
  113. class HGBlock(nn.Module):
  114. """
  115. HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
  116. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
  117. """
  118. def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
  119. """Initializes a CSP Bottleneck with 1 convolution using specified input and output channels."""
  120. super().__init__()
  121. block = LightConv if lightconv else Conv
  122. self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
  123. self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv
  124. self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv
  125. self.add = shortcut and c1 == c2
  126. def forward(self, x):
  127. """Forward pass of a PPHGNetV2 backbone layer."""
  128. y = [x]
  129. y.extend(m(y[-1]) for m in self.m)
  130. y = self.ec(self.sc(torch.cat(y, 1)))
  131. return y + x if self.add else y
  132. class SPP(nn.Module):
  133. """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
  134. def __init__(self, c1, c2, k=(5, 9, 13)):
  135. """Initialize the SPP layer with input/output channels and pooling kernel sizes."""
  136. super().__init__()
  137. c_ = c1 // 2 # hidden channels
  138. self.cv1 = Conv(c1, c_, 1, 1)
  139. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  140. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  141. def forward(self, x):
  142. """Forward pass of the SPP layer, performing spatial pyramid pooling."""
  143. x = self.cv1(x)
  144. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  145. class SPPF(nn.Module):
  146. """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
  147. def __init__(self, c1, c2, k=5):
  148. """
  149. Initializes the SPPF layer with given input/output channels and kernel size.
  150. This module is equivalent to SPP(k=(5, 9, 13)).
  151. """
  152. super().__init__()
  153. c_ = c1 // 2 # hidden channels
  154. self.cv1 = Conv(c1, c_, 1, 1)
  155. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  156. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  157. def forward(self, x):
  158. """Forward pass through Ghost Convolution block."""
  159. y = [self.cv1(x)]
  160. y.extend(self.m(y[-1]) for _ in range(3))
  161. return self.cv2(torch.cat(y, 1))
  162. class C1(nn.Module):
  163. """CSP Bottleneck with 1 convolution."""
  164. def __init__(self, c1, c2, n=1):
  165. """Initializes the CSP Bottleneck with configurations for 1 convolution with arguments ch_in, ch_out, number."""
  166. super().__init__()
  167. self.cv1 = Conv(c1, c2, 1, 1)
  168. self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))
  169. def forward(self, x):
  170. """Applies cross-convolutions to input in the C3 module."""
  171. y = self.cv1(x)
  172. return self.m(y) + y
  173. class C2(nn.Module):
  174. """CSP Bottleneck with 2 convolutions."""
  175. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  176. """Initializes a CSP Bottleneck with 2 convolutions and optional shortcut connection."""
  177. super().__init__()
  178. self.c = int(c2 * e) # hidden channels
  179. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  180. self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)
  181. # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()
  182. self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
  183. def forward(self, x):
  184. """Forward pass through the CSP bottleneck with 2 convolutions."""
  185. a, b = self.cv1(x).chunk(2, 1)
  186. return self.cv2(torch.cat((self.m(a), b), 1))
  187. class C2f(nn.Module):
  188. """Faster Implementation of CSP Bottleneck with 2 convolutions."""
  189. def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
  190. """Initializes a CSP bottleneck with 2 convolutions and n Bottleneck blocks for faster processing."""
  191. super().__init__()
  192. self.c = int(c2 * e) # hidden channels
  193. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  194. self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
  195. self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
  196. def forward(self, x):
  197. """Forward pass through C2f layer."""
  198. y = list(self.cv1(x).chunk(2, 1))
  199. y.extend(m(y[-1]) for m in self.m)
  200. return self.cv2(torch.cat(y, 1))
  201. def forward_split(self, x):
  202. """Forward pass using split() instead of chunk()."""
  203. y = self.cv1(x).split((self.c, self.c), 1)
  204. y = [y[0], y[1]]
  205. y.extend(m(y[-1]) for m in self.m)
  206. return self.cv2(torch.cat(y, 1))
  207. class C3(nn.Module):
  208. """CSP Bottleneck with 3 convolutions."""
  209. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  210. """Initialize the CSP Bottleneck with given channels, number, shortcut, groups, and expansion values."""
  211. super().__init__()
  212. c_ = int(c2 * e) # hidden channels
  213. self.cv1 = Conv(c1, c_, 1, 1)
  214. self.cv2 = Conv(c1, c_, 1, 1)
  215. self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
  216. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
  217. def forward(self, x):
  218. """Forward pass through the CSP bottleneck with 2 convolutions."""
  219. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
  220. class C3x(C3):
  221. """C3 module with cross-convolutions."""
  222. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  223. """Initialize C3TR instance and set default parameters."""
  224. super().__init__(c1, c2, n, shortcut, g, e)
  225. self.c_ = int(c2 * e)
  226. self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))
  227. class RepC3(nn.Module):
  228. """Rep C3."""
  229. def __init__(self, c1, c2, n=3, e=1.0):
  230. """Initialize CSP Bottleneck with a single convolution using input channels, output channels, and number."""
  231. super().__init__()
  232. c_ = int(c2 * e) # hidden channels
  233. self.cv1 = Conv(c1, c_, 1, 1)
  234. self.cv2 = Conv(c1, c_, 1, 1)
  235. self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])
  236. self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()
  237. def forward(self, x):
  238. """Forward pass of RT-DETR neck layer."""
  239. return self.cv3(self.m(self.cv1(x)) + self.cv2(x))
  240. class C3TR(C3):
  241. """C3 module with TransformerBlock()."""
  242. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  243. """Initialize C3Ghost module with GhostBottleneck()."""
  244. super().__init__(c1, c2, n, shortcut, g, e)
  245. c_ = int(c2 * e)
  246. self.m = TransformerBlock(c_, c_, 4, n)
  247. class C3Ghost(C3):
  248. """C3 module with GhostBottleneck()."""
  249. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  250. """Initialize 'SPP' module with various pooling sizes for spatial pyramid pooling."""
  251. super().__init__(c1, c2, n, shortcut, g, e)
  252. c_ = int(c2 * e) # hidden channels
  253. self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
  254. class GhostBottleneck(nn.Module):
  255. """Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
  256. def __init__(self, c1, c2, k=3, s=1):
  257. """Initializes GhostBottleneck module with arguments ch_in, ch_out, kernel, stride."""
  258. super().__init__()
  259. c_ = c2 // 2
  260. self.conv = nn.Sequential(
  261. GhostConv(c1, c_, 1, 1), # pw
  262. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  263. GhostConv(c_, c2, 1, 1, act=False), # pw-linear
  264. )
  265. self.shortcut = (
  266. nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
  267. )
  268. def forward(self, x):
  269. """Applies skip connection and concatenation to input tensor."""
  270. return self.conv(x) + self.shortcut(x)
  271. class Bottleneck(nn.Module):
  272. """Standard bottleneck."""
  273. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
  274. """Initializes a standard bottleneck module with optional shortcut connection and configurable parameters."""
  275. super().__init__()
  276. c_ = int(c2 * e) # hidden channels
  277. self.cv1 = Conv(c1, c_, k[0], 1)
  278. self.cv2 = Conv(c_, c2, k[1], 1, g=g)
  279. self.add = shortcut and c1 == c2
  280. def forward(self, x):
  281. """Applies the YOLO FPN to input data."""
  282. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  283. class BottleneckCSP(nn.Module):
  284. """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
  285. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  286. """Initializes the CSP Bottleneck given arguments for ch_in, ch_out, number, shortcut, groups, expansion."""
  287. super().__init__()
  288. c_ = int(c2 * e) # hidden channels
  289. self.cv1 = Conv(c1, c_, 1, 1)
  290. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  291. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  292. self.cv4 = Conv(2 * c_, c2, 1, 1)
  293. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  294. self.act = nn.SiLU()
  295. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  296. def forward(self, x):
  297. """Applies a CSP bottleneck with 3 convolutions."""
  298. y1 = self.cv3(self.m(self.cv1(x)))
  299. y2 = self.cv2(x)
  300. return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
  301. class ResNetBlock(nn.Module):
  302. """ResNet block with standard convolution layers."""
  303. def __init__(self, c1, c2, s=1, e=4):
  304. """Initialize convolution with given parameters."""
  305. super().__init__()
  306. c3 = e * c2
  307. self.cv1 = Conv(c1, c2, k=1, s=1, act=True)
  308. self.cv2 = Conv(c2, c2, k=3, s=s, p=1, act=True)
  309. self.cv3 = Conv(c2, c3, k=1, act=False)
  310. self.shortcut = nn.Sequential(Conv(c1, c3, k=1, s=s, act=False)) if s != 1 or c1 != c3 else nn.Identity()
  311. def forward(self, x):
  312. """Forward pass through the ResNet block."""
  313. return F.relu(self.cv3(self.cv2(self.cv1(x))) + self.shortcut(x))
  314. class ResNetLayer(nn.Module):
  315. """ResNet layer with multiple ResNet blocks."""
  316. def __init__(self, c1, c2, s=1, is_first=False, n=1, e=4):
  317. """Initializes the ResNetLayer given arguments."""
  318. super().__init__()
  319. self.is_first = is_first
  320. if self.is_first:
  321. self.layer = nn.Sequential(
  322. Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  323. )
  324. else:
  325. blocks = [ResNetBlock(c1, c2, s, e=e)]
  326. blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)])
  327. self.layer = nn.Sequential(*blocks)
  328. def forward(self, x):
  329. """Forward pass through the ResNet layer."""
  330. return self.layer(x)
  331. class MaxSigmoidAttnBlock(nn.Module):
  332. """Max Sigmoid attention block."""
  333. def __init__(self, c1, c2, nh=1, ec=128, gc=512, scale=False):
  334. """Initializes MaxSigmoidAttnBlock with specified arguments."""
  335. super().__init__()
  336. self.nh = nh
  337. self.hc = c2 // nh
  338. self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None
  339. self.gl = nn.Linear(gc, ec)
  340. self.bias = nn.Parameter(torch.zeros(nh))
  341. self.proj_conv = Conv(c1, c2, k=3, s=1, act=False)
  342. self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0
  343. def forward(self, x, guide):
  344. """Forward process."""
  345. bs, _, h, w = x.shape
  346. guide = self.gl(guide)
  347. guide = guide.view(bs, -1, self.nh, self.hc)
  348. embed = self.ec(x) if self.ec is not None else x
  349. embed = embed.view(bs, self.nh, self.hc, h, w)
  350. aw = torch.einsum("bmchw,bnmc->bmhwn", embed, guide)
  351. aw = aw.max(dim=-1)[0]
  352. aw = aw / (self.hc**0.5)
  353. aw = aw + self.bias[None, :, None, None]
  354. aw = aw.sigmoid() * self.scale
  355. x = self.proj_conv(x)
  356. x = x.view(bs, self.nh, -1, h, w)
  357. x = x * aw.unsqueeze(2)
  358. return x.view(bs, -1, h, w)
  359. class C2fAttn(nn.Module):
  360. """C2f module with an additional attn module."""
  361. def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5):
  362. """Initializes C2f module with attention mechanism for enhanced feature extraction and processing."""
  363. super().__init__()
  364. self.c = int(c2 * e) # hidden channels
  365. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  366. self.cv2 = Conv((3 + n) * self.c, c2, 1) # optional act=FReLU(c2)
  367. self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
  368. self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh)
  369. def forward(self, x, guide):
  370. """Forward pass through C2f layer."""
  371. y = list(self.cv1(x).chunk(2, 1))
  372. y.extend(m(y[-1]) for m in self.m)
  373. y.append(self.attn(y[-1], guide))
  374. return self.cv2(torch.cat(y, 1))
  375. def forward_split(self, x, guide):
  376. """Forward pass using split() instead of chunk()."""
  377. y = list(self.cv1(x).split((self.c, self.c), 1))
  378. y.extend(m(y[-1]) for m in self.m)
  379. y.append(self.attn(y[-1], guide))
  380. return self.cv2(torch.cat(y, 1))
  381. class ImagePoolingAttn(nn.Module):
  382. """ImagePoolingAttn: Enhance the text embeddings with image-aware information."""
  383. def __init__(self, ec=256, ch=(), ct=512, nh=8, k=3, scale=False):
  384. """Initializes ImagePoolingAttn with specified arguments."""
  385. super().__init__()
  386. nf = len(ch)
  387. self.query = nn.Sequential(nn.LayerNorm(ct), nn.Linear(ct, ec))
  388. self.key = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))
  389. self.value = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))
  390. self.proj = nn.Linear(ec, ct)
  391. self.scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True) if scale else 1.0
  392. self.projections = nn.ModuleList([nn.Conv2d(in_channels, ec, kernel_size=1) for in_channels in ch])
  393. self.im_pools = nn.ModuleList([nn.AdaptiveMaxPool2d((k, k)) for _ in range(nf)])
  394. self.ec = ec
  395. self.nh = nh
  396. self.nf = nf
  397. self.hc = ec // nh
  398. self.k = k
  399. def forward(self, x, text):
  400. """Executes attention mechanism on input tensor x and guide tensor."""
  401. bs = x[0].shape[0]
  402. assert len(x) == self.nf
  403. num_patches = self.k**2
  404. x = [pool(proj(x)).view(bs, -1, num_patches) for (x, proj, pool) in zip(x, self.projections, self.im_pools)]
  405. x = torch.cat(x, dim=-1).transpose(1, 2)
  406. q = self.query(text)
  407. k = self.key(x)
  408. v = self.value(x)
  409. # q = q.reshape(1, text.shape[1], self.nh, self.hc).repeat(bs, 1, 1, 1)
  410. q = q.reshape(bs, -1, self.nh, self.hc)
  411. k = k.reshape(bs, -1, self.nh, self.hc)
  412. v = v.reshape(bs, -1, self.nh, self.hc)
  413. aw = torch.einsum("bnmc,bkmc->bmnk", q, k)
  414. aw = aw / (self.hc**0.5)
  415. aw = F.softmax(aw, dim=-1)
  416. x = torch.einsum("bmnk,bkmc->bnmc", aw, v)
  417. x = self.proj(x.reshape(bs, -1, self.ec))
  418. return x * self.scale + text
  419. class ContrastiveHead(nn.Module):
  420. """Implements contrastive learning head for region-text similarity in vision-language models."""
  421. def __init__(self):
  422. """Initializes ContrastiveHead with specified region-text similarity parameters."""
  423. super().__init__()
  424. # NOTE: use -10.0 to keep the init cls loss consistency with other losses
  425. self.bias = nn.Parameter(torch.tensor([-10.0]))
  426. self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
  427. def forward(self, x, w):
  428. """Forward function of contrastive learning."""
  429. x = F.normalize(x, dim=1, p=2)
  430. w = F.normalize(w, dim=-1, p=2)
  431. x = torch.einsum("bchw,bkc->bkhw", x, w)
  432. return x * self.logit_scale.exp() + self.bias
  433. class BNContrastiveHead(nn.Module):
  434. """
  435. Batch Norm Contrastive Head for YOLO-World using batch norm instead of l2-normalization.
  436. Args:
  437. embed_dims (int): Embed dimensions of text and image features.
  438. """
  439. def __init__(self, embed_dims: int):
  440. """Initialize ContrastiveHead with region-text similarity parameters."""
  441. super().__init__()
  442. self.norm = nn.BatchNorm2d(embed_dims)
  443. # NOTE: use -10.0 to keep the init cls loss consistency with other losses
  444. self.bias = nn.Parameter(torch.tensor([-10.0]))
  445. # use -1.0 is more stable
  446. self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
  447. def forward(self, x, w):
  448. """Forward function of contrastive learning."""
  449. x = self.norm(x)
  450. w = F.normalize(w, dim=-1, p=2)
  451. x = torch.einsum("bchw,bkc->bkhw", x, w)
  452. return x * self.logit_scale.exp() + self.bias
  453. class RepBottleneck(Bottleneck):
  454. """Rep bottleneck."""
  455. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
  456. """Initializes a RepBottleneck module with customizable in/out channels, shortcuts, groups and expansion."""
  457. super().__init__(c1, c2, shortcut, g, k, e)
  458. c_ = int(c2 * e) # hidden channels
  459. self.cv1 = RepConv(c1, c_, k[0], 1)
  460. class RepCSP(C3):
  461. """Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction."""
  462. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  463. """Initializes RepCSP layer with given channels, repetitions, shortcut, groups and expansion ratio."""
  464. super().__init__(c1, c2, n, shortcut, g, e)
  465. c_ = int(c2 * e) # hidden channels
  466. self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  467. class RepNCSPELAN4(nn.Module):
  468. """CSP-ELAN."""
  469. def __init__(self, c1, c2, c3, c4, n=1):
  470. """Initializes CSP-ELAN layer with specified channel sizes, repetitions, and convolutions."""
  471. super().__init__()
  472. self.c = c3 // 2
  473. self.cv1 = Conv(c1, c3, 1, 1)
  474. self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, n), Conv(c4, c4, 3, 1))
  475. self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1))
  476. self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
  477. def forward(self, x):
  478. """Forward pass through RepNCSPELAN4 layer."""
  479. y = list(self.cv1(x).chunk(2, 1))
  480. y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
  481. return self.cv4(torch.cat(y, 1))
  482. def forward_split(self, x):
  483. """Forward pass using split() instead of chunk()."""
  484. y = list(self.cv1(x).split((self.c, self.c), 1))
  485. y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
  486. return self.cv4(torch.cat(y, 1))
  487. class ELAN1(RepNCSPELAN4):
  488. """ELAN1 module with 4 convolutions."""
  489. def __init__(self, c1, c2, c3, c4):
  490. """Initializes ELAN1 layer with specified channel sizes."""
  491. super().__init__(c1, c2, c3, c4)
  492. self.c = c3 // 2
  493. self.cv1 = Conv(c1, c3, 1, 1)
  494. self.cv2 = Conv(c3 // 2, c4, 3, 1)
  495. self.cv3 = Conv(c4, c4, 3, 1)
  496. self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
  497. class AConv(nn.Module):
  498. """AConv."""
  499. def __init__(self, c1, c2):
  500. """Initializes AConv module with convolution layers."""
  501. super().__init__()
  502. self.cv1 = Conv(c1, c2, 3, 2, 1)
  503. def forward(self, x):
  504. """Forward pass through AConv layer."""
  505. x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
  506. return self.cv1(x)
  507. class ADown(nn.Module):
  508. """ADown."""
  509. def __init__(self, c1, c2):
  510. """Initializes ADown module with convolution layers to downsample input from channels c1 to c2."""
  511. super().__init__()
  512. self.c = c2 // 2
  513. self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
  514. self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
  515. def forward(self, x):
  516. """Forward pass through ADown layer."""
  517. x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
  518. x1, x2 = x.chunk(2, 1)
  519. x1 = self.cv1(x1)
  520. x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
  521. x2 = self.cv2(x2)
  522. return torch.cat((x1, x2), 1)
  523. class SPPELAN(nn.Module):
  524. """SPP-ELAN."""
  525. def __init__(self, c1, c2, c3, k=5):
  526. """Initializes SPP-ELAN block with convolution and max pooling layers for spatial pyramid pooling."""
  527. super().__init__()
  528. self.c = c3
  529. self.cv1 = Conv(c1, c3, 1, 1)
  530. self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  531. self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  532. self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  533. self.cv5 = Conv(4 * c3, c2, 1, 1)
  534. def forward(self, x):
  535. """Forward pass through SPPELAN layer."""
  536. y = [self.cv1(x)]
  537. y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
  538. return self.cv5(torch.cat(y, 1))
  539. class CBLinear(nn.Module):
  540. """CBLinear."""
  541. def __init__(self, c1, c2s, k=1, s=1, p=None, g=1):
  542. """Initializes the CBLinear module, passing inputs unchanged."""
  543. super().__init__()
  544. self.c2s = c2s
  545. self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
  546. def forward(self, x):
  547. """Forward pass through CBLinear layer."""
  548. return self.conv(x).split(self.c2s, dim=1)
  549. class CBFuse(nn.Module):
  550. """CBFuse."""
  551. def __init__(self, idx):
  552. """Initializes CBFuse module with layer index for selective feature fusion."""
  553. super().__init__()
  554. self.idx = idx
  555. def forward(self, xs):
  556. """Forward pass through CBFuse layer."""
  557. target_size = xs[-1].shape[2:]
  558. res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
  559. return torch.sum(torch.stack(res + xs[-1:]), dim=0)
  560. class C3f(nn.Module):
  561. """Faster Implementation of CSP Bottleneck with 2 convolutions."""
  562. def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
  563. """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
  564. expansion.
  565. """
  566. super().__init__()
  567. c_ = int(c2 * e) # hidden channels
  568. self.cv1 = Conv(c1, c_, 1, 1)
  569. self.cv2 = Conv(c1, c_, 1, 1)
  570. self.cv3 = Conv((2 + n) * c_, c2, 1) # optional act=FReLU(c2)
  571. self.m = nn.ModuleList(Bottleneck(c_, c_, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
  572. def forward(self, x):
  573. """Forward pass through C2f layer."""
  574. y = [self.cv2(x), self.cv1(x)]
  575. y.extend(m(y[-1]) for m in self.m)
  576. return self.cv3(torch.cat(y, 1))
  577. class C3k2(C2f):
  578. """Faster Implementation of CSP Bottleneck with 2 convolutions."""
  579. def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True):
  580. """Initializes the C3k2 module, a faster CSP Bottleneck with 2 convolutions and optional C3k blocks."""
  581. super().__init__(c1, c2, n, shortcut, g, e)
  582. self.m = nn.ModuleList(
  583. C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)
  584. )
  585. class C3k(C3):
  586. """C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks."""
  587. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
  588. """Initializes the C3k module with specified channels, number of layers, and configurations."""
  589. super().__init__(c1, c2, n, shortcut, g, e)
  590. c_ = int(c2 * e) # hidden channels
  591. # self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
  592. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
  593. class RepVGGDW(torch.nn.Module):
  594. """RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture."""
  595. def __init__(self, ed) -> None:
  596. """Initializes RepVGGDW with depthwise separable convolutional layers for efficient processing."""
  597. super().__init__()
  598. self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
  599. self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
  600. self.dim = ed
  601. self.act = nn.SiLU()
  602. def forward(self, x):
  603. """
  604. Performs a forward pass of the RepVGGDW block.
  605. Args:
  606. x (torch.Tensor): Input tensor.
  607. Returns:
  608. (torch.Tensor): Output tensor after applying the depth wise separable convolution.
  609. """
  610. return self.act(self.conv(x) + self.conv1(x))
  611. def forward_fuse(self, x):
  612. """
  613. Performs a forward pass of the RepVGGDW block without fusing the convolutions.
  614. Args:
  615. x (torch.Tensor): Input tensor.
  616. Returns:
  617. (torch.Tensor): Output tensor after applying the depth wise separable convolution.
  618. """
  619. return self.act(self.conv(x))
  620. @torch.no_grad()
  621. def fuse(self):
  622. """
  623. Fuses the convolutional layers in the RepVGGDW block.
  624. This method fuses the convolutional layers and updates the weights and biases accordingly.
  625. """
  626. conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
  627. conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)
  628. conv_w = conv.weight
  629. conv_b = conv.bias
  630. conv1_w = conv1.weight
  631. conv1_b = conv1.bias
  632. conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])
  633. final_conv_w = conv_w + conv1_w
  634. final_conv_b = conv_b + conv1_b
  635. conv.weight.data.copy_(final_conv_w)
  636. conv.bias.data.copy_(final_conv_b)
  637. self.conv = conv
  638. del self.conv1
  639. class CIB(nn.Module):
  640. """
  641. Conditional Identity Block (CIB) module.
  642. Args:
  643. c1 (int): Number of input channels.
  644. c2 (int): Number of output channels.
  645. shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.
  646. e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.
  647. lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.
  648. """
  649. def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
  650. """Initializes the custom model with optional shortcut, scaling factor, and RepVGGDW layer."""
  651. super().__init__()
  652. c_ = int(c2 * e) # hidden channels
  653. self.cv1 = nn.Sequential(
  654. Conv(c1, c1, 3, g=c1),
  655. Conv(c1, 2 * c_, 1),
  656. RepVGGDW(2 * c_) if lk else Conv(2 * c_, 2 * c_, 3, g=2 * c_),
  657. Conv(2 * c_, c2, 1),
  658. Conv(c2, c2, 3, g=c2),
  659. )
  660. self.add = shortcut and c1 == c2
  661. def forward(self, x):
  662. """
  663. Forward pass of the CIB module.
  664. Args:
  665. x (torch.Tensor): Input tensor.
  666. Returns:
  667. (torch.Tensor): Output tensor.
  668. """
  669. return x + self.cv1(x) if self.add else self.cv1(x)
  670. class C2fCIB(C2f):
  671. """
  672. C2fCIB class represents a convolutional block with C2f and CIB modules.
  673. Args:
  674. c1 (int): Number of input channels.
  675. c2 (int): Number of output channels.
  676. n (int, optional): Number of CIB modules to stack. Defaults to 1.
  677. shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.
  678. lk (bool, optional): Whether to use local key connection. Defaults to False.
  679. g (int, optional): Number of groups for grouped convolution. Defaults to 1.
  680. e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.
  681. """
  682. def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
  683. """Initializes the module with specified parameters for channel, shortcut, local key, groups, and expansion."""
  684. super().__init__(c1, c2, n, shortcut, g, e)
  685. self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
  686. class Attention(nn.Module):
  687. """
  688. Attention module that performs self-attention on the input tensor.
  689. Args:
  690. dim (int): The input tensor dimension.
  691. num_heads (int): The number of attention heads.
  692. attn_ratio (float): The ratio of the attention key dimension to the head dimension.
  693. Attributes:
  694. num_heads (int): The number of attention heads.
  695. head_dim (int): The dimension of each attention head.
  696. key_dim (int): The dimension of the attention key.
  697. scale (float): The scaling factor for the attention scores.
  698. qkv (Conv): Convolutional layer for computing the query, key, and value.
  699. proj (Conv): Convolutional layer for projecting the attended values.
  700. pe (Conv): Convolutional layer for positional encoding.
  701. """
  702. def __init__(self, dim, num_heads=8, attn_ratio=0.5):
  703. """Initializes multi-head attention module with query, key, and value convolutions and positional encoding."""
  704. super().__init__()
  705. self.num_heads = num_heads
  706. self.head_dim = dim // num_heads
  707. self.key_dim = int(self.head_dim * attn_ratio)
  708. self.scale = self.key_dim**-0.5
  709. nh_kd = self.key_dim * num_heads
  710. h = dim + nh_kd * 2
  711. self.qkv = Conv(dim, h, 1, act=False)
  712. self.proj = Conv(dim, dim, 1, act=False)
  713. self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
  714. def forward(self, x):
  715. """
  716. Forward pass of the Attention module.
  717. Args:
  718. x (torch.Tensor): The input tensor.
  719. Returns:
  720. (torch.Tensor): The output tensor after self-attention.
  721. """
  722. B, C, H, W = x.shape
  723. N = H * W
  724. qkv = self.qkv(x)
  725. q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
  726. [self.key_dim, self.key_dim, self.head_dim], dim=2
  727. )
  728. attn = (q.transpose(-2, -1) @ k) * self.scale
  729. attn = attn.softmax(dim=-1)
  730. x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
  731. x = self.proj(x)
  732. return x
  733. class PSABlock(nn.Module):
  734. """
  735. PSABlock class implementing a Position-Sensitive Attention block for neural networks.
  736. This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers
  737. with optional shortcut connections.
  738. Attributes:
  739. attn (Attention): Multi-head attention module.
  740. ffn (nn.Sequential): Feed-forward neural network module.
  741. add (bool): Flag indicating whether to add shortcut connections.
  742. Methods:
  743. forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.
  744. Examples:
  745. Create a PSABlock and perform a forward pass
  746. >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)
  747. >>> input_tensor = torch.randn(1, 128, 32, 32)
  748. >>> output_tensor = psablock(input_tensor)
  749. """
  750. def __init__(self, c, attn_ratio=0.5, num_heads=4, shortcut=True) -> None:
  751. """Initializes the PSABlock with attention and feed-forward layers for enhanced feature extraction."""
  752. super().__init__()
  753. self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
  754. self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
  755. self.add = shortcut
  756. def forward(self, x):
  757. """Executes a forward pass through PSABlock, applying attention and feed-forward layers to the input tensor."""
  758. x = x + self.attn(x) if self.add else self.attn(x)
  759. x = x + self.ffn(x) if self.add else self.ffn(x)
  760. return x
  761. class PSA(nn.Module):
  762. """
  763. PSA class for implementing Position-Sensitive Attention in neural networks.
  764. This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to
  765. input tensors, enhancing feature extraction and processing capabilities.
  766. Attributes:
  767. c (int): Number of hidden channels after applying the initial convolution.
  768. cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
  769. cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
  770. attn (Attention): Attention module for position-sensitive attention.
  771. ffn (nn.Sequential): Feed-forward network for further processing.
  772. Methods:
  773. forward: Applies position-sensitive attention and feed-forward network to the input tensor.
  774. Examples:
  775. Create a PSA module and apply it to an input tensor
  776. >>> psa = PSA(c1=128, c2=128, e=0.5)
  777. >>> input_tensor = torch.randn(1, 128, 64, 64)
  778. >>> output_tensor = psa.forward(input_tensor)
  779. """
  780. def __init__(self, c1, c2, e=0.5):
  781. """Initializes the PSA module with input/output channels and attention mechanism for feature extraction."""
  782. super().__init__()
  783. assert c1 == c2
  784. self.c = int(c1 * e)
  785. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  786. self.cv2 = Conv(2 * self.c, c1, 1)
  787. self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
  788. self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))
  789. def forward(self, x):
  790. """Executes forward pass in PSA module, applying attention and feed-forward layers to the input tensor."""
  791. a, b = self.cv1(x).split((self.c, self.c), dim=1)
  792. b = b + self.attn(b)
  793. b = b + self.ffn(b)
  794. return self.cv2(torch.cat((a, b), 1))
  795. class C2PSA(nn.Module):
  796. """
  797. C2PSA module with attention mechanism for enhanced feature extraction and processing.
  798. This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing
  799. capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.
  800. Attributes:
  801. c (int): Number of hidden channels.
  802. cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
  803. cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
  804. m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.
  805. Methods:
  806. forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.
  807. Notes:
  808. This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.
  809. Examples:
  810. >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)
  811. >>> input_tensor = torch.randn(1, 256, 64, 64)
  812. >>> output_tensor = c2psa(input_tensor)
  813. """
  814. def __init__(self, c1, c2, n=1, e=0.5):
  815. """Initializes the C2PSA module with specified input/output channels, number of layers, and expansion ratio."""
  816. super().__init__()
  817. assert c1 == c2
  818. self.c = int(c1 * e)
  819. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  820. self.cv2 = Conv(2 * self.c, c1, 1)
  821. self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
  822. def forward(self, x):
  823. """Processes the input tensor 'x' through a series of PSA blocks and returns the transformed tensor."""
  824. a, b = self.cv1(x).split((self.c, self.c), dim=1)
  825. b = self.m(b)
  826. return self.cv2(torch.cat((a, b), 1))
  827. class C2fPSA(C2f):
  828. """
  829. C2fPSA module with enhanced feature extraction using PSA blocks.
  830. This class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature extraction.
  831. Attributes:
  832. c (int): Number of hidden channels.
  833. cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
  834. cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
  835. m (nn.ModuleList): List of PSA blocks for feature extraction.
  836. Methods:
  837. forward: Performs a forward pass through the C2fPSA module.
  838. forward_split: Performs a forward pass using split() instead of chunk().
  839. Examples:
  840. >>> import torch
  841. >>> from ultralytics.models.common import C2fPSA
  842. >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5)
  843. >>> x = torch.randn(1, 64, 128, 128)
  844. >>> output = model(x)
  845. >>> print(output.shape)
  846. """
  847. def __init__(self, c1, c2, n=1, e=0.5):
  848. """Initializes the C2fPSA module, a variant of C2f with PSA blocks for enhanced feature extraction."""
  849. assert c1 == c2
  850. super().__init__(c1, c2, n=n, e=e)
  851. self.m = nn.ModuleList(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))
  852. class SCDown(nn.Module):
  853. """
  854. SCDown module for downsampling with separable convolutions.
  855. This module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in
  856. efficiently reducing the spatial dimensions of the input tensor while maintaining the channel information.
  857. Attributes:
  858. cv1 (Conv): Pointwise convolution layer that reduces the number of channels.
  859. cv2 (Conv): Depthwise convolution layer that performs spatial downsampling.
  860. Methods:
  861. forward: Applies the SCDown module to the input tensor.
  862. Examples:
  863. >>> import torch
  864. >>> from ultralytics import SCDown
  865. >>> model = SCDown(c1=64, c2=128, k=3, s=2)
  866. >>> x = torch.randn(1, 64, 128, 128)
  867. >>> y = model(x)
  868. >>> print(y.shape)
  869. torch.Size([1, 128, 64, 64])
  870. """
  871. def __init__(self, c1, c2, k, s):
  872. """Initializes the SCDown module with specified input/output channels, kernel size, and stride."""
  873. super().__init__()
  874. self.cv1 = Conv(c1, c2, 1, 1)
  875. self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
  876. def forward(self, x):
  877. """Applies convolution and downsampling to the input tensor in the SCDown module."""
  878. return self.cv2(self.cv1(x))
  879. class TorchVision(nn.Module):
  880. """
  881. TorchVision module to allow loading any torchvision model.
  882. This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating or unwrapping layers.
  883. Attributes:
  884. m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
  885. Args:
  886. c1 (int): Input channels.
  887. c2 (): Output channels.
  888. model (str): Name of the torchvision model to load.
  889. weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
  890. unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.
  891. truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.
  892. split (bool, optional): Returns output from intermediate child modules as list. Default is False.
  893. """
  894. def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
  895. """Load the model and weights from torchvision."""
  896. import torchvision # scope for faster 'import ultralytics'
  897. super().__init__()
  898. if hasattr(torchvision.models, "get_model"):
  899. self.m = torchvision.models.get_model(model, weights=weights)
  900. else:
  901. self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
  902. if unwrap:
  903. layers = list(self.m.children())[:-truncate]
  904. if isinstance(layers[0], nn.Sequential): # Second-level for some models like EfficientNet, Swin
  905. layers = [*list(layers[0].children()), *layers[1:]]
  906. self.m = nn.Sequential(*layers)
  907. self.split = split
  908. else:
  909. self.split = False
  910. self.m.head = self.m.heads = nn.Identity()
  911. def forward(self, x):
  912. """Forward pass through the model."""
  913. if self.split:
  914. y = [x]
  915. y.extend(m(y[-1]) for m in self.m)
  916. else:
  917. y = self.m(x)
  918. return y
  919. import logging
  920. logger = logging.getLogger(__name__)
  921. USE_FLASH_ATTN = False
  922. try:
  923. import torch
  924. if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: # Ampere or newer
  925. from flash_attn.flash_attn_interface import flash_attn_func
  926. USE_FLASH_ATTN = True
  927. else:
  928. from torch.nn.functional import scaled_dot_product_attention as sdpa
  929. logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
  930. except Exception:
  931. from torch.nn.functional import scaled_dot_product_attention as sdpa
  932. logger.warning("FlashAttention is not available on this device. Using scaled_dot_product_attention instead.")
class AAttn(nn.Module):
    """
    Area-attention module: multi-head self-attention computed independently inside `area` groups of tokens.

    When the input is on CUDA and the module-level `USE_FLASH_ATTN` flag is set, `flash_attn_func` is used; otherwise
    attention is computed explicitly in PyTorch as softmax(q^T k * head_dim**-0.5) applied to v.

    Args:
        dim (int): Number of hidden channels;
        num_heads (int): Number of heads into which the attention mechanism is divided;
        area (int, optional): Number of areas the feature map is divided. Defaults to 1.

    Methods:
        forward: Performs a forward process of input tensor and outputs a tensor after the execution of the area attention mechanism.

    Examples:
        >>> import torch
        >>> from ultralytics.nn.modules import AAttn
        >>> model = AAttn(dim=64, num_heads=2, area=4)
        >>> x = torch.randn(2, 64, 128, 128)
        >>> output = model(x)
        >>> print(output.shape)

    Notes:
        recommend that dim//num_heads be a multiple of 32 or 64.
        `forward` splits the q/k projections and reshapes v using the input channel count C, so it relies on
        `head_dim * num_heads == dim` (i.e. `dim` divisible by `num_heads`). With `area > 1`, H * W must also be
        divisible by `area`.
    """

    def __init__(self, dim, num_heads, area=1):
        """Initializes the area-attention module, a simple yet efficient attention module for YOLO."""
        super().__init__()
        self.area = area
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        all_head_dim = head_dim * self.num_heads
        # q and k share one projection (2 * all_head_dim channels); v has its own.
        self.qk = Conv(dim, all_head_dim * 2, 1, act=False)
        self.v = Conv(dim, all_head_dim, 1, act=False)
        self.proj = Conv(all_head_dim, dim, 1, act=False)
        # Positional term: kernel-5, padding-2, g=dim convolution applied to the value map.
        self.pe = Conv(all_head_dim, dim, 5, 1, 2, g=dim, act=False)

    def forward(self, x):
        """
        Processes the input tensor 'x' through the area-attention.

        Args:
            x (torch.Tensor): Input feature map of shape (B, C, H, W).

        Returns:
            (torch.Tensor): `self.proj` applied to the attended features plus the positional term.
        """
        B, C, H, W = x.shape
        N = H * W
        # (B, N, 2 * all_head_dim), tokens in row-major (h, w) order; assumes self.qk keeps the H x W grid.
        qk = self.qk(x).flatten(2).transpose(1, 2)
        v = self.v(x)
        # Positional term computed while v is still a (B, C, H, W) map.
        pp = self.pe(v)
        v = v.flatten(2).transpose(1, 2)
        if self.area > 1:
            # Fold each run of N // area consecutive (row-major) tokens into its own batch entry, so attention is
            # restricted to that area. B and N are rebound to the folded sizes.
            qk = qk.reshape(B * self.area, N // self.area, C * 2)
            v = v.reshape(B * self.area, N // self.area, C)
            B, N, _ = qk.shape
        # Splitting by C relies on all_head_dim == C (see class Notes).
        q, k = qk.split([C, C], dim=2)
        if x.is_cuda and USE_FLASH_ATTN:
            # (B, N, num_heads, head_dim) layout for flash_attn_func; inputs are cast to fp16 and the result is cast
            # back to q's dtype. NOTE(review): output presumably keeps this (B, N, num_heads, head_dim) layout.
            q = q.view(B, N, self.num_heads, self.head_dim)
            k = k.view(B, N, self.num_heads, self.head_dim)
            v = v.view(B, N, self.num_heads, self.head_dim)
            x = flash_attn_func(
                q.contiguous().half(),
                k.contiguous().half(),
                v.contiguous().half()
            ).to(q.dtype)
        else:
            # (B, num_heads, head_dim, N) layout for the explicit computation.
            q = q.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
            k = k.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
            v = v.transpose(1, 2).view(B, self.num_heads, self.head_dim, N)
            # (B, num_heads, N, N) scaled scores.
            attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)
            # Softmax over the last dim, written out with row-max subtraction for numerical stability.
            max_attn = attn.max(dim=-1, keepdim=True).values
            exp_attn = torch.exp(attn - max_attn)
            attn = exp_attn / exp_attn.sum(dim=-1, keepdim=True)
            x = (v @ attn.transpose(-2, -1))  # (B, num_heads, head_dim, N)
            x = x.permute(0, 3, 1, 2)  # (B, N, num_heads, head_dim)
        if self.area > 1:
            # Undo the area folding: restore the original batch size and full token count.
            x = x.reshape(B // self.area, N * self.area, C)
            B, N, _ = x.shape
        # Back to a (B, C, H, W) feature map.
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return self.proj(x + pp)
  1000. class ABlock(nn.Module):
  1001. """
  1002. ABlock class implementing a Area-Attention block with effective feature extraction.
  1003. This class encapsulates the functionality for applying multi-head attention with feature map are dividing into areas
  1004. and feed-forward neural network layers.
  1005. Attributes:
  1006. dim (int): Number of hidden channels;
  1007. num_heads (int): Number of heads into which the attention mechanism is divided;
  1008. mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2;
  1009. area (int, optional): Number of areas the feature map is divided. Defaults to 1.
  1010. Methods:
  1011. forward: Performs a forward pass through the ABlock, applying area-attention and feed-forward layers.
  1012. Examples:
  1013. Create a ABlock and perform a forward pass
  1014. >>> model = ABlock(dim=64, num_heads=2, mlp_ratio=1.2, area=4)
  1015. >>> x = torch.randn(2, 64, 128, 128)
  1016. >>> output = model(x)
  1017. >>> print(output.shape)
  1018. Notes:
  1019. recommend that dim//num_heads be a multiple of 32 or 64.
  1020. """
  1021. def __init__(self, dim, num_heads, mlp_ratio=1.2, area=1):
  1022. """Initializes the ABlock with area-attention and feed-forward layers for faster feature extraction."""
  1023. super().__init__()
  1024. self.attn = AAttn(dim, num_heads=num_heads, area=area)
  1025. mlp_hidden_dim = int(dim * mlp_ratio)
  1026. self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))
  1027. self.apply(self._init_weights)
  1028. def _init_weights(self, m):
  1029. """Initialize weights using a truncated normal distribution."""
  1030. if isinstance(m, nn.Conv2d):
  1031. nn.init.trunc_normal_(m.weight, std=0.02)
  1032. if m.bias is not None:
  1033. nn.init.constant_(m.bias, 0)
  1034. def forward(self, x):
  1035. """Executes a forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor."""
  1036. x = x + self.attn(x)
  1037. x = x + self.mlp(x)
  1038. return x
  1039. class A2C2f(nn.Module):
  1040. """
  1041. A2C2f module with residual enhanced feature extraction using ABlock blocks with area-attention. Also known as R-ELAN
  1042. This class extends the C2f module by incorporating ABlock blocks for fast attention mechanisms and feature extraction.
  1043. Attributes:
  1044. c1 (int): Number of input channels;
  1045. c2 (int): Number of output channels;
  1046. n (int, optional): Number of 2xABlock modules to stack. Defaults to 1;
  1047. a2 (bool, optional): Whether use area-attention. Defaults to True;
  1048. area (int, optional): Number of areas the feature map is divided. Defaults to 1;
  1049. residual (bool, optional): Whether use the residual (with layer scale). Defaults to False;
  1050. mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2;
  1051. e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5;
  1052. g (int, optional): Number of groups for grouped convolution. Defaults to 1;
  1053. shortcut (bool, optional): Whether to use shortcut connection. Defaults to True;
  1054. Methods:
  1055. forward: Performs a forward pass through the A2C2f module.
  1056. Examples:
  1057. >>> import torch
  1058. >>> from ultralytics.nn.modules import A2C2f
  1059. >>> model = A2C2f(c1=64, c2=64, n=2, a2=True, area=4, residual=True, e=0.5)
  1060. >>> x = torch.randn(2, 64, 128, 128)
  1061. >>> output = model(x)
  1062. >>> print(output.shape)
  1063. """
  1064. def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True):
  1065. super().__init__()
  1066. c_ = int(c2 * e) # hidden channels
  1067. assert c_ % 32 == 0, "Dimension of ABlock be a multiple of 32."
  1068. # num_heads = c_ // 64 if c_ // 64 >= 2 else c_ // 32
  1069. num_heads = c_ // 32
  1070. self.cv1 = Conv(c1, c_, 1, 1)
  1071. self.cv2 = Conv((1 + n) * c_, c2, 1) # optional act=FReLU(c2)
  1072. init_values = 0.01 # or smaller
  1073. self.gamma = nn.Parameter(init_values * torch.ones((c2)), requires_grad=True) if a2 and residual else None
  1074. self.m = nn.ModuleList(
  1075. nn.Sequential(*(ABlock(c_, num_heads, mlp_ratio, area) for _ in range(2))) if a2 else C3k(c_, c_, 2, shortcut, g) for _ in range(n)
  1076. )
  1077. def forward(self, x):
  1078. """Forward pass through R-ELAN layer."""
  1079. y = [self.cv1(x)]
  1080. y.extend(m(y[-1]) for m in self.m)
  1081. if self.gamma is not None:
  1082. return x + self.gamma.view(1, -1, 1, 1) * self.cv2(torch.cat(y, 1))
  1083. return self.cv2(torch.cat(y, 1))
  1084. class DSBottleneck(nn.Module):
  1085. def __init__(self, c1, c2, shortcut=True, e=0.5, k1=3, k2=5, d2=1):
  1086. super().__init__()
  1087. c_ = int(c2 * e)
  1088. self.cv1 = DSConv(c1, c_, k1, s=1, p=None, d=1)
  1089. self.cv2 = DSConv(c_, c2, k2, s=1, p=None, d=d2)
  1090. self.add = shortcut and c1 == c2
  1091. def forward(self, x):
  1092. y = self.cv2(self.cv1(x))
  1093. return x + y if self.add else y
  1094. class DSC3k(C3):
  1095. def __init__(
  1096. self,
  1097. c1,
  1098. c2,
  1099. n=1,
  1100. shortcut=True,
  1101. g=1,
  1102. e=0.5,
  1103. k1=3,
  1104. k2=5,
  1105. d2=1
  1106. ):
  1107. super().__init__(c1, c2, n, shortcut, g, e)
  1108. c_ = int(c2 * e)
  1109. self.m = nn.Sequential(
  1110. *(
  1111. DSBottleneck(
  1112. c_, c_,
  1113. shortcut=shortcut,
  1114. e=1.0,
  1115. k1=k1,
  1116. k2=k2,
  1117. d2=d2
  1118. )
  1119. for _ in range(n)
  1120. )
  1121. )
  1122. class DSC3k2(C2f):
  1123. def __init__(
  1124. self,
  1125. c1,
  1126. c2,
  1127. n=1,
  1128. dsc3k=False,
  1129. e=0.5,
  1130. g=1,
  1131. shortcut=True,
  1132. k1=3,
  1133. k2=7,
  1134. d2=1
  1135. ):
  1136. super().__init__(c1, c2, n, shortcut, g, e)
  1137. if dsc3k:
  1138. self.m = nn.ModuleList(
  1139. DSC3k(
  1140. self.c, self.c,
  1141. n=2,
  1142. shortcut=shortcut,
  1143. g=g,
  1144. e=1.0,
  1145. k1=k1,
  1146. k2=k2,
  1147. d2=d2
  1148. )
  1149. for _ in range(n)
  1150. )
  1151. else:
  1152. self.m = nn.ModuleList(
  1153. DSBottleneck(
  1154. self.c, self.c,
  1155. shortcut=shortcut,
  1156. e=1.0,
  1157. k1=k1,
  1158. k2=k2,
  1159. d2=d2
  1160. )
  1161. for _ in range(n)
  1162. )
  1163. class AdaHyperedgeGen(nn.Module):
  1164. def __init__(self, node_dim, num_hyperedges, num_heads=4, dropout=0.1, context="both"):
  1165. super().__init__()
  1166. self.num_heads = num_heads
  1167. self.num_hyperedges = num_hyperedges
  1168. self.head_dim = node_dim // num_heads
  1169. self.context = context
  1170. self.prototype_base = nn.Parameter(torch.Tensor(num_hyperedges, node_dim))
  1171. nn.init.xavier_uniform_(self.prototype_base)
  1172. if context in ("mean", "max"):
  1173. self.context_net = nn.Linear(node_dim, num_hyperedges * node_dim)
  1174. elif context == "both":
  1175. self.context_net = nn.Linear(2*node_dim, num_hyperedges * node_dim)
  1176. else:
  1177. raise ValueError(
  1178. f"Unsupported context '{context}'. "
  1179. "Expected one of: 'mean', 'max', 'both'."
  1180. )
  1181. self.pre_head_proj = nn.Linear(node_dim, node_dim)
  1182. self.dropout = nn.Dropout(dropout)
  1183. self.scaling = math.sqrt(self.head_dim)
  1184. def forward(self, X):
  1185. B, N, D = X.shape
  1186. if self.context == "mean":
  1187. context_cat = X.mean(dim=1)
  1188. elif self.context == "max":
  1189. context_cat, _ = X.max(dim=1)
  1190. else:
  1191. avg_context = X.mean(dim=1)
  1192. max_context, _ = X.max(dim=1)
  1193. context_cat = torch.cat([avg_context, max_context], dim=-1)
  1194. prototype_offsets = self.context_net(context_cat).view(B, self.num_hyperedges, D)
  1195. prototypes = self.prototype_base.unsqueeze(0) + prototype_offsets
  1196. X_proj = self.pre_head_proj(X)
  1197. X_heads = X_proj.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
  1198. proto_heads = prototypes.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
  1199. X_heads_flat = X_heads.reshape(B * self.num_heads, N, self.head_dim)
  1200. proto_heads_flat = proto_heads.reshape(B * self.num_heads, self.num_hyperedges, self.head_dim).transpose(1, 2)
  1201. logits = torch.bmm(X_heads_flat, proto_heads_flat) / self.scaling
  1202. logits = logits.view(B, self.num_heads, N, self.num_hyperedges).mean(dim=1)
  1203. logits = self.dropout(logits)
  1204. return F.softmax(logits, dim=1)
  1205. class AdaHGConv(nn.Module):
  1206. def __init__(self, embed_dim, num_hyperedges=16, num_heads=4, dropout=0.1, context="both"):
  1207. super().__init__()
  1208. self.edge_generator = AdaHyperedgeGen(embed_dim, num_hyperedges, num_heads, dropout, context)
  1209. self.edge_proj = nn.Sequential(
  1210. nn.Linear(embed_dim, embed_dim ),
  1211. nn.GELU()
  1212. )
  1213. self.node_proj = nn.Sequential(
  1214. nn.Linear(embed_dim, embed_dim ),
  1215. nn.GELU()
  1216. )
  1217. def forward(self, X):
  1218. A = self.edge_generator(X)
  1219. He = torch.bmm(A.transpose(1, 2), X)
  1220. He = self.edge_proj(He)
  1221. X_new = torch.bmm(A, He)
  1222. X_new = self.node_proj(X_new)
  1223. return X_new + X
  1224. class AdaHGComputation(nn.Module):
  1225. def __init__(self, embed_dim, num_hyperedges=16, num_heads=8, dropout=0.1, context="both"):
  1226. super().__init__()
  1227. self.embed_dim = embed_dim
  1228. self.hgnn = AdaHGConv(
  1229. embed_dim=embed_dim,
  1230. num_hyperedges=num_hyperedges,
  1231. num_heads=num_heads,
  1232. dropout=dropout,
  1233. context=context
  1234. )
  1235. def forward(self, x):
  1236. B, C, H, W = x.shape
  1237. tokens = x.flatten(2).transpose(1, 2)
  1238. tokens = self.hgnn(tokens)
  1239. x_out = tokens.transpose(1, 2).view(B, C, H, W)
  1240. return x_out
  1241. class C3AH(nn.Module):
  1242. def __init__(self, c1, c2, e=1.0, num_hyperedges=8, context="both"):
  1243. super().__init__()
  1244. c_ = int(c2 * e)
  1245. assert c_ % 16 == 0, "Dimension of AdaHGComputation should be a multiple of 16."
  1246. num_heads = c_ // 16
  1247. self.cv1 = Conv(c1, c_, 1, 1)
  1248. self.cv2 = Conv(c1, c_, 1, 1)
  1249. self.m = AdaHGComputation(embed_dim=c_,
  1250. num_hyperedges=num_hyperedges,
  1251. num_heads=num_heads,
  1252. dropout=0.1,
  1253. context=context)
  1254. self.cv3 = Conv(2 * c_, c2, 1)
  1255. def forward(self, x):
  1256. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
  1257. class FuseModule(nn.Module):
  1258. def __init__(self, c_in, channel_adjust):
  1259. super(FuseModule, self).__init__()
  1260. self.downsample = nn.AvgPool2d(kernel_size=2)
  1261. self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
  1262. if channel_adjust:
  1263. self.conv_out = Conv(4 * c_in, c_in, 1)
  1264. else:
  1265. self.conv_out = Conv(3 * c_in, c_in, 1)
  1266. def forward(self, x):
  1267. x1_ds = self.downsample(x[0])
  1268. x3_up = self.upsample(x[2])
  1269. x_cat = torch.cat([x1_ds, x[1], x3_up], dim=1)
  1270. out = self.conv_out(x_cat)
  1271. return out
  1272. class HyperACE(nn.Module):
  1273. def __init__(self, c1, c2, n=1, num_hyperedges=8, dsc3k=True, shortcut=False, e1=0.5, e2=1, context="both", channel_adjust=True):
  1274. super().__init__()
  1275. self.c = int(c2 * e1)
  1276. self.cv1 = Conv(c1, 3 * self.c, 1, 1)
  1277. self.cv2 = Conv((4 + n) * self.c, c2, 1)
  1278. self.m = nn.ModuleList(
  1279. DSC3k(self.c, self.c, 2, shortcut, k1=3, k2=7) if dsc3k else DSBottleneck(self.c, self.c, shortcut=shortcut) for _ in range(n)
  1280. )
  1281. self.fuse = FuseModule(c1, channel_adjust)
  1282. self.branch1 = C3AH(self.c, self.c, e2, num_hyperedges, context)
  1283. self.branch2 = C3AH(self.c, self.c, e2, num_hyperedges, context)
  1284. def forward(self, X):
  1285. x = self.fuse(X)
  1286. y = list(self.cv1(x).chunk(3, 1))
  1287. out1 = self.branch1(y[1])
  1288. out2 = self.branch2(y[1])
  1289. y.extend(m(y[-1]) for m in self.m)
  1290. y[1] = out1
  1291. y.append(out2)
  1292. return self.cv2(torch.cat(y, 1))
  1293. class DownsampleConv(nn.Module):
  1294. def __init__(self, in_channels, channel_adjust=True):
  1295. super().__init__()
  1296. self.downsample = nn.AvgPool2d(kernel_size=2)
  1297. if channel_adjust:
  1298. self.channel_adjust = Conv(in_channels, in_channels * 2, 1)
  1299. else:
  1300. self.channel_adjust = nn.Identity()
  1301. def forward(self, x):
  1302. return self.channel_adjust(self.downsample(x))
  1303. class FullPAD_Tunnel(nn.Module):
  1304. def __init__(self):
  1305. super().__init__()
  1306. self.gate = nn.Parameter(torch.tensor(0.0))
  1307. def forward(self, x):
  1308. out = x[0] + self.gate * x[1]
  1309. return out