|
@@ -187,7 +187,7 @@ class LineDetect(BaseDetectionNet):
|
|
|
|
|
|
|
|
if point_predictor is None and detect_point:
|
|
if point_predictor is None and detect_point:
|
|
|
# keypoint_dim_reduced = 512 # == keypoint_layers[-1]
|
|
# keypoint_dim_reduced = 512 # == keypoint_layers[-1]
|
|
|
- point_predictor = PointPredictor(in_channels=128)
|
|
|
|
|
|
|
+ point_predictor = PointPredictor(in_channels=256)
|
|
|
|
|
|
|
|
if detect_arc and arc_head is None:
|
|
if detect_arc and arc_head is None:
|
|
|
layers = tuple(num_points for _ in range(8))
|
|
layers = tuple(num_points for _ in range(8))
|
|
@@ -389,7 +389,11 @@ def linedetect_newresnet50fpn(
|
|
|
|
|
|
|
|
anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
|
|
anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
|
|
|
|
|
|
|
|
- model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
|
|
|
|
|
|
|
+ model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,
|
|
|
|
|
+ detect_point=True,
|
|
|
|
|
+ detect_line=False,
|
|
|
|
|
+ detect_arc=False,
|
|
|
|
|
+ **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|