RenLiqiang 7 месяцев назад
Родитель
Сommit
9e1eb072f6

+ 2 - 1
models/base/backbone_factory.py

@@ -43,6 +43,7 @@ def get_mobilenet_v3_large_fpn():
     backbone = mobilenet_v3_large(weights=None, progress=True, norm_layer=norm_layer)
     backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
     return backbone
+
 def get_convnext_fpn():
     convnext = models.convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT)
     # convnext = models.convnext_small(pretrained=True)
@@ -97,7 +98,7 @@ def get_efficientnetv2_fpn(name='efficientnet_v2_m', pretrained=True):
 
 
 # 加载 ConvNeXt 模型
-convnext = models.convnext_base(pretrained=True)
+# convnext = models.convnext_base(pretrained=True)
 # convnext = models.convnext_tiny(pretrained=True)
 # convnext = models.convnext_small(pretrained=True)
 # print(convnext)

+ 21 - 20
models/line_detect/line_net.py

@@ -29,6 +29,7 @@ from .roi_heads import RoIHeads
 from .trainer import Trainer
 from ..base import backbone_factory
 from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
+# from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 
@@ -522,11 +523,11 @@ class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     DEFAULT = COCO_V1
 
 
-@register_model()
-@handle_legacy_interface(
-    weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
+# @register_model()
+# @handle_legacy_interface(
+#     weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
+#     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+# )
 def linenet_resnet18_fpn(
         *,
         weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
@@ -681,11 +682,11 @@ def linenet_resnet50_fpn(
 
 
 
-@register_model()
-@handle_legacy_interface(
-    weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
+# @register_model()
+# @handle_legacy_interface(
+#     weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
+#     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+# )
 def linenet_resnet50_fpn_v2(
         *,
         weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
@@ -802,11 +803,11 @@ def _linenet_mobilenet_v3_large_fpn(
     return model
 
 
-@register_model()
-@handle_legacy_interface(
-    weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
-)
+# @register_model()
+# @handle_legacy_interface(
+#     weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
+#     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+# )
 def linenet_mobilenet_v3_large_320_fpn(
         *,
         weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
@@ -876,11 +877,11 @@ def linenet_mobilenet_v3_large_320_fpn(
     )
 
 
-@register_model()
-@handle_legacy_interface(
-    weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
-)
+# @register_model()
+# @handle_legacy_interface(
+#     weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
+#     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+# )
 def linenet_mobilenet_v3_large_fpn(
         *,
         weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,

+ 1 - 1
models/line_detect/train_demo.py

@@ -1,6 +1,6 @@
 import torch
 
-from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, get_line_net_convnext_fpn
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
 from models.line_detect.trainer import Trainer
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

+ 7 - 1
models/line_detect/trainer.py

@@ -15,8 +15,12 @@ from models.line_detect.dataset_LD import WirePointDataset
 from models.wirenet.postprocess import postprocess
 from tools import utils
 from torchvision import transforms
+import matplotlib as mpl
 
-
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
 def _loss(losses):
     total_loss = 0
     for i in losses.keys():
@@ -29,6 +33,8 @@ def _loss(losses):
         loss = loss_labels[0][name].mean()
         total_loss += loss
     return total_loss
+def c(x):
+    return sm.to_rgba(x)
 
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')