@@ -10,9 +10,10 @@ import torch
from torch import nn, Tensor
from libs.vision_libs.utils import _log_api_usage_once
+from models.base.base_model import BaseModel
-class BaseDetectionNet(nn.Module):
+class BaseDetectionNet(BaseModel):
"""
Main class for Generalized R-CNN.
@@ -61,7 +61,7 @@ class LineNet(BaseDetectionNet):
backbone=backbone_factory.get_resnet50_fpn()
print(f'out_chanenels:{backbone.out_channels}')
self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
-
+
def __construct__(
self,