|
|
@@ -23,7 +23,7 @@ from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extract
|
|
|
BackboneWithFPN, resnet_fpn_backbone
|
|
|
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
|
|
|
from .heads.arc_heads import ArcHeads, ArcPredictor
|
|
|
-from .heads.arc_unet import ArcUnet
|
|
|
+from .heads.decoder import FPNDecoder
|
|
|
from .heads.line_heads import LinePredictor
|
|
|
from .heads.point_heads import PointHeads, PointPredictor
|
|
|
from .loi_heads import RoIHeads
|
|
|
@@ -180,8 +180,8 @@ class LineDetect(BaseDetectionNet):
|
|
|
|
|
|
if line_predictor is None and detect_line:
|
|
|
# keypoint_dim_reduced = 512 # == keypoint_layers[-1]
|
|
|
- # line_predictor = LinePredictor(in_channels=256)
|
|
|
- line_predictor = ArcUnet(Bottleneck)
|
|
|
+ line_predictor = LinePredictor(in_channels=256)
|
|
|
+ # line_predictor = ArcUnet(Bottleneck)
|
|
|
|
|
|
if point_head is None and detect_point:
|
|
|
layers = tuple(num_points for _ in range(8))
|
|
|
@@ -197,7 +197,7 @@ class LineDetect(BaseDetectionNet):
|
|
|
if detect_arc and arc_predictor is None:
|
|
|
layers = tuple(num_points for _ in range(8))
|
|
|
# arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
|
|
|
- arc_predictor=ArcUnet(Bottleneck)
|
|
|
+ arc_predictor=FPNDecoder(Bottleneck)
|
|
|
|
|
|
|
|
|
|