Jelajahi Sumber

加入resfpn对应的unet解码器

RenLiqiang 4 bulan lalu
induk
melakukan
dc42aa88fa

+ 3 - 3
models/base/high_reso_resnet.py

@@ -317,9 +317,9 @@ if __name__ == "__main__":
     input_tensor = torch.randn(1, 3, 512, 512).to(device)
     output_tensor = model(input_tensor)
 
-    backbone = ResNet(Bottleneck,[3,4,6,3]).to(device)
-    features = backbone(input_tensor)
-    print("Raw backbone output:", list(features.keys()))
+    # backbone = ResNet(Bottleneck,[3,4,6,3]).to(device)
+    # features = backbone(input_tensor)
+    # print("Raw backbone output:", list(features.keys()))
     print(f"Input shape: {input_tensor.shape}")
     print(f'feat_names:{list(output_tensor.keys())}')
     print(f"Output shape0: {output_tensor['0'].shape}")

+ 139 - 0
models/line_detect/heads/arc_unet.py

@@ -0,0 +1,139 @@
+import torch
+from torch import nn
+
+from models.base.high_reso_resnet import Bottleneck, resnet101fpn
+from typing import Any, Callable, List, Optional, Type, Union
+
+class TestModel(nn.Sequential):
+    def __init__(self,block: Type[Union[Bottleneck]]):
+        super().__init__()
+        self.encoder=resnet101fpn(out_channels=256)
+        self.decoder=ArcUnet(block)
+
+    def forward(self, x):
+        res=self.encoder(x)
+        for k in res.keys():
+            print(f'k:{k}')
+        # print(f'res:{res}')
+
+        out=self.decoder(res)
+        return out
+
+
+class ArcUnet(nn.Sequential):
+    def __init__(self,block: Type[Union[Bottleneck]], in_channels=256):
+        super().__init__()
+        self._norm_layer = nn.BatchNorm2d
+        self.inplanes = 64
+        self.dilation = 1
+        self.groups = 1
+        self.base_width = 64
+
+        self.decoder0=self._make_decoder_layer(block,256,256)
+
+        self.upconv4 = self._make_upsample_layer(256 , 256)
+        self.upconv3 = self._make_upsample_layer(512, 256)
+        self.upconv2 = self._make_upsample_layer(512, 256)
+        self.upconv1 = self._make_upsample_layer(512, 256)
+        self.upconv0 = self._make_upsample_layer(512, 256)
+
+        self.final_up = self._make_upsample_layer(512, 128)
+
+        self.decoder4 = self._make_decoder_layer(block, 512, 256)
+        self.decoder3 = self._make_decoder_layer(block, 512, 256)
+        self.decoder2 = self._make_decoder_layer(block, 512, 256)
+        self.decoder1 = self._make_decoder_layer(block, 512, 256)
+
+        self.final_conv = nn.Conv2d(128, 1, kernel_size=1)
+
+
+
+    def forward(self, fpn_res):
+        # ------------------
+        # Encoder
+        # ------------------
+        e0 = fpn_res['0']  # [B, 64, H/4, W/4]
+        print(f'e0:{e0.shape}')
+        e1 = fpn_res['1']  # [B, 256, H/4, W/4]
+        print(f'e1:{e1.shape}')
+        e2 = fpn_res['2']  # [B, 512, H/8, W/8]
+        print(f'e2:{e2.shape}')
+        e3 = fpn_res['3']  # [B, 1024, H/16, W/16]
+        print(f'e3:{e3.shape}')
+        e4 = fpn_res['4']  # [B, 2048, H/32, W/32]
+        print(f'e4:{e4.shape}')
+
+        # ------------------
+        # Decoder
+        # ------------------
+
+        d4 = self.upconv4(e4)  # [B, 1024, H/16, W/16]
+        print(f'd4 = self.upconv4(e4):{d4.shape}')
+        d4 = torch.cat([d4, e3], dim=1)  # [B, 2048, H/16, W/16]
+        print(f' d4 = torch.cat([d4, e3], dim=1):{d4.shape}')
+        d4 = self.decoder4(d4)  # [B, 2048, H/16, W/16]
+        print(f'd4 = self.decoder4(d4):{d4.shape}')
+
+        d3 = self.upconv3(d4)  # [B, 512, H/8, W/8]
+        print(f'd3 = self.upconv3(d4):{d3.shape}')
+        d3 = torch.cat([d3, e2], dim=1)  # [B, 1024, H/8, W/8]
+        print(f'd3 = torch.cat([d3, e2], dim=1):{d3.shape}')
+        d3 = self.decoder3(d3)  # [B, 1024, H/8, W/8]
+        print(f'd3 = self.decoder3(d3):{d3.shape}')
+
+        d2 = self.upconv2(d3)  # [B, 256, H/4, W/4]
+        print(f'd2 = self.upconv2(d3):{d2.shape}')
+        d2 = torch.cat([d2, e1], dim=1)  # [B, 512, H/4, W/4]
+        print(f'd2 = torch.cat([d2, e1], dim=1):{d2.shape}')
+        d2 = self.decoder2(d2)  # [B, 512, H/4, W/4]
+        print(f'd2 = self.decoder2(d2):{d2.shape}')
+
+        d1 = self.upconv1(d2)  # [B, 64, H/2, W/2]
+        print(f'd1 = self.upconv1(d2):{d1.shape}')
+        d1 = torch.cat([d1, e0], dim=1)  # [B, 128, H/2, W/2]
+        print(f'd1 = torch.cat([d1, e0], dim=1):{d1.shape}')
+        d1 = self.decoder1(d1)  # [B, 128, H/2, W/2]
+        print(f'd1 =self.decoder1(d1):{d1.shape}')
+
+        # ------------------
+        # Output Head
+        # ------------------
+        d0=self.final_up(d1)
+        out = self.final_conv(d0)  # [B, num_classes, H/2, W/2]
+        print(f'out:{out.shape}')
+
+        return out
+
+
+
+
+
+    def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
+        """
+        使用转置卷积进行上采样
+        """
+        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
+
+    def _make_decoder_layer(self, block: Type[Union[Bottleneck]], in_channels: int,
+                            out_channels: int, blocks: int = 1) -> nn.Sequential:
+        """
+        """
+        # assert in_channels == out_channels, "in_channels must equal out_channels"
+        layers = []
+        for _ in range(blocks):
+            layers.append(
+                block(
+                    in_channels,
+                    in_channels // block.expansion,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm_layer=self._norm_layer,
+                )
+            )
+        return nn.Sequential(*layers)
+
+if __name__ == '__main__':
+    model=TestModel(Bottleneck)
+    x=torch.randn(3,3,512,512)
+    out=model(x)

+ 1 - 1
models/line_detect/heads/head_losses.py

@@ -580,7 +580,7 @@ def arc_points_to_heatmap(keypoints, rois, heatmap_size):
     y = keypoints[..., 1].unsqueeze(1)
     num_points=x.shape[2]
     print(f'num_points:{num_points}')
-    gs = generate_mask_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=1.0)
+    gs = generate_mask_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=2.0)
     # show_heatmap(gs[0],'target')
     all_roi_heatmap = []
     for roi, heatmap in zip(rois, gs):

+ 5 - 3
models/line_detect/line_detect.py

@@ -23,6 +23,7 @@ from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extract
     BackboneWithFPN, resnet_fpn_backbone
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 from .heads.arc_heads import ArcHeads, ArcPredictor
+from .heads.arc_unet import ArcUnet
 from .heads.line_heads import LinePredictor
 from .heads.point_heads import PointHeads, PointPredictor
 from .loi_heads import RoIHeads
@@ -36,7 +37,7 @@ from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 
 from ..base.high_reso_maxvit import maxvit_with_fpn
-from ..base.high_reso_resnet import resnet50fpn, resnet18fpn, resnet101fpn
+from ..base.high_reso_resnet import resnet50fpn, resnet18fpn, resnet101fpn, Bottleneck
 
 __all__ = [
     "LineDetect",
@@ -194,7 +195,8 @@ class LineDetect(BaseDetectionNet):
             arc_head=ArcHeads(8,layers)
         if detect_arc and arc_predictor is None:
             layers = tuple(num_points for _ in range(8))
-            arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
+            # arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
+            arc_predictor=ArcUnet(Bottleneck)
 
 
 
@@ -608,7 +610,7 @@ def linedetect_resnet18_fpn(
 ) -> LineDetect:
 
     if num_classes is None:
-        num_classes = 3
+        num_classes = 4
     if num_points is None:
         num_points = 3
     size=1024

+ 2 - 1
models/line_detect/loi_heads.py

@@ -1469,7 +1469,8 @@ class RoIHeads(nn.Module):
         # cs_features= features['0']
         # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
-        cs_features=features['0']
+        # cs_features=features['0']
+        cs_features = features
         # filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
         #
         # if len(filtered_proposals) > 0:

+ 4 - 4
models/line_detect/train.yaml

@@ -1,13 +1,13 @@
 io:
   logdir: train_results
-
-  datadir: \\192.168.50.222\share\rlq\datasets\arc_datasets_100
+#  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
+#  datadir: /data/share/zyh/arc/a_datasetb
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
 #  datadir: /data/share/rlq/datasets/250718caisegangban
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-#  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
+  datadir: \\192.168.50.222/share/rlq/datasets/arc_datasets_100
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb
@@ -20,7 +20,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 1
+  batch_size: 2
   max_epoch: 8000000
 #  augmentation: True
   augmentation: False

+ 2 - 2
models/line_detect/train_demo.py

@@ -17,10 +17,10 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    # model=linedetect_newresnet18fpn(num_points=3)
+    model=linedetect_newresnet18fpn(num_points=3)
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
-    model = linedetect_newresnet152fpn(num_points=3)
+    # model = linedetect_newresnet152fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()