Browse Source

添加backbone_factory

RenLiqiang 3 months ago
parent
commit
0cad5847fb
2 changed files with 3 additions and 2 deletions
  1. 2 1
      models/base/base_detection_net.py
  2. 1 1
      models/line_detect/line_net.py

+ 2 - 1
models/base/base_detection_net.py

@@ -10,9 +10,10 @@ import torch
 from torch import nn, Tensor
 
 from libs.vision_libs.utils import _log_api_usage_once
+from models.base.base_model import BaseModel
 
 
-class BaseDetectionNet(nn.Module):
+class BaseDetectionNet(BaseModel):
     """
     Main class for Generalized R-CNN.
 

+ 1 - 1
models/line_detect/line_net.py

@@ -61,7 +61,7 @@ class LineNet(BaseDetectionNet):
             backbone=backbone_factory.get_resnet50_fpn()
             print(f'out_chanenels:{backbone.out_channels}')
             self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
-            
+
 
     def __construct__(
             self,