소스 검색

修改line_predictor为原来版本

lstrlq 4 달 전
부모
커밋
c07c4076ce

+ 2 - 2
models/line_detect/heads/arc_unet.py → models/line_detect/heads/decoder.py

@@ -8,7 +8,7 @@ class TestModel(nn.Sequential):
     def __init__(self,block: Type[Union[Bottleneck]]):
         super().__init__()
         self.encoder=resnet101fpn(out_channels=256)
-        self.decoder=ArcUnet(block)
+        self.decoder=FPNDecoder(block)
 
     def forward(self, x):
         res=self.encoder(x)
@@ -20,7 +20,7 @@ class TestModel(nn.Sequential):
         return out
 
 
-class ArcUnet(nn.Sequential):
+class FPNDecoder(nn.Sequential):
     def __init__(self,block: Type[Union[Bottleneck]], in_channels=256):
         super().__init__()
         self._norm_layer = nn.BatchNorm2d

+ 4 - 4
models/line_detect/heads/head_losses.py

@@ -411,8 +411,8 @@ def heatmaps_to_lines(maps, rois):
     line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
     line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
 
-    # line_maps=maps[:,1]
-    line_maps = maps.squeeze(1)
+    line_maps=maps[:,1]
+    # line_maps = maps.squeeze(1)
 
     for i in range(len(rois)):
         line_roi_map = line_maps[i].unsqueeze(0)
@@ -503,8 +503,8 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     # line_logits = line_logits.view(N * K, H * W)
     # print(f'line_logits[valid]:{line_logits[valid].shape}')
     print(f'loss1 line_logits:{line_logits.shape}')
-    # line_logits = line_logits[:,1,:,:]
-    line_logits = line_logits.squeeze(1)
+    line_logits = line_logits[:,1,:,:]
+    # line_logits = line_logits.squeeze(1)
     print(f'loss2 line_logits:{line_logits.shape}')
 
     # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])

+ 4 - 4
models/line_detect/line_detect.py

@@ -23,7 +23,7 @@ from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extract
     BackboneWithFPN, resnet_fpn_backbone
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 from .heads.arc_heads import ArcHeads, ArcPredictor
-from .heads.arc_unet import ArcUnet
+from .heads.decoder import FPNDecoder
 from .heads.line_heads import LinePredictor
 from .heads.point_heads import PointHeads, PointPredictor
 from .loi_heads import RoIHeads
@@ -180,8 +180,8 @@ class LineDetect(BaseDetectionNet):
 
         if line_predictor is None and detect_line:
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-        #     line_predictor = LinePredictor(in_channels=256)
-            line_predictor = ArcUnet(Bottleneck)
+            line_predictor = LinePredictor(in_channels=256)
+            # line_predictor = ArcUnet(Bottleneck)
 
         if point_head is None and detect_point:
             layers = tuple(num_points for _ in range(8))
@@ -197,7 +197,7 @@ class LineDetect(BaseDetectionNet):
         if detect_arc and arc_predictor is None:
             layers = tuple(num_points for _ in range(8))
             # arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
-            arc_predictor=ArcUnet(Bottleneck)
+            arc_predictor=FPNDecoder(Bottleneck)
 
 
 

+ 2 - 2
models/line_detect/loi_heads.py

@@ -1422,9 +1422,9 @@ class RoIHeads(nn.Module):
         # cs_features= features['0']
         # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
-        # cs_features=features['0']
+        cs_features=features['0']
 
-        cs_features = features
+        # cs_features = features
         # filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
         #
         # if len(filtered_proposals) > 0:

+ 1 - 1
models/line_detect/train_demo.py

@@ -21,7 +21,7 @@ if __name__ == '__main__':
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
     model = linedetect_newresnet152fpn(num_points=3)
-    model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250728_080143/weights/best_val.pth')
+    # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250728_080143/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()