|
@@ -51,25 +51,25 @@ def _default_anchorgen():
|
|
|
|
|
|
|
|
|
class LineNet(BaseDetectionNet):
|
|
|
- def __init__(self, cfg, **kwargs):
|
|
|
- cfg = read_yaml(cfg)
|
|
|
- self.cfg=cfg
|
|
|
- backbone = cfg['backbone']
|
|
|
- print(f'LineNet Backbone:{backbone}')
|
|
|
- num_classes = cfg['num_classes']
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
- if backbone == 'resnet50_fpn':
|
|
|
- backbone=backbone_factory.get_resnet50_fpn()
|
|
|
- print(f'out_chanenels:{backbone.out_channels}')
|
|
|
- elif backbone== 'mobilenet_v3_large_fpn':
|
|
|
- backbone=backbone_factory.get_mobilenet_v3_large_fpn()
|
|
|
- elif backbone=='resnet18_fpn':
|
|
|
- backbone=backbone_factory.get_resnet18_fpn()
|
|
|
|
|
|
- self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
|
|
|
-
|
|
|
-
|
|
|
- def __construct__(
|
|
|
+ def __init__(
|
|
|
self,
|
|
|
backbone,
|
|
|
num_classes=None,
|
|
@@ -134,12 +134,15 @@ class LineNet(BaseDetectionNet):
|
|
|
|
|
|
out_channels = backbone.out_channels
|
|
|
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
if line_head is None:
|
|
|
num_class = 5
|
|
|
line_head = LineRCNNHeads(out_channels, num_class)
|
|
|
|
|
|
if line_predictor is None:
|
|
|
- line_predictor = LineRCNNPredictor(self.cfg)
|
|
|
+ line_predictor = LineRCNNPredictor()
|
|
|
|
|
|
if rpn_anchor_generator is None:
|
|
|
rpn_anchor_generator = _default_anchorgen()
|
|
@@ -199,6 +202,7 @@ class LineNet(BaseDetectionNet):
|
|
|
|
|
|
super().__init__(backbone, rpn, roi_heads, transform)
|
|
|
|
|
|
+
|
|
|
self.roi_heads = roi_heads
|
|
|
|
|
|
self.roi_heads.line_head = line_head
|