- # from collections import OrderedDict
- # from typing import Dict, List, Optional, Tuple
- #
- # import matplotlib.pyplot as plt
- # import torch
- # import torch.nn.functional as F
- # import torchvision
- # from torch import nn, Tensor
- # from torchvision.ops import boxes as box_ops, roi_align
- #
- # from . import _utils as det_utils
- #
- # from torch.utils.data.dataloader import default_collate
- #
- #
- # def l2loss(input, target):
- # return ((target - input) ** 2).mean(2).mean(1)
- #
- #
- # def cross_entropy_loss(logits, positive):
- # nlogp = -F.log_softmax(logits, dim=0)
- # return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
- #
- #
- # def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
- # logp = torch.sigmoid(logits) + offset
- # loss = torch.abs(logp - target)
- # if mask is not None:
- # w = mask.mean(2, True).mean(1, True)
- # w[w == 0] = 1
- # loss = loss * (mask / w)
- #
- # return loss.mean(2).mean(1)
- #
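- # # Hedged usage sketch (not part of the original module): how sigmoid_l1_loss is
- # # typically called on one junction-offset channel, relying on the module-level
- # # torch import; the tensor shapes and the 0.9 threshold are illustrative assumptions.
- # # joff_logits = torch.randn(4, 128, 128)               # (batch, H, W) offset logits
- # # joff_target = torch.rand(4, 128, 128) - 0.5          # targets in [-0.5, 0.5)
- # # junc_mask = (torch.rand(4, 128, 128) > 0.9).float()  # only supervise junction pixels
- # # per_image = sigmoid_l1_loss(joff_logits, joff_target, offset=-0.5, mask=junc_mask)
- # # print(per_image.shape)                               # torch.Size([4]), one value per image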
- #
- # # def wirepoint_loss(target, outputs, feature, loss_weight,mode):
- # # wires = target['wires']
- # # result = {"feature": feature}
- # # batch, channel, row, col = outputs[0].shape
- # # print(f"Initial Output[0] shape: {outputs[0].shape}") # 打印初始输出形状
- # # print(f"Total Stacks: {len(outputs)}") # 打印堆栈数
- # #
- # # T = wires.copy()
- # # n_jtyp = T["junc_map"].shape[1]
- # # for task in ["junc_map"]:
- # # T[task] = T[task].permute(1, 0, 2, 3)
- # # for task in ["junc_offset"]:
- # # T[task] = T[task].permute(1, 2, 0, 3, 4)
- # #
- # # offset = self.head_off
- # # loss_weight = loss_weight
- # # losses = []
- # #
- # # for stack, output in enumerate(outputs):
- # # output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
- # # print(f"Stack {stack} output shape: {output.shape}") # 打印每层的输出形状
- # # jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
- # # lmap = output[offset[0]: offset[1]].squeeze(0)
- # # joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
- # #
- # # if stack == 0:
- # # result["preds"] = {
- # # "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
- # # "lmap": lmap.sigmoid(),
- # # "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
- # # }
- # # # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
- # # # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
- # # # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
- # #
- # # if mode == "testing":
- # # return result
- # #
- # # L = OrderedDict()
- # # L["junc_map"] = sum(
- # # cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
- # # )
- # # L["line_map"] = (
- # # F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
- # # .mean(2)
- # # .mean(1)
- # # )
- # # L["junc_offset"] = sum(
- # # sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
- # # for i in range(n_jtyp)
- # # for j in range(2)
- # # )
- # # for loss_name in L:
- # # L[loss_name].mul_(loss_weight[loss_name])
- # # losses.append(L)
- # #
- # # result["losses"] = losses
- # # return result
- #
- # def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
- # # output, feature: results returned by the head
- # # x, y, idx: intermediate results produced by the line branch
- # result = {}
- # batch, channel, row, col = output.shape
- #
- # wires_targets = [t["wires"] for t in targets]
- # wires_targets = wires_targets.copy()
- # # print(f'wires_target:{wires_targets}')
- # # extract all the 'junc_map', 'junc_offset' and 'line_map' tensors
- # junc_maps = [d["junc_map"] for d in wires_targets]
- # junc_offsets = [d["junc_offset"] for d in wires_targets]
- # line_maps = [d["line_map"] for d in wires_targets]
- #
- # junc_map_tensor = torch.stack(junc_maps, dim=0)
- # junc_offset_tensor = torch.stack(junc_offsets, dim=0)
- # line_map_tensor = torch.stack(line_maps, dim=0)
- # T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
- #
- # n_jtyp = T["junc_map"].shape[1]
- #
- # for task in ["junc_map"]:
- # T[task] = T[task].permute(1, 0, 2, 3)
- # for task in ["junc_offset"]:
- # T[task] = T[task].permute(1, 2, 0, 3, 4)
- #
- # offset = [2, 3, 5]
- # losses = []
- # output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
- # jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
- # lmap = output[offset[0]: offset[1]].squeeze(0)
- # joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
- # L = OrderedDict()
- # L["junc_map"] = sum(
- # cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
- # )
- # L["line_map"] = (
- # F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
- # .mean(2)
- # .mean(1)
- # )
- # L["junc_offset"] = sum(
- # sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
- # for i in range(n_jtyp)
- # for j in range(2)
- # )
- # for loss_name in L:
- # L[loss_name].mul_(loss_weight[loss_name])
- # losses.append(L)
- # result["losses"] = losses
- #
- # loss = nn.BCEWithLogitsLoss(reduction="none")
- # loss = loss(x, y)
- # lpos_mask, lneg_mask = y, 1 - y
- # loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
- #
- # def sum_batch(x):
- # xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
- # return torch.cat(xs)
- #
- # lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
- # lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
- # result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
- # result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
- #
- # return result
- #
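- # # Hedged call sketch: this is how the loss is invoked further down in
- # # RoIHeads.forward (the weight dict is copied verbatim from that call site);
- # # targets, outputs, x, y and idx are assumed to come from the wirepoint head.
- # # loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
- # # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
- # # # rcnn_loss_wirepoint["losses"][0] then maps "junc_map", "line_map", "junc_offset",
- # # # "lpos" and "lneg" to their weighted per-image loss tensors.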
- #
- # def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
- # result = {}
- # result["wires"] = {}
- # p = torch.cat(ps)
- # s = torch.sigmoid(input)
- # b = s > 0.5
- # lines = []
- # score = []
- # # print(f"n_batch:{n_batch}")
- # for i in range(n_batch):
- # # print(f"idx:{idx}")
- # p0 = p[idx[i]: idx[i + 1]]
- # s0 = s[idx[i]: idx[i + 1]]
- # mask = b[idx[i]: idx[i + 1]]
- # p0 = p0[mask]
- # s0 = s0[mask]
- # if len(p0) == 0:
- # lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
- # score.append(torch.zeros([1, n_out_line], device=p.device))
- # else:
- # arg = torch.argsort(s0, descending=True)
- # p0, s0 = p0[arg], s0[arg]
- # lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
- # score.append(s0[None, torch.arange(n_out_line) % len(s0)])
- # for j in range(len(jcs[i])):
- # if len(jcs[i][j]) == 0:
- # jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
- # jcs[i][j] = jcs[i][j][
- # None, torch.arange(n_out_junc) % len(jcs[i][j])
- # ]
- # result["wires"]["lines"] = torch.cat(lines)
- # result["wires"]["score"] = torch.cat(score)
- # result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
- #
- # if len(jcs[i]) > 1:
- # result["preds"]["junts"] = torch.cat(
- # [jcs[i][1] for i in range(n_batch)]
- # )
- #
- # return result
- #
- #
- # def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
- # # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
- # """
- # Computes the loss for Faster R-CNN.
- #
- # Args:
- # class_logits (Tensor)
- # box_regression (Tensor)
- # labels (list[BoxList])
- # regression_targets (Tensor)
- #
- # Returns:
- # classification_loss (Tensor)
- # box_loss (Tensor)
- # """
- #
- # labels = torch.cat(labels, dim=0)
- # regression_targets = torch.cat(regression_targets, dim=0)
- #
- # classification_loss = F.cross_entropy(class_logits, labels)
- #
- # # get indices that correspond to the regression targets for
- # # the corresponding ground truth labels, to be used with
- # # advanced indexing
- # sampled_pos_inds_subset = torch.where(labels > 0)[0]
- # labels_pos = labels[sampled_pos_inds_subset]
- # N, num_classes = class_logits.shape
- # box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
- #
- # box_loss = F.smooth_l1_loss(
- # box_regression[sampled_pos_inds_subset, labels_pos],
- # regression_targets[sampled_pos_inds_subset],
- # beta=1 / 9,
- # reduction="sum",
- # )
- # box_loss = box_loss / labels.numel()
- #
- # return classification_loss, box_loss
- #
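- # # Hedged shape sketch (not part of the original module): 8 sampled proposals over
- # # 2 images and 3 classes (background + 2 foreground); all values are illustrative,
- # # only the shapes matter.
- # # class_logits = torch.randn(8, 3)
- # # box_regression = torch.randn(8, 3 * 4)               # 4 regression values per class
- # # labels = [torch.tensor([0, 2, 1, 0]), torch.tensor([1, 0, 0, 2])]
- # # regression_targets = [torch.randn(4, 4), torch.randn(4, 4)]
- # # cls_loss, box_loss = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)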
- #
- # def maskrcnn_inference(x, labels):
- # # type: (Tensor, List[Tensor]) -> List[Tensor]
- # """
- # From the results of the CNN, post process the masks
- # by taking the mask corresponding to the class with max
- # probability (which are of fixed size and directly output
- # by the CNN) and return the masks in the mask field of the BoxList.
- #
- # Args:
- # x (Tensor): the mask logits
- # labels (list[BoxList]): bounding boxes that are used as
- # reference, one for each image
- #
- # Returns:
- # results (list[BoxList]): one BoxList for each image, containing
- # the extra field mask
- # """
- # mask_prob = x.sigmoid()
- #
- # # select masks corresponding to the predicted classes
- # num_masks = x.shape[0]
- # boxes_per_image = [label.shape[0] for label in labels]
- # labels = torch.cat(labels)
- # index = torch.arange(num_masks, device=labels.device)
- # mask_prob = mask_prob[index, labels][:, None]
- # mask_prob = mask_prob.split(boxes_per_image, dim=0)
- #
- # return mask_prob
- #
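- # # Hedged usage sketch: 5 mask predictions split over two images (3 + 2 boxes);
- # # the class count and mask resolution are illustrative assumptions.
- # # mask_logits = torch.randn(5, 3, 28, 28)              # 3 classes, 28x28 masks
- # # labels = [torch.tensor([1, 2, 1]), torch.tensor([2, 1])]
- # # mask_probs = maskrcnn_inference(mask_logits, labels)
- # # # mask_probs[0] has shape (3, 1, 28, 28); mask_probs[1] has shape (2, 1, 28, 28)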
- #
- # def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
- # # type: (Tensor, Tensor, Tensor, int) -> Tensor
- # """
- # Given segmentation masks and the bounding boxes corresponding
- # to the location of the masks in the image, this function
- # crops and resizes the masks in the position defined by the
- # boxes. This prepares the masks for them to be fed to the
- # loss computation as the targets.
- # """
- # matched_idxs = matched_idxs.to(boxes)
- # rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
- # gt_masks = gt_masks[:, None].to(rois)
- # return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
- #
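- # # Hedged usage sketch: crop two ground-truth masks to two matched proposal boxes
- # # and resize them to M=28 targets; mask sizes and box coordinates are made-up values.
- # # gt_masks = torch.randint(0, 2, (2, 100, 150)).float()
- # # boxes = torch.tensor([[10.0, 10.0, 60.0, 60.0], [30.0, 20.0, 120.0, 90.0]])
- # # matched_idxs = torch.tensor([0, 1])                  # proposal i matched to gt mask i
- # # mask_targets = project_masks_on_boxes(gt_masks, boxes, matched_idxs, 28)
- # # print(mask_targets.shape)                            # torch.Size([2, 28, 28])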
- #
- # def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
- # # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- # """
- # Args:
- # proposals (list[BoxList])
- # mask_logits (Tensor)
- # targets (list[BoxList])
- #
- # Return:
- # mask_loss (Tensor): scalar tensor containing the loss
- # """
- #
- # discretization_size = mask_logits.shape[-1]
- # # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
- # # print(f'mask discretization_size:{discretization_size}')
- # labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
- # # print(f'mask labels:{labels}')
- # mask_targets = [
- # project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
- # ]
- #
- # labels = torch.cat(labels, dim=0)
- # # print(f'mask labels1:{labels}')
- # mask_targets = torch.cat(mask_targets, dim=0)
- #
- # # torch.mean (in binary_cross_entropy_with_logits) doesn't
- # # accept empty tensors, so handle it separately
- # if mask_targets.numel() == 0:
- # return mask_logits.sum() * 0
- # # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
- # # print(f'mask_targets:{mask_targets}')
- # mask_loss = F.binary_cross_entropy_with_logits(
- # mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
- # )
- # # print(f'mask_loss:{mask_loss}')
- # return mask_loss
- #
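- # # Hedged shape sketch: 4 positive proposals on a single image, 3 classes and
- # # 28x28 mask logits; every value below is illustrative.
- # # mask_logits = torch.randn(4, 3, 28, 28)
- # # proposals = [torch.tensor([[5.0, 5.0, 60.0, 60.0]] * 4)]
- # # gt_masks = [torch.randint(0, 2, (2, 100, 100)).float()]
- # # gt_labels = [torch.tensor([1, 2])]
- # # mask_matched_idxs = [torch.tensor([0, 1, 0, 1])]
- # # loss = maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs)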
- #
- # def keypoints_to_heatmap(keypoints, rois, heatmap_size):
- # # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
- # offset_x = rois[:, 0]
- # offset_y = rois[:, 1]
- # scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
- # scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
- #
- # offset_x = offset_x[:, None]
- # offset_y = offset_y[:, None]
- # scale_x = scale_x[:, None]
- # scale_y = scale_y[:, None]
- #
- # x = keypoints[..., 0]
- # y = keypoints[..., 1]
- #
- # x_boundary_inds = x == rois[:, 2][:, None]
- # y_boundary_inds = y == rois[:, 3][:, None]
- #
- # x = (x - offset_x) * scale_x
- # x = x.floor().long()
- # y = (y - offset_y) * scale_y
- # y = y.floor().long()
- #
- # x[x_boundary_inds] = heatmap_size - 1
- # y[y_boundary_inds] = heatmap_size - 1
- #
- # valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
- # vis = keypoints[..., 2] > 0
- # valid = (valid_loc & vis).long()
- #
- # lin_ind = y * heatmap_size + x
- # heatmaps = lin_ind * valid
- #
- # return heatmaps, valid
- #
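- # # Hedged numeric sketch of the encoding: a visible keypoint inside its roi becomes a
- # # single linear index y * heatmap_size + x into the flattened heatmap; the box and
- # # keypoint coordinates below are made-up values.
- # # rois = torch.tensor([[10.0, 10.0, 74.0, 74.0]])       # one box, xyxy
- # # keypoints = torch.tensor([[[42.0, 30.0, 2.0]]])       # (N, K, 3): x, y, visibility
- # # heatmaps, valid = keypoints_to_heatmap(keypoints, rois, 56)
- # # # scale is 56 / 64 = 0.875, so x -> 28, y -> 17 and heatmaps[0, 0] == 17 * 56 + 28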
- #
- # def _onnx_heatmaps_to_keypoints(
- # maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
- # ):
- # num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
- #
- # width_correction = widths_i / roi_map_width
- # height_correction = heights_i / roi_map_height
- #
- # roi_map = F.interpolate(
- # maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
- # )[:, 0]
- #
- # w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
- # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- #
- # x_int = pos % w
- # y_int = (pos - x_int) // w
- #
- # x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
- # dtype=torch.float32
- # )
- # y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
- # dtype=torch.float32
- # )
- #
- # xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
- # xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
- # xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
- # xy_preds_i = torch.stack(
- # [
- # xy_preds_i_0.to(dtype=torch.float32),
- # xy_preds_i_1.to(dtype=torch.float32),
- # xy_preds_i_2.to(dtype=torch.float32),
- # ],
- # 0,
- # )
- #
- # # TODO: simplify when indexing without rank will be supported by ONNX
- # base = num_keypoints * num_keypoints + num_keypoints + 1
- # ind = torch.arange(num_keypoints)
- # ind = ind.to(dtype=torch.int64) * base
- # end_scores_i = (
- # roi_map.index_select(1, y_int.to(dtype=torch.int64))
- # .index_select(2, x_int.to(dtype=torch.int64))
- # .view(-1)
- # .index_select(0, ind.to(dtype=torch.int64))
- # )
- #
- # return xy_preds_i, end_scores_i
- #
- #
- # @torch.jit._script_if_tracing
- # def _onnx_heatmaps_to_keypoints_loop(
- # maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
- # ):
- # xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
- # end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
- #
- # for i in range(int(rois.size(0))):
- # xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
- # maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
- # )
- # xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
- # end_scores = torch.cat(
- # (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
- # )
- # return xy_preds, end_scores
- #
- #
- # def heatmaps_to_keypoints(maps, rois):
- # """Extract predicted keypoint locations from heatmaps. Output has shape
- # (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
- # for each keypoint.
- # """
- # # This function converts a discrete image coordinate in a HEATMAP_SIZE x
- # # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
- # # consistency with keypoints_to_heatmap_labels by using the conversion from
- # # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
- # # continuous coordinate.
- # offset_x = rois[:, 0]
- # offset_y = rois[:, 1]
- #
- # widths = rois[:, 2] - rois[:, 0]
- # heights = rois[:, 3] - rois[:, 1]
- # widths = widths.clamp(min=1)
- # heights = heights.clamp(min=1)
- # widths_ceil = widths.ceil()
- # heights_ceil = heights.ceil()
- #
- # num_keypoints = maps.shape[1]
- #
- # if torchvision._is_tracing():
- # xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
- # maps,
- # rois,
- # widths_ceil,
- # heights_ceil,
- # widths,
- # heights,
- # offset_x,
- # offset_y,
- # torch.scalar_tensor(num_keypoints, dtype=torch.int64),
- # )
- # return xy_preds.permute(0, 2, 1), end_scores
- #
- # xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
- # end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
- # for i in range(len(rois)):
- # roi_map_width = int(widths_ceil[i].item())
- # roi_map_height = int(heights_ceil[i].item())
- # width_correction = widths[i] / roi_map_width
- # height_correction = heights[i] / roi_map_height
- # roi_map = F.interpolate(
- # maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
- # )[:, 0]
- # # roi_map_probs = scores_to_probs(roi_map.copy())
- # w = roi_map.shape[2]
- # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- #
- # x_int = pos % w
- # y_int = torch.div(pos - x_int, w, rounding_mode="floor")
- # # assert (roi_map_probs[k, y_int, x_int] ==
- # # roi_map_probs[k, :, :].max())
- # x = (x_int.float() + 0.5) * width_correction
- # y = (y_int.float() + 0.5) * height_correction
- # xy_preds[i, 0, :] = x + offset_x[i]
- # xy_preds[i, 1, :] = y + offset_y[i]
- # xy_preds[i, 2, :] = 1
- # end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
- #
- # return xy_preds.permute(0, 2, 1), end_scores
- #
- #
- # def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
- # # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- # N, K, H, W = keypoint_logits.shape
- # if H != W:
- # raise ValueError(
- # f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
- # )
- # discretization_size = H
- # heatmaps = []
- # valid = []
- # for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
- # kp = gt_kp_in_image[midx]
- # heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
- # heatmaps.append(heatmaps_per_image.view(-1))
- # valid.append(valid_per_image.view(-1))
- #
- # keypoint_targets = torch.cat(heatmaps, dim=0)
- # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
- # valid = torch.where(valid)[0]
- #
- # # torch.mean (in binary_cross_entropy_with_logits) doesn't
- # # accept empty tensors, so handle it separately
- # if keypoint_targets.numel() == 0 or len(valid) == 0:
- # return keypoint_logits.sum() * 0
- #
- # keypoint_logits = keypoint_logits.view(N * K, H * W)
- #
- # keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
- # return keypoint_loss
- #
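- # # Hedged shape sketch: logits for 6 positive proposals, 17 keypoints and 56x56
- # # heatmaps on one image; every value below is illustrative.
- # # keypoint_logits = torch.randn(6, 17, 56, 56)
- # # proposals = [torch.tensor([[0.0, 0.0, 100.0, 100.0]] * 6)]
- # # gt_keypoints = [torch.rand(2, 17, 3) * 100]           # (x, y, visibility), all positive here
- # # keypoint_matched_idxs = [torch.tensor([0, 1, 0, 1, 0, 1])]
- # # loss = keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs)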
- #
- # def keypointrcnn_inference(x, boxes):
- # # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
- # # print(f'x:{x.shape}')
- # kp_probs = []
- # kp_scores = []
- #
- # boxes_per_image = [box.size(0) for box in boxes]
- # x2 = x.split(boxes_per_image, dim=0)
- # # print(f'x2:{x2}')
- #
- # for xx, bb in zip(x2, boxes):
- # kp_prob, scores = heatmaps_to_keypoints(xx, bb)
- # kp_probs.append(kp_prob)
- # kp_scores.append(scores)
- #
- # return kp_probs, kp_scores
- #
- #
- # def _onnx_expand_boxes(boxes, scale):
- # # type: (Tensor, float) -> Tensor
- # w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
- # h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
- # x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
- # y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
- #
- # w_half = w_half.to(dtype=torch.float32) * scale
- # h_half = h_half.to(dtype=torch.float32) * scale
- #
- # boxes_exp0 = x_c - w_half
- # boxes_exp1 = y_c - h_half
- # boxes_exp2 = x_c + w_half
- # boxes_exp3 = y_c + h_half
- # boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
- # return boxes_exp
- #
- #
- # # the next two functions should be merged inside Masker
- # # but are kept here for the moment while we need them
- # # temporarily for paste_mask_in_image
- # def expand_boxes(boxes, scale):
- # # type: (Tensor, float) -> Tensor
- # if torchvision._is_tracing():
- # return _onnx_expand_boxes(boxes, scale)
- # w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
- # h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
- # x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
- # y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
- #
- # w_half *= scale
- # h_half *= scale
- #
- # boxes_exp = torch.zeros_like(boxes)
- # boxes_exp[:, 0] = x_c - w_half
- # boxes_exp[:, 2] = x_c + w_half
- # boxes_exp[:, 1] = y_c - h_half
- # boxes_exp[:, 3] = y_c + h_half
- # return boxes_exp
- #
- #
- # @torch.jit.unused
- # def expand_masks_tracing_scale(M, padding):
- # # type: (int, int) -> float
- # return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
- #
- #
- # def expand_masks(mask, padding):
- # # type: (Tensor, int) -> Tuple[Tensor, float]
- # M = mask.shape[-1]
- # if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
- # scale = expand_masks_tracing_scale(M, padding)
- # else:
- # scale = float(M + 2 * padding) / M
- # padded_mask = F.pad(mask, (padding,) * 4)
- # return padded_mask, scale
- #
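- # # Hedged numeric sketch: with padding=1 and a 28x28 mask, scale == 30 / 28, and a
- # # box of width 56 grows symmetrically to width 60 so it still covers the padded mask.
- # # mask = torch.rand(1, 28, 28)
- # # padded_mask, scale = expand_masks(mask, padding=1)    # padded_mask is (1, 30, 30)
- # # boxes = torch.tensor([[10.0, 10.0, 66.0, 66.0]])
- # # expanded = expand_boxes(boxes, scale)                 # tensor([[ 8.,  8., 68., 68.]])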
- #
- # def paste_mask_in_image(mask, box, im_h, im_w):
- # # type: (Tensor, Tensor, int, int) -> Tensor
- # TO_REMOVE = 1
- # w = int(box[2] - box[0] + TO_REMOVE)
- # h = int(box[3] - box[1] + TO_REMOVE)
- # w = max(w, 1)
- # h = max(h, 1)
- #
- # # Set shape to [batchxCxHxW]
- # mask = mask.expand((1, 1, -1, -1))
- #
- # # Resize mask
- # mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
- # mask = mask[0][0]
- #
- # im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
- # x_0 = max(box[0], 0)
- # x_1 = min(box[2] + 1, im_w)
- # y_0 = max(box[1], 0)
- # y_1 = min(box[3] + 1, im_h)
- #
- # im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
- # return im_mask
- #
- #
- # def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
- # one = torch.ones(1, dtype=torch.int64)
- # zero = torch.zeros(1, dtype=torch.int64)
- #
- # w = box[2] - box[0] + one
- # h = box[3] - box[1] + one
- # w = torch.max(torch.cat((w, one)))
- # h = torch.max(torch.cat((h, one)))
- #
- # # Set shape to [batchxCxHxW]
- # mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
- #
- # # Resize mask
- # mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
- # mask = mask[0][0]
- #
- # x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
- # x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
- # y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
- # y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
- #
- # unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
- #
- # # TODO : replace below with a dynamic padding when support is added in ONNX
- #
- # # pad y
- # zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
- # zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
- # concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
- # # pad x
- # zeros_x0 = torch.zeros(concat_0.size(0), x_0)
- # zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
- # im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
- # return im_mask
- #
- #
- # @torch.jit._script_if_tracing
- # def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
- # res_append = torch.zeros(0, im_h, im_w)
- # for i in range(masks.size(0)):
- # mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
- # mask_res = mask_res.unsqueeze(0)
- # res_append = torch.cat((res_append, mask_res))
- # return res_append
- #
- #
- # def paste_masks_in_image(masks, boxes, img_shape, padding=1):
- # # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
- # masks, scale = expand_masks(masks, padding=padding)
- # boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
- # im_h, im_w = img_shape
- #
- # if torchvision._is_tracing():
- # return _onnx_paste_masks_in_image_loop(
- # masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
- # )[:, None]
- # res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
- # if len(res) > 0:
- # ret = torch.stack(res, dim=0)[:, None]
- # else:
- # ret = masks.new_empty((0, 1, im_h, im_w))
- # return ret
- #
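- # # Hedged end-to-end sketch: paste two predicted 28x28 mask probability maps back
- # # into a 480x640 image canvas; box coordinates and image size are illustrative.
- # # masks = torch.rand(2, 1, 28, 28)
- # # boxes = torch.tensor([[30.0, 40.0, 130.0, 180.0], [200.0, 50.0, 300.0, 220.0]])
- # # full_masks = paste_masks_in_image(masks, boxes, (480, 640), padding=1)
- # # print(full_masks.shape)                               # torch.Size([2, 1, 480, 640])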
- #
- # class RoIHeads(nn.Module):
- # __annotations__ = {
- # "box_coder": det_utils.BoxCoder,
- # "proposal_matcher": det_utils.Matcher,
- # "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
- # }
- #
- # def __init__(
- # self,
- # box_roi_pool,
- # box_head,
- # box_predictor,
- # # Faster R-CNN training
- # fg_iou_thresh,
- # bg_iou_thresh,
- # batch_size_per_image,
- # positive_fraction,
- # bbox_reg_weights,
- # # Faster R-CNN inference
- # score_thresh,
- # nms_thresh,
- # detections_per_img,
- # # Mask
- # mask_roi_pool=None,
- # mask_head=None,
- # mask_predictor=None,
- # keypoint_roi_pool=None,
- # keypoint_head=None,
- # keypoint_predictor=None,
- # wirepoint_roi_pool=None,
- # wirepoint_head=None,
- # wirepoint_predictor=None,
- # ):
- # super().__init__()
- #
- # self.box_similarity = box_ops.box_iou
- # # assign ground-truth boxes for each proposal
- # self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
- #
- # self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
- #
- # if bbox_reg_weights is None:
- # bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
- # self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
- #
- # self.box_roi_pool = box_roi_pool
- # self.box_head = box_head
- # self.box_predictor = box_predictor
- #
- # self.score_thresh = score_thresh
- # self.nms_thresh = nms_thresh
- # self.detections_per_img = detections_per_img
- #
- # self.mask_roi_pool = mask_roi_pool
- # self.mask_head = mask_head
- # self.mask_predictor = mask_predictor
- #
- # self.keypoint_roi_pool = keypoint_roi_pool
- # self.keypoint_head = keypoint_head
- # self.keypoint_predictor = keypoint_predictor
- #
- # self.wirepoint_roi_pool = wirepoint_roi_pool
- # self.wirepoint_head = wirepoint_head
- # self.wirepoint_predictor = wirepoint_predictor
- #
- # def has_mask(self):
- # if self.mask_roi_pool is None:
- # return False
- # if self.mask_head is None:
- # return False
- # if self.mask_predictor is None:
- # return False
- # return True
- #
- # def has_keypoint(self):
- # if self.keypoint_roi_pool is None:
- # return False
- # if self.keypoint_head is None:
- # return False
- # if self.keypoint_predictor is None:
- # return False
- # return True
- #
- # def has_wirepoint(self):
- # if self.wirepoint_roi_pool is None:
- # print(f'wirepoint_roi_pool is None')
- # return False
- # if self.wirepoint_head is None:
- # print(f'wirepoint_head is None')
- # return False
- # if self.wirepoint_predictor is None:
- # print('wirepoint_predictor is None')
- # return False
- # return True
- #
- # def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
- # # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
- # matched_idxs = []
- # labels = []
- # for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
- #
- # if gt_boxes_in_image.numel() == 0:
- # # Background image
- # device = proposals_in_image.device
- # clamped_matched_idxs_in_image = torch.zeros(
- # (proposals_in_image.shape[0],), dtype=torch.int64, device=device
- # )
- # labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
- # else:
- # # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
- # match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
- # matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
- #
- # clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
- #
- # labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
- # labels_in_image = labels_in_image.to(dtype=torch.int64)
- #
- # # Label background (below the low threshold)
- # bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
- # labels_in_image[bg_inds] = 0
- #
- # # Label ignore proposals (between low and high thresholds)
- # ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
- # labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
- #
- # matched_idxs.append(clamped_matched_idxs_in_image)
- # labels.append(labels_in_image)
- # return matched_idxs, labels
- #
- # def subsample(self, labels):
- # # type: (List[Tensor]) -> List[Tensor]
- # sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
- # sampled_inds = []
- # for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
- # img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
- # sampled_inds.append(img_sampled_inds)
- # return sampled_inds
- #
- # def add_gt_proposals(self, proposals, gt_boxes):
- # # type: (List[Tensor], List[Tensor]) -> List[Tensor]
- # proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
- #
- # return proposals
- #
- # def check_targets(self, targets):
- # # type: (Optional[List[Dict[str, Tensor]]]) -> None
- # if targets is None:
- # raise ValueError("targets should not be None")
- # if not all(["boxes" in t for t in targets]):
- # raise ValueError("Every element of targets should have a boxes key")
- # if not all(["labels" in t for t in targets]):
- # raise ValueError("Every element of targets should have a labels key")
- # if self.has_mask():
- # if not all(["masks" in t for t in targets]):
- # raise ValueError("Every element of targets should have a masks key")
- #
- # def select_training_samples(
- # self,
- # proposals, # type: List[Tensor]
- # targets, # type: Optional[List[Dict[str, Tensor]]]
- # ):
- # # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
- # self.check_targets(targets)
- # if targets is None:
- # raise ValueError("targets should not be None")
- # dtype = proposals[0].dtype
- # device = proposals[0].device
- #
- # gt_boxes = [t["boxes"].to(dtype) for t in targets]
- # gt_labels = [t["labels"] for t in targets]
- #
- # # append ground-truth bboxes to proposals
- # proposals = self.add_gt_proposals(proposals, gt_boxes)
- #
- # # get matching gt indices for each proposal
- # matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
- # # sample a fixed proportion of positive-negative proposals
- # sampled_inds = self.subsample(labels)
- # matched_gt_boxes = []
- # num_images = len(proposals)
- # for img_id in range(num_images):
- # img_sampled_inds = sampled_inds[img_id]
- # proposals[img_id] = proposals[img_id][img_sampled_inds]
- # labels[img_id] = labels[img_id][img_sampled_inds]
- # matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
- #
- # gt_boxes_in_image = gt_boxes[img_id]
- # if gt_boxes_in_image.numel() == 0:
- # gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
- # matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
- #
- # regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
- # return proposals, matched_idxs, labels, regression_targets
- #
- # def postprocess_detections(
- # self,
- # class_logits, # type: Tensor
- # box_regression, # type: Tensor
- # proposals, # type: List[Tensor]
- # image_shapes, # type: List[Tuple[int, int]]
- # ):
- # # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
- # device = class_logits.device
- # num_classes = class_logits.shape[-1]
- #
- # boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
- # pred_boxes = self.box_coder.decode(box_regression, proposals)
- #
- # pred_scores = F.softmax(class_logits, -1)
- #
- # pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
- # pred_scores_list = pred_scores.split(boxes_per_image, 0)
- #
- # all_boxes = []
- # all_scores = []
- # all_labels = []
- # for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
- # boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
- #
- # # create labels for each prediction
- # labels = torch.arange(num_classes, device=device)
- # labels = labels.view(1, -1).expand_as(scores)
- #
- # # remove predictions with the background label
- # boxes = boxes[:, 1:]
- # scores = scores[:, 1:]
- # labels = labels[:, 1:]
- #
- # # batch everything, by making every class prediction be a separate instance
- # boxes = boxes.reshape(-1, 4)
- # scores = scores.reshape(-1)
- # labels = labels.reshape(-1)
- #
- # # remove low scoring boxes
- # inds = torch.where(scores > self.score_thresh)[0]
- # boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
- #
- # # remove empty boxes
- # keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
- # boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
- #
- # # non-maximum suppression, independently done per class
- # keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
- # # keep only topk scoring predictions
- # keep = keep[: self.detections_per_img]
- # boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
- #
- # all_boxes.append(boxes)
- # all_scores.append(scores)
- # all_labels.append(labels)
- #
- # return all_boxes, all_scores, all_labels
- #
- # def forward(
- # self,
- # features, # type: Dict[str, Tensor]
- # proposals, # type: List[Tensor]
- # image_shapes, # type: List[Tuple[int, int]]
- # targets=None, # type: Optional[List[Dict[str, Tensor]]]
- # ):
- # # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
- # """
- # Args:
- # features (List[Tensor])
- # proposals (List[Tensor[N, 4]])
- # image_shapes (List[Tuple[H, W]])
- # targets (List[Dict])
- # """
- # if targets is not None:
- # for t in targets:
- # # TODO: https://github.com/pytorch/pytorch/issues/26731
- # floating_point_types = (torch.float, torch.double, torch.half)
- # if not t["boxes"].dtype in floating_point_types:
- # raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
- # if not t["labels"].dtype == torch.int64:
- # raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
- # if self.has_keypoint():
- # if not t["keypoints"].dtype == torch.float32:
- # raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
- #
- # if self.training:
- # proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
- # else:
- # labels = None
- # regression_targets = None
- # matched_idxs = None
- #
- # print(f"image_shapes:{image_shapes}")
- # box_features = self.box_roi_pool(features, proposals, image_shapes)
- # box_features = self.box_head(box_features)
- # class_logits, box_regression = self.box_predictor(box_features)
- #
- # result: List[Dict[str, torch.Tensor]] = []
- # losses = {}
- # if self.training:
- # if labels is None:
- # raise ValueError("labels cannot be None")
- # if regression_targets is None:
- # raise ValueError("regression_targets cannot be None")
- # loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
- # losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
- # else:
- # print('result append boxes!!!')
- # boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
- # num_images = len(boxes)
- # for i in range(num_images):
- # result.append(
- # {
- # "boxes": boxes[i],
- # "labels": labels[i],
- # "scores": scores[i],
- # }
- # )
- #
- # if self.has_mask():
- # mask_proposals = [p["boxes"] for p in result]
- # if self.training:
- # if matched_idxs is None:
- # raise ValueError("if in training, matched_idxs should not be None")
- #
- # # during training, only focus on positive boxes
- # num_images = len(proposals)
- # mask_proposals = []
- # pos_matched_idxs = []
- # for img_id in range(num_images):
- # pos = torch.where(labels[img_id] > 0)[0]
- # mask_proposals.append(proposals[img_id][pos])
- # pos_matched_idxs.append(matched_idxs[img_id][pos])
- # else:
- # pos_matched_idxs = None
- #
- # if self.mask_roi_pool is not None:
- # mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
- # mask_features = self.mask_head(mask_features)
- # mask_logits = self.mask_predictor(mask_features)
- # else:
- # raise Exception("Expected mask_roi_pool to be not None")
- #
- # loss_mask = {}
- # if self.training:
- # if targets is None or pos_matched_idxs is None or mask_logits is None:
- # raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
- #
- # gt_masks = [t["masks"] for t in targets]
- # gt_labels = [t["labels"] for t in targets]
- # rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
- # loss_mask = {"loss_mask": rcnn_loss_mask}
- # else:
- # labels = [r["labels"] for r in result]
- # masks_probs = maskrcnn_inference(mask_logits, labels)
- # for mask_prob, r in zip(masks_probs, result):
- # r["masks"] = mask_prob
- #
- # losses.update(loss_mask)
- #
- # # keep none checks in if conditional so torchscript will conditionally
- # # compile each branch
- # if self.has_keypoint():
- #
- # keypoint_proposals = [p["boxes"] for p in result]
- # if self.training:
- # # during training, only focus on positive boxes
- # num_images = len(proposals)
- # keypoint_proposals = []
- # pos_matched_idxs = []
- # if matched_idxs is None:
- # raise ValueError("if in trainning, matched_idxs should not be None")
- #
- # for img_id in range(num_images):
- # pos = torch.where(labels[img_id] > 0)[0]
- # keypoint_proposals.append(proposals[img_id][pos])
- # pos_matched_idxs.append(matched_idxs[img_id][pos])
- # else:
- # pos_matched_idxs = None
- #
- # keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
- # # tmp = keypoint_features[0][0]
- # # plt.imshow(tmp.detach().numpy())
- # # print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
- # keypoint_features = self.keypoint_head(keypoint_features)
- #
- # # print(f'keypoint_features:{keypoint_features.shape}')
- # tmp = keypoint_features[0][0]
- # plt.imshow(tmp.detach().numpy())
- # keypoint_logits = self.keypoint_predictor(keypoint_features)
- # # print(f'keypoint_logits:{keypoint_logits.shape}')
- # """
- # hook into the wirenet branch here
- # """
- #
- # loss_keypoint = {}
- # if self.training:
- # if targets is None or pos_matched_idxs is None:
- # raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
- #
- # gt_keypoints = [t["keypoints"] for t in targets]
- # rcnn_loss_keypoint = keypointrcnn_loss(
- # keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
- # )
- # loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
- # else:
- # if keypoint_logits is None or keypoint_proposals is None:
- # raise ValueError(
- # "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
- # )
- #
- # keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
- # for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
- # r["keypoints"] = keypoint_prob
- # r["keypoints_scores"] = kps
- # losses.update(loss_keypoint)
- #
- # if self.has_wirepoint():
- # # print(f'result:{result}')
- # wirepoint_proposals = [p["boxes"] for p in result]
- # if self.training:
- # # during training, only focus on positive boxes
- # num_images = len(proposals)
- # wirepoint_proposals = []
- # pos_matched_idxs = []
- # if matched_idxs is None:
- # raise ValueError("if in trainning, matched_idxs should not be None")
- #
- # for img_id in range(num_images):
- # pos = torch.where(labels[img_id] > 0)[0]
- # wirepoint_proposals.append(proposals[img_id][pos])
- # pos_matched_idxs.append(matched_idxs[img_id][pos])
- # else:
- # pos_matched_idxs = None
- #
- # # print(f'proposals:{len(proposals)}')
- # wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
- #
- # # tmp = keypoint_features[0][0]
- # # plt.imshow(tmp.detach().numpy())
- # # print(f'keypoint_features from roi_pool:{wirepoint_features.shape}')
- # outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
- #
- # print(f"wirepoint_features:{wirepoint_features}")
- #
- #
- # outputs = merge_features(outputs, wirepoint_proposals)
- # wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
- #
- # print(f'outputs:{outputs.shape}')
- #
- # wirepoint_logits = self.wirepoint_predictor(inputs=outputs, features=wirepoint_features, targets=targets)
- # x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
- #
- # # print(f'keypoint_features:{wirepoint_features.shape}')
- # if self.training:
- #
- # if targets is None or pos_matched_idxs is None:
- # raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
- #
- # loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
- # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
- #
- # loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
- #
- # else:
- # pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
- # result.append(pred)
- #
- # loss_wirepoint = {}
- #
- # # loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
- # # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
- # # loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
- #
- # # tmp = wirepoint_features[0][0]
- # # plt.imshow(tmp.detach().numpy())
- # # wirepoint_logits = self.wirepoint_predictor((outputs,wirepoint_features))
- # # print(f'keypoint_logits:{wirepoint_logits.shape}')
- #
- # # loss_wirepoint = {} lm
- # # result=wirepoint_logits
- #
- # # result.append(pred) lm
- # losses.update(loss_wirepoint)
- # # print(f"result{result}")
- # # print(f"losses{losses}")
- #
- # return result, losses
- #
- #
- # # def merge_features(features, proposals):
- # # Assume roi_pool_features is your input tensor with shape [600, 256, 128, 128]
- # #
- # # Split features according to the number of proposals per image (torch.split style)
- # # proposals_count = sum([p.size(0) for p in proposals])
- # # features_size = features.size(0)
- # # # (f'proposals sum:{proposals_count},features batch:{features.size(0)}')
- # # if proposals_count != features_size:
- # # raise ValueError("The length of proposals must match the batch size of features.")
- # #
- # # split_features = []
- # # start_idx = 0
- # # print(f"proposals:{proposals}")
- # # for proposal in proposals:
- # # Extract the features of the current image
- # # current_features = features[start_idx:start_idx + proposal.size(0)]
- # # # print(f'current_features:{current_features.shape}')
- # # split_features.append(current_features)
- # # start_idx += 1
- # #
- # # features_imgs = []
- # # for features_per_img in split_features:
- # # features_per_img, _ = torch.max(features_per_img, dim=0, keepdim=True)
- # # features_imgs.append(features_per_img)
- # #
- # # merged_features = torch.cat(features_imgs, dim=0)
- # # # print(f' merged_features:{merged_features.shape}')
- # # return merged_features
- #
- # def merge_features(features, proposals):
- # print(f'features:{features.shape}')
- # print(f'proposals:{len(proposals)}')
- # def diagnose_input(features, proposals):
- # """诊断输入数据"""
- # print("Input Diagnostics:")
- # print(f"Features type: {type(features)}, shape: {features.shape}")
- # print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
- # for i, p in enumerate(proposals):
- # print(f"Proposal {i} shape: {p.shape}")
- #
- # def validate_inputs(features, proposals):
- # """验证输入的有效性"""
- # if features is None or proposals is None:
- # raise ValueError("Features or proposals cannot be None")
- #
- # proposals_count = sum([p.size(0) for p in proposals])
- # features_size = features.size(0)
- #
- # if proposals_count != features_size:
- # raise ValueError(
- # f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
- # )
- #
- # def safe_max_reduction(features_per_img,proposals):
- #
- # print(f'proposal:{proposals.shape},features_per_img:{features_per_img.shape}')
- # """安全的最大值压缩"""
- # if features_per_img.numel() == 0:
- # return torch.zeros_like(features_per_img).unsqueeze(0)
- #
- # for feature_map,roi in zip(features_per_img,proposals):
- # # print(f'feature_map:{feature_map.shape},roi:{roi}')
- # roi_off_x=roi[0]
- # roi_off_y=roi[1]
- #
- #
- # try:
- # # Take the max along dim 0, keeping the dimension
- # max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
- # return max_features
- # except Exception as e:
- # print(f"Max reduction error: {e}")
- # return features_per_img.unsqueeze(0)
- #
- # try:
- # # Diagnose the input (optional)
- # # diagnose_input(features, proposals)
- #
- # # Validate the inputs
- # validate_inputs(features, proposals)
- #
- # # Split the features
- # split_features = []
- # start_idx = 0
- #
- # for proposal in proposals:
- # # Extract the features of the current image
- # current_features = features[start_idx:start_idx + proposal.size(0)]
- # split_features.append(current_features)
- # start_idx += proposal.size(0)
- #
- # # Compress the features of each image
- # features_imgs = []
- #
- # print(f'split_features:{len(split_features)}')
- # for features_per_img,proposal in zip(split_features,proposals):
- # compressed_features = safe_max_reduction(features_per_img,proposal)
- # features_imgs.append(compressed_features)
- #
- # # Merge the features
- # merged_features = torch.cat(features_imgs, dim=0)
- #
- # return merged_features
- #
- # except Exception as e:
- # print(f"Error in merge_features: {e}")
- # # Return the original features or None
- # return features
- #
- '''
- from collections import OrderedDict
- from typing import Dict, List, Optional, Tuple
- import matplotlib.pyplot as plt
- import torch
- import torch.nn.functional as F
- import torchvision
- from torch import nn, Tensor
- from torchvision.ops import boxes as box_ops, roi_align
- from models.wirenet import _utils as det_utils
- from torch.utils.data.dataloader import default_collate
- def l2loss(input, target):
- return ((target - input) ** 2).mean(2).mean(1)
- def cross_entropy_loss(logits, positive):
- nlogp = -F.log_softmax(logits, dim=0)
- return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
- def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
- logp = torch.sigmoid(logits) + offset
- loss = torch.abs(logp - target)
- if mask is not None:
- w = mask.mean(2, True).mean(1, True)
- w[w == 0] = 1
- loss = loss * (mask / w)
- return loss.mean(2).mean(1)
- def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
- # output, feature: results returned by the head
- # x, y, idx: intermediate results from the line branch
- result = {}
- batch, channel, row, col = output.shape
- wires_targets = [t["wires"] for t in targets]
- wires_targets = wires_targets.copy()
- # print(f'wires_target:{wires_targets}')
- # Extract all 'junc_map', 'junc_offset' and 'line_map' tensors
- junc_maps = [d["junc_map"] for d in wires_targets]
- junc_offsets = [d["junc_offset"] for d in wires_targets]
- line_maps = [d["line_map"] for d in wires_targets]
- junc_map_tensor = torch.stack(junc_maps, dim=0)
- junc_offset_tensor = torch.stack(junc_offsets, dim=0)
- line_map_tensor = torch.stack(line_maps, dim=0)
- T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
- n_jtyp = T["junc_map"].shape[1]
- for task in ["junc_map"]:
- T[task] = T[task].permute(1, 0, 2, 3)
- for task in ["junc_offset"]:
- T[task] = T[task].permute(1, 2, 0, 3, 4)
- offset = [2, 3, 5]
- losses = []
- output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
- jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
- lmap = output[offset[0]: offset[1]].squeeze(0)
- joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
- L = OrderedDict()
- L["junc_map"] = sum(
- cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
- )
- L["line_map"] = (
- F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
- .mean(2)
- .mean(1)
- )
- L["junc_offset"] = sum(
- sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
- for i in range(n_jtyp)
- for j in range(2)
- )
- for loss_name in L:
- L[loss_name].mul_(loss_weight[loss_name])
- losses.append(L)
- result["losses"] = losses
- loss = nn.BCEWithLogitsLoss(reduction="none")
- loss = loss(x, y)
- lpos_mask, lneg_mask = y, 1 - y
- loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
- def sum_batch(x):
- xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
- return torch.cat(xs)
- lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
- lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
- result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
- result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
- return result
- def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
- result = {}
- result["wires"] = {}
- p = torch.cat(ps)
- s = torch.sigmoid(input)
- b = s > 0.5
- lines = []
- score = []
- # print(f"n_batch:{n_batch}")
- for i in range(n_batch):
- # print(f"idx:{idx}")
- p0 = p[idx[i]: idx[i + 1]]
- s0 = s[idx[i]: idx[i + 1]]
- mask = b[idx[i]: idx[i + 1]]
- p0 = p0[mask]
- s0 = s0[mask]
- if len(p0) == 0:
- lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
- score.append(torch.zeros([1, n_out_line], device=p.device))
- else:
- arg = torch.argsort(s0, descending=True)
- p0, s0 = p0[arg], s0[arg]
- lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
- score.append(s0[None, torch.arange(n_out_line) % len(s0)])
- for j in range(len(jcs[i])):
- if len(jcs[i][j]) == 0:
- jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
- jcs[i][j] = jcs[i][j][
- None, torch.arange(n_out_junc) % len(jcs[i][j])
- ]
- result["wires"]["lines"] = torch.cat(lines)
- result["wires"]["score"] = torch.cat(score)
- result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
- if len(jcs[i]) > 1:
- result["preds"]["junts"] = torch.cat(
- [jcs[i][1] for i in range(n_batch)]
- )
- return result
- def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
- # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
- """
- Computes the loss for Faster R-CNN.
- Args:
- class_logits (Tensor)
- box_regression (Tensor)
- labels (list[BoxList])
- regression_targets (Tensor)
- Returns:
- classification_loss (Tensor)
- box_loss (Tensor)
- """
- labels = torch.cat(labels, dim=0)
- regression_targets = torch.cat(regression_targets, dim=0)
- classification_loss = F.cross_entropy(class_logits, labels)
- # get indices that correspond to the regression targets for
- # the corresponding ground truth labels, to be used with
- # advanced indexing
- sampled_pos_inds_subset = torch.where(labels > 0)[0]
- labels_pos = labels[sampled_pos_inds_subset]
- N, num_classes = class_logits.shape
- box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
- box_loss = F.smooth_l1_loss(
- box_regression[sampled_pos_inds_subset, labels_pos],
- regression_targets[sampled_pos_inds_subset],
- beta=1 / 9,
- reduction="sum",
- )
- box_loss = box_loss / labels.numel()
- return classification_loss, box_loss
- def maskrcnn_inference(x, labels):
- # type: (Tensor, List[Tensor]) -> List[Tensor]
- """
- From the results of the CNN, post process the masks
- by taking the mask corresponding to the class with max
- probability (which are of fixed size and directly output
- by the CNN) and return the masks in the mask field of the BoxList.
- Args:
- x (Tensor): the mask logits
- labels (list[BoxList]): bounding boxes that are used as
- reference, one for each image
- Returns:
- results (list[BoxList]): one BoxList for each image, containing
- the extra field mask
- """
- mask_prob = x.sigmoid()
- # select masks corresponding to the predicted classes
- num_masks = x.shape[0]
- boxes_per_image = [label.shape[0] for label in labels]
- labels = torch.cat(labels)
- index = torch.arange(num_masks, device=labels.device)
- mask_prob = mask_prob[index, labels][:, None]
- mask_prob = mask_prob.split(boxes_per_image, dim=0)
- return mask_prob
- def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
- # type: (Tensor, Tensor, Tensor, int) -> Tensor
- """
- Given segmentation masks and the bounding boxes corresponding
- to the location of the masks in the image, this function
- crops and resizes the masks in the position defined by the
- boxes. This prepares the masks for them to be fed to the
- loss computation as the targets.
- """
- matched_idxs = matched_idxs.to(boxes)
- rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
- gt_masks = gt_masks[:, None].to(rois)
- return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
- def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- """
- Args:
- proposals (list[BoxList])
- mask_logits (Tensor)
- targets (list[BoxList])
- Return:
- mask_loss (Tensor): scalar tensor containing the loss
- """
- discretization_size = mask_logits.shape[-1]
- # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
- # print(f'mask discretization_size:{discretization_size}')
- labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
- # print(f'mask labels:{labels}')
- mask_targets = [
- project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
- ]
- labels = torch.cat(labels, dim=0)
- # print(f'mask labels1:{labels}')
- mask_targets = torch.cat(mask_targets, dim=0)
- # torch.mean (in binary_cross_entropy_with_logits) doesn't
- # accept empty tensors, so handle it separately
- if mask_targets.numel() == 0:
- return mask_logits.sum() * 0
- # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
- # print(f'mask_targets:{mask_targets}')
- mask_loss = F.binary_cross_entropy_with_logits(
- mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
- )
- # print(f'mask_loss:{mask_loss}')
- return mask_loss
- def keypoints_to_heatmap(keypoints, rois, heatmap_size):
- # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
- offset_x = rois[:, 0]
- offset_y = rois[:, 1]
- scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
- scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
- offset_x = offset_x[:, None]
- offset_y = offset_y[:, None]
- scale_x = scale_x[:, None]
- scale_y = scale_y[:, None]
- x = keypoints[..., 0]
- y = keypoints[..., 1]
- x_boundary_inds = x == rois[:, 2][:, None]
- y_boundary_inds = y == rois[:, 3][:, None]
- x = (x - offset_x) * scale_x
- x = x.floor().long()
- y = (y - offset_y) * scale_y
- y = y.floor().long()
- x[x_boundary_inds] = heatmap_size - 1
- y[y_boundary_inds] = heatmap_size - 1
- valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
- vis = keypoints[..., 2] > 0
- valid = (valid_loc & vis).long()
- lin_ind = y * heatmap_size + x
- heatmaps = lin_ind * valid
- return heatmaps, valid
- def _onnx_heatmaps_to_keypoints(
- maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
- ):
- num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
- width_correction = widths_i / roi_map_width
- height_correction = heights_i / roi_map_height
- roi_map = F.interpolate(
- maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
- )[:, 0]
- w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
- pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- x_int = pos % w
- y_int = (pos - x_int) // w
- x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
- dtype=torch.float32
- )
- y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
- dtype=torch.float32
- )
- xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
- xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
- xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
- xy_preds_i = torch.stack(
- [
- xy_preds_i_0.to(dtype=torch.float32),
- xy_preds_i_1.to(dtype=torch.float32),
- xy_preds_i_2.to(dtype=torch.float32),
- ],
- 0,
- )
- # TODO: simplify when indexing without rank will be supported by ONNX
- base = num_keypoints * num_keypoints + num_keypoints + 1
- ind = torch.arange(num_keypoints)
- ind = ind.to(dtype=torch.int64) * base
- end_scores_i = (
- roi_map.index_select(1, y_int.to(dtype=torch.int64))
- .index_select(2, x_int.to(dtype=torch.int64))
- .view(-1)
- .index_select(0, ind.to(dtype=torch.int64))
- )
- return xy_preds_i, end_scores_i
- @torch.jit._script_if_tracing
- def _onnx_heatmaps_to_keypoints_loop(
- maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
- ):
- xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
- end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
- for i in range(int(rois.size(0))):
- xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
- maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
- )
- xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
- end_scores = torch.cat(
- (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
- )
- return xy_preds, end_scores
- def heatmaps_to_keypoints(maps, rois):
- """Extract predicted keypoint locations from heatmaps. Output has shape
- (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
- for each keypoint.
- """
- # This function converts a discrete image coordinate in a HEATMAP_SIZE x
- # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
- # consistency with keypoints_to_heatmap_labels by using the conversion from
- # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
- # continuous coordinate.
- offset_x = rois[:, 0]
- offset_y = rois[:, 1]
- widths = rois[:, 2] - rois[:, 0]
- heights = rois[:, 3] - rois[:, 1]
- widths = widths.clamp(min=1)
- heights = heights.clamp(min=1)
- widths_ceil = widths.ceil()
- heights_ceil = heights.ceil()
- num_keypoints = maps.shape[1]
- if torchvision._is_tracing():
- xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
- maps,
- rois,
- widths_ceil,
- heights_ceil,
- widths,
- heights,
- offset_x,
- offset_y,
- torch.scalar_tensor(num_keypoints, dtype=torch.int64),
- )
- return xy_preds.permute(0, 2, 1), end_scores
- xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
- end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
- for i in range(len(rois)):
- roi_map_width = int(widths_ceil[i].item())
- roi_map_height = int(heights_ceil[i].item())
- width_correction = widths[i] / roi_map_width
- height_correction = heights[i] / roi_map_height
- roi_map = F.interpolate(
- maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
- )[:, 0]
- # roi_map_probs = scores_to_probs(roi_map.copy())
- w = roi_map.shape[2]
- pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- x_int = pos % w
- y_int = torch.div(pos - x_int, w, rounding_mode="floor")
- # assert (roi_map_probs[k, y_int, x_int] ==
- # roi_map_probs[k, :, :].max())
- x = (x_int.float() + 0.5) * width_correction
- y = (y_int.float() + 0.5) * height_correction
- xy_preds[i, 0, :] = x + offset_x[i]
- xy_preds[i, 1, :] = y + offset_y[i]
- xy_preds[i, 2, :] = 1
- end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
- return xy_preds.permute(0, 2, 1), end_scores
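- # Illustrative sketch (not part of the original pipeline): a minimal worked example of the
- # Heckbert-style conversion used above, mapping a heatmap argmax cell back to a continuous
- # image coordinate via c = (d + 0.5) * (roi_size / heatmap_size) + roi_offset. The ROI
- # geometry below is made up purely for demonstration.
- def _demo_heatmap_argmax_to_xy():
-     roi_x0, roi_width, heatmap_size = 100.0, 50.0, 14
-     heatmap = torch.zeros(heatmap_size, heatmap_size)
-     heatmap[3, 7] = 1.0  # peak at row 3, column 7
-     pos = heatmap.reshape(-1).argmax()
-     x_int = (pos % heatmap_size).float()
-     # (7 + 0.5) * (50 / 14) + 100 ~= 126.8: continuous x of the peak in image space
-     return (x_int + 0.5) * (roi_width / heatmap_size) + roi_x0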
- def heatmaps_to_keypoints_new(maps, rois):
- # """Extract predicted keypoint locations from heatmaps. Output has shape
- # (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
- # for each keypoint.
- # """
- # This function converts a discrete image coordinate in a HEATMAP_SIZE x
- # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
- # consistency with keypoints_to_heatmap_labels by using the conversion from
- # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
- # continuous coordinate.
- print(f"maps.shape:{maps.shape}")
- rois = rois[0]
- offset_x = rois[:, 0]
- offset_y = rois[:, 1]
- widths = rois[:, 2] - rois[:, 0]
- heights = rois[:, 3] - rois[:, 1]
- widths = widths.clamp(min=1)
- heights = heights.clamp(min=1)
- widths_ceil = widths.ceil()
- heights_ceil = heights.ceil()
- num_keypoints = maps.shape[1]
- if torchvision._is_tracing():
- xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
- maps,
- rois,
- widths_ceil,
- heights_ceil,
- widths,
- heights,
- offset_x,
- offset_y,
- torch.scalar_tensor(num_keypoints, dtype=torch.int64),
- )
- return xy_preds.permute(0, 2, 1), end_scores
- xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
- end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
- # Create a blank 512x512 canvas
- # combined_map = torch.zeros((1, maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
- combined_map = torch.zeros((len(rois), maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
- combined_mask = torch.zeros((1, 1, 512, 512), dtype=torch.float32, device=maps.device)
- print(f"combined_map.shape: {combined_map.shape}")
- print(f"len of rois:{len(rois)}")
- for i in range(len(rois)):
- roi_map_width = int(widths_ceil[i].item())
- roi_map_height = int(heights_ceil[i].item())
- width_correction = widths[i] / roi_map_width
- height_correction = heights[i] / roi_map_height
- roi_map = F.interpolate(
- maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
- )[:, 0]
- x_offset = int(offset_x[i].item()) # convert to a scalar
- y_offset = int(offset_y[i].item()) # convert to a scalar
- # print(f"x_offset: {x_offset}, y_offset: {y_offset}, roi_map.shape: {roi_map.shape}")
- # Check that the offsets fit inside combined_map
- if y_offset < 0 or y_offset + roi_map.shape[1] > combined_map.shape[2] or x_offset < 0 or x_offset + \
- roi_map.shape[2] > combined_map.shape[3]:
- print("Error: Offset exceeds combined_map dimensions.")
- else:
- # Check the roi_map size
- if roi_map.shape[1] <= 0 or roi_map.shape[2] <= 0:
- print("Error: Invalid ROI size.")
- else:
- # Fill combined_map
- # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
- # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = \
- # torch.max(
- # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]],
- # roi_map)
- combined_map[i, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
- roi_mask = torch.ones((1, roi_map_height, roi_map_width), dtype=torch.float32, device=maps.device)
- combined_mask[0, 0, y_offset:y_offset + roi_map_height, x_offset:x_offset + roi_map_width] = roi_mask
- # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
- w = roi_map.shape[2]
- pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- x_int = pos % w
- y_int = torch.div(pos - x_int, w, rounding_mode="floor")
- # assert (roi_map_probs[k, y_int, x_int] ==
- # roi_map_probs[k, :, :].max())
- x = (x_int.float() + 0.5) * width_correction
- y = (y_int.float() + 0.5) * height_correction
- xy_preds[i, 0, :] = x + offset_x[i]
- xy_preds[i, 1, :] = y + offset_y[i]
- xy_preds[i, 2, :] = 1
- end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
- combined_map_final, _ = torch.max(combined_map, dim=0, keepdim=True)
- combined_map1 = F.interpolate(combined_map_final, size=(128, 128), mode='bilinear', align_corners=False)
- # print(f"combined_map.shape:{combined_map1.shape}")
- combined_mask = F.interpolate(combined_mask, size=(128, 128), mode='bilinear', align_corners=False)
- combined_mask = (combined_mask >= 0.5).float() # apply a 0.5 threshold
- return combined_map1, xy_preds.permute(0, 2, 1), end_scores, combined_mask
- # def heatmaps_to_keypoints_new(maps, rois):
- # # """Extract predicted keypoint locations from heatmaps. Output has shape
- # # (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
- # # for each keypoint.
- # # """
- # # This function converts a discrete image coordinate in a HEATMAP_SIZE x
- # # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
- # # consistency with keypoints_to_heatmap_labels by using the conversion from
- # # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
- # # continuous coordinate.
- # print(f"maps.shape:{maps.shape}")
- # rois = rois[0]
- # offset_x = rois[:, 0]
- # offset_y = rois[:, 1]
- #
- # widths = rois[:, 2] - rois[:, 0]
- # heights = rois[:, 3] - rois[:, 1]
- # widths = widths.clamp(min=1)
- # heights = heights.clamp(min=1)
- # widths_ceil = widths.ceil()
- # heights_ceil = heights.ceil()
- #
- # num_keypoints = maps.shape[1]
- #
- # if torchvision._is_tracing():
- # xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
- # maps,
- # rois,
- # widths_ceil,
- # heights_ceil,
- # widths,
- # heights,
- # offset_x,
- # offset_y,
- # torch.scalar_tensor(num_keypoints, dtype=torch.int64),
- # )
- # return xy_preds.permute(0, 2, 1), end_scores
- #
- # xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
- # end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
- # # Create a blank 512x512 canvas
- #
- # # combined_map = torch.zeros((1, maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
- # combined = torch.zeros((1, maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
- # combined_map = torch.zeros((len(rois), maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
- # combined_mask = torch.zeros((1, 1, 512, 512), dtype=torch.float32, device=maps.device)
- #
- # print(f"combined_map.shape: {combined_map.shape}")
- # print(f"len of rois:{len(rois)}")
- # for i in range(len(rois)):
- # roi_map_width = int(widths_ceil[i].item())
- # roi_map_height = int(heights_ceil[i].item())
- # width_correction = widths[i] / roi_map_width
- # height_correction = heights[i] / roi_map_height
- # roi_map = F.interpolate(
- # maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
- # )[:, 0]
- # x_offset = int(offset_x[i].item()) # convert to a scalar
- # y_offset = int(offset_y[i].item()) # convert to a scalar
- # # print(f"x_offset: {x_offset}, y_offset: {y_offset}, roi_map.shape: {roi_map.shape}")
- # # Check that the offsets fit inside combined_map
- # if y_offset < 0 or y_offset + roi_map.shape[1] > combined_map.shape[2] or x_offset < 0 or x_offset + \
- # roi_map.shape[2] > combined_map.shape[3]:
- # print("Error: Offset exceeds combined_map dimensions.")
- # else:
- # # Check the roi_map size
- # if roi_map.shape[1] <= 0 or roi_map.shape[2] <= 0:
- # print("Error: Invalid ROI size.")
- # else:
- # # Fill combined_map
- # # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
- # # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = \
- # # torch.max(
- # # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]],
- # # roi_map)
- # combined[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = \
- # torch.max(
- # combined[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]],
- # roi_map)
- #
- #
- # combined_map[i, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
- #
- # roi_mask = torch.ones((1, roi_map_height, roi_map_width), dtype=torch.float32, device=maps.device)
- # combined_mask[0, 0, y_offset:y_offset + roi_map_height, x_offset:x_offset + roi_map_width] = roi_mask
- #
- # # combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
- # w = roi_map.shape[2]
- # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- #
- # x_int = pos % w
- # y_int = torch.div(pos - x_int, w, rounding_mode="floor")
- # # assert (roi_map_probs[k, y_int, x_int] ==
- # # roi_map_probs[k, :, :].max())
- # x = (x_int.float() + 0.5) * width_correction
- # y = (y_int.float() + 0.5) * height_correction
- # xy_preds[i, 0, :] = x + offset_x[i]
- # xy_preds[i, 1, :] = y + offset_y[i]
- # xy_preds[i, 2, :] = 1
- # end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
- # combined_map_final, _ = torch.max(combined_map, dim=0, keepdim=True)
- # print(f"判断:{torch.equal(combined_map_final,combined)}")
- # # print(f"combined_map_final:{combined_map_final.shape}")
- # combined_map1 = F.interpolate(combined_map_final, size=(128, 128), mode='bilinear', align_corners=False)
- # # print(f"combined_map.shape:{combined_map1.shape}")
- #
- # combined_mask = F.interpolate(combined_mask, size=(128, 128), mode='bilinear', align_corners=False)
- # combined_mask = (combined_mask >= 0.5).float() # apply a 0.5 threshold
- #
- # return combined_map1, xy_preds.permute(0, 2, 1), end_scores, combined_mask
- def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- N, K, H, W = keypoint_logits.shape
- if H != W:
- raise ValueError(
- f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
- )
- discretization_size = H
- heatmaps = []
- valid = []
- for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
- kp = gt_kp_in_image[midx]
- heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
- heatmaps.append(heatmaps_per_image.view(-1))
- valid.append(valid_per_image.view(-1))
- keypoint_targets = torch.cat(heatmaps, dim=0)
- valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
- valid = torch.where(valid)[0]
- # torch.mean (in binary_cross_entropy_with_logits) doesn't
- # accept empty tensors, so handle it separately
- if keypoint_targets.numel() == 0 or len(valid) == 0:
- return keypoint_logits.sum() * 0
- keypoint_logits = keypoint_logits.view(N * K, H * W)
- keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
- return keypoint_loss
- def keypointrcnn_inference(x, boxes):
- # print(f'x:{x.shape}')
- # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
- kp_probs = []
- kp_scores = []
- boxes_per_image = [box.size(0) for box in boxes]
- x2 = x.split(boxes_per_image, dim=0)
- # print(f'x2:{x2}')
- for xx, bb in zip(x2, boxes):
- kp_prob, scores = heatmaps_to_keypoints(xx, bb)
- kp_probs.append(kp_prob)
- kp_scores.append(scores)
- return kp_probs, kp_scores
- def _onnx_expand_boxes(boxes, scale):
- # type: (Tensor, float) -> Tensor
- w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
- h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
- x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
- y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
- w_half = w_half.to(dtype=torch.float32) * scale
- h_half = h_half.to(dtype=torch.float32) * scale
- boxes_exp0 = x_c - w_half
- boxes_exp1 = y_c - h_half
- boxes_exp2 = x_c + w_half
- boxes_exp3 = y_c + h_half
- boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
- return boxes_exp
- # the next two functions should be merged inside Masker
- # but are kept here for the moment while we need them
- # temporarily for paste_mask_in_image
- def expand_boxes(boxes, scale):
- # type: (Tensor, float) -> Tensor
- if torchvision._is_tracing():
- return _onnx_expand_boxes(boxes, scale)
- w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
- h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
- x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
- y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
- w_half *= scale
- h_half *= scale
- boxes_exp = torch.zeros_like(boxes)
- boxes_exp[:, 0] = x_c - w_half
- boxes_exp[:, 2] = x_c + w_half
- boxes_exp[:, 1] = y_c - h_half
- boxes_exp[:, 3] = y_c + h_half
- return boxes_exp
- @torch.jit.unused
- def expand_masks_tracing_scale(M, padding):
- # type: (int, int) -> float
- return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
- def expand_masks(mask, padding):
- # type: (Tensor, int) -> Tuple[Tensor, float]
- M = mask.shape[-1]
- if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
- scale = expand_masks_tracing_scale(M, padding)
- else:
- scale = float(M + 2 * padding) / M
- padded_mask = F.pad(mask, (padding,) * 4)
- return padded_mask, scale
- def paste_mask_in_image(mask, box, im_h, im_w):
- # type: (Tensor, Tensor, int, int) -> Tensor
- TO_REMOVE = 1
- w = int(box[2] - box[0] + TO_REMOVE)
- h = int(box[3] - box[1] + TO_REMOVE)
- w = max(w, 1)
- h = max(h, 1)
- # Set shape to [batchxCxHxW]
- mask = mask.expand((1, 1, -1, -1))
- # Resize mask
- mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
- mask = mask[0][0]
- im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
- x_0 = max(box[0], 0)
- x_1 = min(box[2] + 1, im_w)
- y_0 = max(box[1], 0)
- y_1 = min(box[3] + 1, im_h)
- im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
- return im_mask
- def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
- one = torch.ones(1, dtype=torch.int64)
- zero = torch.zeros(1, dtype=torch.int64)
- w = box[2] - box[0] + one
- h = box[3] - box[1] + one
- w = torch.max(torch.cat((w, one)))
- h = torch.max(torch.cat((h, one)))
- # Set shape to [batchxCxHxW]
- mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
- # Resize mask
- mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
- mask = mask[0][0]
- x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
- x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
- y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
- y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
- unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
- # TODO : replace below with a dynamic padding when support is added in ONNX
- # pad y
- zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
- zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
- concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
- # pad x
- zeros_x0 = torch.zeros(concat_0.size(0), x_0)
- zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
- im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
- return im_mask
- @torch.jit._script_if_tracing
- def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
- res_append = torch.zeros(0, im_h, im_w)
- for i in range(masks.size(0)):
- mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
- mask_res = mask_res.unsqueeze(0)
- res_append = torch.cat((res_append, mask_res))
- return res_append
- def paste_masks_in_image(masks, boxes, img_shape, padding=1):
- # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
- masks, scale = expand_masks(masks, padding=padding)
- boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
- im_h, im_w = img_shape
- if torchvision._is_tracing():
- return _onnx_paste_masks_in_image_loop(
- masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
- )[:, None]
- res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
- if len(res) > 0:
- ret = torch.stack(res, dim=0)[:, None]
- else:
- ret = masks.new_empty((0, 1, im_h, im_w))
- return ret
- class RoIHeads(nn.Module):
- __annotations__ = {
- "box_coder": det_utils.BoxCoder,
- "proposal_matcher": det_utils.Matcher,
- "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
- }
- def __init__(
- self,
- box_roi_pool,
- box_head,
- box_predictor,
- # Faster R-CNN training
- fg_iou_thresh,
- bg_iou_thresh,
- batch_size_per_image,
- positive_fraction,
- bbox_reg_weights,
- # Faster R-CNN inference
- score_thresh,
- nms_thresh,
- detections_per_img,
- # Mask
- mask_roi_pool=None,
- mask_head=None,
- mask_predictor=None,
- keypoint_roi_pool=None,
- keypoint_head=None,
- keypoint_predictor=None,
- wirepoint_roi_pool=None,
- wirepoint_head=None,
- wirepoint_predictor=None,
- ):
- super().__init__()
- self.box_similarity = box_ops.box_iou
- # assign ground-truth boxes for each proposal
- self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
- self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
- if bbox_reg_weights is None:
- bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
- self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
- self.box_roi_pool = box_roi_pool
- self.box_head = box_head
- self.box_predictor = box_predictor
- self.score_thresh = score_thresh
- self.nms_thresh = nms_thresh
- self.detections_per_img = detections_per_img
- self.mask_roi_pool = mask_roi_pool
- self.mask_head = mask_head
- self.mask_predictor = mask_predictor
- self.keypoint_roi_pool = keypoint_roi_pool
- self.keypoint_head = keypoint_head
- self.keypoint_predictor = keypoint_predictor
- self.wirepoint_roi_pool = wirepoint_roi_pool
- self.wirepoint_head = wirepoint_head
- self.wirepoint_predictor = wirepoint_predictor
- def has_mask(self):
- if self.mask_roi_pool is None:
- return False
- if self.mask_head is None:
- return False
- if self.mask_predictor is None:
- return False
- return True
- def has_keypoint(self):
- if self.keypoint_roi_pool is None:
- return False
- if self.keypoint_head is None:
- return False
- if self.keypoint_predictor is None:
- return False
- return True
- def has_wirepoint(self):
- if self.wirepoint_roi_pool is None:
- print(f'wirepoint_roi_pool is None')
- return False
- if self.wirepoint_head is None:
- print(f'wirepoint_head is None')
- return False
- if self.wirepoint_predictor is None:
- print(f'wirepoint_predictor is None')
- return False
- return True
- def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
- # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
- matched_idxs = []
- labels = []
- for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
- if gt_boxes_in_image.numel() == 0:
- # Background image
- device = proposals_in_image.device
- clamped_matched_idxs_in_image = torch.zeros(
- (proposals_in_image.shape[0],), dtype=torch.int64, device=device
- )
- labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
- else:
- # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
- match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
- matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
- clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
- labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
- labels_in_image = labels_in_image.to(dtype=torch.int64)
- # Label background (below the low threshold)
- bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
- labels_in_image[bg_inds] = 0
- # Label ignore proposals (between low and high thresholds)
- ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
- labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
- matched_idxs.append(clamped_matched_idxs_in_image)
- labels.append(labels_in_image)
- return matched_idxs, labels
- def subsample(self, labels):
- # type: (List[Tensor]) -> List[Tensor]
- sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
- sampled_inds = []
- for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
- img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
- sampled_inds.append(img_sampled_inds)
- return sampled_inds
- def add_gt_proposals(self, proposals, gt_boxes):
- # type: (List[Tensor], List[Tensor]) -> List[Tensor]
- proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
- return proposals
- def check_targets(self, targets):
- # type: (Optional[List[Dict[str, Tensor]]]) -> None
- if targets is None:
- raise ValueError("targets should not be None")
- if not all(["boxes" in t for t in targets]):
- raise ValueError("Every element of targets should have a boxes key")
- if not all(["labels" in t for t in targets]):
- raise ValueError("Every element of targets should have a labels key")
- if self.has_mask():
- if not all(["masks" in t for t in targets]):
- raise ValueError("Every element of targets should have a masks key")
- def select_training_samples(
- self,
- proposals, # type: List[Tensor]
- targets, # type: Optional[List[Dict[str, Tensor]]]
- ):
- # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
- self.check_targets(targets)
- if targets is None:
- raise ValueError("targets should not be None")
- dtype = proposals[0].dtype
- device = proposals[0].device
- gt_boxes = [t["boxes"].to(dtype) for t in targets]
- gt_labels = [t["labels"] for t in targets]
- # append ground-truth bboxes to proposals
- proposals = self.add_gt_proposals(proposals, gt_boxes)
- # get matching gt indices for each proposal
- matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
- # sample a fixed proportion of positive-negative proposals
- sampled_inds = self.subsample(labels)
- matched_gt_boxes = []
- num_images = len(proposals)
- for img_id in range(num_images):
- img_sampled_inds = sampled_inds[img_id]
- proposals[img_id] = proposals[img_id][img_sampled_inds]
- labels[img_id] = labels[img_id][img_sampled_inds]
- matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
- gt_boxes_in_image = gt_boxes[img_id]
- if gt_boxes_in_image.numel() == 0:
- gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
- matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
- regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
- return proposals, matched_idxs, labels, regression_targets
- def postprocess_detections(
- self,
- class_logits, # type: Tensor
- box_regression, # type: Tensor
- proposals, # type: List[Tensor]
- image_shapes, # type: List[Tuple[int, int]]
- ):
- # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
- device = class_logits.device
- num_classes = class_logits.shape[-1]
- boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
- pred_boxes = self.box_coder.decode(box_regression, proposals)
- pred_scores = F.softmax(class_logits, -1)
- pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
- pred_scores_list = pred_scores.split(boxes_per_image, 0)
- all_boxes = []
- all_scores = []
- all_labels = []
- for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
- boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
- # create labels for each prediction
- labels = torch.arange(num_classes, device=device)
- labels = labels.view(1, -1).expand_as(scores)
- # remove predictions with the background label
- boxes = boxes[:, 1:]
- scores = scores[:, 1:]
- labels = labels[:, 1:]
- # batch everything, by making every class prediction be a separate instance
- boxes = boxes.reshape(-1, 4)
- scores = scores.reshape(-1)
- labels = labels.reshape(-1)
- # remove low scoring boxes
- inds = torch.where(scores > self.score_thresh)[0]
- boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
- # remove empty boxes
- keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
- boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
- # non-maximum suppression, independently done per class
- keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
- # keep only topk scoring predictions
- keep = keep[: self.detections_per_img]
- boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
- all_boxes.append(boxes)
- all_scores.append(scores)
- all_labels.append(labels)
- return all_boxes, all_scores, all_labels
- def forward(
- self,
- features, # type: Dict[str, Tensor]
- proposals, # type: List[Tensor]
- image_shapes, # type: List[Tuple[int, int]]
- targets=None, # type: Optional[List[Dict[str, Tensor]]]
- ):
- # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
- """
- Args:
- features (List[Tensor])
- proposals (List[Tensor[N, 4]])
- image_shapes (List[Tuple[H, W]])
- targets (List[Dict])
- """
- if targets is not None:
- for t in targets:
- # TODO: https://github.com/pytorch/pytorch/issues/26731
- floating_point_types = (torch.float, torch.double, torch.half)
- if not t["boxes"].dtype in floating_point_types:
- raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
- if not t["labels"].dtype == torch.int64:
- raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
- if self.has_keypoint():
- if not t["keypoints"].dtype == torch.float32:
- raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
- print(f"proposals len:{proposals[0].shape}")
- if self.training:
- proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
- else:
- labels = None
- regression_targets = None
- matched_idxs = None
- print(f"proposals:{proposals[0].shape}")
- box_features = self.box_roi_pool(features, proposals, image_shapes)
- box_features = self.box_head(box_features)
- class_logits, box_regression = self.box_predictor(box_features)
- result: List[Dict[str, torch.Tensor]] = []
- losses = {}
- if self.training:
- if labels is None:
- raise ValueError("labels cannot be None")
- if regression_targets is None:
- raise ValueError("regression_targets cannot be None")
- loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
- losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
- else:
- boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
- num_images = len(boxes)
- for i in range(num_images):
- result.append(
- {
- "boxes": boxes[i],
- "labels": labels[i],
- "scores": scores[i],
- }
- )
- print(f"proposals len:{proposals[0].shape}")
- print(f"boxes len:{boxes[0].shape}")
- print(f"proposals:{proposals}")
- print(f"boxes:{boxes}")
- # This branch is not taken in this model
- if self.has_mask():
- mask_proposals = [p["boxes"] for p in result]
- if self.training:
- if matched_idxs is None:
- raise ValueError("if in training, matched_idxs should not be None")
- # during training, only focus on positive boxes
- num_images = len(proposals)
- mask_proposals = []
- pos_matched_idxs = []
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- mask_proposals.append(proposals[img_id][pos])
- pos_matched_idxs.append(matched_idxs[img_id][pos])
- else:
- pos_matched_idxs = None
- if self.mask_roi_pool is not None:
- mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
- mask_features = self.mask_head(mask_features)
- mask_logits = self.mask_predictor(mask_features)
- else:
- raise Exception("Expected mask_roi_pool to be not None")
- loss_mask = {}
- if self.training:
- if targets is None or pos_matched_idxs is None or mask_logits is None:
- raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
- gt_masks = [t["masks"] for t in targets]
- gt_labels = [t["labels"] for t in targets]
- rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
- loss_mask = {"loss_mask": rcnn_loss_mask}
- else:
- labels = [r["labels"] for r in result]
- masks_probs = maskrcnn_inference(mask_logits, labels)
- for mask_prob, r in zip(masks_probs, result):
- r["masks"] = mask_prob
- losses.update(loss_mask)
- # keep none checks in if conditional so torchscript will conditionally
- # compile each branch
- if self.has_keypoint():
- keypoint_proposals = [p["boxes"] for p in result]
- if self.training:
- # during training, only focus on positive boxes
- num_images = len(proposals)
- keypoint_proposals = []
- pos_matched_idxs = []
- if matched_idxs is None:
- raise ValueError("if in trainning, matched_idxs should not be None")
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- keypoint_proposals.append(proposals[img_id][pos])
- pos_matched_idxs.append(matched_idxs[img_id][pos])
- else:
- pos_matched_idxs = None
- keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
- # tmp = keypoint_features[0][0]
- # plt.imshow(tmp.detach().numpy())
- # print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
- keypoint_features = self.keypoint_head(keypoint_features)
- # print(f'keypoint_features:{keypoint_features.shape}')
- tmp = keypoint_features[0][0]
- plt.imshow(tmp.detach().numpy())
- keypoint_logits = self.keypoint_predictor(keypoint_features)
- # print(f'keypoint_logits:{keypoint_logits.shape}')
- """
- Connects to the wirenet branch
- """
- loss_keypoint = {}
- if self.training:
- if targets is None or pos_matched_idxs is None:
- raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
- gt_keypoints = [t["keypoints"] for t in targets]
- rcnn_loss_keypoint = keypointrcnn_loss(
- keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
- )
- loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
- else:
- if keypoint_logits is None or keypoint_proposals is None:
- raise ValueError(
- "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
- )
- keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
- for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
- r["keypoints"] = keypoint_prob
- r["keypoints_scores"] = kps
- losses.update(loss_keypoint)
- if self.has_wirepoint():
- wirepoint_proposals = [p["boxes"] for p in result]
- if self.training:
- # during training, only focus on positive boxes
- num_images = len(proposals)
- wirepoint_proposals = []
- pos_matched_idxs = []
- if matched_idxs is None:
- raise ValueError("if in trainning, matched_idxs should not be None")
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- wirepoint_proposals.append(proposals[img_id][pos])
- pos_matched_idxs.append(matched_idxs[img_id][pos])
- else:
- pos_matched_idxs = None
- wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
- outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
- # print(f"wirepoint_proposal:{type(wirepoint_proposals)}")
- # print(f"wirepoint_proposal:{wirepoint_proposals.__len__()}")
- print(f"wirepoint_proposal[0].shape:{wirepoint_proposals[0].shape}")
- # print(f"wirepoint_proposal[0]:{wirepoint_proposals[0]}")
- print(f"wirepoint_features:{wirepoint_features.shape}")
- # outputs = merge_features(outputs, wirepoint_proposals)
- combined_output, xy_preds, end_scores, mask_key = heatmaps_to_keypoints_new(outputs, wirepoint_proposals)
- wire_combined_features, wire_xy_preds, wire_end_scores, wire_mask = heatmaps_to_keypoints_new(
- wirepoint_features, wirepoint_proposals)
- # print(f'combined_output:{combined_output.shape}')
- print(f"wire_combined_features:{wire_combined_features.shape}")
- wirepoint_logits = self.wirepoint_predictor(inputs=combined_output, features=wire_combined_features,
- mask=wire_mask, targets=targets)
- x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
- # print(f'keypoint_features:{wirepoint_features.shape}')
- if self.training:
- if targets is None or pos_matched_idxs is None:
- raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
- loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
- rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, combined_output, x, y, idx, loss_weight)
- loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
- else:
- pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
- result.append(pred)
- loss_wirepoint = {}
- losses.update(loss_wirepoint)
- return result, losses
- def merge_features(features, proposals):
- print("merge==========================================================================start")
- print(f"Features type: {type(features)}, shape: {features.shape}")
- print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
- print(f"Proposals : {proposals[0].shape},")
- def diagnose_input(features, proposals):
- """诊断输入数据"""
- print("Input Diagnostics:")
- print(f"Features type: {type(features)}, shape: {features.shape}")
- print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
- for i, p in enumerate(proposals):
- print(f"Proposal {i} shape: {p.shape}")
- def validate_inputs(features, proposals):
- """验证输入的有效性"""
- if features is None or proposals is None:
- raise ValueError("Features or proposals cannot be None")
- proposals_count = sum([p.size(0) for p in proposals])
- features_size = features.size(0)
- if proposals_count != features_size:
- raise ValueError(
- f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
- )
- def safe_max_reduction(features_per_img):
- """安全的最大值压缩"""
- if features_per_img.numel() == 0:
- return torch.zeros_like(features_per_img).unsqueeze(0)
- try:
- # Take the max along dim 0, keeping the dimension
- max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
- return max_features
- except Exception as e:
- print(f"Max reduction error: {e}")
- return features_per_img.unsqueeze(0)
- try:
- # Diagnose the input (optional)
- # diagnose_input(features, proposals)
- # Validate the inputs
- validate_inputs(features, proposals)
- # Split the features
- split_features = []
- start_idx = 0
- for proposal in proposals:
- # Extract the features of the current image
- current_features = features[start_idx:start_idx + proposal.size(0)]
- split_features.append(current_features)
- start_idx += proposal.size(0)
- # Compress the features of each image
- features_imgs = []
- for features_per_img in split_features:
- compressed_features = safe_max_reduction(features_per_img)
- features_imgs.append(compressed_features)
- # Merge the features
- merged_features = torch.cat(features_imgs, dim=0)
- return merged_features
- except Exception as e:
- print(f"Error in merge_features: {e}")
- # Return the original features or None
- return features
- '''
- from collections import OrderedDict
- from typing import Dict, List, Optional, Tuple
- import matplotlib.pyplot as plt
- import torch
- import torch.nn.functional as F
- import torchvision
- from torch import nn, Tensor
- from torchvision.ops import boxes as box_ops, roi_align
- from models.wirenet import _utils as det_utils
- from torch.utils.data.dataloader import default_collate
- def l2loss(input, target):
- return ((target - input) ** 2).mean(2).mean(1)
- def cross_entropy_loss(logits, positive):
- nlogp = -F.log_softmax(logits, dim=0)
- return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
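- # Hedged usage sketch (not called anywhere): per-pixel binary cross-entropy over a
- # two-channel logit map, where `positive` marks junction pixels. The shapes below are
- # assumptions for illustration only.
- def _demo_cross_entropy_loss():
-     logits = torch.randn(2, 3, 16, 16)                 # [2 classes, batch, H, W]
-     positive = torch.randint(0, 2, (3, 16, 16)).float()  # junction ground truth
-     return cross_entropy_loss(logits, positive)        # one loss value per image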
- def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
- logp = torch.sigmoid(logits) + offset
- loss = torch.abs(logp - target)
- if mask is not None:
- w = mask.mean(2, True).mean(1, True)
- w[w == 0] = 1
- loss = loss * (mask / w)
- return loss.mean(2).mean(1)
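- # Hedged usage sketch (not called anywhere): sigmoid_l1_loss is an L1 loss on sigmoid outputs;
- # with offset=-0.5 the prediction is recentred to [-0.5, 0.5] to match junction offsets, and the
- # optional mask restricts the loss to (and re-weights it by) the junction pixels. The shapes
- # below are assumptions for illustration only.
- def _demo_sigmoid_l1_loss():
-     logits = torch.randn(2, 16, 16)                    # per-image offset logits
-     target = torch.rand(2, 16, 16) - 0.5               # ground-truth offsets in [-0.5, 0.5]
-     mask = torch.randint(0, 2, (2, 16, 16)).float()    # junction map used as the mask
-     return sigmoid_l1_loss(logits, target, offset=-0.5, mask=mask)  # one loss value per image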
- def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
- # output, feature: results returned by the head
- # x, y, idx: intermediate results from the line branch
- result = {}
- batch, channel, row, col = output.shape
- wires_targets = [t["wires"] for t in targets]
- wires_targets = wires_targets.copy()
- # print(f'wires_target:{wires_targets}')
- # Extract all 'junc_map', 'junc_offset' and 'line_map' tensors
- junc_maps = [d["junc_map"] for d in wires_targets]
- junc_offsets = [d["junc_offset"] for d in wires_targets]
- line_maps = [d["line_map"] for d in wires_targets]
- junc_map_tensor = torch.stack(junc_maps, dim=0)
- junc_offset_tensor = torch.stack(junc_offsets, dim=0)
- line_map_tensor = torch.stack(line_maps, dim=0)
- T = {"junc_map": junc_map_tensor, "junc_offset": junc_offset_tensor, "line_map": line_map_tensor}
- n_jtyp = T["junc_map"].shape[1]
- for task in ["junc_map"]:
- T[task] = T[task].permute(1, 0, 2, 3)
- for task in ["junc_offset"]:
- T[task] = T[task].permute(1, 2, 0, 3, 4)
- offset = [2, 3, 5]
- losses = []
- output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
- jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
- lmap = output[offset[0]: offset[1]].squeeze(0)
- joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
- L = OrderedDict()
- L["junc_map"] = sum(
- cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
- )
- L["line_map"] = (
- F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
- .mean(2)
- .mean(1)
- )
- L["junc_offset"] = sum(
- sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
- for i in range(n_jtyp)
- for j in range(2)
- )
- for loss_name in L:
- L[loss_name].mul_(loss_weight[loss_name])
- losses.append(L)
- result["losses"] = losses
- loss = nn.BCEWithLogitsLoss(reduction="none")
- loss = loss(x, y)
- lpos_mask, lneg_mask = y, 1 - y
- loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
- def sum_batch(x):
- xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(batch)]
- return torch.cat(xs)
- lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
- lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
- result["losses"][0]["lpos"] = lpos * loss_weight["lpos"]
- result["losses"][0]["lneg"] = lneg * loss_weight["lneg"]
- return result
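- # Hedged usage sketch (shapes are assumptions inferred from the slicing above: offset = [2, 3, 5]
- # implies a 5-channel head output and a single junction type). It builds dummy targets with the
- # 'wires' keys this loss reads, plus flattened per-line logits x / labels y with per-image
- # boundaries idx, and is for illustration only.
- def _demo_wirepoint_head_line_loss():
-     batch, row, col, n_lines = 2, 128, 128, 10
-     targets = [
-         {"wires": {
-             "junc_map": torch.randint(0, 2, (1, row, col)).float(),  # [n_jtyp, H, W]
-             "junc_offset": torch.rand(1, 2, row, col) - 0.5,         # [n_jtyp, 2, H, W]
-             "line_map": torch.randint(0, 2, (row, col)).float(),     # [H, W]
-         }}
-         for _ in range(batch)
-     ]
-     output = torch.randn(batch, 5, row, col)             # head output
-     x = torch.randn(batch * n_lines)                      # sampled line logits
-     y = torch.randint(0, 2, (batch * n_lines,)).float()   # sampled line labels
-     idx = [i * n_lines for i in range(batch + 1)]         # per-image boundaries into x / y
-     loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
-     return wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight)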
- def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
- result = {}
- result["wires"] = {}
- p = torch.cat(ps)
- s = torch.sigmoid(input)
- b = s > 0.5
- lines = []
- score = []
- # print(f"n_batch:{n_batch}")
- for i in range(n_batch):
- # print(f"idx:{idx}")
- p0 = p[idx[i]: idx[i + 1]]
- s0 = s[idx[i]: idx[i + 1]]
- mask = b[idx[i]: idx[i + 1]]
- p0 = p0[mask]
- s0 = s0[mask]
- if len(p0) == 0:
- lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
- score.append(torch.zeros([1, n_out_line], device=p.device))
- else:
- arg = torch.argsort(s0, descending=True)
- p0, s0 = p0[arg], s0[arg]
- lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
- score.append(s0[None, torch.arange(n_out_line) % len(s0)])
- for j in range(len(jcs[i])):
- if len(jcs[i][j]) == 0:
- jcs[i][j] = torch.zeros([n_out_junc, 2], device=p.device)
- jcs[i][j] = jcs[i][j][
- None, torch.arange(n_out_junc) % len(jcs[i][j])
- ]
- result["wires"]["lines"] = torch.cat(lines)
- result["wires"]["score"] = torch.cat(score)
- result["wires"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
- if len(jcs[i]) > 1:
- result["wires"]["junts"] = torch.cat(
- [jcs[i][1] for i in range(n_batch)]
- )
- return result
- def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
- # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
- """
- Computes the loss for Faster R-CNN.
- Args:
- class_logits (Tensor)
- box_regression (Tensor)
- labels (list[BoxList])
- regression_targets (Tensor)
- Returns:
- classification_loss (Tensor)
- box_loss (Tensor)
- """
- labels = torch.cat(labels, dim=0)
- regression_targets = torch.cat(regression_targets, dim=0)
- classification_loss = F.cross_entropy(class_logits, labels)
- # get indices that correspond to the regression targets for
- # the corresponding ground truth labels, to be used with
- # advanced indexing
- sampled_pos_inds_subset = torch.where(labels > 0)[0]
- labels_pos = labels[sampled_pos_inds_subset]
- N, num_classes = class_logits.shape
- box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
- box_loss = F.smooth_l1_loss(
- box_regression[sampled_pos_inds_subset, labels_pos],
- regression_targets[sampled_pos_inds_subset],
- beta=1 / 9,
- reduction="sum",
- )
- box_loss = box_loss / labels.numel()
- return classification_loss, box_loss
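- # Hedged usage sketch (not part of the original module): toy call with 8 sampled
- # proposals over 2 images and 3 classes (0 = background). The _example_* name and
- # all shapes are illustrative assumptions.
- def _example_fastrcnn_loss():
-     class_logits = torch.randn(8, 3)
-     box_regression = torch.randn(8, 3 * 4)           # 4 box deltas per class
-     labels = [torch.tensor([0, 1, 2, 0]), torch.tensor([1, 0, 0, 2])]
-     regression_targets = [torch.randn(4, 4), torch.randn(4, 4)]
-     return fastrcnn_loss(class_logits, box_regression, labels, regression_targets)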
- def maskrcnn_inference(x, labels):
- # type: (Tensor, List[Tensor]) -> List[Tensor]
- """
- From the results of the CNN, post process the masks
- by taking the mask corresponding to the class with max
- probability (which are of fixed size and directly output
- by the CNN) and return the masks in the mask field of the BoxList.
- Args:
- x (Tensor): the mask logits
- labels (list[BoxList]): bounding boxes that are used as
- reference, one for each image
- Returns:
- results (list[BoxList]): one BoxList for each image, containing
- the extra field mask
- """
- mask_prob = x.sigmoid()
- # select masks corresponding to the predicted classes
- num_masks = x.shape[0]
- boxes_per_image = [label.shape[0] for label in labels]
- labels = torch.cat(labels)
- index = torch.arange(num_masks, device=labels.device)
- mask_prob = mask_prob[index, labels][:, None]
- mask_prob = mask_prob.split(boxes_per_image, dim=0)
- return mask_prob
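- # Hedged usage sketch (illustrative assumptions only): 5 detections split over
- # 2 images, 3 classes, 28x28 mask logits; returns one (N_i, 1, 28, 28) tensor per image.
- def _example_maskrcnn_inference():
-     mask_logits = torch.randn(5, 3, 28, 28)
-     labels = [torch.tensor([1, 2]), torch.tensor([0, 1, 2])]
-     return maskrcnn_inference(mask_logits, labels)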
- def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
- # type: (Tensor, Tensor, Tensor, int) -> Tensor
- """
- Given segmentation masks and the bounding boxes corresponding
- to the location of the masks in the image, this function
- crops and resizes the masks in the position defined by the
- boxes. This prepares the masks for them to be fed to the
- loss computation as the targets.
- """
- matched_idxs = matched_idxs.to(boxes)
- rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
- gt_masks = gt_masks[:, None].to(rois)
- return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
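- # Hedged usage sketch (illustrative assumptions only): crop two 64x64 GT masks to
- # 28x28 targets for two proposals; assumes roi_align is imported at module level
- # as used above.
- def _example_project_masks_on_boxes():
-     gt_masks = torch.zeros(2, 64, 64)
-     gt_masks[0, 10:40, 10:40] = 1.0
-     boxes = torch.tensor([[8.0, 8.0, 42.0, 42.0], [0.0, 0.0, 63.0, 63.0]])
-     matched_idxs = torch.tensor([0, 1])
-     return project_masks_on_boxes(gt_masks, boxes, matched_idxs, 28)  # (2, 28, 28)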
- def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- """
- Args:
- mask_logits (Tensor)
- proposals (list[BoxList])
- gt_masks (list[Tensor])
- gt_labels (list[Tensor])
- mask_matched_idxs (list[Tensor])
- Return:
- mask_loss (Tensor): scalar tensor containing the loss
- """
- discretization_size = mask_logits.shape[-1]
- # print(f'mask_logits:{mask_logits},gt_masks:{gt_masks},,gt_labels:{gt_labels}]')
- # print(f'mask discretization_size:{discretization_size}')
- labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
- # print(f'mask labels:{labels}')
- mask_targets = [
- project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
- ]
- labels = torch.cat(labels, dim=0)
- # print(f'mask labels1:{labels}')
- mask_targets = torch.cat(mask_targets, dim=0)
- # torch.mean (in binary_cross_entropy_with_logits) doesn't
- # accept empty tensors, so handle it separately
- if mask_targets.numel() == 0:
- return mask_logits.sum() * 0
- # print(f'mask_targets:{mask_targets.shape},mask_logits:{mask_logits.shape}')
- # print(f'mask_targets:{mask_targets}')
- mask_loss = F.binary_cross_entropy_with_logits(
- mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
- )
- # print(f'mask_loss:{mask_loss}')
- return mask_loss
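- # Hedged usage sketch (illustrative assumptions only): 3 proposals in one image
- # matched against 2 GT instances; shapes are toy values.
- def _example_maskrcnn_loss():
-     mask_logits = torch.randn(3, 3, 28, 28)          # (proposals, classes, M, M)
-     proposals = [torch.tensor([[0.0, 0.0, 20.0, 20.0], [5.0, 5.0, 30.0, 30.0], [2.0, 2.0, 15.0, 15.0]])]
-     gt_masks = [torch.zeros(2, 64, 64)]
-     gt_masks[0][0, 10:40, 10:40] = 1.0
-     gt_labels = [torch.tensor([1, 2])]
-     mask_matched_idxs = [torch.tensor([0, 1, 0])]
-     return maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs)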
- def keypoints_to_heatmap(keypoints, rois, heatmap_size):
- # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
- offset_x = rois[:, 0]
- offset_y = rois[:, 1]
- scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
- scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
- offset_x = offset_x[:, None]
- offset_y = offset_y[:, None]
- scale_x = scale_x[:, None]
- scale_y = scale_y[:, None]
- x = keypoints[..., 0]
- y = keypoints[..., 1]
- x_boundary_inds = x == rois[:, 2][:, None]
- y_boundary_inds = y == rois[:, 3][:, None]
- x = (x - offset_x) * scale_x
- x = x.floor().long()
- y = (y - offset_y) * scale_y
- y = y.floor().long()
- x[x_boundary_inds] = heatmap_size - 1
- y[y_boundary_inds] = heatmap_size - 1
- valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
- vis = keypoints[..., 2] > 0
- valid = (valid_loc & vis).long()
- lin_ind = y * heatmap_size + x
- heatmaps = lin_ind * valid
- return heatmaps, valid
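- # Hedged usage sketch (illustrative assumptions only): encode 2 instances x 3
- # keypoints into flat 56x56 heatmap indices plus a validity mask.
- def _example_keypoints_to_heatmap():
-     keypoints = torch.tensor([[[10.0, 12.0, 2.0], [30.0, 40.0, 1.0], [0.0, 0.0, 0.0]],
-                               [[5.0, 5.0, 1.0], [20.0, 22.0, 1.0], [50.0, 60.0, 2.0]]])
-     rois = torch.tensor([[0.0, 0.0, 56.0, 56.0], [0.0, 0.0, 112.0, 112.0]])
-     heatmaps, valid = keypoints_to_heatmap(keypoints, rois, 56)  # both (2, 3)
-     return heatmaps, valid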
- def _onnx_heatmaps_to_keypoints(
- maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
- ):
- num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
- width_correction = widths_i / roi_map_width
- height_correction = heights_i / roi_map_height
- roi_map = F.interpolate(
- maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
- )[:, 0]
- w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
- pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- x_int = pos % w
- y_int = (pos - x_int) // w
- x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
- dtype=torch.float32
- )
- y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
- dtype=torch.float32
- )
- xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
- xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
- xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
- xy_preds_i = torch.stack(
- [
- xy_preds_i_0.to(dtype=torch.float32),
- xy_preds_i_1.to(dtype=torch.float32),
- xy_preds_i_2.to(dtype=torch.float32),
- ],
- 0,
- )
- # TODO: simplify when indexing without rank will be supported by ONNX
- base = num_keypoints * num_keypoints + num_keypoints + 1
- ind = torch.arange(num_keypoints)
- ind = ind.to(dtype=torch.int64) * base
- end_scores_i = (
- roi_map.index_select(1, y_int.to(dtype=torch.int64))
- .index_select(2, x_int.to(dtype=torch.int64))
- .view(-1)
- .index_select(0, ind.to(dtype=torch.int64))
- )
- return xy_preds_i, end_scores_i
- @torch.jit._script_if_tracing
- def _onnx_heatmaps_to_keypoints_loop(
- maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
- ):
- xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
- end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
- for i in range(int(rois.size(0))):
- xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
- maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
- )
- xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
- end_scores = torch.cat(
- (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
- )
- return xy_preds, end_scores
- def heatmaps_to_keypoints(maps, rois):
- """Extract predicted keypoint locations from heatmaps. Output has shape
- (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
- for each keypoint.
- """
- # This function converts a discrete image coordinate in a HEATMAP_SIZE x
- # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
- # consistency with keypoints_to_heatmap_labels by using the conversion from
- # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
- # continuous coordinate.
- offset_x = rois[:, 0]
- offset_y = rois[:, 1]
- widths = rois[:, 2] - rois[:, 0]
- heights = rois[:, 3] - rois[:, 1]
- widths = widths.clamp(min=1)
- heights = heights.clamp(min=1)
- widths_ceil = widths.ceil()
- heights_ceil = heights.ceil()
- num_keypoints = maps.shape[1]
- if torchvision._is_tracing():
- xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
- maps,
- rois,
- widths_ceil,
- heights_ceil,
- widths,
- heights,
- offset_x,
- offset_y,
- torch.scalar_tensor(num_keypoints, dtype=torch.int64),
- )
- return xy_preds.permute(0, 2, 1), end_scores
- xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
- end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
- for i in range(len(rois)):
- roi_map_width = int(widths_ceil[i].item())
- roi_map_height = int(heights_ceil[i].item())
- width_correction = widths[i] / roi_map_width
- height_correction = heights[i] / roi_map_height
- roi_map = F.interpolate(
- maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
- )[:, 0]
- # roi_map_probs = scores_to_probs(roi_map.copy())
- w = roi_map.shape[2]
- pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- x_int = pos % w
- y_int = torch.div(pos - x_int, w, rounding_mode="floor")
- # assert (roi_map_probs[k, y_int, x_int] ==
- # roi_map_probs[k, :, :].max())
- x = (x_int.float() + 0.5) * width_correction
- y = (y_int.float() + 0.5) * height_correction
- xy_preds[i, 0, :] = x + offset_x[i]
- xy_preds[i, 1, :] = y + offset_y[i]
- xy_preds[i, 2, :] = 1
- end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
- return xy_preds.permute(0, 2, 1), end_scores
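- # Hedged usage sketch (illustrative assumptions only): decode argmax keypoints from
- # 4 predicted 56x56 heatmaps per ROI; assumes torchvision is imported at module level.
- def _example_heatmaps_to_keypoints():
-     maps = torch.randn(2, 4, 56, 56)                 # (#rois, #keypoints, H, W)
-     rois = torch.tensor([[0.0, 0.0, 64.0, 64.0], [10.0, 10.0, 90.0, 74.0]])
-     xy_preds, scores = heatmaps_to_keypoints(maps, rois)  # (2, 4, 3), (2, 4)
-     return xy_preds, scores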
- def heatmaps_to_keypoints_new(maps, rois):
- """Variant of heatmaps_to_keypoints that also pastes every interpolated ROI
- heatmap into a 512x512 canvas and returns that canvas downsampled to 128x128,
- together with the (x, y, prob) predictions and per-keypoint scores.
- """
- # This function converts a discrete image coordinate in a HEATMAP_SIZE x
- # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
- # consistency with keypoints_to_heatmap_labels by using the conversion from
- # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
- # continuous coordinate.
- # print(f"maps.shape:{maps.shape}")
- rois = rois[0]  # only the proposals of the first image are used here
- offset_x = rois[:, 0]
- offset_y = rois[:, 1]
- widths = rois[:, 2] - rois[:, 0]
- heights = rois[:, 3] - rois[:, 1]
- widths = widths.clamp(min=1)
- heights = heights.clamp(min=1)
- widths_ceil = widths.ceil()
- heights_ceil = heights.ceil()
- num_keypoints = maps.shape[1]
- if torchvision._is_tracing():
- xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
- maps,
- rois,
- widths_ceil,
- heights_ceil,
- widths,
- heights,
- offset_x,
- offset_y,
- torch.scalar_tensor(num_keypoints, dtype=torch.int64),
- )
- return xy_preds.permute(0, 2, 1), end_scores
- xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
- end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
- # create a blank 512x512 canvas with one channel per keypoint map
- combined_map = torch.zeros((1, maps.shape[1], 512, 512), dtype=torch.float32, device=maps.device)
- # print(f"combined_map.shape: {combined_map.shape}")
- # print(f"len of rois:{len(rois)}")
- for i in range(len(rois)):
- roi_map_width = int(widths_ceil[i].item())
- roi_map_height = int(heights_ceil[i].item())
- width_correction = widths[i] / roi_map_width
- height_correction = heights[i] / roi_map_height
- roi_map = F.interpolate(
- maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
- )[:, 0]
- x_offset = int(offset_x[i].item())  # convert to a Python scalar
- y_offset = int(offset_y[i].item())  # convert to a Python scalar
- # print(f"x_offset: {x_offset}, y_offset: {y_offset}, roi_map.shape: {roi_map.shape}")
- # make sure the pasted ROI stays inside the canvas
- if y_offset < 0 or y_offset + roi_map.shape[1] > combined_map.shape[2] or x_offset < 0 or x_offset + roi_map.shape[2] > combined_map.shape[3]:
- print("Error: Offset exceeds combined_map dimensions.")
- else:
- # sanity-check the interpolated ROI size
- if roi_map.shape[1] <= 0 or roi_map.shape[2] <= 0:
- print("Error: Invalid ROI size.")
- else:
- # paste the ROI heatmap into the canvas
- combined_map[0, :, y_offset:y_offset + roi_map.shape[1], x_offset:x_offset + roi_map.shape[2]] = roi_map
- w = roi_map.shape[2]
- pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
- x_int = pos % w
- y_int = torch.div(pos - x_int, w, rounding_mode="floor")
- # assert (roi_map_probs[k, y_int, x_int] ==
- # roi_map_probs[k, :, :].max())
- x = (x_int.float() + 0.5) * width_correction
- y = (y_int.float() + 0.5) * height_correction
- xy_preds[i, 0, :] = x + offset_x[i]
- xy_preds[i, 1, :] = y + offset_y[i]
- xy_preds[i, 2, :] = 1
- end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
- combined_map = F.interpolate(combined_map, size=(128, 128), mode='bilinear', align_corners=False)
- # print(f"combined_map.shape:{combined_map.shape}")
- return combined_map, xy_preds.permute(0, 2, 1), end_scores
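- # Hedged usage sketch (illustrative assumptions only): note that this variant expects
- # rois as a list and only uses rois[0]; all shapes below are toy values.
- def _example_heatmaps_to_keypoints_new():
-     maps = torch.randn(3, 4, 56, 56)
-     rois = [torch.tensor([[0.0, 0.0, 64.0, 64.0],
-                           [100.0, 100.0, 180.0, 164.0],
-                           [20.0, 30.0, 60.0, 90.0]])]
-     combined, xy_preds, scores = heatmaps_to_keypoints_new(maps, rois)
-     return combined.shape, xy_preds.shape, scores.shape  # (1, 4, 128, 128), (3, 4, 3), (3, 4)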
- def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
- # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
- N, K, H, W = keypoint_logits.shape
- if H != W:
- raise ValueError(
- f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
- )
- discretization_size = H
- heatmaps = []
- valid = []
- for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
- kp = gt_kp_in_image[midx]
- heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
- heatmaps.append(heatmaps_per_image.view(-1))
- valid.append(valid_per_image.view(-1))
- keypoint_targets = torch.cat(heatmaps, dim=0)
- valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
- valid = torch.where(valid)[0]
- # torch.mean (in binary_cross_entropy_with_logits) doesn't
- # accept empty tensors, so handle it separately
- if keypoint_targets.numel() == 0 or len(valid) == 0:
- return keypoint_logits.sum() * 0
- keypoint_logits = keypoint_logits.view(N * K, H * W)
- keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
- return keypoint_loss
- def keypointrcnn_inference(x, boxes):
- # print(f'x:{x.shape}')
- # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
- kp_probs = []
- kp_scores = []
- boxes_per_image = [box.size(0) for box in boxes]
- x2 = x.split(boxes_per_image, dim=0)
- # print(f'x2:{x2}')
- for xx, bb in zip(x2, boxes):
- kp_prob, scores = heatmaps_to_keypoints(xx, bb)
- kp_probs.append(kp_prob)
- kp_scores.append(scores)
- return kp_probs, kp_scores
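- # Hedged usage sketch (illustrative assumptions only): split 5 heatmap stacks across
- # 2 images and decode keypoints per box.
- def _example_keypointrcnn_inference():
-     x = torch.randn(5, 17, 56, 56)
-     boxes = [torch.tensor([[0.0, 0.0, 32.0, 32.0], [4.0, 4.0, 40.0, 28.0]]),
-              torch.tensor([[0.0, 0.0, 64.0, 64.0], [8.0, 8.0, 24.0, 24.0], [2.0, 2.0, 30.0, 50.0]])]
-     kp_probs, kp_scores = keypointrcnn_inference(x, boxes)
-     return kp_probs[0].shape, kp_scores[0].shape     # (2, 17, 3), (2, 17)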
- def _onnx_expand_boxes(boxes, scale):
- # type: (Tensor, float) -> Tensor
- w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
- h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
- x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
- y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
- w_half = w_half.to(dtype=torch.float32) * scale
- h_half = h_half.to(dtype=torch.float32) * scale
- boxes_exp0 = x_c - w_half
- boxes_exp1 = y_c - h_half
- boxes_exp2 = x_c + w_half
- boxes_exp3 = y_c + h_half
- boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
- return boxes_exp
- # the next two functions should be merged inside Masker
- # but are kept here for the moment while we need them
- # temporarily for paste_mask_in_image
- def expand_boxes(boxes, scale):
- # type: (Tensor, float) -> Tensor
- if torchvision._is_tracing():
- return _onnx_expand_boxes(boxes, scale)
- w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
- h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
- x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
- y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
- w_half *= scale
- h_half *= scale
- boxes_exp = torch.zeros_like(boxes)
- boxes_exp[:, 0] = x_c - w_half
- boxes_exp[:, 2] = x_c + w_half
- boxes_exp[:, 1] = y_c - h_half
- boxes_exp[:, 3] = y_c + h_half
- return boxes_exp
- @torch.jit.unused
- def expand_masks_tracing_scale(M, padding):
- # type: (int, int) -> float
- return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
- def expand_masks(mask, padding):
- # type: (Tensor, int) -> Tuple[Tensor, float]
- M = mask.shape[-1]
- if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
- scale = expand_masks_tracing_scale(M, padding)
- else:
- scale = float(M + 2 * padding) / M
- padded_mask = F.pad(mask, (padding,) * 4)
- return padded_mask, scale
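- # Hedged usage sketch (illustrative assumptions only): pad 28x28 mask logits by 1 px
- # and obtain the matching box scale factor.
- def _example_expand_masks():
-     masks = torch.rand(3, 1, 28, 28)
-     padded, scale = expand_masks(masks, padding=1)   # (3, 1, 30, 30), 30 / 28
-     return padded.shape, scale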
- def paste_mask_in_image(mask, box, im_h, im_w):
- # type: (Tensor, Tensor, int, int) -> Tensor
- TO_REMOVE = 1
- w = int(box[2] - box[0] + TO_REMOVE)
- h = int(box[3] - box[1] + TO_REMOVE)
- w = max(w, 1)
- h = max(h, 1)
- # Set shape to [batchxCxHxW]
- mask = mask.expand((1, 1, -1, -1))
- # Resize mask
- mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
- mask = mask[0][0]
- im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
- x_0 = max(box[0], 0)
- x_1 = min(box[2] + 1, im_w)
- y_0 = max(box[1], 0)
- y_1 = min(box[3] + 1, im_h)
- im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
- return im_mask
- def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
- one = torch.ones(1, dtype=torch.int64)
- zero = torch.zeros(1, dtype=torch.int64)
- w = box[2] - box[0] + one
- h = box[3] - box[1] + one
- w = torch.max(torch.cat((w, one)))
- h = torch.max(torch.cat((h, one)))
- # Set shape to [batchxCxHxW]
- mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
- # Resize mask
- mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
- mask = mask[0][0]
- x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
- x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
- y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
- y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
- unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
- # TODO : replace below with a dynamic padding when support is added in ONNX
- # pad y
- zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
- zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
- concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
- # pad x
- zeros_x0 = torch.zeros(concat_0.size(0), x_0)
- zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
- im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
- return im_mask
- @torch.jit._script_if_tracing
- def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
- res_append = torch.zeros(0, im_h, im_w)
- for i in range(masks.size(0)):
- mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
- mask_res = mask_res.unsqueeze(0)
- res_append = torch.cat((res_append, mask_res))
- return res_append
- def paste_masks_in_image(masks, boxes, img_shape, padding=1):
- # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
- masks, scale = expand_masks(masks, padding=padding)
- boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
- im_h, im_w = img_shape
- if torchvision._is_tracing():
- return _onnx_paste_masks_in_image_loop(
- masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
- )[:, None]
- res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
- if len(res) > 0:
- ret = torch.stack(res, dim=0)[:, None]
- else:
- ret = masks.new_empty((0, 1, im_h, im_w))
- return ret
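- # Hedged usage sketch (illustrative assumptions only): paste two 28x28 instance masks
- # back onto a 100x120 image canvas.
- def _example_paste_masks_in_image():
-     masks = torch.rand(2, 1, 28, 28)
-     boxes = torch.tensor([[10.0, 20.0, 50.0, 60.0], [30.0, 5.0, 90.0, 80.0]])
-     pasted = paste_masks_in_image(masks, boxes, (100, 120), padding=1)  # (2, 1, 100, 120)
-     return pasted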
- class RoIHeads(nn.Module):
- __annotations__ = {
- "box_coder": det_utils.BoxCoder,
- "proposal_matcher": det_utils.Matcher,
- "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
- }
- def __init__(
- self,
- box_roi_pool,
- box_head,
- box_predictor,
- # Faster R-CNN training
- fg_iou_thresh,
- bg_iou_thresh,
- batch_size_per_image,
- positive_fraction,
- bbox_reg_weights,
- # Faster R-CNN inference
- score_thresh,
- nms_thresh,
- detections_per_img,
- # Mask
- mask_roi_pool=None,
- mask_head=None,
- mask_predictor=None,
- keypoint_roi_pool=None,
- keypoint_head=None,
- keypoint_predictor=None,
- wirepoint_roi_pool=None,
- wirepoint_head=None,
- wirepoint_predictor=None,
- ):
- super().__init__()
- self.box_similarity = box_ops.box_iou
- # assign ground-truth boxes for each proposal
- self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
- self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
- if bbox_reg_weights is None:
- bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
- self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
- self.box_roi_pool = box_roi_pool
- self.box_head = box_head
- self.box_predictor = box_predictor
- self.score_thresh = score_thresh
- self.nms_thresh = nms_thresh
- self.detections_per_img = detections_per_img
- self.mask_roi_pool = mask_roi_pool
- self.mask_head = mask_head
- self.mask_predictor = mask_predictor
- self.keypoint_roi_pool = keypoint_roi_pool
- self.keypoint_head = keypoint_head
- self.keypoint_predictor = keypoint_predictor
- self.wirepoint_roi_pool = wirepoint_roi_pool
- self.wirepoint_head = wirepoint_head
- self.wirepoint_predictor = wirepoint_predictor
- def has_mask(self):
- if self.mask_roi_pool is None:
- return False
- if self.mask_head is None:
- return False
- if self.mask_predictor is None:
- return False
- return True
- def has_keypoint(self):
- if self.keypoint_roi_pool is None:
- return False
- if self.keypoint_head is None:
- return False
- if self.keypoint_predictor is None:
- return False
- return True
- def has_wirepoint(self):
- if self.wirepoint_roi_pool is None:
- print('wirepoint_roi_pool is None')
- return False
- if self.wirepoint_head is None:
- print('wirepoint_head is None')
- return False
- if self.wirepoint_predictor is None:
- print('wirepoint_predictor is None')
- return False
- return True
- def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
- # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
- matched_idxs = []
- labels = []
- for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
- if gt_boxes_in_image.numel() == 0:
- # Background image
- device = proposals_in_image.device
- clamped_matched_idxs_in_image = torch.zeros(
- (proposals_in_image.shape[0],), dtype=torch.int64, device=device
- )
- labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
- else:
- # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
- match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
- matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
- clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
- labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
- labels_in_image = labels_in_image.to(dtype=torch.int64)
- # Label background (below the low threshold)
- bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
- labels_in_image[bg_inds] = 0
- # Label ignore proposals (between low and high thresholds)
- ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
- labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
- matched_idxs.append(clamped_matched_idxs_in_image)
- labels.append(labels_in_image)
- return matched_idxs, labels
- def subsample(self, labels):
- # type: (List[Tensor]) -> List[Tensor]
- sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
- sampled_inds = []
- for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
- img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
- sampled_inds.append(img_sampled_inds)
- return sampled_inds
- def add_gt_proposals(self, proposals, gt_boxes):
- # type: (List[Tensor], List[Tensor]) -> List[Tensor]
- proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
- return proposals
- def check_targets(self, targets):
- # type: (Optional[List[Dict[str, Tensor]]]) -> None
- if targets is None:
- raise ValueError("targets should not be None")
- if not all(["boxes" in t for t in targets]):
- raise ValueError("Every element of targets should have a boxes key")
- if not all(["labels" in t for t in targets]):
- raise ValueError("Every element of targets should have a labels key")
- if self.has_mask():
- if not all(["masks" in t for t in targets]):
- raise ValueError("Every element of targets should have a masks key")
- def select_training_samples(
- self,
- proposals, # type: List[Tensor]
- targets, # type: Optional[List[Dict[str, Tensor]]]
- ):
- # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
- self.check_targets(targets)
- if targets is None:
- raise ValueError("targets should not be None")
- dtype = proposals[0].dtype
- device = proposals[0].device
- gt_boxes = [t["boxes"].to(dtype) for t in targets]
- gt_labels = [t["labels"] for t in targets]
- # append ground-truth bboxes to proposals
- proposals = self.add_gt_proposals(proposals, gt_boxes)
- # get matching gt indices for each proposal
- matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
- # sample a fixed proportion of positive-negative proposals
- sampled_inds = self.subsample(labels)
- matched_gt_boxes = []
- num_images = len(proposals)
- for img_id in range(num_images):
- img_sampled_inds = sampled_inds[img_id]
- proposals[img_id] = proposals[img_id][img_sampled_inds]
- labels[img_id] = labels[img_id][img_sampled_inds]
- matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
- gt_boxes_in_image = gt_boxes[img_id]
- if gt_boxes_in_image.numel() == 0:
- gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
- matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
- regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
- return proposals, matched_idxs, labels, regression_targets
- def postprocess_detections(
- self,
- class_logits, # type: Tensor
- box_regression, # type: Tensor
- proposals, # type: List[Tensor]
- image_shapes, # type: List[Tuple[int, int]]
- ):
- # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
- device = class_logits.device
- num_classes = class_logits.shape[-1]
- boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
- pred_boxes = self.box_coder.decode(box_regression, proposals)
- pred_scores = F.softmax(class_logits, -1)
- pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
- pred_scores_list = pred_scores.split(boxes_per_image, 0)
- all_boxes = []
- all_scores = []
- all_labels = []
- for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
- boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
- # create labels for each prediction
- labels = torch.arange(num_classes, device=device)
- labels = labels.view(1, -1).expand_as(scores)
- # remove predictions with the background label
- boxes = boxes[:, 1:]
- scores = scores[:, 1:]
- labels = labels[:, 1:]
- # batch everything, by making every class prediction be a separate instance
- boxes = boxes.reshape(-1, 4)
- scores = scores.reshape(-1)
- labels = labels.reshape(-1)
- # remove low scoring boxes
- inds = torch.where(scores > self.score_thresh)[0]
- boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
- # remove empty boxes
- keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
- boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
- # non-maximum suppression, independently done per class
- keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
- # keep only topk scoring predictions
- keep = keep[: self.detections_per_img]
- boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
- all_boxes.append(boxes)
- all_scores.append(scores)
- all_labels.append(labels)
- return all_boxes, all_scores, all_labels
- def forward(
- self,
- features, # type: Dict[str, Tensor]
- proposals, # type: List[Tensor]
- image_shapes, # type: List[Tuple[int, int]]
- targets=None, # type: Optional[List[Dict[str, Tensor]]]
- ):
- # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
- """
- Args:
- features (Dict[str, Tensor])
- proposals (List[Tensor[N, 4]])
- image_shapes (List[Tuple[H, W]])
- targets (List[Dict])
- """
- if targets is not None:
- for t in targets:
- # TODO: https://github.com/pytorch/pytorch/issues/26731
- floating_point_types = (torch.float, torch.double, torch.half)
- if not t["boxes"].dtype in floating_point_types:
- raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
- if not t["labels"].dtype == torch.int64:
- raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
- if self.has_keypoint():
- if not t["keypoints"].dtype == torch.float32:
- raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
- if self.training:
- proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
- else:
- labels = None
- regression_targets = None
- matched_idxs = None
- box_features = self.box_roi_pool(features, proposals, image_shapes)
- box_features = self.box_head(box_features)
- class_logits, box_regression = self.box_predictor(box_features)
- result: List[Dict[str, torch.Tensor]] = []
- losses = {}
- if self.training:
- if labels is None:
- raise ValueError("labels cannot be None")
- if regression_targets is None:
- raise ValueError("regression_targets cannot be None")
- loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
- losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
- else:
- boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
- num_images = len(boxes)
- for i in range(num_images):
- result.append(
- {
- "boxes": boxes[i],
- "labels": labels[i],
- "scores": scores[i],
- }
- )
- if self.has_mask():
- mask_proposals = [p["boxes"] for p in result]
- if self.training:
- if matched_idxs is None:
- raise ValueError("if in training, matched_idxs should not be None")
- # during training, only focus on positive boxes
- num_images = len(proposals)
- mask_proposals = []
- pos_matched_idxs = []
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- mask_proposals.append(proposals[img_id][pos])
- pos_matched_idxs.append(matched_idxs[img_id][pos])
- else:
- pos_matched_idxs = None
- if self.mask_roi_pool is not None:
- mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
- mask_features = self.mask_head(mask_features)
- mask_logits = self.mask_predictor(mask_features)
- else:
- raise Exception("Expected mask_roi_pool to be not None")
- loss_mask = {}
- if self.training:
- if targets is None or pos_matched_idxs is None or mask_logits is None:
- raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
- gt_masks = [t["masks"] for t in targets]
- gt_labels = [t["labels"] for t in targets]
- rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
- loss_mask = {"loss_mask": rcnn_loss_mask}
- else:
- labels = [r["labels"] for r in result]
- masks_probs = maskrcnn_inference(mask_logits, labels)
- for mask_prob, r in zip(masks_probs, result):
- r["masks"] = mask_prob
- losses.update(loss_mask)
- # keep none checks in if conditional so torchscript will conditionally
- # compile each branch
- if self.has_keypoint():
- keypoint_proposals = [p["boxes"] for p in result]
- if self.training:
- # during training, only focus on positive boxes
- num_images = len(proposals)
- keypoint_proposals = []
- pos_matched_idxs = []
- if matched_idxs is None:
- raise ValueError("if in trainning, matched_idxs should not be None")
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- keypoint_proposals.append(proposals[img_id][pos])
- pos_matched_idxs.append(matched_idxs[img_id][pos])
- else:
- pos_matched_idxs = None
- keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
- # tmp = keypoint_features[0][0]
- # plt.imshow(tmp.detach().numpy())
- # print(f'keypoint_features from roi_pool:{keypoint_features.shape}')
- keypoint_features = self.keypoint_head(keypoint_features)
- # print(f'keypoint_features:{keypoint_features.shape}')
- # debug visualization of the first keypoint feature map
- # tmp = keypoint_features[0][0]
- # plt.imshow(tmp.detach().numpy())
- keypoint_logits = self.keypoint_predictor(keypoint_features)
- # print(f'keypoint_logits:{keypoint_logits.shape}')
- """
- 接wirenet
- """
- loss_keypoint = {}
- if self.training:
- if targets is None or pos_matched_idxs is None:
- raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
- gt_keypoints = [t["keypoints"] for t in targets]
- rcnn_loss_keypoint = keypointrcnn_loss(
- keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
- )
- loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
- else:
- if keypoint_logits is None or keypoint_proposals is None:
- raise ValueError(
- "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
- )
- keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
- for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
- r["keypoints"] = keypoint_prob
- r["keypoints_scores"] = kps
- losses.update(loss_keypoint)
- if self.has_wirepoint():
- wirepoint_proposals = [p["boxes"] for p in result]
- if self.training:
- # during training, only focus on positive boxes
- num_images = len(proposals)
- wirepoint_proposals = []
- pos_matched_idxs = []
- if matched_idxs is None:
- raise ValueError("if in trainning, matched_idxs should not be None")
- for img_id in range(num_images):
- pos = torch.where(labels[img_id] > 0)[0]
- wirepoint_proposals.append(proposals[img_id][pos])
- pos_matched_idxs.append(matched_idxs[img_id][pos])
- else:
- pos_matched_idxs = None
- wirepoint_features = self.wirepoint_roi_pool(features, wirepoint_proposals, image_shapes)
- outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
- # print(f"wirepoint_proposal:{type(wirepoint_proposals)}")
- # print(f"wirepoint_proposal:{wirepoint_proposals.__len__()}")
- # print(f"wirepoint_proposal[0].shape:{wirepoint_proposals[0].shape}")
- # print(f"wirepoint_proposal[0]:{wirepoint_proposals[0]}")
- # outputs = merge_features(outputs, wirepoint_proposals)
- combined_output, xy_preds, end_scores = heatmaps_to_keypoints_new(outputs, wirepoint_proposals)
- wire_combined_features, wire_xy_preds, wire_end_scores = heatmaps_to_keypoints_new(wirepoint_features, wirepoint_proposals)
- # print(f'combined_output:{combined_output.shape}')
- wirepoint_logits = self.wirepoint_predictor(inputs=combined_output, features=wire_combined_features,
- targets=targets)
- x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = wirepoint_logits
- # print(f'keypoint_features:{wirepoint_features.shape}')
- if self.training:
- if targets is None or pos_matched_idxs is None:
- raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
- loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
- rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, combined_output, x, y, idx, loss_weight)
- loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
- else:
- pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
- result.append(pred)
- loss_wirepoint = {}
- losses.update(loss_wirepoint)
- return result, losses
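- # Hedged call-pattern sketch (illustrative only; module construction elided): in
- # training, roi_heads(features, proposals, image_shapes, targets) returns an empty
- # detection list plus a loss dict such as {"loss_classifier", "loss_box_reg",
- # "loss_mask", "loss_keypoint", "loss_wirepoint"}; in eval mode it returns per-image
- # dicts with "boxes", "labels", "scores" and, when the optional heads are present,
- # "masks", "keypoints"/"keypoints_scores" and the "wires" predictions, plus an empty
- # loss dict.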
- def merge_features(features, proposals):
- print("merge==========================================================================start")
- print(f"Features type: {type(features)}, shape: {features.shape}")
- print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
- print(f"Proposals : {proposals[0].shape},")
- def diagnose_input(features, proposals):
- """诊断输入数据"""
- print("Input Diagnostics:")
- print(f"Features type: {type(features)}, shape: {features.shape}")
- print(f"Proposals type: {type(proposals)}, length: {len(proposals)}")
- for i, p in enumerate(proposals):
- print(f"Proposal {i} shape: {p.shape}")
- def validate_inputs(features, proposals):
- """验证输入的有效性"""
- if features is None or proposals is None:
- raise ValueError("Features or proposals cannot be None")
- proposals_count = sum([p.size(0) for p in proposals])
- features_size = features.size(0)
- if proposals_count != features_size:
- raise ValueError(
- f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
- )
- def safe_max_reduction(features_per_img):
- """安全的最大值压缩"""
- if features_per_img.numel() == 0:
- return torch.zeros_like(features_per_img).unsqueeze(0)
- try:
- # 沿着第0维求最大值,保持维度
- max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
- return max_features
- except Exception as e:
- print(f"Max reduction error: {e}")
- return features_per_img.unsqueeze(0)
- try:
- # optional input diagnostics
- # diagnose_input(features, proposals)
- # validate the inputs
- validate_inputs(features, proposals)
- # split the pooled features back into per-image chunks
- split_features = []
- start_idx = 0
- for proposal in proposals:
- # slice out the features belonging to the current image
- current_features = features[start_idx:start_idx + proposal.size(0)]
- split_features.append(current_features)
- start_idx += proposal.size(0)
- # reduce each image's features to a single map
- features_imgs = []
- for features_per_img in split_features:
- compressed_features = safe_max_reduction(features_per_img)
- features_imgs.append(compressed_features)
- # concatenate the per-image feature maps
- merged_features = torch.cat(features_imgs, dim=0)
- return merged_features
- except Exception as e:
- print(f"Error in merge_features: {e}")
- # fall back to returning the original features
- return features
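- # Hedged usage sketch (illustrative assumptions only): 5 pooled ROI feature maps split
- # over two images (3 + 2 proposals), reduced to one max-pooled map per image.
- def _example_merge_features():
-     features = torch.randn(5, 256, 14, 14)
-     proposals = [torch.randn(3, 4), torch.randn(2, 4)]
-     merged = merge_features(features, proposals)     # (2, 256, 14, 14)
-     return merged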