|
@@ -24,6 +24,7 @@ from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extract
|
|
|
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
|
|
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
|
|
|
from .heads.arc_heads import ArcHeads, ArcPredictor
|
|
from .heads.arc_heads import ArcHeads, ArcPredictor
|
|
|
from .heads.circle_heads import CircleHeads, CirclePredictor
|
|
from .heads.circle_heads import CircleHeads, CirclePredictor
|
|
|
|
|
+from .heads.decoder import FPNDecoder
|
|
|
from .heads.line_heads import LinePredictor
|
|
from .heads.line_heads import LinePredictor
|
|
|
from .heads.point_heads import PointHeads, PointPredictor
|
|
from .heads.point_heads import PointHeads, PointPredictor
|
|
|
from .loi_heads import RoIHeads
|
|
from .loi_heads import RoIHeads
|
|
@@ -202,7 +203,7 @@ class LineDetect(BaseDetectionNet):
|
|
|
if detect_arc and arc_predictor is None:
|
|
if detect_arc and arc_predictor is None:
|
|
|
layers = tuple(num_points for _ in range(8))
|
|
layers = tuple(num_points for _ in range(8))
|
|
|
# arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
|
|
# arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
|
|
|
- arc_predictor=ArcUnet(Bottleneck)
|
|
|
|
|
|
|
+ arc_predictor=FPNDecoder(Bottleneck)
|
|
|
|
|
|
|
|
if detect_circle and circle_head is None:
|
|
if detect_circle and circle_head is None:
|
|
|
layers = tuple(num_points for _ in range(8))
|
|
layers = tuple(num_points for _ in range(8))
|
|
@@ -210,7 +211,7 @@ class LineDetect(BaseDetectionNet):
|
|
|
if detect_circle and circle_predictor is None:
|
|
if detect_circle and circle_predictor is None:
|
|
|
layers = tuple(num_points for _ in range(8))
|
|
layers = tuple(num_points for _ in range(8))
|
|
|
# arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
|
|
# arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
|
|
|
- circle_predictor = CirclePredictor(in_channels=256)
|
|
|
|
|
|
|
+ circle_predictor = CirclePredictor(in_channels=256,out_channels=4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -400,7 +401,7 @@ def linedetect_newresnet50fpn(
|
|
|
if num_points is None:
|
|
if num_points is None:
|
|
|
num_points = 4
|
|
num_points = 4
|
|
|
|
|
|
|
|
- size=768
|
|
|
|
|
|
|
+ size=512
|
|
|
backbone =resnet50fpn(out_channels=256)
|
|
backbone =resnet50fpn(out_channels=256)
|
|
|
featmap_names=['0', '1', '2', '3','4','pool']
|
|
featmap_names=['0', '1', '2', '3','4','pool']
|
|
|
# print(f'featmap_names:{featmap_names}')
|
|
# print(f'featmap_names:{featmap_names}')
|
|
@@ -511,12 +512,12 @@ def linedetect_maxvitfpn(
|
|
|
# weights = LineNet_ResNet50_FPN_Weights.verify(weights)
|
|
# weights = LineNet_ResNet50_FPN_Weights.verify(weights)
|
|
|
# weights_backbone = ResNet50_Weights.verify(weights_backbone)
|
|
# weights_backbone = ResNet50_Weights.verify(weights_backbone)
|
|
|
if num_classes is None:
|
|
if num_classes is None:
|
|
|
- num_classes = 4
|
|
|
|
|
|
|
+ num_classes = 5
|
|
|
if num_points is None:
|
|
if num_points is None:
|
|
|
num_points = 3
|
|
num_points = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
- size=224*3
|
|
|
|
|
|
|
+ size=224*2
|
|
|
|
|
|
|
|
|
|
|
|
|
maxvit = MaxVitBackbone(input_size=(size,size))
|
|
maxvit = MaxVitBackbone(input_size=(size,size))
|
|
@@ -537,7 +538,7 @@ def linedetect_maxvitfpn(
|
|
|
return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'},
|
|
return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'},
|
|
|
# ç¡®ä¿è¿äºé®å¯¹åºå°å®é
çå±
|
|
# ç¡®ä¿è¿äºé®å¯¹åºå°å®é
çå±
|
|
|
in_channels_list=in_channels_list,
|
|
in_channels_list=in_channels_list,
|
|
|
- out_channels=128
|
|
|
|
|
|
|
+ out_channels=256
|
|
|
)
|
|
)
|
|
|
test_input = torch.randn(1, 3,size,size)
|
|
test_input = torch.randn(1, 3,size,size)
|
|
|
|
|
|
|
@@ -550,7 +551,8 @@ def linedetect_maxvitfpn(
|
|
|
box_roi_pool=roi_pooler,
|
|
box_roi_pool=roi_pooler,
|
|
|
detect_line=False,
|
|
detect_line=False,
|
|
|
detect_point=False,
|
|
detect_point=False,
|
|
|
- detect_arc=True,
|
|
|
|
|
|
|
+ detect_arc=False,
|
|
|
|
|
+ detect_circle=True,
|
|
|
)
|
|
)
|
|
|
return model
|
|
return model
|
|
|
|
|
|