소스 검색

change to linenet_resnet50

RenLiqiang 3 달 전
부모
커밋
8d994ffb62
4개의 변경된 파일69개의 추가작업 그리고 41개의 파일을 삭제
  1. 22 18
      models/line_detect/line_net.py
  2. 2 2
      models/line_detect/line_net.yaml
  3. 43 20
      models/line_detect/line_predictor.py
  4. 2 1
      models/line_detect/test_train.py

+ 22 - 18
models/line_detect/line_net.py

@@ -51,25 +51,25 @@ def _default_anchorgen():
 
 
 class LineNet(BaseDetectionNet):
-    def __init__(self, cfg, **kwargs):
-        cfg = read_yaml(cfg)
-        self.cfg=cfg
-        backbone = cfg['backbone']
-        print(f'LineNet Backbone:{backbone}')
-        num_classes = cfg['num_classes']
+    # def __init__(self, cfg, **kwargs):
+    #     cfg = read_yaml(cfg)
+    #     self.cfg=cfg
+    #     backbone = cfg['backbone']
+    #     print(f'LineNet Backbone:{backbone}')
+    #     num_classes = cfg['num_classes']
+    #
+    #     if backbone == 'resnet50_fpn':
+    #         backbone=backbone_factory.get_resnet50_fpn()
+    #         print(f'out_chanenels:{backbone.out_channels}')
+    #     elif backbone== 'mobilenet_v3_large_fpn':
+    #         backbone=backbone_factory.get_mobilenet_v3_large_fpn()
+    #     elif backbone=='resnet18_fpn':
+    #         backbone=backbone_factory.get_resnet18_fpn()
+    #
+    #     self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
 
-        if backbone == 'resnet50_fpn':
-            backbone=backbone_factory.get_resnet50_fpn()
-            print(f'out_chanenels:{backbone.out_channels}')
-        elif backbone== 'mobilenet_v3_large_fpn':
-            backbone=backbone_factory.get_mobilenet_v3_large_fpn()
-        elif backbone=='resnet18_fpn':
-            backbone=backbone_factory.get_resnet18_fpn()
 
-        self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
-
-
-    def __construct__(
+    def __init__(
             self,
             backbone,
             num_classes=None,
@@ -134,12 +134,15 @@ class LineNet(BaseDetectionNet):
 
         out_channels = backbone.out_channels
 
+        # cfg = read_yaml(cfg)
+        # self.cfg=cfg
+
         if line_head is None:
             num_class = 5
             line_head = LineRCNNHeads(out_channels, num_class)
 
         if line_predictor is None:
-            line_predictor = LineRCNNPredictor(self.cfg)
+            line_predictor = LineRCNNPredictor()
 
         if rpn_anchor_generator is None:
             rpn_anchor_generator = _default_anchorgen()
@@ -199,6 +202,7 @@ class LineNet(BaseDetectionNet):
 
         super().__init__(backbone, rpn, roi_heads, transform)
 
+
         self.roi_heads = roi_heads
 
         self.roi_heads.line_head = line_head

+ 2 - 2
models/line_detect/line_net.yaml

@@ -16,8 +16,8 @@ lneg: 1
 boxes: 1.0
 
 # backbone parameters
-#backbone: resnet50_fpn
-backbone: resnet18_fpn
+backbone: resnet50_fpn
+#backbone: resnet18_fpn
 #backbone: mobilenet_v3_large_fpn
 #  backbone: unet
 depth: 4

+ 43 - 20
models/line_detect/line_predictor.py

@@ -47,30 +47,53 @@ class Bottleneck1D(nn.Module):
         return x + self.op(x)
 
 class LineRCNNPredictor(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self,**kwargs):
         super().__init__()
         # self.backbone = backbone
         # self.cfg = read_yaml(cfg)
         # self.cfg = read_yaml(r'./config/wireframe.yaml')
-        self.cfg = cfg
-        self.n_pts0 = self.cfg['n_pts0']
-        self.n_pts1 = self.cfg['n_pts1']
-        self.n_stc_posl = self.cfg['n_stc_posl']
-        self.dim_loi = self.cfg['dim_loi']
-        self.use_conv = self.cfg['use_conv']
-        self.dim_fc = self.cfg['dim_fc']
-        self.n_out_line = self.cfg['n_out_line']
-        self.n_out_junc = self.cfg['n_out_junc']
-        self.loss_weight = self.cfg['loss_weight']
-        self.n_dyn_junc = self.cfg['n_dyn_junc']
-        self.eval_junc_thres = self.cfg['eval_junc_thres']
-        self.n_dyn_posl = self.cfg['n_dyn_posl']
-        self.n_dyn_negl = self.cfg['n_dyn_negl']
-        self.n_dyn_othr = self.cfg['n_dyn_othr']
-        self.use_cood = self.cfg['use_cood']
-        self.use_slop = self.cfg['use_slop']
-        self.n_stc_negl = self.cfg['n_stc_negl']
-        self.head_size = self.cfg['head_size']
+
+        # print(f'linePredictor cfg:{cfg}')
+        #
+        # self.cfg = cfg
+        # self.n_pts0 = self.cfg['n_pts0']
+        # self.n_pts1 = self.cfg['n_pts1']
+        # self.n_stc_posl = self.cfg['n_stc_posl']
+        # self.dim_loi = self.cfg['dim_loi']
+        # self.use_conv = self.cfg['use_conv']
+        # self.dim_fc = self.cfg['dim_fc']
+        # self.n_out_line = self.cfg['n_out_line']
+        # self.n_out_junc = self.cfg['n_out_junc']
+        # self.loss_weight = self.cfg['loss_weight']
+        # self.n_dyn_junc = self.cfg['n_dyn_junc']
+        # self.eval_junc_thres = self.cfg['eval_junc_thres']
+        # self.n_dyn_posl = self.cfg['n_dyn_posl']
+        # self.n_dyn_negl = self.cfg['n_dyn_negl']
+        # self.n_dyn_othr = self.cfg['n_dyn_othr']
+        # self.use_cood = self.cfg['use_cood']
+        # self.use_slop = self.cfg['use_slop']
+        # self.n_stc_negl = self.cfg['n_stc_negl']
+        # self.head_size = self.cfg['head_size']
+
+
+        self.n_pts0 = 32
+        self.n_pts1 = 8
+        self.n_stc_posl =300
+        self.dim_loi = 128
+        self.use_conv = 0
+        self.dim_fc = 1024
+        self.n_out_line = 2500
+        self.n_out_junc =250
+        # self.loss_weight =
+        self.n_dyn_junc = 300
+        self.eval_junc_thres = 0.008
+        self.n_dyn_posl =300
+        self.n_dyn_negl =80
+        self.n_dyn_othr = 600
+        self.use_cood = 0
+        self.use_slop = 0
+        self.n_stc_negl = 80
+        self.head_size = [[2], [1], [2]]
 
         self.num_class = sum(sum(self.head_size, []))
         self.head_off = np.cumsum([sum(h) for h in self.head_size])

+ 2 - 1
models/line_detect/test_train.py

@@ -6,7 +6,8 @@ from models.line_detect.trainer import Trainer
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
 
-    model = LineNet('line_net.yaml')
+    # model = LineNet('line_net.yaml')
+    model=linenet_resnet50_fpn()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
     model.train_by_cfg(cfg='train.yaml')