head.py 155 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218321932203221322232233224322532263227322832293230323132323233323432353236323732383239324032413242324332443245324632473248324932503251325232533254325532563257325832593260326132623263326432653266326732683269327032713272327332743275327632773278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950
  1. # from collections import OrderedDict
  2. # from typing import Dict, List, Optional, Tuple
  3. #
  4. # import matplotlib.pyplot as plt
  5. # import torch
  6. # import torch.nn.functional as F
  7. # import torchvision
  8. # from torch import nn, Tensor
  9. # from torchvision.ops import boxes as box_ops, roi_align
  10. #
  11. # from . import _utils as det_utils
  12. #
  13. # from torch.utils.data.dataloader import default_collate
  14. #
  15. #
  16. # def l2loss(input, target):
  17. # return ((target - input) ** 2).mean(2).mean(1)
  18. #
  19. #
  20. # def cross_entropy_loss(logits, positive):
  21. # nlogp = -F.log_softmax(logits, dim=0)
  22. # return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
  23. #
  24. #
  25. # def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
  26. # logp = torch.sigmoid(logits) + offset
  27. # loss = torch.abs(logp - target)
  28. # if mask is not None:
  29. # w = mask.mean(2, True).mean(1, True)
  30. # w[w == 0] = 1
  31. # loss = loss * (mask / w)
  32. #
  33. # return loss.mean(2).mean(1)
  34. #
  35. #
  36. # # def wirepoint_loss(target, outputs, feature, loss_weight,mode):
  37. # # wires = target['wires']
  38. # # result = {"feature": feature}
  39. # # batch, channel, row, col = outputs[0].shape
  40. # # print(f"Initial Output[0] shape: {outputs[0].shape}") # 打印初始输出形状
  41. # # print(f"Total Stacks: {len(outputs)}") # 打印堆栈数
  42. # #
  43. # # T = wires.copy()
  44. # # n_jtyp = T["junc_map"].shape[1]
  45. # # for task in ["junc_map"]:
  46. # # T[task] = T[task].permute(1, 0, 2, 3)
  47. # # for task in ["junc_offset"]:
  48. # # T[task] = T[task].permute(1, 2, 0, 3, 4)
  49. # #
  50. # # offset = self.head_off
  51. # # loss_weight = loss_weight
  52. # # losses = []
  53. # #
  54. # # for stack, output in enumerate(outputs):
  55. # # output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
  56. # # print(f"Stack {stack} output shape: {output.shape}") # 打印每层的输出形状
  57. # # jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
  58. # # lmap = output[offset[0]: offset[1]].squeeze(0)
  59. # # joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
  60. # #
  61. # # if stack == 0:
  62. # # result["preds"] = {
  63. # # "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
  64. # # "lmap": lmap.sigmoid(),
  65. # # "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
  66. # # }
  67. # # # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
  68. # # # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
  69. # # # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
  70. # #
  71. # # if mode == "testing":
  72. # # return result
  73. # #
  74. # # L = OrderedDict()
  75. # # L["junc_map"] = sum(
  76. # # cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
  77. # # )
  78. # # L["line_map"] = (
  79. # # F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
  80. # # .mean(2)
  81. # # .mean(1)
  82. # # )
  83. # # L["junc_offset"] = sum(
  84. # # sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
  85. # # for i in range(n_jtyp)
  86. # # for j in range(2)
  87. # # )
  88. # # for loss_name in L:
  89. # # L[loss_name].mul_(loss_weight[loss_name])
  90. # # losses.append(L)
  91. # #
  92. # # result["losses"] = losses
  93. # # return result
  94. #
  95. # def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
  96. # # output, feature: head返回结果
  97. # # x, y, idx : line中间生成结果
  98. # result = {}
  99. # batch, channel, row, col = output.shape
  100. #
  101. # wires_targets = [t["wires"] for t in targets]
  102. # wires_targets = wires_targets.copy()
  103. # # print(f'wires_target:{wires_targets}')
  104. # # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
  105. # junc_maps = [d["junc_map"] for d in wires_targets]
  106. # junc_offsets = [d["junc_offset"] for d in wires_targets]
  107. # line_maps = [d["line_map"] for d in wires_targets]
  108. #
  109. # junc_map_tensor = torch.stack(junc_maps, dim=0)
  110. # junc_offset_tensor = torch.stack(junc_offsets, dim=0)
  111. # line_map_tensor = torch.stack(line_maps, dim=0)
  112. # T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
  113. #
  114. # n_jtyp = T["junc_map"].shape[1]
  115. #
  116. # for task in ["junc_map"]:
  117. # T[task] = T[task].permute(1, 0, 2, 3)
  118. # for task in ["junc_offset"]:
  119. # T[task] = T[task].permute(1, 2, 0, 3, 4)
  120. #
  121. # offset = [2, 3, 5]
  122. # losses = []
  123. # output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
  124. # jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
  125. # lmap = output[offset[0]: offset[1]].squeeze(0)
  126. # joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
  127. # L = OrderedDict()
  128. # L["junc_map"] = sum(
  129. # cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
  130. # )
  131. # L["line_map"] = (
  132. # F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
  133. # .mean(2)
  134. # .mean(1)
  135. # )
  136. # L["junc_offset"] = sum(
  137. # sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
  138. # for i in range(n_jtyp)
  139. # for j in range(2)
  140. # )
  141. # for loss_name in L:
  142. # L[loss_name].mul_(loss_weight[loss_name])
  143. # losses.append(L)
  144. # result["losses"] = losses
  145. #
  146. # loss = nn.BCEWithLogitsLoss(reduction="none")
  147. # loss = loss(x, y)
  148. # lpos_mask, lneg_mask = y, 1 - y
  149. # loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
  150. #
  151. # def sum_batch(x):
  152. # xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
  153. # return torch.cat(xs)
  154. #
  155. # lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
  156. # lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
  157. # result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
  158. # result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
  159. #
  160. # return result
  161. #
  162. #
  163. # def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
  164. # result = {}
  165. # result["wires"] = {}
  166. # p = torch.cat(ps)
  167. # s = torch.sigmoid(input)
  168. # b = s > 0.5
  169. # lines = []
  170. # score = []
  171. # # print(f"n_batch:{n_batch}")
  172. # for i in range(n_batch):
  173. # # print(f"idx:{idx}")
  174. # p0 = p[idx[i]: idx[i + 1]]
  175. # s0 = s[idx[i]: idx[i + 1]]
  176. # mask = b[idx[i]: idx[i + 1]]
  177. # p0 = p0[mask]
  178. # s0 = s0[mask]
  179. # if len(p0) == 0:
  180. # lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
  181. # score.append(torch.zeros([1, n_out_line], device=p.device))
  182. # else:
  183. # arg = torch.argsort(s0, descending=True)
  184. # p0, s0 = p0[arg], s0[arg]
  185. # lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
  186. # score.append(s0[None, torch.arange(n_out_line) % len(s0)])
  187. # for j in range(len(jcs[i])):
  188. # if len(jcs[i][j]) == 0:
  189. # jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
  190. # jcs[i][j] = jcs[i][j][
  191. # None, torch.arange(n_out_junc) % len(jcs[i][j])
  192. # ]
  193. # result["wires"]["lines"] = torch.cat(lines)
  194. # result["wires"]["score"] = torch.cat(score)
  195. # result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
  196. #
  197. # if len(jcs[i]) > 1:
  198. # result["preds"]["junts"] = torch.cat(
  199. # [jcs[i][1] for i in range(n_batch)]
  200. # )
  201. #
  202. # return result
  203. #
  204. #
  205. # def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
  206. # # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
  207. # """
  208. # Computes the loss for Faster R-CNN.
  209. #
  210. # Args:
  211. # class_logits (Tensor)
  212. # box_regression (Tensor)
  213. # labels (list[BoxList])
  214. # regression_targets (Tensor)
  215. #
  216. # Returns:
  217. # classification_loss (Tensor)
  218. # box_loss (Tensor)
  219. # """
  220. #
  221. # labels = torch.cat(labels, dim=0)
  222. # regression_targets = torch.cat(regression_targets, dim=0)
  223. #
  224. # classification_loss = F.cross_entropy(class_logits, labels)
  225. #
  226. # # get indices that correspond to the regression targets for
  227. # # the corresponding ground truth labels, to be used with
  228. # # advanced indexing
  229. # sampled_pos_inds_subset = torch.where(labels > 0)[0]
  230. # labels_pos = labels[sampled_pos_inds_subset]
  231. # N, num_classes = class_logits.shape
  232. # box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
  233. #
  234. # box_loss = F.smooth_l1_loss(
  235. # box_regression[sampled_pos_inds_subset, labels_pos],
  236. # regression_targets[sampled_pos_inds_subset],
  237. # beta=1 / 9,
  238. # reduction="sum",
  239. # )
  240. # box_loss = box_loss / labels.numel()
  241. #
  242. # return classification_loss, box_loss
  243. #
  244. #
  245. # def maskrcnn_inference(x, labels):
  246. # # type: (Tensor, List[Tensor]) -> List[Tensor]
  247. # """
  248. # From the results of the CNN, post process the masks
  249. # by taking the mask corresponding to the class with max
  250. # probability (which are of fixed size and directly output
  251. # by the CNN) and return the masks in the mask field of the BoxList.
  252. #
  253. # Args:
  254. # x (Tensor): the mask logits
  255. # labels (list[BoxList]): bounding boxes that are used as
  256. # reference, one for ech image
  257. #
  258. # Returns:
  259. # results (list[BoxList]): one BoxList for each image, containing
  260. # the extra field mask
  261. # """
  262. # mask_prob = x.sigmoid()
  263. #
  264. # # select masks corresponding to the predicted classes
  265. # num_masks = x.shape[0]
  266. # boxes_per_image = [label.shape[0] for label in labels]
  267. # labels = torch.cat(labels)
  268. # index = torch.arange(num_masks, device=labels.device)
  269. # mask_prob = mask_prob[index, labels][:, None]
  270. # mask_prob = mask_prob.split(boxes_per_image, dim=0)
  271. #
  272. # return mask_prob
  273. #
  274. #
  275. # def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
  276. # # type: (Tensor, Tensor, Tensor, int) -> Tensor
  277. # """
  278. # Given segmentation masks and the bounding boxes corresponding
  279. # to the location of the masks in the image, this function
  280. # crops and resizes the masks in the position defined by the
  281. # boxes. This prepares the masks for them to be fed to the
  282. # loss computation as the targets.
  283. # """
  284. # matched_idxs = matched_idxs.to(boxes)
  285. # rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
  286. # gt_masks = gt_masks[:, None].to(rois)
  287. # return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
  288. #
  289. #
  290. # def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
  291. # # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  292. # """
  293. # Args:
  294. # proposals (list[BoxList])
  295. # mask_logits (Tensor)
  296. # targets (list[BoxList])
  297. #
  298. # Return:
  299. # mask_loss (Tensor): scalar tensor containing the loss
  300. # """
  301. #
  302. # discretization_size = mask_logits.shape[-1]
  303. # # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
  304. # # print(f'mask discretization_size:{discretization_size}')
  305. # labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
  306. # # print(f'mask labels:{labels}')
  307. # mask_targets = [
  308. # project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
  309. # ]
  310. #
  311. # labels = torch.cat(labels, dim=0)
  312. # # print(f'mask labels1:{labels}')
  313. # mask_targets = torch.cat(mask_targets, dim=0)
  314. #
  315. # # torch.mean (in binary_cross_entropy_with_logits) doesn't
  316. # # accept empty tensors, so handle it separately
  317. # if mask_targets.numel() == 0:
  318. # return mask_logits.sum() * 0
  319. # # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
  320. # # print(f'mask_targets:{mask_targets}')
  321. # mask_loss = F.binary_cross_entropy_with_logits(
  322. # mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
  323. # )
  324. # # print(f'mask_loss:{mask_loss}')
  325. # return mask_loss
  326. #
  327. #
  328. # def keypoints_to_heatmap(keypoints, rois, heatmap_size):
  329. # # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
  330. # offset_x = rois[:, 0]
  331. # offset_y = rois[:, 1]
  332. # scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  333. # scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  334. #
  335. # offset_x = offset_x[:, None]
  336. # offset_y = offset_y[:, None]
  337. # scale_x = scale_x[:, None]
  338. # scale_y = scale_y[:, None]
  339. #
  340. # x = keypoints[..., 0]
  341. # y = keypoints[..., 1]
  342. #
  343. # x_boundary_inds = x == rois[:, 2][:, None]
  344. # y_boundary_inds = y == rois[:, 3][:, None]
  345. #
  346. # x = (x - offset_x) * scale_x
  347. # x = x.floor().long()
  348. # y = (y - offset_y) * scale_y
  349. # y = y.floor().long()
  350. #
  351. # x[x_boundary_inds] = heatmap_size - 1
  352. # y[y_boundary_inds] = heatmap_size - 1
  353. #
  354. # valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  355. # vis = keypoints[..., 2] > 0
  356. # valid = (valid_loc & vis).long()
  357. #
  358. # lin_ind = y * heatmap_size + x
  359. # heatmaps = lin_ind * valid
  360. #
  361. # return heatmaps, valid
  362. #
  363. #
  364. # def _onnx_heatmaps_to_keypoints(
  365. # maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
  366. # ):
  367. # num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
  368. #
  369. # width_correction = widths_i / roi_map_width
  370. # height_correction = heights_i / roi_map_height
  371. #
  372. # roi_map = F.interpolate(
  373. # maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
  374. # )[:, 0]
  375. #
  376. # w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
  377. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  378. #
  379. # x_int = pos % w
  380. # y_int = (pos - x_int) // w
  381. #
  382. # x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
  383. # dtype=torch.float32
  384. # )
  385. # y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
  386. # dtype=torch.float32
  387. # )
  388. #
  389. # xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
  390. # xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
  391. # xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
  392. # xy_preds_i = torch.stack(
  393. # [
  394. # xy_preds_i_0.to(dtype=torch.float32),
  395. # xy_preds_i_1.to(dtype=torch.float32),
  396. # xy_preds_i_2.to(dtype=torch.float32),
  397. # ],
  398. # 0,
  399. # )
  400. #
  401. # # TODO: simplify when indexing without rank will be supported by ONNX
  402. # base = num_keypoints * num_keypoints + num_keypoints + 1
  403. # ind = torch.arange(num_keypoints)
  404. # ind = ind.to(dtype=torch.int64) * base
  405. # end_scores_i = (
  406. # roi_map.index_select(1, y_int.to(dtype=torch.int64))
  407. # .index_select(2, x_int.to(dtype=torch.int64))
  408. # .view(-1)
  409. # .index_select(0, ind.to(dtype=torch.int64))
  410. # )
  411. #
  412. # return xy_preds_i, end_scores_i
  413. #
  414. #
  415. # @torch.jit._script_if_tracing
  416. # def _onnx_heatmaps_to_keypoints_loop(
  417. # maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
  418. # ):
  419. # xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
  420. # end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
  421. #
  422. # for i in range(int(rois.size(0))):
  423. # xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
  424. # maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
  425. # )
  426. # xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
  427. # end_scores = torch.cat(
  428. # (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
  429. # )
  430. # return xy_preds, end_scores
  431. #
  432. #
  433. # def heatmaps_to_keypoints(maps, rois):
  434. # """Extract predicted keypoint locations from heatmaps. Output has shape
  435. # (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
  436. # for each keypoint.
  437. # """
  438. # # This function converts a discrete image coordinate in a HEATMAP_SIZE x
  439. # # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
  440. # # consistency with keypoints_to_heatmap_labels by using the conversion from
  441. # # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
  442. # # continuous coordinate.
  443. # offset_x = rois[:, 0]
  444. # offset_y = rois[:, 1]
  445. #
  446. # widths = rois[:, 2] - rois[:, 0]
  447. # heights = rois[:, 3] - rois[:, 1]
  448. # widths = widths.clamp(min=1)
  449. # heights = heights.clamp(min=1)
  450. # widths_ceil = widths.ceil()
  451. # heights_ceil = heights.ceil()
  452. #
  453. # num_keypoints = maps.shape[1]
  454. #
  455. # if torchvision._is_tracing():
  456. # xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
  457. # maps,
  458. # rois,
  459. # widths_ceil,
  460. # heights_ceil,
  461. # widths,
  462. # heights,
  463. # offset_x,
  464. # offset_y,
  465. # torch.scalar_tensor(num_keypoints, dtype=torch.int64),
  466. # )
  467. # return xy_preds.permute(0, 2, 1), end_scores
  468. #
  469. # xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
  470. # end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
  471. # for i in range(len(rois)):
  472. # roi_map_width = int(widths_ceil[i].item())
  473. # roi_map_height = int(heights_ceil[i].item())
  474. # width_correction = widths[i] / roi_map_width
  475. # height_correction = heights[i] / roi_map_height
  476. # roi_map = F.interpolate(
  477. # maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
  478. # )[:, 0]
  479. # # roi_map_probs = scores_to_probs(roi_map.copy())
  480. # w = roi_map.shape[2]
  481. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  482. #
  483. # x_int = pos % w
  484. # y_int = torch.div(pos - x_int, w, rounding_mode="floor")
  485. # # assert (roi_map_probs[k, y_int, x_int] ==
  486. # # roi_map_probs[k, :, :].max())
  487. # x = (x_int.float() + 0.5) * width_correction
  488. # y = (y_int.float() + 0.5) * height_correction
  489. # xy_preds[i, 0, :] = x + offset_x[i]
  490. # xy_preds[i, 1, :] = y + offset_y[i]
  491. # xy_preds[i, 2, :] = 1
  492. # end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
  493. #
  494. # return xy_preds.permute(0, 2, 1), end_scores
  495. #
  496. #
  497. # def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
  498. # # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  499. # N, K, H, W = keypoint_logits.shape
  500. # if H != W:
  501. # raise ValueError(
  502. # f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  503. # )
  504. # discretization_size = H
  505. # heatmaps = []
  506. # valid = []
  507. # for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
  508. # kp = gt_kp_in_image[midx]
  509. # heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
  510. # heatmaps.append(heatmaps_per_image.view(-1))
  511. # valid.append(valid_per_image.view(-1))
  512. #
  513. # keypoint_targets = torch.cat(heatmaps, dim=0)
  514. # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  515. # valid = torch.where(valid)[0]
  516. #
  517. # # torch.mean (in binary_cross_entropy_with_logits) doesn't
  518. # # accept empty tensors, so handle it sepaartely
  519. # if keypoint_targets.numel() == 0 or len(valid) == 0:
  520. # return keypoint_logits.sum() * 0
  521. #
  522. # keypoint_logits = keypoint_logits.view(N * K, H * W)
  523. #
  524. # keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
  525. # return keypoint_loss
  526. #
  527. #
  528. # def keypointrcnn_inference(x, boxes):
  529. # # print(f'x:{x.shape}')
  530. # # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  531. # kp_probs = []
  532. # kp_scores = []
  533. #
  534. # boxes_per_image = [box.size(0) for box in boxes]
  535. # x2 = x.split(boxes_per_image, dim=0)
  536. # # print(f'x2:{x2}')
  537. #
  538. # for xx, bb in zip(x2, boxes):
  539. # kp_prob, scores = heatmaps_to_keypoints(xx, bb)
  540. # kp_probs.append(kp_prob)
  541. # kp_scores.append(scores)
  542. #
  543. # return kp_probs, kp_scores
  544. #
  545. #
  546. # def _onnx_expand_boxes(boxes, scale):
  547. # # type: (Tensor, float) -> Tensor
  548. # w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
  549. # h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
  550. # x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  551. # y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  552. #
  553. # w_half = w_half.to(dtype=torch.float32) * scale
  554. # h_half = h_half.to(dtype=torch.float32) * scale
  555. #
  556. # boxes_exp0 = x_c - w_half
  557. # boxes_exp1 = y_c - h_half
  558. # boxes_exp2 = x_c + w_half
  559. # boxes_exp3 = y_c + h_half
  560. # boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
  561. # return boxes_exp
  562. #
  563. #
  564. # # the next two functions should be merged inside Masker
  565. # # but are kept here for the moment while we need them
  566. # # temporarily for paste_mask_in_image
  567. # def expand_boxes(boxes, scale):
  568. # # type: (Tensor, float) -> Tensor
  569. # if torchvision._is_tracing():
  570. # return _onnx_expand_boxes(boxes, scale)
  571. # w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
  572. # h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
  573. # x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  574. # y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  575. #
  576. # w_half *= scale
  577. # h_half *= scale
  578. #
  579. # boxes_exp = torch.zeros_like(boxes)
  580. # boxes_exp[:, 0] = x_c - w_half
  581. # boxes_exp[:, 2] = x_c + w_half
  582. # boxes_exp[:, 1] = y_c - h_half
  583. # boxes_exp[:, 3] = y_c + h_half
  584. # return boxes_exp
  585. #
  586. #
  587. # @torch.jit.unused
  588. # def expand_masks_tracing_scale(M, padding):
  589. # # type: (int, int) -> float
  590. # return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
  591. #
  592. #
  593. # def expand_masks(mask, padding):
  594. # # type: (Tensor, int) -> Tuple[Tensor, float]
  595. # M = mask.shape[-1]
  596. # if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
  597. # scale = expand_masks_tracing_scale(M, padding)
  598. # else:
  599. # scale = float(M + 2 * padding) / M
  600. # padded_mask = F.pad(mask, (padding,) * 4)
  601. # return padded_mask, scale
  602. #
  603. #
  604. # def paste_mask_in_image(mask, box, im_h, im_w):
  605. # # type: (Tensor, Tensor, int, int) -> Tensor
  606. # TO_REMOVE = 1
  607. # w = int(box[2] - box[0] + TO_REMOVE)
  608. # h = int(box[3] - box[1] + TO_REMOVE)
  609. # w = max(w, 1)
  610. # h = max(h, 1)
  611. #
  612. # # Set shape to [batchxCxHxW]
  613. # mask = mask.expand((1, 1, -1, -1))
  614. #
  615. # # Resize mask
  616. # mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
  617. # mask = mask[0][0]
  618. #
  619. # im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
  620. # x_0 = max(box[0], 0)
  621. # x_1 = min(box[2] + 1, im_w)
  622. # y_0 = max(box[1], 0)
  623. # y_1 = min(box[3] + 1, im_h)
  624. #
  625. # im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
  626. # return im_mask
  627. #
  628. #
  629. # def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
  630. # one = torch.ones(1, dtype=torch.int64)
  631. # zero = torch.zeros(1, dtype=torch.int64)
  632. #
  633. # w = box[2] - box[0] + one
  634. # h = box[3] - box[1] + one
  635. # w = torch.max(torch.cat((w, one)))
  636. # h = torch.max(torch.cat((h, one)))
  637. #
  638. # # Set shape to [batchxCxHxW]
  639. # mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
  640. #
  641. # # Resize mask
  642. # mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
  643. # mask = mask[0][0]
  644. #
  645. # x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
  646. # x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
  647. # y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
  648. # y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
  649. #
  650. # unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
  651. #
  652. # # TODO : replace below with a dynamic padding when support is added in ONNX
  653. #
  654. # # pad y
  655. # zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
  656. # zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
  657. # concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
  658. # # pad x
  659. # zeros_x0 = torch.zeros(concat_0.size(0), x_0)
  660. # zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
  661. # im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
  662. # return im_mask
  663. #
  664. #
  665. # @torch.jit._script_if_tracing
  666. # def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
  667. # res_append = torch.zeros(0, im_h, im_w)
  668. # for i in range(masks.size(0)):
  669. # mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
  670. # mask_res = mask_res.unsqueeze(0)
  671. # res_append = torch.cat((res_append, mask_res))
  672. # return res_append
  673. #
  674. #
  675. # def paste_masks_in_image(masks, boxes, img_shape, padding=1):
  676. # # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
  677. # masks, scale = expand_masks(masks, padding=padding)
  678. # boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
  679. # im_h, im_w = img_shape
  680. #
  681. # if torchvision._is_tracing():
  682. # return _onnx_paste_masks_in_image_loop(
  683. # masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
  684. # )[:, None]
  685. # res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
  686. # if len(res) > 0:
  687. # ret = torch.stack(res, dim=0)[:, None]
  688. # else:
  689. # ret = masks.new_empty((0, 1, im_h, im_w))
  690. # return ret
  691. #
  692. #
  693. # class RoIHeads(nn.Module):
  694. # __annotations__ = {
  695. # "box_coder": det_utils.BoxCoder,
  696. # "proposal_matcher": det_utils.Matcher,
  697. # "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
  698. # }
  699. #
  700. # def __init__(
  701. # self,
  702. # box_roi_pool,
  703. # box_head,
  704. # box_predictor,
  705. # # Faster R-CNN training
  706. # fg_iou_thresh,
  707. # bg_iou_thresh,
  708. # batch_size_per_image,
  709. # positive_fraction,
  710. # bbox_reg_weights,
  711. # # Faster R-CNN inference
  712. # score_thresh,
  713. # nms_thresh,
  714. # detections_per_img,
  715. # # Mask
  716. # mask_roi_pool=None,
  717. # mask_head=None,
  718. # mask_predictor=None,
  719. # keypoint_roi_pool=None,
  720. # keypoint_head=None,
  721. # keypoint_predictor=None,
  722. # wirepoint_roi_pool=None,
  723. # wirepoint_head=None,
  724. # wirepoint_predictor=None,
  725. # ):
  726. # super().__init__()
  727. #
  728. # self.box_similarity = box_ops.box_iou
  729. # # assign ground-truth boxes for each proposal
  730. # self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
  731. #
  732. # self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
  733. #
  734. # if bbox_reg_weights is None:
  735. # bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
  736. # self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
  737. #
  738. # self.box_roi_pool = box_roi_pool
  739. # self.box_head = box_head
  740. # self.box_predictor = box_predictor
  741. #
  742. # self.score_thresh = score_thresh
  743. # self.nms_thresh = nms_thresh
  744. # self.detections_per_img = detections_per_img
  745. #
  746. # self.mask_roi_pool = mask_roi_pool
  747. # self.mask_head = mask_head
  748. # self.mask_predictor = mask_predictor
  749. #
  750. # self.keypoint_roi_pool = keypoint_roi_pool
  751. # self.keypoint_head = keypoint_head
  752. # self.keypoint_predictor = keypoint_predictor
  753. #
  754. # self.wirepoint_roi_pool = wirepoint_roi_pool
  755. # self.wirepoint_head = wirepoint_head
  756. # self.wirepoint_predictor = wirepoint_predictor
  757. #
  758. # def has_mask(self):
  759. # if self.mask_roi_pool is None:
  760. # return False
  761. # if self.mask_head is None:
  762. # return False
  763. # if self.mask_predictor is None:
  764. # return False
  765. # return True
  766. #
  767. # def has_keypoint(self):
  768. # if self.keypoint_roi_pool is None:
  769. # return False
  770. # if self.keypoint_head is None:
  771. # return False
  772. # if self.keypoint_predictor is None:
  773. # return False
  774. # return True
  775. #
  776. # def has_wirepoint(self):
  777. # if self.wirepoint_roi_pool is None:
  778. # print(f'wirepoint_roi_pool is None')
  779. # return False
  780. # if self.wirepoint_head is None:
  781. # print(f'wirepoint_head is None')
  782. # return False
  783. # if self.wirepoint_predictor is None:
  784. # print(f'wirepoint_roi_predictor is None')
  785. # return False
  786. # return True
  787. #
  788. # def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
  789. # # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  790. # matched_idxs = []
  791. # labels = []
  792. # for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
  793. #
  794. # if gt_boxes_in_image.numel() == 0:
  795. # # Background image
  796. # device = proposals_in_image.device
  797. # clamped_matched_idxs_in_image = torch.zeros(
  798. # (proposals_in_image.shape[0],), dtype=torch.int64, device=device
  799. # )
  800. # labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
  801. # else:
  802. # # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
  803. # match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
  804. # matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
  805. #
  806. # clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
  807. #
  808. # labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
  809. # labels_in_image = labels_in_image.to(dtype=torch.int64)
  810. #
  811. # # Label background (below the low threshold)
  812. # bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
  813. # labels_in_image[bg_inds] = 0
  814. #
  815. # # Label ignore proposals (between low and high thresholds)
  816. # ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
  817. # labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
  818. #
  819. # matched_idxs.append(clamped_matched_idxs_in_image)
  820. # labels.append(labels_in_image)
  821. # return matched_idxs, labels
  822. #
  823. # def subsample(self, labels):
  824. # # type: (List[Tensor]) -> List[Tensor]
  825. # sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
  826. # sampled_inds = []
  827. # for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
  828. # img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
  829. # sampled_inds.append(img_sampled_inds)
  830. # return sampled_inds
  831. #
  832. # def add_gt_proposals(self, proposals, gt_boxes):
  833. # # type: (List[Tensor], List[Tensor]) -> List[Tensor]
  834. # proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
  835. #
  836. # return proposals
  837. #
  838. # def check_targets(self, targets):
  839. # # type: (Optional[List[Dict[str, Tensor]]]) -> None
  840. # if targets is None:
  841. # raise ValueError("targets should not be None")
  842. # if not all(["boxes" in t for t in targets]):
  843. # raise ValueError("Every element of targets should have a boxes key")
  844. # if not all(["labels" in t for t in targets]):
  845. # raise ValueError("Every element of targets should have a labels key")
  846. # if self.has_mask():
  847. # if not all(["masks" in t for t in targets]):
  848. # raise ValueError("Every element of targets should have a masks key")
  849. #
  850. # def select_training_samples(
  851. # self,
  852. # proposals, # type: List[Tensor]
  853. # targets, # type: Optional[List[Dict[str, Tensor]]]
  854. # ):
  855. # # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
  856. # self.check_targets(targets)
  857. # if targets is None:
  858. # raise ValueError("targets should not be None")
  859. # dtype = proposals[0].dtype
  860. # device = proposals[0].device
  861. #
  862. # gt_boxes = [t["boxes"].to(dtype) for t in targets]
  863. # gt_labels = [t["labels"] for t in targets]
  864. #
  865. # # append ground-truth bboxes to propos
  866. # proposals = self.add_gt_proposals(proposals, gt_boxes)
  867. #
  868. # # get matching gt indices for each proposal
  869. # matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
  870. # # sample a fixed proportion of positive-negative proposals
  871. # sampled_inds = self.subsample(labels)
  872. # matched_gt_boxes = []
  873. # num_images = len(proposals)
  874. # for img_id in range(num_images):
  875. # img_sampled_inds = sampled_inds[img_id]
  876. # proposals[img_id] = proposals[img_id][img_sampled_inds]
  877. # labels[img_id] = labels[img_id][img_sampled_inds]
  878. # matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
  879. #
  880. # gt_boxes_in_image = gt_boxes[img_id]
  881. # if gt_boxes_in_image.numel() == 0:
  882. # gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
  883. # matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
  884. #
  885. # regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
  886. # return proposals, matched_idxs, labels, regression_targets
  887. #
  888. # def postprocess_detections(
  889. # self,
  890. # class_logits, # type: Tensor
  891. # box_regression, # type: Tensor
  892. # proposals, # type: List[Tensor]
  893. # image_shapes, # type: List[Tuple[int, int]]
  894. # ):
  895. # # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
  896. # device = class_logits.device
  897. # num_classes = class_logits.shape[-1]
  898. #
  899. # boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
  900. # pred_boxes = self.box_coder.decode(box_regression, proposals)
  901. #
  902. # pred_scores = F.softmax(class_logits, -1)
  903. #
  904. # pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
  905. # pred_scores_list = pred_scores.split(boxes_per_image, 0)
  906. #
  907. # all_boxes = []
  908. # all_scores = []
  909. # all_labels = []
  910. # for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
  911. # boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
  912. #
  913. # # create labels for each prediction
  914. # labels = torch.arange(num_classes, device=device)
  915. # labels = labels.view(1, -1).expand_as(scores)
  916. #
  917. # # remove predictions with the background label
  918. # boxes = boxes[:, 1:]
  919. # scores = scores[:, 1:]
  920. # labels = labels[:, 1:]
  921. #
  922. # # batch everything, by making every class prediction be a separate instance
  923. # boxes = boxes.reshape(-1, 4)
  924. # scores = scores.reshape(-1)
  925. # labels = labels.reshape(-1)
  926. #
  927. # # remove low scoring boxes
  928. # inds = torch.where(scores > self.score_thresh)[0]
  929. # boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
  930. #
  931. # # remove empty boxes
  932. # keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
  933. # boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  934. #
  935. # # non-maximum suppression, independently done per class
  936. # keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
  937. # # keep only topk scoring predictions
  938. # keep = keep[: self.detections_per_img]
  939. # boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  940. #
  941. # all_boxes.append(boxes)
  942. # all_scores.append(scores)
  943. # all_labels.append(labels)
  944. #
  945. # return all_boxes, all_scores, all_labels
  946. #
  947. # def forward(
  948. # self,
  949. # features, # type: Dict[str, Tensor]
  950. # proposals, # type: List[Tensor]
  951. # image_shapes, # type: List[Tuple[int, int]]
  952. # targets=None, # type: Optional[List[Dict[str, Tensor]]]
  953. # ):
  954. # # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
  955. # """
  956. # Args:
  957. # features (List[Tensor])
  958. # proposals (List[Tensor[N, 4]])
  959. # image_shapes (List[Tuple[H, W]])
  960. # targets (List[Dict])
  961. # """
  962. # if targets is not None:
  963. # for t in targets:
  964. # # TODO: https://github.com/pytorch/pytorch/issues/26731
  965. # floating_point_types = (torch.float, torch.double, torch.half)
  966. # if not t["boxes"].dtype in floating_point_types:
  967. # raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
  968. # if not t["labels"].dtype == torch.int64:
  969. # raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
  970. # if self.has_keypoint():
  971. # if not t["keypoints"].dtype == torch.float32:
  972. # raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
  973. #
  974. # if self.training:
  975. # proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
  976. # else:
  977. # labels = None
  978. # regression_targets = None
  979. # matched_idxs = None
  980. #
  981. # print(f"image_shapes:{image_shapes}")
  982. # box_features = self.box_roi_pool(features, proposals, image_shapes)
  983. # box_features = self.box_head(box_features)
  984. # class_logits, box_regression = self.box_predictor(box_features)
  985. #
  986. # result: List[Dict[str, torch.Tensor]] = []
  987. # losses = {}
  988. # if self.training:
  989. # if labels is None:
  990. # raise ValueError("labels cannot be None")
  991. # if regression_targets is None:
  992. # raise ValueError("regression_targets cannot be None")
  993. # loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
  994. # losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
  995. # else:
  996. # print('result append boxes!!!')
  997. # boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
  998. # num_images = len(boxes)
  999. # for i in range(num_images):
  1000. # result.append(
  1001. # {
  1002. # "boxes": boxes[i],
  1003. # "labels": labels[i],
  1004. # "scores": scores[i],
  1005. # }
  1006. # )
  1007. #
  1008. # if self.has_mask():
  1009. # mask_proposals = [p["boxes"] for p in result]
  1010. # if self.training:
  1011. # if matched_idxs is None:
  1012. # raise ValueError("if in training, matched_idxs should not be None")
  1013. #
  1014. # # during training, only focus on positive boxes
  1015. # num_images = len(proposals)
  1016. # mask_proposals = []
  1017. # pos_matched_idxs = []
  1018. # for img_id in range(num_images):
  1019. # pos = torch.where(labels[img_id] > 0)[0]
  1020. # mask_proposals.append(proposals[img_id][pos])
  1021. # pos_matched_idxs.append(matched_idxs[img_id][pos])
  1022. # else:
  1023. # pos_matched_idxs = None
  1024. #
  1025. # if self.mask_roi_pool is not None:
  1026. # mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
  1027. # mask_features = self.mask_head(mask_features)
  1028. # mask_logits = self.mask_predictor(mask_features)
  1029. # else:
  1030. # raise Exception("Expected mask_roi_pool to be not None")
  1031. #
  1032. # loss_mask = {}
  1033. # if self.training:
  1034. # if targets is None or pos_matched_idxs is None or mask_logits is None:
  1035. # raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
  1036. #
  1037. # gt_masks = [t["masks"] for t in targets]
  1038. # gt_labels = [t["labels"] for t in targets]
  1039. # rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
  1040. # loss_mask = {"loss_mask": rcnn_loss_mask}
  1041. # else:
  1042. # labels = [r["labels"] for r in result]
  1043. # masks_probs = maskrcnn_inference(mask_logits, labels)
  1044. # for mask_prob, r in zip(masks_probs, result):
  1045. # r["masks"] = mask_prob
  1046. #
  1047. # losses.update(loss_mask)
  1048. #
  1049. # # keep none checks in if conditional so torchscript will conditionally
  1050. # # compile each branch
  1051. # if self.has_keypoint():
  1052. #
  1053. # keypoint_proposals = [p["boxes"] for p in result]
  1054. # if self.training:
  1055. # # during training, only focus on positive boxes
  1056. # num_images = len(proposals)
  1057. # keypoint_proposals = []
  1058. # pos_matched_idxs = []
  1059. # if matched_idxs is None:
  1060. # raise ValueError("if in trainning, matched_idxs should not be None")
  1061. #
  1062. # for img_id in range(num_images):
  1063. # pos = torch.where(labels[img_id] > 0)[0]
  1064. # keypoint_proposals.append(proposals[img_id][pos])
  1065. # pos_matched_idxs.append(matched_idxs[img_id][pos])
  1066. # else:
  1067. # pos_matched_idxs = None
  1068. #
  1069. # keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
  1070. # # tmp = keypoint_features[0][0]
  1071. # # plt.imshow(tmp.detach().numpy())
  1072. # # print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
  1073. # keypoint_features = self.keypoint_head(keypoint_features)
  1074. #
  1075. # # print(f'keypoint_features:{keypoint_features.shape}')
  1076. # tmp = keypoint_features[0][0]
  1077. # plt.imshow(tmp.detach().numpy())
  1078. # keypoint_logits = self.keypoint_predictor(keypoint_features)
  1079. # # print(f'keypoint_logits:{keypoint_logits.shape}')
  1080. # """
  1081. # 接wirenet
  1082. # """
  1083. #
  1084. # loss_keypoint = {}
  1085. # if self.training:
  1086. # if targets is None or pos_matched_idxs is None:
  1087. # raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
  1088. #
  1089. # gt_keypoints = [t["keypoints"] for t in targets]
  1090. # rcnn_loss_keypoint = keypointrcnn_loss(
  1091. # keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
  1092. # )
  1093. # loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
  1094. # else:
  1095. # if keypoint_logits is None or keypoint_proposals is None:
  1096. # raise ValueError(
  1097. # "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
  1098. # )
  1099. #
  1100. # keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
  1101. # for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
  1102. # r["keypoints"] = keypoint_prob
  1103. # r["keypoints_scores"] = kps
  1104. # losses.update(loss_keypoint)
  1105. #
  1106. # if self.has_wirepoint():
  1107. # # print(f'result:{result}')
  1108. # wirepoint_proposals = [p["boxes"] for p in result]
  1109. # if self.training:
  1110. # # during training, only focus on positive boxes
  1111. # num_images = len(proposals)
  1112. # wirepoint_proposals = []
  1113. # pos_matched_idxs = []
  1114. # if matched_idxs is None:
  1115. # raise ValueError("if in trainning, matched_idxs should not be None")
  1116. #
  1117. # for img_id in range(num_images):
  1118. # pos = torch.where(labels[img_id] > 0)[0]
  1119. # wirepoint_proposals.append(proposals[img_id][pos])
  1120. # pos_matched_idxs.append(matched_idxs[img_id][pos])
  1121. # else:
  1122. # pos_matched_idxs = None
  1123. #
  1124. # # print(f'proposals:{len(proposals)}')
  1125. # wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
  1126. #
  1127. # # tmp = keypoint_features[0][0]
  1128. # # plt.imshow(tmp.detach().numpy())
  1129. # # print(f'keypoint_features from roi_pool:{wirepoint_features.shape}')
  1130. # outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
  1131. #
  1132. # print(f"wirepoint_features:{wirepoint_features}")
  1133. #
  1134. #
  1135. #
  1136. # outputs = merge_features(outputs, wirepoint_proposals)
  1137. #
  1138. #
  1139. #
  1140. # wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
  1141. #
  1142. # print(f'outpust:{outputs.shape}')
  1143. #
  1144. # wirepoint_logits = self.wirepoint_predictor(inputs=outputs, features=wirepoint_features, targets=targets)
  1145. # x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
  1146. #
  1147. # # print(f'keypoint_features:{wirepoint_features.shape}')
  1148. # if self.training:
  1149. #
  1150. # if targets is None or pos_matched_idxs is None:
  1151. # raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
  1152. #
  1153. # loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
  1154. # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
  1155. #
  1156. # loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
  1157. #
  1158. # else:
  1159. # pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
  1160. # result.append(pred)
  1161. #
  1162. # loss_wirepoint = {}
  1163. #
  1164. # # loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
  1165. # # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
  1166. # # loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
  1167. #
  1168. # # tmp = wirepoint_features[0][0]
  1169. # # plt.imshow(tmp.detach().numpy())
  1170. # # wirepoint_logits = self.wirepoint_predictor((outputs,wirepoint_features))
  1171. # # print(f'keypoint_logits:{wirepoint_logits.shape}')
  1172. #
  1173. # # loss_wirepoint = {} lm
  1174. # # result=wirepoint_logits
  1175. #
  1176. # # result.append(pred) lm
  1177. # losses.update(loss_wirepoint)
  1178. # # print(f"result{result}")
  1179. # # print(f"losses{losses}")
  1180. #
  1181. # return result, losses
  1182. #
  1183. #
  1184. # # def merge_features(features, proposals):
  1185. # # # 假设 roi_pool_features 是你的输入张量,形状为 [600, 256, 128, 128]
  1186. # #
  1187. # # # 使用 torch.split 按照每个图像的提议数量分割 features
  1188. # # proposals_count = sum([p.size(0) for p in proposals])
  1189. # # features_size = features.size(0)
  1190. # # # (f'proposals sum:{proposals_count},features batch:{features.size(0)}')
  1191. # # if proposals_count != features_size:
  1192. # # raise ValueError("The length of proposals must match the batch size of features.")
  1193. # #
  1194. # # split_features = []
  1195. # # start_idx = 0
  1196. # # print(f"proposals:{proposals}")
  1197. # # for proposal in proposals:
  1198. # # # 提取当前图像的特征
  1199. # # current_features = features[start_idx:start_idx + proposal.size(0)]
  1200. # # # print(f'current_features:{current_features.shape}')
  1201. # # split_features.append(current_features)
  1202. # # start_idx += 1
  1203. # #
  1204. # # features_imgs = []
  1205. # # for features_per_img in split_features:
  1206. # # features_per_img, _ = torch.max(features_per_img, dim=0, keepdim=True)
  1207. # # features_imgs.append(features_per_img)
  1208. # #
  1209. # # merged_features = torch.cat(features_imgs, dim=0)
  1210. # # # print(f' merged_features:{merged_features.shape}')
  1211. # # return merged_features
  1212. #
  1213. # def merge_features(features, proposals):
  1214. # print(f'features:{features.shape}')
  1215. # print(f'proposals:{len(proposals)}')
  1216. # def diagnose_input(features, proposals):
  1217. # """诊断输入数据"""
  1218. # print("Input Diagnostics:")
  1219. # print(f"Features type: {type(features)}, shape: {features.shape}")
  1220. # print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
  1221. # for i, p in enumerate(proposals):
  1222. # print(f"Proposal {i} shape: {p.shape}")
  1223. #
  1224. # def validate_inputs(features, proposals):
  1225. # """验证输入的有效性"""
  1226. # if features is None or proposals is None:
  1227. # raise ValueError("Features or proposals cannot be None")
  1228. #
  1229. # proposals_count = sum([p.size(0) for p in proposals])
  1230. # features_size = features.size(0)
  1231. #
  1232. # if proposals_count != features_size:
  1233. # raise ValueError(
  1234. # f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
  1235. # )
  1236. #
  1237. # def safe_max_reduction(features_per_img,proposals):
  1238. #
  1239. # print(f'proposal:{proposals.shape},features_per_img:{features_per_img.shape}')
  1240. # """安全的最大值压缩"""
  1241. # if features_per_img.numel() == 0:
  1242. # return torch.zeros_like(features_per_img).unsqueeze(0)
  1243. #
  1244. # for feature_map,roi in zip(features_per_img,proposals):
  1245. # # print(f'feature_map:{feature_map.shape},roi:{roi}')
  1246. # roi_off_x=roi[0]
  1247. # roi_off_y=roi[1]
  1248. #
  1249. #
  1250. # try:
  1251. # # 沿着第0维求最大值,保持维度
  1252. # max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
  1253. # return max_features
  1254. # except Exception as e:
  1255. # print(f"Max reduction error: {e}")
  1256. # return features_per_img.unsqueeze(0)
  1257. #
  1258. # try:
  1259. # # 诊断输入(可选)
  1260. # # diagnose_input(features, proposals)
  1261. #
  1262. # # 验证输入
  1263. # validate_inputs(features, proposals)
  1264. #
  1265. # # 分割特征
  1266. # split_features = []
  1267. # start_idx = 0
  1268. #
  1269. # for proposal in proposals:
  1270. # # 提取当前图像的特征
  1271. # current_features = features[start_idx:start_idx + proposal.size(0)]
  1272. # split_features.append(current_features)
  1273. # start_idx += proposal.size(0)
  1274. #
  1275. # # 每张图像特征压缩
  1276. # features_imgs = []
  1277. #
  1278. # print(f'split_features:{len(split_features)}')
  1279. # for features_per_img,proposal in zip(split_features,proposals):
  1280. # compressed_features = safe_max_reduction(features_per_img,proposal)
  1281. # features_imgs.append(compressed_features)
  1282. #
  1283. # # 合并特征
  1284. # merged_features = torch.cat(features_imgs, dim=0)
  1285. #
  1286. # return merged_features
  1287. #
  1288. # except Exception as e:
  1289. # print(f"Error in merge_features: {e}")
  1290. # # 返回原始特征或None
  1291. # return features
  1292. #
  1293. '''
  1294. from collections import OrderedDict
  1295. from typing import Dict, List, Optional, Tuple
  1296. import matplotlib.pyplot as plt
  1297. import torch
  1298. import torch.nn.functional as F
  1299. import torchvision
  1300. from torch import nn, Tensor
  1301. from torchvision.ops import boxes as box_ops, roi_align
  1302. from models.wirenet import _utils as det_utils
  1303. from torch.utils.data.dataloader import default_collate
  1304. def l2loss(input, target):
  1305. return ((target - input) ** 2).mean(2).mean(1)
  1306. def cross_entropy_loss(logits, positive):
  1307. nlogp = -F.log_softmax(logits, dim=0)
  1308. return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
  1309. def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
  1310. logp = torch.sigmoid(logits) + offset
  1311. loss = torch.abs(logp - target)
  1312. if mask is not None:
  1313. w = mask.mean(2, True).mean(1, True)
  1314. w[w == 0] = 1
  1315. loss = loss * (mask / w)
  1316. return loss.mean(2).mean(1)
  1317. def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
  1318. # output, feature: head返回结果
  1319. # x, y, idx : line中间生成结果
  1320. result = {}
  1321. batch, channel, row, col = output.shape
  1322. wires_targets = [t["wires"] for t in targets]
  1323. wires_targets = wires_targets.copy()
  1324. # print(f'wires_target:{wires_targets}')
  1325. # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
  1326. junc_maps = [d["junc_map"] for d in wires_targets]
  1327. junc_offsets = [d["junc_offset"] for d in wires_targets]
  1328. line_maps = [d["line_map"] for d in wires_targets]
  1329. junc_map_tensor = torch.stack(junc_maps, dim=0)
  1330. junc_offset_tensor = torch.stack(junc_offsets, dim=0)
  1331. line_map_tensor = torch.stack(line_maps, dim=0)
  1332. T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
  1333. n_jtyp = T["junc_map"].shape[1]
  1334. for task in ["junc_map"]:
  1335. T[task] = T[task].permute(1, 0, 2, 3)
  1336. for task in ["junc_offset"]:
  1337. T[task] = T[task].permute(1, 2, 0, 3, 4)
  1338. offset = [2, 3, 5]
  1339. losses = []
  1340. output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
  1341. jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
  1342. lmap = output[offset[0]: offset[1]].squeeze(0)
  1343. joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
  1344. L = OrderedDict()
  1345. L["junc_map"] = sum(
  1346. cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
  1347. )
  1348. L["line_map"] = (
  1349. F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
  1350. .mean(2)
  1351. .mean(1)
  1352. )
  1353. L["junc_offset"] = sum(
  1354. sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
  1355. for i in range(n_jtyp)
  1356. for j in range(2)
  1357. )
  1358. for loss_name in L:
  1359. L[loss_name].mul_(loss_weight[loss_name])
  1360. losses.append(L)
  1361. result["losses"] = losses
  1362. loss = nn.BCEWithLogitsLoss(reduction="none")
  1363. loss = loss(x, y)
  1364. lpos_mask, lneg_mask = y, 1 - y
  1365. loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
  1366. def sum_batch(x):
  1367. xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
  1368. return torch.cat(xs)
  1369. lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
  1370. lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
  1371. result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
  1372. result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
  1373. return result
  1374. def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
  1375. result = {}
  1376. result["wires"] = {}
  1377. p = torch.cat(ps)
  1378. s = torch.sigmoid(input)
  1379. b = s > 0.5
  1380. lines = []
  1381. score = []
  1382. # print(f"n_batch:{n_batch}")
  1383. for i in range(n_batch):
  1384. # print(f"idx:{idx}")
  1385. p0 = p[idx[i]: idx[i + 1]]
  1386. s0 = s[idx[i]: idx[i + 1]]
  1387. mask = b[idx[i]: idx[i + 1]]
  1388. p0 = p0[mask]
  1389. s0 = s0[mask]
  1390. if len(p0) == 0:
  1391. lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
  1392. score.append(torch.zeros([1, n_out_line], device=p.device))
  1393. else:
  1394. arg = torch.argsort(s0, descending=True)
  1395. p0, s0 = p0[arg], s0[arg]
  1396. lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
  1397. score.append(s0[None, torch.arange(n_out_line) % len(s0)])
  1398. for j in range(len(jcs[i])):
  1399. if len(jcs[i][j]) == 0:
  1400. jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
  1401. jcs[i][j] = jcs[i][j][
  1402. None, torch.arange(n_out_junc) % len(jcs[i][j])
  1403. ]
  1404. result["wires"]["lines"] = torch.cat(lines)
  1405. result["wires"]["score"] = torch.cat(score)
  1406. result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
  1407. if len(jcs[i]) > 1:
  1408. result["preds"]["junts"] = torch.cat(
  1409. [jcs[i][1] for i in range(n_batch)]
  1410. )
  1411. return result
  1412. def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
  1413. # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
  1414. """
  1415. Computes the loss for Faster R-CNN.
  1416. Args:
  1417. class_logits (Tensor)
  1418. box_regression (Tensor)
  1419. labels (list[BoxList])
  1420. regression_targets (Tensor)
  1421. Returns:
  1422. classification_loss (Tensor)
  1423. box_loss (Tensor)
  1424. """
  1425. labels = torch.cat(labels, dim=0)
  1426. regression_targets = torch.cat(regression_targets, dim=0)
  1427. classification_loss = F.cross_entropy(class_logits, labels)
  1428. # get indices that correspond to the regression targets for
  1429. # the corresponding ground truth labels, to be used with
  1430. # advanced indexing
  1431. sampled_pos_inds_subset = torch.where(labels > 0)[0]
  1432. labels_pos = labels[sampled_pos_inds_subset]
  1433. N, num_classes = class_logits.shape
  1434. box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
  1435. box_loss = F.smooth_l1_loss(
  1436. box_regression[sampled_pos_inds_subset, labels_pos],
  1437. regression_targets[sampled_pos_inds_subset],
  1438. beta=1 / 9,
  1439. reduction="sum",
  1440. )
  1441. box_loss = box_loss / labels.numel()
  1442. return classification_loss, box_loss
  1443. def maskrcnn_inference(x, labels):
  1444. # type: (Tensor, List[Tensor]) -> List[Tensor]
  1445. """
  1446. From the results of the CNN, post process the masks
  1447. by taking the mask corresponding to the class with max
  1448. probability (which are of fixed size and directly output
  1449. by the CNN) and return the masks in the mask field of the BoxList.
  1450. Args:
  1451. x (Tensor): the mask logits
  1452. labels (list[BoxList]): bounding boxes that are used as
  1453. reference, one for ech image
  1454. Returns:
  1455. results (list[BoxList]): one BoxList for each image, containing
  1456. the extra field mask
  1457. """
  1458. mask_prob = x.sigmoid()
  1459. # select masks corresponding to the predicted classes
  1460. num_masks = x.shape[0]
  1461. boxes_per_image = [label.shape[0] for label in labels]
  1462. labels = torch.cat(labels)
  1463. index = torch.arange(num_masks, device=labels.device)
  1464. mask_prob = mask_prob[index, labels][:, None]
  1465. mask_prob = mask_prob.split(boxes_per_image, dim=0)
  1466. return mask_prob
  1467. def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
  1468. # type: (Tensor, Tensor, Tensor, int) -> Tensor
  1469. """
  1470. Given segmentation masks and the bounding boxes corresponding
  1471. to the location of the masks in the image, this function
  1472. crops and resizes the masks in the position defined by the
  1473. boxes. This prepares the masks for them to be fed to the
  1474. loss computation as the targets.
  1475. """
  1476. matched_idxs = matched_idxs.to(boxes)
  1477. rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
  1478. gt_masks = gt_masks[:, None].to(rois)
  1479. return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
  1480. def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
  1481. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  1482. """
  1483. Args:
  1484. proposals (list[BoxList])
  1485. mask_logits (Tensor)
  1486. targets (list[BoxList])
  1487. Return:
  1488. mask_loss (Tensor): scalar tensor containing the loss
  1489. """
  1490. discretization_size = mask_logits.shape[-1]
  1491. # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
  1492. # print(f'mask discretization_size:{discretization_size}')
  1493. labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
  1494. # print(f'mask labels:{labels}')
  1495. mask_targets = [
  1496. project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
  1497. ]
  1498. labels = torch.cat(labels, dim=0)
  1499. # print(f'mask labels1:{labels}')
  1500. mask_targets = torch.cat(mask_targets, dim=0)
  1501. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  1502. # accept empty tensors, so handle it separately
  1503. if mask_targets.numel() == 0:
  1504. return mask_logits.sum() * 0
  1505. # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
  1506. # print(f'mask_targets:{mask_targets}')
  1507. mask_loss = F.binary_cross_entropy_with_logits(
  1508. mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
  1509. )
  1510. # print(f'mask_loss:{mask_loss}')
  1511. return mask_loss
  1512. def keypoints_to_heatmap(keypoints, rois, heatmap_size):
  1513. # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
  1514. offset_x = rois[:, 0]
  1515. offset_y = rois[:, 1]
  1516. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  1517. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  1518. offset_x = offset_x[:, None]
  1519. offset_y = offset_y[:, None]
  1520. scale_x = scale_x[:, None]
  1521. scale_y = scale_y[:, None]
  1522. x = keypoints[..., 0]
  1523. y = keypoints[..., 1]
  1524. x_boundary_inds = x == rois[:, 2][:, None]
  1525. y_boundary_inds = y == rois[:, 3][:, None]
  1526. x = (x - offset_x) * scale_x
  1527. x = x.floor().long()
  1528. y = (y - offset_y) * scale_y
  1529. y = y.floor().long()
  1530. x[x_boundary_inds] = heatmap_size - 1
  1531. y[y_boundary_inds] = heatmap_size - 1
  1532. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  1533. vis = keypoints[..., 2] > 0
  1534. valid = (valid_loc & vis).long()
  1535. lin_ind = y * heatmap_size + x
  1536. heatmaps = lin_ind * valid
  1537. return heatmaps, valid
  1538. def _onnx_heatmaps_to_keypoints(
  1539. maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
  1540. ):
  1541. num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
  1542. width_correction = widths_i / roi_map_width
  1543. height_correction = heights_i / roi_map_height
  1544. roi_map = F.interpolate(
  1545. maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
  1546. )[:, 0]
  1547. w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
  1548. pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  1549. x_int = pos % w
  1550. y_int = (pos - x_int) // w
  1551. x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
  1552. dtype=torch.float32
  1553. )
  1554. y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
  1555. dtype=torch.float32
  1556. )
  1557. xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
  1558. xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
  1559. xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
  1560. xy_preds_i = torch.stack(
  1561. [
  1562. xy_preds_i_0.to(dtype=torch.float32),
  1563. xy_preds_i_1.to(dtype=torch.float32),
  1564. xy_preds_i_2.to(dtype=torch.float32),
  1565. ],
  1566. 0,
  1567. )
  1568. # TODO: simplify when indexing without rank will be supported by ONNX
  1569. base = num_keypoints * num_keypoints + num_keypoints + 1
  1570. ind = torch.arange(num_keypoints)
  1571. ind = ind.to(dtype=torch.int64) * base
  1572. end_scores_i = (
  1573. roi_map.index_select(1, y_int.to(dtype=torch.int64))
  1574. .index_select(2, x_int.to(dtype=torch.int64))
  1575. .view(-1)
  1576. .index_select(0, ind.to(dtype=torch.int64))
  1577. )
  1578. return xy_preds_i, end_scores_i
  1579. @torch.jit._script_if_tracing
  1580. def _onnx_heatmaps_to_keypoints_loop(
  1581. maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
  1582. ):
  1583. xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
  1584. end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
  1585. for i in range(int(rois.size(0))):
  1586. xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
  1587. maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
  1588. )
  1589. xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
  1590. end_scores = torch.cat(
  1591. (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
  1592. )
  1593. return xy_preds, end_scores
  1594. def heatmaps_to_keypoints(maps, rois):
  1595. """Extract predicted keypoint locations from heatmaps. Output has shape
  1596. (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
  1597. for each keypoint.
  1598. """
  1599. # This function converts a discrete image coordinate in a HEATMAP_SIZE x
  1600. # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
  1601. # consistency with keypoints_to_heatmap_labels by using the conversion from
  1602. # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
  1603. # continuous coordinate.
  1604. offset_x = rois[:, 0]
  1605. offset_y = rois[:, 1]
  1606. widths = rois[:, 2] - rois[:, 0]
  1607. heights = rois[:, 3] - rois[:, 1]
  1608. widths = widths.clamp(min=1)
  1609. heights = heights.clamp(min=1)
  1610. widths_ceil = widths.ceil()
  1611. heights_ceil = heights.ceil()
  1612. num_keypoints = maps.shape[1]
  1613. if torchvision._is_tracing():
  1614. xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
  1615. maps,
  1616. rois,
  1617. widths_ceil,
  1618. heights_ceil,
  1619. widths,
  1620. heights,
  1621. offset_x,
  1622. offset_y,
  1623. torch.scalar_tensor(num_keypoints, dtype=torch.int64),
  1624. )
  1625. return xy_preds.permute(0, 2, 1), end_scores
  1626. xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
  1627. end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
  1628. for i in range(len(rois)):
  1629. roi_map_width = int(widths_ceil[i].item())
  1630. roi_map_height = int(heights_ceil[i].item())
  1631. width_correction = widths[i] / roi_map_width
  1632. height_correction = heights[i] / roi_map_height
  1633. roi_map = F.interpolate(
  1634. maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
  1635. )[:, 0]
  1636. # roi_map_probs = scores_to_probs(roi_map.copy())
  1637. w = roi_map.shape[2]
  1638. pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  1639. x_int = pos % w
  1640. y_int = torch.div(pos - x_int, w, rounding_mode="floor")
  1641. # assert (roi_map_probs[k, y_int, x_int] ==
  1642. # roi_map_probs[k, :, :].max())
  1643. x = (x_int.float() + 0.5) * width_correction
  1644. y = (y_int.float() + 0.5) * height_correction
  1645. xy_preds[i, 0, :] = x + offset_x[i]
  1646. xy_preds[i, 1, :] = y + offset_y[i]
  1647. xy_preds[i, 2, :] = 1
  1648. end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
  1649. return xy_preds.permute(0, 2, 1), end_scores
  1650. def heatmaps_to_keypoints_new(maps, rois):
  1651. # """Extract predicted keypoint locations from heatmaps. Output has shape
  1652. # (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
  1653. # for each keypoint.
  1654. # """
  1655. # This function converts a discrete image coordinate in a HEATMAP_SIZE x
  1656. # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
  1657. # consistency with keypoints_to_heatmap_labels by using the conversion from
  1658. # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
  1659. # continuous coordinate.
  1660. print(f"maps.shape:{maps.shape}")
  1661. rois = rois[0]
  1662. offset_x = rois[:, 0]
  1663. offset_y = rois[:, 1]
  1664. widths = rois[:, 2] - rois[:, 0]
  1665. heights = rois[:, 3] - rois[:, 1]
  1666. widths = widths.clamp(min=1)
  1667. heights = heights.clamp(min=1)
  1668. widths_ceil = widths.ceil()
  1669. heights_ceil = heights.ceil()
  1670. num_keypoints = maps.shape[1]
  1671. if torchvision._is_tracing():
  1672. xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
  1673. maps,
  1674. rois,
  1675. widths_ceil,
  1676. heights_ceil,
  1677. widths,
  1678. heights,
  1679. offset_x,
  1680. offset_y,
  1681. torch.scalar_tensor(num_keypoints, dtype=torch.int64),
  1682. )
  1683. return xy_preds.permute(0, 2, 1), end_scores
  1684. xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
  1685. end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
  1686. # 创建一个 512x512 的空白图像
  1687. # combined_map = torch.zeros((1, maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
  1688. combined_map = torch.zeros((len(rois), maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
  1689. combined_mask = torch.zeros((1, 1, 512, 512), dtype=torch.float32, device=maps.device)
  1690. print(f"combined_map.shape: {combined_map.shape}")
  1691. print(f"len of rois:{len(rois)}")
  1692. for i in range(len(rois)):
  1693. roi_map_width = int(widths_ceil[i].item())
  1694. roi_map_height = int(heights_ceil[i].item())
  1695. width_correction = widths[i] / roi_map_width
  1696. height_correction = heights[i] / roi_map_height
  1697. roi_map = F.interpolate(
  1698. maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
  1699. )[:, 0]
  1700. x_offset = int(offset_x[i].item()) # 转换为标量
  1701. y_offset = int(offset_y[i].item()) # 转换为标量
  1702. # print(f"x_offset: {x_offset}, y_offset: {y_offset}, roi_map.shape: {roi_map.shape}")
  1703. # 检查偏移量是否合理
  1704. if y_offset < 0 or y_offset + roi_map.shape[1] > combined_map.shape[2] or x_offset < 0 or x_offset + \
  1705. roi_map.shape[2] > combined_map.shape[3]:
  1706. print("Error: Offset exceeds combined_map dimensions.")
  1707. else:
  1708. # 检查 roi_map 的大小
  1709. if roi_map.shape[1] <= 0 or roi_map.shape[2] <= 0:
  1710. print("Error: Invalid ROI size.")
  1711. else:
  1712. # 填充 combined_map
  1713. # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
  1714. # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = \
  1715. # torch.max(
  1716. # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]],
  1717. # roi_map)
  1718. combined_map[i, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
  1719. roi_mask = torch.ones((1, roi_map_height, roi_map_width), dtype=torch.float32, device=maps.device)
  1720. combined_mask[0, 0, y_offset:y_offset + roi_map_height, x_offset:x_offset + roi_map_width] = roi_mask
  1721. # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
  1722. w = roi_map.shape[2]
  1723. pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  1724. x_int = pos % w
  1725. y_int = torch.div(pos - x_int, w, rounding_mode="floor")
  1726. # assert (roi_map_probs[k, y_int, x_int] ==
  1727. # roi_map_probs[k, :, :].max())
  1728. x = (x_int.float() + 0.5) * width_correction
  1729. y = (y_int.float() + 0.5) * height_correction
  1730. xy_preds[i, 0, :] = x + offset_x[i]
  1731. xy_preds[i, 1, :] = y + offset_y[i]
  1732. xy_preds[i, 2, :] = 1
  1733. end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
  1734. combined_map_final, _ = torch.max(combined_map, dim=0, keepdim=True)
  1735. combined_map1 = F.interpolate(combined_map_final, size=(128, 128), mode='bilinear', align_corners=False)
  1736. # print(f"combined_map.shape:{combined_map1.shape}")
  1737. combined_mask = F.interpolate(combined_mask, size=(128, 128), mode='bilinear', align_corners=False)
  1738. combined_mask = (combined_mask >= 0.5).float() # 应用阈值0.5
  1739. return combined_map1, xy_preds.permute(0, 2, 1), end_scores, combined_mask
  1740. # def heatmaps_to_keypoints_new(maps, rois):
  1741. # # """Extract predicted keypoint locations from heatmaps. Output has shape
  1742. # # (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
  1743. # # for each keypoint.
  1744. # # """
  1745. # # This function converts a discrete image coordinate in a HEATMAP_SIZE x
  1746. # # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
  1747. # # consistency with keypoints_to_heatmap_labels by using the conversion from
  1748. # # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
  1749. # # continuous coordinate.
  1750. # print(f"maps.shape:{maps.shape}")
  1751. # rois = rois[0]
  1752. # offset_x = rois[:, 0]
  1753. # offset_y = rois[:, 1]
  1754. #
  1755. # widths = rois[:, 2] - rois[:, 0]
  1756. # heights = rois[:, 3] - rois[:, 1]
  1757. # widths = widths.clamp(min=1)
  1758. # heights = heights.clamp(min=1)
  1759. # widths_ceil = widths.ceil()
  1760. # heights_ceil = heights.ceil()
  1761. #
  1762. # num_keypoints = maps.shape[1]
  1763. #
  1764. # if torchvision._is_tracing():
  1765. # xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
  1766. # maps,
  1767. # rois,
  1768. # widths_ceil,
  1769. # heights_ceil,
  1770. # widths,
  1771. # heights,
  1772. # offset_x,
  1773. # offset_y,
  1774. # torch.scalar_tensor(num_keypoints, dtype=torch.int64),
  1775. # )
  1776. # return xy_preds.permute(0, 2, 1), end_scores
  1777. #
  1778. # xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
  1779. # end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
  1780. # # 创建一个 512x512 的空白图像
  1781. #
  1782. # # combined_map = torch.zeros((1, maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
  1783. # combined = torch.zeros((1, maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
  1784. # combined_map = torch.zeros((len(rois), maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
  1785. # combined_mask = torch.zeros((1, 1, 512, 512), dtype=torch.float32, device=maps.device)
  1786. #
  1787. # print(f"combined_map.shape: {combined_map.shape}")
  1788. # print(f"len of rois:{len(rois)}")
  1789. # for i in range(len(rois)):
  1790. # roi_map_width = int(widths_ceil[i].item())
  1791. # roi_map_height = int(heights_ceil[i].item())
  1792. # width_correction = widths[i] / roi_map_width
  1793. # height_correction = heights[i] / roi_map_height
  1794. # roi_map = F.interpolate(
  1795. # maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
  1796. # )[:, 0]
  1797. # x_offset = int(offset_x[i].item()) # 转换为标量
  1798. # y_offset = int(offset_y[i].item()) # 转换为标量
  1799. # # print(f"x_offset: {x_offset}, y_offset: {y_offset}, roi_map.shape: {roi_map.shape}")
  1800. # # 检查偏移量是否合理
  1801. # if y_offset < 0 or y_offset + roi_map.shape[1] > combined_map.shape[2] or x_offset < 0 or x_offset + \
  1802. # roi_map.shape[2] > combined_map.shape[3]:
  1803. # print("Error: Offset exceeds combined_map dimensions.")
  1804. # else:
  1805. # # 检查 roi_map 的大小
  1806. # if roi_map.shape[1] <= 0 or roi_map.shape[2] <= 0:
  1807. # print("Error: Invalid ROI size.")
  1808. # else:
  1809. # # 填充 combined_map
  1810. # # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
  1811. # # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = \
  1812. # # torch.max(
  1813. # # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]],
  1814. # # roi_map)
  1815. # combined[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = \
  1816. # torch.max(
  1817. # combined[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]],
  1818. # roi_map)
  1819. #
  1820. #
  1821. # combined_map[i, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
  1822. #
  1823. # roi_mask = torch.ones((1, roi_map_height, roi_map_width), dtype=torch.float32, device=maps.device)
  1824. # combined_mask[0, 0, y_offset:y_offset + roi_map_height, x_offset:x_offset + roi_map_width] = roi_mask
  1825. #
  1826. # # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
  1827. # w = roi_map.shape[2]
  1828. # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  1829. #
  1830. # x_int = pos % w
  1831. # y_int = torch.div(pos - x_int, w, rounding_mode="floor")
  1832. # # assert (roi_map_probs[k, y_int, x_int] ==
  1833. # # roi_map_probs[k, :, :].max())
  1834. # x = (x_int.float() + 0.5) * width_correction
  1835. # y = (y_int.float() + 0.5) * height_correction
  1836. # xy_preds[i, 0, :] = x + offset_x[i]
  1837. # xy_preds[i, 1, :] = y + offset_y[i]
  1838. # xy_preds[i, 2, :] = 1
  1839. # end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
  1840. # combined_map_final, _ = torch.max(combined_map, dim=0, keepdim=True)
  1841. # print(f"判断:{torch.equal(combined_map_final,combined)}")
  1842. # # print(f"combined_map_final:{combined_map_final.shape}")
  1843. # combined_map1 = F.interpolate(combined_map_final, size=(128, 128), mode='bilinear', align_corners=False)
  1844. # # print(f"combined_map.shape:{combined_map1.shape}")
  1845. #
  1846. # combined_mask = F.interpolate(combined_mask, size=(128, 128), mode='bilinear', align_corners=False)
  1847. # combined_mask = (combined_mask >= 0.5).float() # 应用阈值0.5
  1848. #
  1849. # return combined_map1, xy_preds.permute(0, 2, 1), end_scores, combined_mask
  1850. def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
  1851. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  1852. N, K, H, W = keypoint_logits.shape
  1853. if H != W:
  1854. raise ValueError(
  1855. f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  1856. )
  1857. discretization_size = H
  1858. heatmaps = []
  1859. valid = []
  1860. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
  1861. kp = gt_kp_in_image[midx]
  1862. heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
  1863. heatmaps.append(heatmaps_per_image.view(-1))
  1864. valid.append(valid_per_image.view(-1))
  1865. keypoint_targets = torch.cat(heatmaps, dim=0)
  1866. valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  1867. valid = torch.where(valid)[0]
  1868. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  1869. # accept empty tensors, so handle it sepaartely
  1870. if keypoint_targets.numel() == 0 or len(valid) == 0:
  1871. return keypoint_logits.sum() * 0
  1872. keypoint_logits = keypoint_logits.view(N * K, H * W)
  1873. keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
  1874. return keypoint_loss
  1875. def keypointrcnn_inference(x, boxes):
  1876. # print(f'x:{x.shape}')
  1877. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  1878. kp_probs = []
  1879. kp_scores = []
  1880. boxes_per_image = [box.size(0) for box in boxes]
  1881. x2 = x.split(boxes_per_image, dim=0)
  1882. # print(f'x2:{x2}')
  1883. for xx, bb in zip(x2, boxes):
  1884. kp_prob, scores = heatmaps_to_keypoints(xx, bb)
  1885. kp_probs.append(kp_prob)
  1886. kp_scores.append(scores)
  1887. return kp_probs, kp_scores
  1888. def _onnx_expand_boxes(boxes, scale):
  1889. # type: (Tensor, float) -> Tensor
  1890. w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
  1891. h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
  1892. x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  1893. y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  1894. w_half = w_half.to(dtype=torch.float32) * scale
  1895. h_half = h_half.to(dtype=torch.float32) * scale
  1896. boxes_exp0 = x_c - w_half
  1897. boxes_exp1 = y_c - h_half
  1898. boxes_exp2 = x_c + w_half
  1899. boxes_exp3 = y_c + h_half
  1900. boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
  1901. return boxes_exp
  1902. # the next two functions should be merged inside Masker
  1903. # but are kept here for the moment while we need them
  1904. # temporarily for paste_mask_in_image
  1905. def expand_boxes(boxes, scale):
  1906. # type: (Tensor, float) -> Tensor
  1907. if torchvision._is_tracing():
  1908. return _onnx_expand_boxes(boxes, scale)
  1909. w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
  1910. h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
  1911. x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  1912. y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  1913. w_half *= scale
  1914. h_half *= scale
  1915. boxes_exp = torch.zeros_like(boxes)
  1916. boxes_exp[:, 0] = x_c - w_half
  1917. boxes_exp[:, 2] = x_c + w_half
  1918. boxes_exp[:, 1] = y_c - h_half
  1919. boxes_exp[:, 3] = y_c + h_half
  1920. return boxes_exp
  1921. @torch.jit.unused
  1922. def expand_masks_tracing_scale(M, padding):
  1923. # type: (int, int) -> float
  1924. return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
  1925. def expand_masks(mask, padding):
  1926. # type: (Tensor, int) -> Tuple[Tensor, float]
  1927. M = mask.shape[-1]
  1928. if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
  1929. scale = expand_masks_tracing_scale(M, padding)
  1930. else:
  1931. scale = float(M + 2 * padding) / M
  1932. padded_mask = F.pad(mask, (padding,) * 4)
  1933. return padded_mask, scale
  1934. def paste_mask_in_image(mask, box, im_h, im_w):
  1935. # type: (Tensor, Tensor, int, int) -> Tensor
  1936. TO_REMOVE = 1
  1937. w = int(box[2] - box[0] + TO_REMOVE)
  1938. h = int(box[3] - box[1] + TO_REMOVE)
  1939. w = max(w, 1)
  1940. h = max(h, 1)
  1941. # Set shape to [batchxCxHxW]
  1942. mask = mask.expand((1, 1, -1, -1))
  1943. # Resize mask
  1944. mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
  1945. mask = mask[0][0]
  1946. im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
  1947. x_0 = max(box[0], 0)
  1948. x_1 = min(box[2] + 1, im_w)
  1949. y_0 = max(box[1], 0)
  1950. y_1 = min(box[3] + 1, im_h)
  1951. im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
  1952. return im_mask
  1953. def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
  1954. one = torch.ones(1, dtype=torch.int64)
  1955. zero = torch.zeros(1, dtype=torch.int64)
  1956. w = box[2] - box[0] + one
  1957. h = box[3] - box[1] + one
  1958. w = torch.max(torch.cat((w, one)))
  1959. h = torch.max(torch.cat((h, one)))
  1960. # Set shape to [batchxCxHxW]
  1961. mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
  1962. # Resize mask
  1963. mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
  1964. mask = mask[0][0]
  1965. x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
  1966. x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
  1967. y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
  1968. y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
  1969. unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
  1970. # TODO : replace below with a dynamic padding when support is added in ONNX
  1971. # pad y
  1972. zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
  1973. zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
  1974. concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
  1975. # pad x
  1976. zeros_x0 = torch.zeros(concat_0.size(0), x_0)
  1977. zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
  1978. im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
  1979. return im_mask
  1980. @torch.jit._script_if_tracing
  1981. def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
  1982. res_append = torch.zeros(0, im_h, im_w)
  1983. for i in range(masks.size(0)):
  1984. mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
  1985. mask_res = mask_res.unsqueeze(0)
  1986. res_append = torch.cat((res_append, mask_res))
  1987. return res_append
  1988. def paste_masks_in_image(masks, boxes, img_shape, padding=1):
  1989. # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
  1990. masks, scale = expand_masks(masks, padding=padding)
  1991. boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
  1992. im_h, im_w = img_shape
  1993. if torchvision._is_tracing():
  1994. return _onnx_paste_masks_in_image_loop(
  1995. masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
  1996. )[:, None]
  1997. res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
  1998. if len(res) > 0:
  1999. ret = torch.stack(res, dim=0)[:, None]
  2000. else:
  2001. ret = masks.new_empty((0, 1, im_h, im_w))
  2002. return ret
  2003. class RoIHeads(nn.Module):
  2004. __annotations__ = {
  2005. "box_coder": det_utils.BoxCoder,
  2006. "proposal_matcher": det_utils.Matcher,
  2007. "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
  2008. }
  2009. def __init__(
  2010. self,
  2011. box_roi_pool,
  2012. box_head,
  2013. box_predictor,
  2014. # Faster R-CNN training
  2015. fg_iou_thresh,
  2016. bg_iou_thresh,
  2017. batch_size_per_image,
  2018. positive_fraction,
  2019. bbox_reg_weights,
  2020. # Faster R-CNN inference
  2021. score_thresh,
  2022. nms_thresh,
  2023. detections_per_img,
  2024. # Mask
  2025. mask_roi_pool=None,
  2026. mask_head=None,
  2027. mask_predictor=None,
  2028. keypoint_roi_pool=None,
  2029. keypoint_head=None,
  2030. keypoint_predictor=None,
  2031. wirepoint_roi_pool=None,
  2032. wirepoint_head=None,
  2033. wirepoint_predictor=None,
  2034. ):
  2035. super().__init__()
  2036. self.box_similarity = box_ops.box_iou
  2037. # assign ground-truth boxes for each proposal
  2038. self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
  2039. self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
  2040. if bbox_reg_weights is None:
  2041. bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
  2042. self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
  2043. self.box_roi_pool = box_roi_pool
  2044. self.box_head = box_head
  2045. self.box_predictor = box_predictor
  2046. self.score_thresh = score_thresh
  2047. self.nms_thresh = nms_thresh
  2048. self.detections_per_img = detections_per_img
  2049. self.mask_roi_pool = mask_roi_pool
  2050. self.mask_head = mask_head
  2051. self.mask_predictor = mask_predictor
  2052. self.keypoint_roi_pool = keypoint_roi_pool
  2053. self.keypoint_head = keypoint_head
  2054. self.keypoint_predictor = keypoint_predictor
  2055. self.wirepoint_roi_pool = wirepoint_roi_pool
  2056. self.wirepoint_head = wirepoint_head
  2057. self.wirepoint_predictor = wirepoint_predictor
  2058. def has_mask(self):
  2059. if self.mask_roi_pool is None:
  2060. return False
  2061. if self.mask_head is None:
  2062. return False
  2063. if self.mask_predictor is None:
  2064. return False
  2065. return True
  2066. def has_keypoint(self):
  2067. if self.keypoint_roi_pool is None:
  2068. return False
  2069. if self.keypoint_head is None:
  2070. return False
  2071. if self.keypoint_predictor is None:
  2072. return False
  2073. return True
  2074. def has_wirepoint(self):
  2075. if self.wirepoint_roi_pool is None:
  2076. print(f'wirepoint_roi_pool is None')
  2077. return False
  2078. if self.wirepoint_head is None:
  2079. print(f'wirepoint_head is None')
  2080. return False
  2081. if self.wirepoint_predictor is None:
  2082. print(f'wirepoint_roi_predictor is None')
  2083. return False
  2084. return True
  2085. def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
  2086. # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  2087. matched_idxs = []
  2088. labels = []
  2089. for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
  2090. if gt_boxes_in_image.numel() == 0:
  2091. # Background image
  2092. device = proposals_in_image.device
  2093. clamped_matched_idxs_in_image = torch.zeros(
  2094. (proposals_in_image.shape[0],), dtype=torch.int64, device=device
  2095. )
  2096. labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
  2097. else:
  2098. # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
  2099. match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
  2100. matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
  2101. clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
  2102. labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
  2103. labels_in_image = labels_in_image.to(dtype=torch.int64)
  2104. # Label background (below the low threshold)
  2105. bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
  2106. labels_in_image[bg_inds] = 0
  2107. # Label ignore proposals (between low and high thresholds)
  2108. ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
  2109. labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
  2110. matched_idxs.append(clamped_matched_idxs_in_image)
  2111. labels.append(labels_in_image)
  2112. return matched_idxs, labels
  2113. def subsample(self, labels):
  2114. # type: (List[Tensor]) -> List[Tensor]
  2115. sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
  2116. sampled_inds = []
  2117. for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
  2118. img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
  2119. sampled_inds.append(img_sampled_inds)
  2120. return sampled_inds
  2121. def add_gt_proposals(self, proposals, gt_boxes):
  2122. # type: (List[Tensor], List[Tensor]) -> List[Tensor]
  2123. proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
  2124. return proposals
  2125. def check_targets(self, targets):
  2126. # type: (Optional[List[Dict[str, Tensor]]]) -> None
  2127. if targets is None:
  2128. raise ValueError("targets should not be None")
  2129. if not all(["boxes" in t for t in targets]):
  2130. raise ValueError("Every element of targets should have a boxes key")
  2131. if not all(["labels" in t for t in targets]):
  2132. raise ValueError("Every element of targets should have a labels key")
  2133. if self.has_mask():
  2134. if not all(["masks" in t for t in targets]):
  2135. raise ValueError("Every element of targets should have a masks key")
  2136. def select_training_samples(
  2137. self,
  2138. proposals, # type: List[Tensor]
  2139. targets, # type: Optional[List[Dict[str, Tensor]]]
  2140. ):
  2141. # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
  2142. self.check_targets(targets)
  2143. if targets is None:
  2144. raise ValueError("targets should not be None")
  2145. dtype = proposals[0].dtype
  2146. device = proposals[0].device
  2147. gt_boxes = [t["boxes"].to(dtype) for t in targets]
  2148. gt_labels = [t["labels"] for t in targets]
  2149. # append ground-truth bboxes to propos
  2150. proposals = self.add_gt_proposals(proposals, gt_boxes)
  2151. # get matching gt indices for each proposal
  2152. matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
  2153. # sample a fixed proportion of positive-negative proposals
  2154. sampled_inds = self.subsample(labels)
  2155. matched_gt_boxes = []
  2156. num_images = len(proposals)
  2157. for img_id in range(num_images):
  2158. img_sampled_inds = sampled_inds[img_id]
  2159. proposals[img_id] = proposals[img_id][img_sampled_inds]
  2160. labels[img_id] = labels[img_id][img_sampled_inds]
  2161. matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
  2162. gt_boxes_in_image = gt_boxes[img_id]
  2163. if gt_boxes_in_image.numel() == 0:
  2164. gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
  2165. matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
  2166. regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
  2167. return proposals, matched_idxs, labels, regression_targets
  2168. def postprocess_detections(
  2169. self,
  2170. class_logits, # type: Tensor
  2171. box_regression, # type: Tensor
  2172. proposals, # type: List[Tensor]
  2173. image_shapes, # type: List[Tuple[int, int]]
  2174. ):
  2175. # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
  2176. device = class_logits.device
  2177. num_classes = class_logits.shape[-1]
  2178. boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
  2179. pred_boxes = self.box_coder.decode(box_regression, proposals)
  2180. pred_scores = F.softmax(class_logits, -1)
  2181. pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
  2182. pred_scores_list = pred_scores.split(boxes_per_image, 0)
  2183. all_boxes = []
  2184. all_scores = []
  2185. all_labels = []
  2186. for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
  2187. boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
  2188. # create labels for each prediction
  2189. labels = torch.arange(num_classes, device=device)
  2190. labels = labels.view(1, -1).expand_as(scores)
  2191. # remove predictions with the background label
  2192. boxes = boxes[:, 1:]
  2193. scores = scores[:, 1:]
  2194. labels = labels[:, 1:]
  2195. # batch everything, by making every class prediction be a separate instance
  2196. boxes = boxes.reshape(-1, 4)
  2197. scores = scores.reshape(-1)
  2198. labels = labels.reshape(-1)
  2199. # remove low scoring boxes
  2200. inds = torch.where(scores > self.score_thresh)[0]
  2201. boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
  2202. # remove empty boxes
  2203. keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
  2204. boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  2205. # non-maximum suppression, independently done per class
  2206. keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
  2207. # keep only topk scoring predictions
  2208. keep = keep[: self.detections_per_img]
  2209. boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  2210. all_boxes.append(boxes)
  2211. all_scores.append(scores)
  2212. all_labels.append(labels)
  2213. return all_boxes, all_scores, all_labels
  2214. def forward(
  2215. self,
  2216. features, # type: Dict[str, Tensor]
  2217. proposals, # type: List[Tensor]
  2218. image_shapes, # type: List[Tuple[int, int]]
  2219. targets=None, # type: Optional[List[Dict[str, Tensor]]]
  2220. ):
  2221. # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
  2222. """
  2223. Args:
  2224. features (List[Tensor])
  2225. proposals (List[Tensor[N, 4]])
  2226. image_shapes (List[Tuple[H, W]])
  2227. targets (List[Dict])
  2228. """
  2229. if targets is not None:
  2230. for t in targets:
  2231. # TODO: https://github.com/pytorch/pytorch/issues/26731
  2232. floating_point_types = (torch.float, torch.double, torch.half)
  2233. if not t["boxes"].dtype in floating_point_types:
  2234. raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
  2235. if not t["labels"].dtype == torch.int64:
  2236. raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
  2237. if self.has_keypoint():
  2238. if not t["keypoints"].dtype == torch.float32:
  2239. raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
  2240. print(f"proposals len:{proposals[0].shape}")
  2241. if self.training:
  2242. proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
  2243. else:
  2244. labels = None
  2245. regression_targets = None
  2246. matched_idxs = None
  2247. print(f"proposals:{proposals[0].shape}")
  2248. box_features = self.box_roi_pool(features, proposals, image_shapes)
  2249. box_features = self.box_head(box_features)
  2250. class_logits, box_regression = self.box_predictor(box_features)
  2251. result: List[Dict[str, torch.Tensor]] = []
  2252. losses = {}
  2253. if self.training:
  2254. if labels is None:
  2255. raise ValueError("labels cannot be None")
  2256. if regression_targets is None:
  2257. raise ValueError("regression_targets cannot be None")
  2258. loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
  2259. losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
  2260. else:
  2261. boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
  2262. num_images = len(boxes)
  2263. for i in range(num_images):
  2264. result.append(
  2265. {
  2266. "boxes": boxes[i],
  2267. "labels": labels[i],
  2268. "scores": scores[i],
  2269. }
  2270. )
  2271. print(f"proposals len:{proposals[0].shape}")
  2272. print(f"boxes len:{boxes[0].shape}")
  2273. print(f"proposals:{proposals}")
  2274. print(f"boxes:{boxes}")
  2275. # 不走这个
  2276. if self.has_mask():
  2277. mask_proposals = [p["boxes"] for p in result]
  2278. if self.training:
  2279. if matched_idxs is None:
  2280. raise ValueError("if in training, matched_idxs should not be None")
  2281. # during training, only focus on positive boxes
  2282. num_images = len(proposals)
  2283. mask_proposals = []
  2284. pos_matched_idxs = []
  2285. for img_id in range(num_images):
  2286. pos = torch.where(labels[img_id] > 0)[0]
  2287. mask_proposals.append(proposals[img_id][pos])
  2288. pos_matched_idxs.append(matched_idxs[img_id][pos])
  2289. else:
  2290. pos_matched_idxs = None
  2291. if self.mask_roi_pool is not None:
  2292. mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
  2293. mask_features = self.mask_head(mask_features)
  2294. mask_logits = self.mask_predictor(mask_features)
  2295. else:
  2296. raise Exception("Expected mask_roi_pool to be not None")
  2297. loss_mask = {}
  2298. if self.training:
  2299. if targets is None or pos_matched_idxs is None or mask_logits is None:
  2300. raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
  2301. gt_masks = [t["masks"] for t in targets]
  2302. gt_labels = [t["labels"] for t in targets]
  2303. rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
  2304. loss_mask = {"loss_mask": rcnn_loss_mask}
  2305. else:
  2306. labels = [r["labels"] for r in result]
  2307. masks_probs = maskrcnn_inference(mask_logits, labels)
  2308. for mask_prob, r in zip(masks_probs, result):
  2309. r["masks"] = mask_prob
  2310. losses.update(loss_mask)
  2311. # keep none checks in if conditional so torchscript will conditionally
  2312. # compile each branch
  2313. if self.has_keypoint():
  2314. keypoint_proposals = [p["boxes"] for p in result]
  2315. if self.training:
  2316. # during training, only focus on positive boxes
  2317. num_images = len(proposals)
  2318. keypoint_proposals = []
  2319. pos_matched_idxs = []
  2320. if matched_idxs is None:
  2321. raise ValueError("if in trainning, matched_idxs should not be None")
  2322. for img_id in range(num_images):
  2323. pos = torch.where(labels[img_id] > 0)[0]
  2324. keypoint_proposals.append(proposals[img_id][pos])
  2325. pos_matched_idxs.append(matched_idxs[img_id][pos])
  2326. else:
  2327. pos_matched_idxs = None
  2328. keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
  2329. # tmp = keypoint_features[0][0]
  2330. # plt.imshow(tmp.detach().numpy())
  2331. # print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
  2332. keypoint_features = self.keypoint_head(keypoint_features)
  2333. # print(f'keypoint_features:{keypoint_features.shape}')
  2334. tmp = keypoint_features[0][0]
  2335. plt.imshow(tmp.detach().numpy())
  2336. keypoint_logits = self.keypoint_predictor(keypoint_features)
  2337. # print(f'keypoint_logits:{keypoint_logits.shape}')
  2338. """
  2339. 接wirenet
  2340. """
  2341. loss_keypoint = {}
  2342. if self.training:
  2343. if targets is None or pos_matched_idxs is None:
  2344. raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
  2345. gt_keypoints = [t["keypoints"] for t in targets]
  2346. rcnn_loss_keypoint = keypointrcnn_loss(
  2347. keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
  2348. )
  2349. loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
  2350. else:
  2351. if keypoint_logits is None or keypoint_proposals is None:
  2352. raise ValueError(
  2353. "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
  2354. )
  2355. keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
  2356. for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
  2357. r["keypoints"] = keypoint_prob
  2358. r["keypoints_scores"] = kps
  2359. losses.update(loss_keypoint)
  2360. if self.has_wirepoint():
  2361. wirepoint_proposals = [p["boxes"] for p in result]
  2362. if self.training:
  2363. # during training, only focus on positive boxes
  2364. num_images = len(proposals)
  2365. wirepoint_proposals = []
  2366. pos_matched_idxs = []
  2367. if matched_idxs is None:
  2368. raise ValueError("if in trainning, matched_idxs should not be None")
  2369. for img_id in range(num_images):
  2370. pos = torch.where(labels[img_id] > 0)[0]
  2371. wirepoint_proposals.append(proposals[img_id][pos])
  2372. pos_matched_idxs.append(matched_idxs[img_id][pos])
  2373. else:
  2374. pos_matched_idxs = None
  2375. wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
  2376. outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
  2377. # print(f"wirepoint_proposal:{type(wirepoint_proposals)}")
  2378. # print(f"wirepoint_proposal:{wirepoint_proposals.__len__()}")
  2379. print(f"wirepoint_proposal[0].shape:{wirepoint_proposals[0].shape}")
  2380. # print(f"wirepoint_proposal[0]:{wirepoint_proposals[0]}")
  2381. print(f"wirepoint_features:{wirepoint_features.shape}")
  2382. # outputs = merge_features(outputs, wirepoint_proposals)
  2383. combined_output, xy_preds, end_scores, mask_key = heatmaps_to_keypoints_new(outputs, wirepoint_proposals)
  2384. wire_combined_features, wire_xy_preds, wire_end_scores, wire_mask = heatmaps_to_keypoints_new(
  2385. wirepoint_features, wirepoint_proposals)
  2386. # print(f'combined_output:{combined_output.shape}')
  2387. print(f"wire_combined_features:{wire_combined_features.shape}")
  2388. wirepoint_logits = self.wirepoint_predictor(inputs=combined_output, features=wire_combined_features,
  2389. mask=wire_mask, targets=targets)
  2390. x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
  2391. # print(f'keypoint_features:{wirepoint_features.shape}')
  2392. if self.training:
  2393. if targets is None or pos_matched_idxs is None:
  2394. raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
  2395. loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
  2396. rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, combined_output, x, y, idx, loss_weight)
  2397. loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
  2398. else:
  2399. pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
  2400. result.append(pred)
  2401. loss_wirepoint = {}
  2402. losses.update(loss_wirepoint)
  2403. return result, losses
  2404. def merge_features(features, proposals):
  2405. print("merge==========================================================================start")
  2406. print(f"Features type: {type(features)}, shape: {features.shape}")
  2407. print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
  2408. print(f"Proposals : {proposals[0].shape},")
  2409. def diagnose_input(features, proposals):
  2410. """诊断输入数据"""
  2411. print("Input Diagnostics:")
  2412. print(f"Features type: {type(features)}, shape: {features.shape}")
  2413. print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
  2414. for i, p in enumerate(proposals):
  2415. print(f"Proposal {i} shape: {p.shape}")
  2416. def validate_inputs(features, proposals):
  2417. """验证输入的有效性"""
  2418. if features is None or proposals is None:
  2419. raise ValueError("Features or proposals cannot be None")
  2420. proposals_count = sum([p.size(0) for p in proposals])
  2421. features_size = features.size(0)
  2422. if proposals_count != features_size:
  2423. raise ValueError(
  2424. f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
  2425. )
  2426. def safe_max_reduction(features_per_img):
  2427. """安全的最大值压缩"""
  2428. if features_per_img.numel() == 0:
  2429. return torch.zeros_like(features_per_img).unsqueeze(0)
  2430. try:
  2431. # 沿着第0维求最大值,保持维度
  2432. max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
  2433. return max_features
  2434. except Exception as e:
  2435. print(f"Max reduction error: {e}")
  2436. return features_per_img.unsqueeze(0)
  2437. try:
  2438. # 诊断输入(可选)
  2439. # diagnose_input(features, proposals)
  2440. # 验证输入
  2441. validate_inputs(features, proposals)
  2442. # 分割特征
  2443. split_features = []
  2444. start_idx = 0
  2445. for proposal in proposals:
  2446. # 提取当前图像的特征
  2447. current_features = features[start_idx:start_idx + proposal.size(0)]
  2448. split_features.append(current_features)
  2449. start_idx += proposal.size(0)
  2450. # 每张图像特征压缩
  2451. features_imgs = []
  2452. for features_per_img in split_features:
  2453. compressed_features = safe_max_reduction(features_per_img)
  2454. features_imgs.append(compressed_features)
  2455. # 合并特征
  2456. merged_features = torch.cat(features_imgs, dim=0)
  2457. return merged_features
  2458. except Exception as e:
  2459. print(f"Error in merge_features: {e}")
  2460. # 返回原始特征或None
  2461. return features
  2462. '''
  2463. from collections import OrderedDict
  2464. from typing import Dict, List, Optional, Tuple
  2465. import matplotlib.pyplot as plt
  2466. import torch
  2467. import torch.nn.functional as F
  2468. import torchvision
  2469. from torch import nn, Tensor
  2470. from torchvision.ops import boxes as box_ops, roi_align
  2471. from models.wirenet import _utils as det_utils
  2472. from torch.utils.data.dataloader import default_collate
  2473. def l2loss(input, target):
  2474. return ((target - input) ** 2).mean(2).mean(1)
  2475. def cross_entropy_loss(logits, positive):
  2476. nlogp = -F.log_softmax(logits, dim=0)
  2477. return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
  2478. def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
  2479. logp = torch.sigmoid(logits) + offset
  2480. loss = torch.abs(logp - target)
  2481. if mask is not None:
  2482. w = mask.mean(2, True).mean(1, True)
  2483. w[w == 0] = 1
  2484. loss = loss * (mask / w)
  2485. return loss.mean(2).mean(1)
  2486. def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
  2487. # output, feature: head返回结果
  2488. # x, y, idx : line中间生成结果
  2489. result = {}
  2490. batch, channel, row, col = output.shape
  2491. wires_targets = [t["wires"] for t in targets]
  2492. wires_targets = wires_targets.copy()
  2493. # print(f'wires_target:{wires_targets}')
  2494. # 提取所有 'junc_map', 'junc_offset', 'line_map' 的张量
  2495. junc_maps = [d["junc_map"] for d in wires_targets]
  2496. junc_offsets = [d["junc_offset"] for d in wires_targets]
  2497. line_maps = [d["line_map"] for d in wires_targets]
  2498. junc_map_tensor = torch.stack(junc_maps, dim=0)
  2499. junc_offset_tensor = torch.stack(junc_offsets, dim=0)
  2500. line_map_tensor = torch.stack(line_maps, dim=0)
  2501. T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
  2502. n_jtyp = T["junc_map"].shape[1]
  2503. for task in ["junc_map"]:
  2504. T[task] = T[task].permute(1, 0, 2, 3)
  2505. for task in ["junc_offset"]:
  2506. T[task] = T[task].permute(1, 2, 0, 3, 4)
  2507. offset = [2, 3, 5]
  2508. losses = []
  2509. output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
  2510. jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
  2511. lmap = output[offset[0]: offset[1]].squeeze(0)
  2512. joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
  2513. L = OrderedDict()
  2514. L["junc_map"] = sum(
  2515. cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
  2516. )
  2517. L["line_map"] = (
  2518. F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
  2519. .mean(2)
  2520. .mean(1)
  2521. )
  2522. L["junc_offset"] = sum(
  2523. sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
  2524. for i in range(n_jtyp)
  2525. for j in range(2)
  2526. )
  2527. for loss_name in L:
  2528. L[loss_name].mul_(loss_weight[loss_name])
  2529. losses.append(L)
  2530. result["losses"] = losses
  2531. loss = nn.BCEWithLogitsLoss(reduction="none")
  2532. loss = loss(x, y)
  2533. lpos_mask, lneg_mask = y, 1 - y
  2534. loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
  2535. def sum_batch(x):
  2536. xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
  2537. return torch.cat(xs)
  2538. lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
  2539. lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
  2540. result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
  2541. result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
  2542. return result
  2543. def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
  2544. result = {}
  2545. result["wires"] = {}
  2546. p = torch.cat(ps)
  2547. s = torch.sigmoid(input)
  2548. b = s > 0.5
  2549. lines = []
  2550. score = []
  2551. # print(f"n_batch:{n_batch}")
  2552. for i in range(n_batch):
  2553. # print(f"idx:{idx}")
  2554. p0 = p[idx[i]: idx[i + 1]]
  2555. s0 = s[idx[i]: idx[i + 1]]
  2556. mask = b[idx[i]: idx[i + 1]]
  2557. p0 = p0[mask]
  2558. s0 = s0[mask]
  2559. if len(p0) == 0:
  2560. lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
  2561. score.append(torch.zeros([1, n_out_line], device=p.device))
  2562. else:
  2563. arg = torch.argsort(s0, descending=True)
  2564. p0, s0 = p0[arg], s0[arg]
  2565. lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
  2566. score.append(s0[None, torch.arange(n_out_line) % len(s0)])
  2567. for j in range(len(jcs[i])):
  2568. if len(jcs[i][j]) == 0:
  2569. jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
  2570. jcs[i][j] = jcs[i][j][
  2571. None, torch.arange(n_out_junc) % len(jcs[i][j])
  2572. ]
  2573. result["wires"]["lines"] = torch.cat(lines)
  2574. result["wires"]["score"] = torch.cat(score)
  2575. result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
  2576. if len(jcs[i]) > 1:
  2577. result["preds"]["junts"] = torch.cat(
  2578. [jcs[i][1] for i in range(n_batch)]
  2579. )
  2580. return result
  2581. def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
  2582. # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
  2583. """
  2584. Computes the loss for Faster R-CNN.
  2585. Args:
  2586. class_logits (Tensor)
  2587. box_regression (Tensor)
  2588. labels (list[BoxList])
  2589. regression_targets (Tensor)
  2590. Returns:
  2591. classification_loss (Tensor)
  2592. box_loss (Tensor)
  2593. """
  2594. labels = torch.cat(labels, dim=0)
  2595. regression_targets = torch.cat(regression_targets, dim=0)
  2596. classification_loss = F.cross_entropy(class_logits, labels)
  2597. # get indices that correspond to the regression targets for
  2598. # the corresponding ground truth labels, to be used with
  2599. # advanced indexing
  2600. sampled_pos_inds_subset = torch.where(labels > 0)[0]
  2601. labels_pos = labels[sampled_pos_inds_subset]
  2602. N, num_classes = class_logits.shape
  2603. box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
  2604. box_loss = F.smooth_l1_loss(
  2605. box_regression[sampled_pos_inds_subset, labels_pos],
  2606. regression_targets[sampled_pos_inds_subset],
  2607. beta=1 / 9,
  2608. reduction="sum",
  2609. )
  2610. box_loss = box_loss / labels.numel()
  2611. return classification_loss, box_loss
  2612. def maskrcnn_inference(x, labels):
  2613. # type: (Tensor, List[Tensor]) -> List[Tensor]
  2614. """
  2615. From the results of the CNN, post process the masks
  2616. by taking the mask corresponding to the class with max
  2617. probability (which are of fixed size and directly output
  2618. by the CNN) and return the masks in the mask field of the BoxList.
  2619. Args:
  2620. x (Tensor): the mask logits
  2621. labels (list[BoxList]): bounding boxes that are used as
  2622. reference, one for ech image
  2623. Returns:
  2624. results (list[BoxList]): one BoxList for each image, containing
  2625. the extra field mask
  2626. """
  2627. mask_prob = x.sigmoid()
  2628. # select masks corresponding to the predicted classes
  2629. num_masks = x.shape[0]
  2630. boxes_per_image = [label.shape[0] for label in labels]
  2631. labels = torch.cat(labels)
  2632. index = torch.arange(num_masks, device=labels.device)
  2633. mask_prob = mask_prob[index, labels][:, None]
  2634. mask_prob = mask_prob.split(boxes_per_image, dim=0)
  2635. return mask_prob
  2636. def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
  2637. # type: (Tensor, Tensor, Tensor, int) -> Tensor
  2638. """
  2639. Given segmentation masks and the bounding boxes corresponding
  2640. to the location of the masks in the image, this function
  2641. crops and resizes the masks in the position defined by the
  2642. boxes. This prepares the masks for them to be fed to the
  2643. loss computation as the targets.
  2644. """
  2645. matched_idxs = matched_idxs.to(boxes)
  2646. rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
  2647. gt_masks = gt_masks[:, None].to(rois)
  2648. return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
  2649. def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
  2650. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  2651. """
  2652. Args:
  2653. proposals (list[BoxList])
  2654. mask_logits (Tensor)
  2655. targets (list[BoxList])
  2656. Return:
  2657. mask_loss (Tensor): scalar tensor containing the loss
  2658. """
  2659. discretization_size = mask_logits.shape[-1]
  2660. # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
  2661. # print(f'mask discretization_size:{discretization_size}')
  2662. labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
  2663. # print(f'mask labels:{labels}')
  2664. mask_targets = [
  2665. project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
  2666. ]
  2667. labels = torch.cat(labels, dim=0)
  2668. # print(f'mask labels1:{labels}')
  2669. mask_targets = torch.cat(mask_targets, dim=0)
  2670. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  2671. # accept empty tensors, so handle it separately
  2672. if mask_targets.numel() == 0:
  2673. return mask_logits.sum() * 0
  2674. # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
  2675. # print(f'mask_targets:{mask_targets}')
  2676. mask_loss = F.binary_cross_entropy_with_logits(
  2677. mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
  2678. )
  2679. # print(f'mask_loss:{mask_loss}')
  2680. return mask_loss
  2681. def keypoints_to_heatmap(keypoints, rois, heatmap_size):
  2682. # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
  2683. offset_x = rois[:, 0]
  2684. offset_y = rois[:, 1]
  2685. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  2686. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  2687. offset_x = offset_x[:, None]
  2688. offset_y = offset_y[:, None]
  2689. scale_x = scale_x[:, None]
  2690. scale_y = scale_y[:, None]
  2691. x = keypoints[..., 0]
  2692. y = keypoints[..., 1]
  2693. x_boundary_inds = x == rois[:, 2][:, None]
  2694. y_boundary_inds = y == rois[:, 3][:, None]
  2695. x = (x - offset_x) * scale_x
  2696. x = x.floor().long()
  2697. y = (y - offset_y) * scale_y
  2698. y = y.floor().long()
  2699. x[x_boundary_inds] = heatmap_size - 1
  2700. y[y_boundary_inds] = heatmap_size - 1
  2701. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  2702. vis = keypoints[..., 2] > 0
  2703. valid = (valid_loc & vis).long()
  2704. lin_ind = y * heatmap_size + x
  2705. heatmaps = lin_ind * valid
  2706. return heatmaps, valid
  2707. def _onnx_heatmaps_to_keypoints(
  2708. maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
  2709. ):
  2710. num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
  2711. width_correction = widths_i / roi_map_width
  2712. height_correction = heights_i / roi_map_height
  2713. roi_map = F.interpolate(
  2714. maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
  2715. )[:, 0]
  2716. w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
  2717. pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  2718. x_int = pos % w
  2719. y_int = (pos - x_int) // w
  2720. x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
  2721. dtype=torch.float32
  2722. )
  2723. y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
  2724. dtype=torch.float32
  2725. )
  2726. xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
  2727. xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
  2728. xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
  2729. xy_preds_i = torch.stack(
  2730. [
  2731. xy_preds_i_0.to(dtype=torch.float32),
  2732. xy_preds_i_1.to(dtype=torch.float32),
  2733. xy_preds_i_2.to(dtype=torch.float32),
  2734. ],
  2735. 0,
  2736. )
  2737. # TODO: simplify when indexing without rank will be supported by ONNX
  2738. base = num_keypoints * num_keypoints + num_keypoints + 1
  2739. ind = torch.arange(num_keypoints)
  2740. ind = ind.to(dtype=torch.int64) * base
  2741. end_scores_i = (
  2742. roi_map.index_select(1, y_int.to(dtype=torch.int64))
  2743. .index_select(2, x_int.to(dtype=torch.int64))
  2744. .view(-1)
  2745. .index_select(0, ind.to(dtype=torch.int64))
  2746. )
  2747. return xy_preds_i, end_scores_i
  2748. @torch.jit._script_if_tracing
  2749. def _onnx_heatmaps_to_keypoints_loop(
  2750. maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
  2751. ):
  2752. xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
  2753. end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
  2754. for i in range(int(rois.size(0))):
  2755. xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
  2756. maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
  2757. )
  2758. xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
  2759. end_scores = torch.cat(
  2760. (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
  2761. )
  2762. return xy_preds, end_scores
  2763. def heatmaps_to_keypoints(maps, rois):
  2764. """Extract predicted keypoint locations from heatmaps. Output has shape
  2765. (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
  2766. for each keypoint.
  2767. """
  2768. # This function converts a discrete image coordinate in a HEATMAP_SIZE x
  2769. # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
  2770. # consistency with keypoints_to_heatmap_labels by using the conversion from
  2771. # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
  2772. # continuous coordinate.
  2773. offset_x = rois[:, 0]
  2774. offset_y = rois[:, 1]
  2775. widths = rois[:, 2] - rois[:, 0]
  2776. heights = rois[:, 3] - rois[:, 1]
  2777. widths = widths.clamp(min=1)
  2778. heights = heights.clamp(min=1)
  2779. widths_ceil = widths.ceil()
  2780. heights_ceil = heights.ceil()
  2781. num_keypoints = maps.shape[1]
  2782. if torchvision._is_tracing():
  2783. xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
  2784. maps,
  2785. rois,
  2786. widths_ceil,
  2787. heights_ceil,
  2788. widths,
  2789. heights,
  2790. offset_x,
  2791. offset_y,
  2792. torch.scalar_tensor(num_keypoints, dtype=torch.int64),
  2793. )
  2794. return xy_preds.permute(0, 2, 1), end_scores
  2795. xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
  2796. end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
  2797. for i in range(len(rois)):
  2798. roi_map_width = int(widths_ceil[i].item())
  2799. roi_map_height = int(heights_ceil[i].item())
  2800. width_correction = widths[i] / roi_map_width
  2801. height_correction = heights[i] / roi_map_height
  2802. roi_map = F.interpolate(
  2803. maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
  2804. )[:, 0]
  2805. # roi_map_probs = scores_to_probs(roi_map.copy())
  2806. w = roi_map.shape[2]
  2807. pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  2808. x_int = pos % w
  2809. y_int = torch.div(pos - x_int, w, rounding_mode="floor")
  2810. # assert (roi_map_probs[k, y_int, x_int] ==
  2811. # roi_map_probs[k, :, :].max())
  2812. x = (x_int.float() + 0.5) * width_correction
  2813. y = (y_int.float() + 0.5) * height_correction
  2814. xy_preds[i, 0, :] = x + offset_x[i]
  2815. xy_preds[i, 1, :] = y + offset_y[i]
  2816. xy_preds[i, 2, :] = 1
  2817. end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
  2818. return xy_preds.permute(0, 2, 1), end_scores
  2819. import torch
  2820. import torch.nn.functional as F
  2821. def heatmaps_to_keypoints_new(maps, rois):
  2822. # """Extract predicted keypoint locations from heatmaps. Output has shape
  2823. # (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
  2824. # for each keypoint.
  2825. # """
  2826. # This function converts a discrete image coordinate in a HEATMAP_SIZE x
  2827. # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
  2828. # consistency with keypoints_to_heatmap_labels by using the conversion from
  2829. # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
  2830. # continuous coordinate.
  2831. print(f"maps.shape:{maps.shape}")
  2832. rois = rois[0]
  2833. offset_x = rois[:, 0]
  2834. offset_y = rois[:, 1]
  2835. widths = rois[:, 2] - rois[:, 0]
  2836. heights = rois[:, 3] - rois[:, 1]
  2837. widths = widths.clamp(min=1)
  2838. heights = heights.clamp(min=1)
  2839. widths_ceil = widths.ceil()
  2840. heights_ceil = heights.ceil()
  2841. num_keypoints = maps.shape[1]
  2842. if torchvision._is_tracing():
  2843. xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
  2844. maps,
  2845. rois,
  2846. widths_ceil,
  2847. heights_ceil,
  2848. widths,
  2849. heights,
  2850. offset_x,
  2851. offset_y,
  2852. torch.scalar_tensor(num_keypoints, dtype=torch.int64),
  2853. )
  2854. return xy_preds.permute(0, 2, 1), end_scores
  2855. xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
  2856. end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
  2857. # 创建一个 512x512 的空白图像
  2858. combined_map = torch.zeros((1, maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
  2859. print(f"combined_map.shape: {combined_map.shape}")
  2860. print(f"len of rois:{len(rois)}")
  2861. for i in range(len(rois)):
  2862. roi_map_width = int(widths_ceil[i].item())
  2863. roi_map_height = int(heights_ceil[i].item())
  2864. width_correction = widths[i] / roi_map_width
  2865. height_correction = heights[i] / roi_map_height
  2866. roi_map = F.interpolate(
  2867. maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
  2868. )[:, 0]
  2869. x_offset = int(offset_x[i].item()) # 转换为标量
  2870. y_offset = int(offset_y[i].item()) # 转换为标量
  2871. # print(f"x_offset: {x_offset}, y_offset: {y_offset}, roi_map.shape: {roi_map.shape}")
  2872. # 检查偏移量是否合理
  2873. if y_offset < 0 or y_offset + roi_map.shape[1] > combined_map.shape[2] or x_offset < 0 or x_offset + roi_map.shape[2] > combined_map.shape[3]:
  2874. print("Error: Offset exceeds combined_map dimensions.")
  2875. else:
  2876. # 检查 roi_map 的大小
  2877. if roi_map.shape[1] <= 0 or roi_map.shape[2] <= 0:
  2878. print("Error: Invalid ROI size.")
  2879. else:
  2880. # 填充 combined_map
  2881. combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
  2882. # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
  2883. w = roi_map.shape[2]
  2884. pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  2885. x_int = pos % w
  2886. y_int = torch.div(pos - x_int, w, rounding_mode="floor")
  2887. # assert (roi_map_probs[k, y_int, x_int] ==
  2888. # roi_map_probs[k, :, :].max())
  2889. x = (x_int.float() + 0.5) * width_correction
  2890. y = (y_int.float() + 0.5) * height_correction
  2891. xy_preds[i, 0, :] = x + offset_x[i]
  2892. xy_preds[i, 1, :] = y + offset_y[i]
  2893. xy_preds[i, 2, :] = 1
  2894. end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
  2895. combined_map= F.interpolate(combined_map, size=(128, 128), mode='bilinear', align_corners=False)
  2896. print(f"combined_map.shape:{combined_map.shape}")
  2897. return combined_map, xy_preds.permute(0, 2, 1), end_scores
  2898. def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
  2899. # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
  2900. N, K, H, W = keypoint_logits.shape
  2901. if H != W:
  2902. raise ValueError(
  2903. f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  2904. )
  2905. discretization_size = H
  2906. heatmaps = []
  2907. valid = []
  2908. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
  2909. kp = gt_kp_in_image[midx]
  2910. heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
  2911. heatmaps.append(heatmaps_per_image.view(-1))
  2912. valid.append(valid_per_image.view(-1))
  2913. keypoint_targets = torch.cat(heatmaps, dim=0)
  2914. valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  2915. valid = torch.where(valid)[0]
  2916. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  2917. # accept empty tensors, so handle it sepaartely
  2918. if keypoint_targets.numel() == 0 or len(valid) == 0:
  2919. return keypoint_logits.sum() * 0
  2920. keypoint_logits = keypoint_logits.view(N * K, H * W)
  2921. keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
  2922. return keypoint_loss
  2923. def keypointrcnn_inference(x, boxes):
  2924. # print(f'x:{x.shape}')
  2925. # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  2926. kp_probs = []
  2927. kp_scores = []
  2928. boxes_per_image = [box.size(0) for box in boxes]
  2929. x2 = x.split(boxes_per_image, dim=0)
  2930. # print(f'x2:{x2}')
  2931. for xx, bb in zip(x2, boxes):
  2932. kp_prob, scores = heatmaps_to_keypoints(xx, bb)
  2933. kp_probs.append(kp_prob)
  2934. kp_scores.append(scores)
  2935. return kp_probs, kp_scores
  2936. def _onnx_expand_boxes(boxes, scale):
  2937. # type: (Tensor, float) -> Tensor
  2938. w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
  2939. h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
  2940. x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  2941. y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  2942. w_half = w_half.to(dtype=torch.float32) * scale
  2943. h_half = h_half.to(dtype=torch.float32) * scale
  2944. boxes_exp0 = x_c - w_half
  2945. boxes_exp1 = y_c - h_half
  2946. boxes_exp2 = x_c + w_half
  2947. boxes_exp3 = y_c + h_half
  2948. boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
  2949. return boxes_exp
  2950. # the next two functions should be merged inside Masker
  2951. # but are kept here for the moment while we need them
  2952. # temporarily for paste_mask_in_image
  2953. def expand_boxes(boxes, scale):
  2954. # type: (Tensor, float) -> Tensor
  2955. if torchvision._is_tracing():
  2956. return _onnx_expand_boxes(boxes, scale)
  2957. w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
  2958. h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
  2959. x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  2960. y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  2961. w_half *= scale
  2962. h_half *= scale
  2963. boxes_exp = torch.zeros_like(boxes)
  2964. boxes_exp[:, 0] = x_c - w_half
  2965. boxes_exp[:, 2] = x_c + w_half
  2966. boxes_exp[:, 1] = y_c - h_half
  2967. boxes_exp[:, 3] = y_c + h_half
  2968. return boxes_exp
  2969. @torch.jit.unused
  2970. def expand_masks_tracing_scale(M, padding):
  2971. # type: (int, int) -> float
  2972. return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
  2973. def expand_masks(mask, padding):
  2974. # type: (Tensor, int) -> Tuple[Tensor, float]
  2975. M = mask.shape[-1]
  2976. if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
  2977. scale = expand_masks_tracing_scale(M, padding)
  2978. else:
  2979. scale = float(M + 2 * padding) / M
  2980. padded_mask = F.pad(mask, (padding,) * 4)
  2981. return padded_mask, scale
  2982. def paste_mask_in_image(mask, box, im_h, im_w):
  2983. # type: (Tensor, Tensor, int, int) -> Tensor
  2984. TO_REMOVE = 1
  2985. w = int(box[2] - box[0] + TO_REMOVE)
  2986. h = int(box[3] - box[1] + TO_REMOVE)
  2987. w = max(w, 1)
  2988. h = max(h, 1)
  2989. # Set shape to [batchxCxHxW]
  2990. mask = mask.expand((1, 1, -1, -1))
  2991. # Resize mask
  2992. mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
  2993. mask = mask[0][0]
  2994. im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
  2995. x_0 = max(box[0], 0)
  2996. x_1 = min(box[2] + 1, im_w)
  2997. y_0 = max(box[1], 0)
  2998. y_1 = min(box[3] + 1, im_h)
  2999. im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
  3000. return im_mask
  3001. def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
  3002. one = torch.ones(1, dtype=torch.int64)
  3003. zero = torch.zeros(1, dtype=torch.int64)
  3004. w = box[2] - box[0] + one
  3005. h = box[3] - box[1] + one
  3006. w = torch.max(torch.cat((w, one)))
  3007. h = torch.max(torch.cat((h, one)))
  3008. # Set shape to [batchxCxHxW]
  3009. mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
  3010. # Resize mask
  3011. mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
  3012. mask = mask[0][0]
  3013. x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
  3014. x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
  3015. y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
  3016. y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
  3017. unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
  3018. # TODO : replace below with a dynamic padding when support is added in ONNX
  3019. # pad y
  3020. zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
  3021. zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
  3022. concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
  3023. # pad x
  3024. zeros_x0 = torch.zeros(concat_0.size(0), x_0)
  3025. zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
  3026. im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
  3027. return im_mask
  3028. @torch.jit._script_if_tracing
  3029. def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
  3030. res_append = torch.zeros(0, im_h, im_w)
  3031. for i in range(masks.size(0)):
  3032. mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
  3033. mask_res = mask_res.unsqueeze(0)
  3034. res_append = torch.cat((res_append, mask_res))
  3035. return res_append
  3036. def paste_masks_in_image(masks, boxes, img_shape, padding=1):
  3037. # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
  3038. masks, scale = expand_masks(masks, padding=padding)
  3039. boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
  3040. im_h, im_w = img_shape
  3041. if torchvision._is_tracing():
  3042. return _onnx_paste_masks_in_image_loop(
  3043. masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
  3044. )[:, None]
  3045. res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
  3046. if len(res) > 0:
  3047. ret = torch.stack(res, dim=0)[:, None]
  3048. else:
  3049. ret = masks.new_empty((0, 1, im_h, im_w))
  3050. return ret
  3051. class RoIHeads(nn.Module):
  3052. __annotations__ = {
  3053. "box_coder": det_utils.BoxCoder,
  3054. "proposal_matcher": det_utils.Matcher,
  3055. "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
  3056. }
  3057. def __init__(
  3058. self,
  3059. box_roi_pool,
  3060. box_head,
  3061. box_predictor,
  3062. # Faster R-CNN training
  3063. fg_iou_thresh,
  3064. bg_iou_thresh,
  3065. batch_size_per_image,
  3066. positive_fraction,
  3067. bbox_reg_weights,
  3068. # Faster R-CNN inference
  3069. score_thresh,
  3070. nms_thresh,
  3071. detections_per_img,
  3072. # Mask
  3073. mask_roi_pool=None,
  3074. mask_head=None,
  3075. mask_predictor=None,
  3076. keypoint_roi_pool=None,
  3077. keypoint_head=None,
  3078. keypoint_predictor=None,
  3079. wirepoint_roi_pool=None,
  3080. wirepoint_head=None,
  3081. wirepoint_predictor=None,
  3082. ):
  3083. super().__init__()
  3084. self.box_similarity = box_ops.box_iou
  3085. # assign ground-truth boxes for each proposal
  3086. self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
  3087. self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
  3088. if bbox_reg_weights is None:
  3089. bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
  3090. self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
  3091. self.box_roi_pool = box_roi_pool
  3092. self.box_head = box_head
  3093. self.box_predictor = box_predictor
  3094. self.score_thresh = score_thresh
  3095. self.nms_thresh = nms_thresh
  3096. self.detections_per_img = detections_per_img
  3097. self.mask_roi_pool = mask_roi_pool
  3098. self.mask_head = mask_head
  3099. self.mask_predictor = mask_predictor
  3100. self.keypoint_roi_pool = keypoint_roi_pool
  3101. self.keypoint_head = keypoint_head
  3102. self.keypoint_predictor = keypoint_predictor
  3103. self.wirepoint_roi_pool = wirepoint_roi_pool
  3104. self.wirepoint_head = wirepoint_head
  3105. self.wirepoint_predictor = wirepoint_predictor
  3106. def has_mask(self):
  3107. if self.mask_roi_pool is None:
  3108. return False
  3109. if self.mask_head is None:
  3110. return False
  3111. if self.mask_predictor is None:
  3112. return False
  3113. return True
  3114. def has_keypoint(self):
  3115. if self.keypoint_roi_pool is None:
  3116. return False
  3117. if self.keypoint_head is None:
  3118. return False
  3119. if self.keypoint_predictor is None:
  3120. return False
  3121. return True
  3122. def has_wirepoint(self):
  3123. if self.wirepoint_roi_pool is None:
  3124. print(f'wirepoint_roi_pool is None')
  3125. return False
  3126. if self.wirepoint_head is None:
  3127. print(f'wirepoint_head is None')
  3128. return False
  3129. if self.wirepoint_predictor is None:
  3130. print(f'wirepoint_roi_predictor is None')
  3131. return False
  3132. return True
  3133. def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
  3134. # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
  3135. matched_idxs = []
  3136. labels = []
  3137. for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
  3138. if gt_boxes_in_image.numel() == 0:
  3139. # Background image
  3140. device = proposals_in_image.device
  3141. clamped_matched_idxs_in_image = torch.zeros(
  3142. (proposals_in_image.shape[0],), dtype=torch.int64, device=device
  3143. )
  3144. labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
  3145. else:
  3146. # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
  3147. match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
  3148. matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
  3149. clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
  3150. labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
  3151. labels_in_image = labels_in_image.to(dtype=torch.int64)
  3152. # Label background (below the low threshold)
  3153. bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
  3154. labels_in_image[bg_inds] = 0
  3155. # Label ignore proposals (between low and high thresholds)
  3156. ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
  3157. labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
  3158. matched_idxs.append(clamped_matched_idxs_in_image)
  3159. labels.append(labels_in_image)
  3160. return matched_idxs, labels
  3161. def subsample(self, labels):
  3162. # type: (List[Tensor]) -> List[Tensor]
  3163. sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
  3164. sampled_inds = []
  3165. for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
  3166. img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
  3167. sampled_inds.append(img_sampled_inds)
  3168. return sampled_inds
  3169. def add_gt_proposals(self, proposals, gt_boxes):
  3170. # type: (List[Tensor], List[Tensor]) -> List[Tensor]
  3171. proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
  3172. return proposals
  3173. def check_targets(self, targets):
  3174. # type: (Optional[List[Dict[str, Tensor]]]) -> None
  3175. if targets is None:
  3176. raise ValueError("targets should not be None")
  3177. if not all(["boxes" in t for t in targets]):
  3178. raise ValueError("Every element of targets should have a boxes key")
  3179. if not all(["labels" in t for t in targets]):
  3180. raise ValueError("Every element of targets should have a labels key")
  3181. if self.has_mask():
  3182. if not all(["masks" in t for t in targets]):
  3183. raise ValueError("Every element of targets should have a masks key")
  3184. def select_training_samples(
  3185. self,
  3186. proposals, # type: List[Tensor]
  3187. targets, # type: Optional[List[Dict[str, Tensor]]]
  3188. ):
  3189. # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
  3190. self.check_targets(targets)
  3191. if targets is None:
  3192. raise ValueError("targets should not be None")
  3193. dtype = proposals[0].dtype
  3194. device = proposals[0].device
  3195. gt_boxes = [t["boxes"].to(dtype) for t in targets]
  3196. gt_labels = [t["labels"] for t in targets]
  3197. # append ground-truth bboxes to propos
  3198. proposals = self.add_gt_proposals(proposals, gt_boxes)
  3199. # get matching gt indices for each proposal
  3200. matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
  3201. # sample a fixed proportion of positive-negative proposals
  3202. sampled_inds = self.subsample(labels)
  3203. matched_gt_boxes = []
  3204. num_images = len(proposals)
  3205. for img_id in range(num_images):
  3206. img_sampled_inds = sampled_inds[img_id]
  3207. proposals[img_id] = proposals[img_id][img_sampled_inds]
  3208. labels[img_id] = labels[img_id][img_sampled_inds]
  3209. matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
  3210. gt_boxes_in_image = gt_boxes[img_id]
  3211. if gt_boxes_in_image.numel() == 0:
  3212. gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
  3213. matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
  3214. regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
  3215. return proposals, matched_idxs, labels, regression_targets
  3216. def postprocess_detections(
  3217. self,
  3218. class_logits, # type: Tensor
  3219. box_regression, # type: Tensor
  3220. proposals, # type: List[Tensor]
  3221. image_shapes, # type: List[Tuple[int, int]]
  3222. ):
  3223. # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
  3224. device = class_logits.device
  3225. num_classes = class_logits.shape[-1]
  3226. boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
  3227. pred_boxes = self.box_coder.decode(box_regression, proposals)
  3228. pred_scores = F.softmax(class_logits, -1)
  3229. pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
  3230. pred_scores_list = pred_scores.split(boxes_per_image, 0)
  3231. all_boxes = []
  3232. all_scores = []
  3233. all_labels = []
  3234. for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
  3235. boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
  3236. # create labels for each prediction
  3237. labels = torch.arange(num_classes, device=device)
  3238. labels = labels.view(1, -1).expand_as(scores)
  3239. # remove predictions with the background label
  3240. boxes = boxes[:, 1:]
  3241. scores = scores[:, 1:]
  3242. labels = labels[:, 1:]
  3243. # batch everything, by making every class prediction be a separate instance
  3244. boxes = boxes.reshape(-1, 4)
  3245. scores = scores.reshape(-1)
  3246. labels = labels.reshape(-1)
  3247. # remove low scoring boxes
  3248. inds = torch.where(scores > self.score_thresh)[0]
  3249. boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
  3250. # remove empty boxes
  3251. keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
  3252. boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  3253. # non-maximum suppression, independently done per class
  3254. keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
  3255. # keep only topk scoring predictions
  3256. keep = keep[: self.detections_per_img]
  3257. boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  3258. all_boxes.append(boxes)
  3259. all_scores.append(scores)
  3260. all_labels.append(labels)
  3261. return all_boxes, all_scores, all_labels
  3262. def forward(
  3263. self,
  3264. features, # type: Dict[str, Tensor]
  3265. proposals, # type: List[Tensor]
  3266. image_shapes, # type: List[Tuple[int, int]]
  3267. targets=None, # type: Optional[List[Dict[str, Tensor]]]
  3268. ):
  3269. # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
  3270. """
  3271. Args:
  3272. features (List[Tensor])
  3273. proposals (List[Tensor[N, 4]])
  3274. image_shapes (List[Tuple[H, W]])
  3275. targets (List[Dict])
  3276. """
  3277. if targets is not None:
  3278. for t in targets:
  3279. # TODO: https://github.com/pytorch/pytorch/issues/26731
  3280. floating_point_types = (torch.float, torch.double, torch.half)
  3281. if not t["boxes"].dtype in floating_point_types:
  3282. raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
  3283. if not t["labels"].dtype == torch.int64:
  3284. raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
  3285. if self.has_keypoint():
  3286. if not t["keypoints"].dtype == torch.float32:
  3287. raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
  3288. if self.training:
  3289. proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
  3290. else:
  3291. labels = None
  3292. regression_targets = None
  3293. matched_idxs = None
  3294. box_features = self.box_roi_pool(features, proposals, image_shapes)
  3295. box_features = self.box_head(box_features)
  3296. class_logits, box_regression = self.box_predictor(box_features)
  3297. result: List[Dict[str, torch.Tensor]] = []
  3298. losses = {}
  3299. if self.training:
  3300. if labels is None:
  3301. raise ValueError("labels cannot be None")
  3302. if regression_targets is None:
  3303. raise ValueError("regression_targets cannot be None")
  3304. loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
  3305. losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
  3306. else:
  3307. boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
  3308. num_images = len(boxes)
  3309. for i in range(num_images):
  3310. result.append(
  3311. {
  3312. "boxes": boxes[i],
  3313. "labels": labels[i],
  3314. "scores": scores[i],
  3315. }
  3316. )
  3317. if self.has_mask():
  3318. mask_proposals = [p["boxes"] for p in result]
  3319. if self.training:
  3320. if matched_idxs is None:
  3321. raise ValueError("if in training, matched_idxs should not be None")
  3322. # during training, only focus on positive boxes
  3323. num_images = len(proposals)
  3324. mask_proposals = []
  3325. pos_matched_idxs = []
  3326. for img_id in range(num_images):
  3327. pos = torch.where(labels[img_id] > 0)[0]
  3328. mask_proposals.append(proposals[img_id][pos])
  3329. pos_matched_idxs.append(matched_idxs[img_id][pos])
  3330. else:
  3331. pos_matched_idxs = None
  3332. if self.mask_roi_pool is not None:
  3333. mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
  3334. mask_features = self.mask_head(mask_features)
  3335. mask_logits = self.mask_predictor(mask_features)
  3336. else:
  3337. raise Exception("Expected mask_roi_pool to be not None")
  3338. loss_mask = {}
  3339. if self.training:
  3340. if targets is None or pos_matched_idxs is None or mask_logits is None:
  3341. raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
  3342. gt_masks = [t["masks"] for t in targets]
  3343. gt_labels = [t["labels"] for t in targets]
  3344. rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
  3345. loss_mask = {"loss_mask": rcnn_loss_mask}
  3346. else:
  3347. labels = [r["labels"] for r in result]
  3348. masks_probs = maskrcnn_inference(mask_logits, labels)
  3349. for mask_prob, r in zip(masks_probs, result):
  3350. r["masks"] = mask_prob
  3351. losses.update(loss_mask)
  3352. # keep none checks in if conditional so torchscript will conditionally
  3353. # compile each branch
  3354. if self.has_keypoint():
  3355. keypoint_proposals = [p["boxes"] for p in result]
  3356. if self.training:
  3357. # during training, only focus on positive boxes
  3358. num_images = len(proposals)
  3359. keypoint_proposals = []
  3360. pos_matched_idxs = []
  3361. if matched_idxs is None:
  3362. raise ValueError("if in trainning, matched_idxs should not be None")
  3363. for img_id in range(num_images):
  3364. pos = torch.where(labels[img_id] > 0)[0]
  3365. keypoint_proposals.append(proposals[img_id][pos])
  3366. pos_matched_idxs.append(matched_idxs[img_id][pos])
  3367. else:
  3368. pos_matched_idxs = None
  3369. keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
  3370. # tmp = keypoint_features[0][0]
  3371. # plt.imshow(tmp.detach().numpy())
  3372. # print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
  3373. keypoint_features = self.keypoint_head(keypoint_features)
  3374. # print(f'keypoint_features:{keypoint_features.shape}')
  3375. tmp = keypoint_features[0][0]
  3376. plt.imshow(tmp.detach().numpy())
  3377. keypoint_logits = self.keypoint_predictor(keypoint_features)
  3378. # print(f'keypoint_logits:{keypoint_logits.shape}')
  3379. """
  3380. 接wirenet
  3381. """
  3382. loss_keypoint = {}
  3383. if self.training:
  3384. if targets is None or pos_matched_idxs is None:
  3385. raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
  3386. gt_keypoints = [t["keypoints"] for t in targets]
  3387. rcnn_loss_keypoint = keypointrcnn_loss(
  3388. keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
  3389. )
  3390. loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
  3391. else:
  3392. if keypoint_logits is None or keypoint_proposals is None:
  3393. raise ValueError(
  3394. "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
  3395. )
  3396. keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
  3397. for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
  3398. r["keypoints"] = keypoint_prob
  3399. r["keypoints_scores"] = kps
  3400. losses.update(loss_keypoint)
  3401. if self.has_wirepoint():
  3402. wirepoint_proposals = [p["boxes"] for p in result]
  3403. if self.training:
  3404. # during training, only focus on positive boxes
  3405. num_images = len(proposals)
  3406. wirepoint_proposals = []
  3407. pos_matched_idxs = []
  3408. if matched_idxs is None:
  3409. raise ValueError("if in trainning, matched_idxs should not be None")
  3410. for img_id in range(num_images):
  3411. pos = torch.where(labels[img_id] > 0)[0]
  3412. wirepoint_proposals.append(proposals[img_id][pos])
  3413. pos_matched_idxs.append(matched_idxs[img_id][pos])
  3414. else:
  3415. pos_matched_idxs = None
  3416. wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
  3417. outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
  3418. # print(f"wirepoint_proposal:{type(wirepoint_proposals)}")
  3419. # print(f"wirepoint_proposal:{wirepoint_proposals.__len__()}")
  3420. # print(f"wirepoint_proposal[0].shape:{wirepoint_proposals[0].shape}")
  3421. # print(f"wirepoint_proposal[0]:{wirepoint_proposals[0]}")
  3422. # outputs = merge_features(outputs, wirepoint_proposals)
  3423. combined_output, xy_preds, end_scores = heatmaps_to_keypoints_new(outputs, wirepoint_proposals)
  3424. wire_combined_features, wire_xy_preds, wire_end_scores = heatmaps_to_keypoints_new(wirepoint_features, wirepoint_proposals)
  3425. # print(f'combined_output:{combined_output.shape}')
  3426. wirepoint_logits = self.wirepoint_predictor(inputs=combined_output, features=wire_combined_features,
  3427. targets=targets)
  3428. x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
  3429. # print(f'keypoint_features:{wirepoint_features.shape}')
  3430. if self.training:
  3431. if targets is None or pos_matched_idxs is None:
  3432. raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
  3433. loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
  3434. rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, combined_output, x, y, idx, loss_weight)
  3435. loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
  3436. else:
  3437. pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
  3438. result.append(pred)
  3439. loss_wirepoint = {}
  3440. losses.update(loss_wirepoint)
  3441. return result, losses
  3442. def merge_features(features, proposals):
  3443. print("merge==========================================================================start")
  3444. print(f"Features type: {type(features)}, shape: {features.shape}")
  3445. print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
  3446. print(f"Proposals : {proposals[0].shape},")
  3447. def diagnose_input(features, proposals):
  3448. """诊断输入数据"""
  3449. print("Input Diagnostics:")
  3450. print(f"Features type: {type(features)}, shape: {features.shape}")
  3451. print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
  3452. for i, p in enumerate(proposals):
  3453. print(f"Proposal {i} shape: {p.shape}")
  3454. def validate_inputs(features, proposals):
  3455. """验证输入的有效性"""
  3456. if features is None or proposals is None:
  3457. raise ValueError("Features or proposals cannot be None")
  3458. proposals_count = sum([p.size(0) for p in proposals])
  3459. features_size = features.size(0)
  3460. if proposals_count != features_size:
  3461. raise ValueError(
  3462. f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
  3463. )
  3464. def safe_max_reduction(features_per_img):
  3465. """安全的最大值压缩"""
  3466. if features_per_img.numel() == 0:
  3467. return torch.zeros_like(features_per_img).unsqueeze(0)
  3468. try:
  3469. # 沿着第0维求最大值,保持维度
  3470. max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
  3471. return max_features
  3472. except Exception as e:
  3473. print(f"Max reduction error: {e}")
  3474. return features_per_img.unsqueeze(0)
  3475. try:
  3476. # 诊断输入(可选)
  3477. # diagnose_input(features, proposals)
  3478. # 验证输入
  3479. validate_inputs(features, proposals)
  3480. # 分割特征
  3481. split_features = []
  3482. start_idx = 0
  3483. for proposal in proposals:
  3484. # 提取当前图像的特征
  3485. current_features = features[start_idx:start_idx + proposal.size(0)]
  3486. split_features.append(current_features)
  3487. start_idx += proposal.size(0)
  3488. # 每张图像特征压缩
  3489. features_imgs = []
  3490. for features_per_img in split_features:
  3491. compressed_features = safe_max_reduction(features_per_img)
  3492. features_imgs.append(compressed_features)
  3493. # 合并特征
  3494. merged_features = torch.cat(features_imgs, dim=0)
  3495. return merged_features
  3496. except Exception as e:
  3497. print(f"Error in merge_features: {e}")
  3498. # 返回原始特征或None
  3499. return features