|
@@ -22,6 +22,7 @@ from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead
|
|
|
|
|
|
from .roi_heads import RoIHeads
|
|
|
from .trainer import Trainer
|
|
|
+from ..base import backbone_factory
|
|
|
from ..base.base_detection_net import BaseDetectionNet
|
|
|
import torch.nn.functional as F
|
|
|
|
|
@@ -57,87 +58,12 @@ class LineNet(BaseDetectionNet):
|
|
|
num_classes = cfg['num_classes']
|
|
|
|
|
|
if backbone == 'resnet50_fpn':
|
|
|
- is_trained = False
|
|
|
- trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
|
|
|
- norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
|
|
|
- backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
|
|
|
- backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
|
|
|
+ backbone=backbone_factory.get_resnet50_fpn()
|
|
|
print(f'out_chanenels:{backbone.out_channels}')
|
|
|
- self.construct(backbone=backbone,num_classes=num_classes,**kwargs)
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- def construct(
|
|
|
+ self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
|
|
|
+
|
|
|
+
|
|
|
+ def __construct__(
|
|
|
self,
|
|
|
backbone,
|
|
|
num_classes=None,
|