RenLiqiang 3 months ago
parent
commit
d16559053e

+ 1 - 1
config/wireframe.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: D:\python\PycharmProjects\data
+  datadir: I:/datasets/wirenet_lm
   resume_from:
   num_workers: 8
   tensorboard_port: 6000

+ 2 - 2
lcnn/trainer.py

@@ -228,7 +228,7 @@ class Trainer(object):
             )
 
         if training:
-            self.model.train()
+            self.model.train1()
 
     def verify_freeze_params(model, freeze_config):
         """
@@ -253,7 +253,7 @@ class Trainer(object):
                                 print(f"  {param_name}: requires_grad = {param.requires_grad}")
 
     def train_epoch(self):
-        self.model.train()
+        self.model.train1()
 
         time = timer()
         for batch_idx, (image, meta, target, target_b) in enumerate(self.train_loader):

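Note: the train() → train1() rename recurs in every trainer touched by this commit, presumably to free a custom hook while avoiding a clash with nn.Module.train(mode). A minimal sketch of such an alias, assuming it merely delegates to the stock mode switch (the name train1 is the repo's; the delegation body is a guess):

    import torch

    class TrainModeMixin(torch.nn.Module):
        def train1(self):
            # assumed: enter training mode without overriding the inherited
            # nn.Module.train(mode=True) signature
            return self.train(True)
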
+ 2 - 1
libs/vision_libs/__init__.py

@@ -3,7 +3,8 @@ import warnings
 from modulefinder import Module
 
 import torch
-from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
+# from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
+from torchvision import datasets, io, models, ops, transforms, utils
 
 from .extension import _HAS_OPS
 

+ 1 - 0
libs/vision_libs/models/detection/rpn.py

@@ -370,6 +370,7 @@ class RegionProposalNetwork(torch.nn.Module):
         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
         proposals = proposals.view(num_images, -1, 4)
         boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
+        # print(f'boxes:{boxes.shape},scores:{scores.shape}')
 
         losses = {}
         if self.training:

+ 15 - 0
models/base/backbone_factory.py

@@ -0,0 +1,15 @@
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.ops import misc as misc_nn_ops
+from torch import nn
+
+
+def get_resnet50_fpn():
+    is_trained = False
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+    backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    return backbone

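get_resnet50_fpn() follows torchvision's standard ResNet-50 + FPN recipe; with is_trained = False, all five backbone layers stay trainable and plain BatchNorm2d is used instead of FrozenBatchNorm2d. A smoke test, assuming the vendored libs behave like upstream torchvision (where the FPN extractor returns an OrderedDict of 256-channel feature maps):

    import torch
    from models.base.backbone_factory import get_resnet50_fpn

    backbone = get_resnet50_fpn()
    feats = backbone(torch.rand(1, 3, 512, 512))
    for name, f in feats.items():
        # expect levels '0'..'3' plus 'pool', each with backbone.out_channels == 256
        print(name, tuple(f.shape))
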
+ 2 - 1
models/base/base_detection_net.py

@@ -10,9 +10,10 @@ import torch
 from torch import nn, Tensor
 
 from libs.vision_libs.utils import _log_api_usage_once
+from models.base.base_model import BaseModel
 
 
-class BaseDetectionNet(nn.Module):
+class BaseDetectionNet(BaseModel):
     """
     Main class for Generalized R-CNN.
 

+ 106 - 0
models/base/base_model.py

@@ -0,0 +1,106 @@
+import numpy as np
+import torch
+from abc import ABC, abstractmethod
+
+from models.base.base_trainer import BaseTrainer
+
+
+class BaseModel(ABC, torch.nn.Module):
+
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.cfg = None
+        self.trainer = None
+
+    @abstractmethod
+    def train_by_cfg(self, cfg):
+        return
+
+    # @abstractmethod
+    # def get_loss(self, Loss, results, inputs, device):
+    #     """Computes the loss given the network input and outputs.
+    #
+    #     Args:
+    #         Loss: A loss object.
+    #         results: This is the output of the model.
+    #         inputs: This is the input to the model.
+    #         device: The torch device to be used.
+    #
+    #     Returns:
+    #         Returns the loss value.
+    #     """
+    #     return
+    #
+    # @abstractmethod
+    # def get_optimizer(self, cfg_pipeline):
+    #     """Returns an optimizer object for the model.
+    #
+    #     Args:
+    #         cfg_pipeline: A Config object with the configuration of the pipeline.
+    #
+    #     Returns:
+    #         Returns a new optimizer object.
+    #     """
+    #     return
+    #
+    # @abstractmethod
+    # def preprocess(self, cfg_pipeline):
+    #     """Data preprocessing function.
+    #
+    #     This function is called before training to preprocess the data from a
+    #     dataset.
+    #
+    #     Args:
+    #         data: A sample from the dataset.
+    #         attr: The corresponding attributes.
+    #
+    #     Returns:
+    #         Returns the preprocessed data
+    #     """
+    #     return
+    # #
+    # # @abstractmethod
+    # # def transform(self, cfg_pipeline):
+    # #     """Transform function for the point cloud and features.
+    # #
+    # #     Args:
+    # #         cfg_pipeline: config file for pipeline.
+    # #     """
+    # #     return
+    #
+    # @abstractmethod
+    # def inference_begin(self, data):
+    #     """Function called right before running inference.
+    #
+    #     Args:
+    #         data: A data from the dataset.
+    #     """
+    #     return
+    #
+    # @abstractmethod
+    # def inference_preprocess(self):
+    #     """This function prepares the inputs for the model.
+    #
+    #     Returns:
+    #         The inputs to be consumed by the call() function of the model.
+    #     """
+    #     return
+    #
+    # @abstractmethod
+    # def inference_end(self, inputs, results):
+    #     """This function is called after the inference.
+    #
+    #     This function can be implemented to apply post-processing on the
+    #     network outputs.
+    #
+    #     Args:
+    #         results: The model outputs as returned by the call() function.
+    #             Post-processing is applied on this object.
+    #
+    #     Returns:
+    #         Returns True if the inference is complete and otherwise False.
+    #         Returning False can be used to implement inference for large point
+    #         clouds which require multiple passes.
+    #     """
+    #     return

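BaseModel combines ABC with torch.nn.Module and keeps train_by_cfg as its only abstract hook; the rest of the interface is parked in comments for later. A minimal concrete subclass under those assumptions (DummyNet and the trainer wiring are illustrative only, mirroring models/line_detect/trainer.py):

    import torch
    from models.base.base_model import BaseModel

    class DummyNet(BaseModel):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 2)

        def forward(self, x):
            return self.fc(x)

        def train_by_cfg(self, cfg):
            # hand off to a concrete trainer; cfg is a YAML path or dict
            from models.line_detect.trainer import Trainer
            self.trainer = Trainer(model=self)
            self.trainer.train_cfg(self, cfg)
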
+ 21 - 0
models/base/base_trainer.py

@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+
+
+class BaseTrainer(ABC):
+    def __init__(self,
+                 model=None,
+                 dataset=None,
+                 device='cuda',
+                 **kwargs):
+        #
+        self.model = model
+        self.dataset = dataset
+        self.device = device
+
+    # @abstractmethod
+    # def train_cfg(self,model,cfg):
+    #     return
+
+    # @abstractmethod
+    # def train(self,model, **kwargs):
+    #     return

+ 1 - 1
models/ins/trainer.py

@@ -14,7 +14,7 @@ from tools import utils, presets
 
 
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
-    model.train()
+    model.train1()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"

+ 1 - 1
models/ins_detect/trainer.py

@@ -14,7 +14,7 @@ from tools import utils, presets
 
 
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
-    model.train()
+    model.train1()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"

+ 1 - 1
models/keypoint/trainer.py

@@ -33,7 +33,7 @@ def log_losses_to_tensorboard(writer, result, step):
 
 
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
-    model.train()
+    model.train1()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"

+ 12 - 77
models/line_detect/line_net.py

@@ -1,33 +1,25 @@
-
 from typing import Any, Callable, List, Optional, Tuple, Union
+
 import torch
+import torch.nn.functional as F
 from torch import nn
 from torchvision.ops import MultiScaleRoIAlign
 
-from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
-from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
-from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
-from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
-from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
-from libs.vision_libs.ops import misc as misc_nn_ops
+from libs.vision_libs.ops import misc as misc_nn_ops
 from libs.vision_libs.transforms._presets import ObjectDetection
-from .line_head import LineRCNNHeads
-from .line_predictor import LineRCNNPredictor
-from .roi_heads import RoIHeads
 from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
-from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
+from libs.vision_libs.models._meta import _COCO_CATEGORIES
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
 from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
 from libs.vision_libs.models.detection._utils import overwrite_eps
-from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
-from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
+from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
+from libs.vision_libs.models.detection.backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
 
-from models.config.config_tool import read_yaml
-import numpy as np
-import torch.nn.functional as F
+from libs.vision_libs.models.detection.rpn import RegionProposalNetwork, RPNHead
+from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
 
-FEATURE_DIM = 8
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+######## deprecated ###########
 
 __all__ = [
     "LineNet",
@@ -36,50 +28,10 @@ __all__ = [
     "LineNet_MobileNet_V3_Large_FPN_Weights",
     "LineNet_MobileNet_V3_Large_320_FPN_Weights",
     "linenet_resnet50_fpn",
-    "linenet_resnet50_fpn_v2",
+    "fasterrcnn_resnet50_fpn_v2",
     "linenet_mobilenet_v3_large_fpn",
     "linenet_mobilenet_v3_large_320_fpn",
 ]
-# __all__ = [
-#     "LineNet",
-#     "LineRCNN_ResNet50_FPN_Weights",
-#     "linercnn_resnet50_fpn",
-# ]
-
-
-def non_maximum_suppression(a):
-    ap = F.max_pool2d(a, 3, stride=1, padding=1)
-    mask = (a == ap).float().clamp(min=0.0)
-    return a * mask
-
-
-# class Bottleneck1D(nn.Module):
-#     def __init__(self, inplanes, outplanes):
-#         super(Bottleneck1D, self).__init__()
-#
-#         planes = outplanes // 2
-#         self.op = nn.Sequential(
-#             nn.BatchNorm1d(inplanes),
-#             nn.ReLU(inplace=True),
-#             nn.Conv1d(inplanes, planes, kernel_size=1),
-#             nn.BatchNorm1d(planes),
-#             nn.ReLU(inplace=True),
-#             nn.Conv1d(planes, planes, kernel_size=3, padding=1),
-#             nn.BatchNorm1d(planes),
-#             nn.ReLU(inplace=True),
-#             nn.Conv1d(planes, outplanes, kernel_size=1),
-#         )
-#
-#     def forward(self, x):
-#         return x + self.op(x)
-
-
-
-
-
-
-
-
 
 from .roi_heads import RoIHeads
 
@@ -247,9 +199,6 @@ class LineNet(BaseDetectionNet):
         box_batch_size_per_image=512,
         box_positive_fraction=0.25,
         bbox_reg_weights=None,
-        # line parameters
-        line_head=None,
-        line_predictor=None,
         **kwargs,
     ):
 
@@ -278,13 +227,6 @@ class LineNet(BaseDetectionNet):
 
         out_channels = backbone.out_channels
 
-        if line_head is None:
-            num_class = 5
-            line_head = LineRCNNHeads(out_channels, num_class)
-
-        if line_predictor is None:
-            line_predictor = LineRCNNPredictor()
-
         if rpn_anchor_generator is None:
             rpn_anchor_generator = _default_anchorgen()
         if rpn_head is None:
@@ -323,8 +265,6 @@ class LineNet(BaseDetectionNet):
             box_roi_pool,
             box_head,
             box_predictor,
-            line_head,
-            line_predictor,
             box_fg_iou_thresh,
             box_bg_iou_thresh,
             box_batch_size_per_image,
@@ -343,10 +283,6 @@ class LineNet(BaseDetectionNet):
 
         super().__init__(backbone, rpn, roi_heads, transform)
 
-        self.roi_heads = roi_heads
-        # self.roi_heads.line_head = line_head
-        # self.roi_heads.line_predictor = line_predictor
-
 
 class TwoMLPHead(nn.Module):
     """
@@ -651,7 +587,7 @@ def linenet_resnet50_fpn(
     weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
 )
-def linenet_resnet50_fpn_v2(
+def fasterrcnn_resnet50_fpn_v2(
     *,
     weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
     progress: bool = True,
@@ -909,4 +845,3 @@ def linenet_mobilenet_v3_large_fpn(
         trainable_backbone_layers=trainable_backbone_layers,
         **kwargs,
     )
-

+ 57 - 0
models/line_detect/line_net.yaml

@@ -0,0 +1,57 @@
+image:
+  mean: [109.730, 103.832, 98.681]
+  stddev: [22.275, 22.124, 23.229]
+
+batch_size: 4
+batch_size_eval: 2
+
+# backbone multi-task parameters
+head_size: [[2], [1], [2]]
+loss_weight:
+  jmap: 8.0
+  lmap: 0.5
+  joff: 0.25
+  lpos: 1
+  lneg: 1
+  boxes: 1.0
+
+# backbone parameters
+backbone: resnet50_fpn
+#  backbone: unet
+depth: 4
+num_stacks: 1
+num_blocks: 1
+num_classes: 2
+
+# sampler parameters
+## static sampler
+n_stc_posl: 300
+n_stc_negl: 40
+
+## dynamic sampler
+n_dyn_junc: 300
+n_dyn_posl: 300
+n_dyn_negl: 80
+n_dyn_othr: 600
+
+# LOIPool layer parameters
+n_pts0: 32
+n_pts1: 8
+
+# line verification network parameters
+dim_loi: 128
+dim_fc: 1024
+
+# maximum junction and line outputs
+n_out_junc: 250
+n_out_line: 2500
+
+# additional ablation study parameters
+use_cood: 0
+use_slop: 0
+use_conv: 0
+
+# junction threshold for evaluation (See #5)
+eval_junc_thres: 0.008
+
+

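Because LineRCNNPredictor below indexes the config flat (cfg['n_pts0'], cfg['loss_weight'], ...), loss_weight must parse as a nested mapping, which the indentation above provides. A quick sanity check, assuming PyYAML and this file path:

    import yaml

    with open('models/line_detect/line_net.yaml') as f:
        cfg = yaml.safe_load(f)
    assert isinstance(cfg['loss_weight'], dict)
    assert cfg['loss_weight']['jmap'] == 8.0
    assert cfg['head_size'] == [[2], [1], [2]]
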
+ 27 - 21
models/line_detect/line_predictor.py

@@ -20,35 +20,40 @@ import numpy as np
 import torch.nn.functional as F
 
 FEATURE_DIM = 8
+
 def non_maximum_suppression(a):
     ap = F.max_pool2d(a, 3, stride=1, padding=1)
     mask = (a == ap).float().clamp(min=0.0)
     return a * mask
 
+
+
 class LineRCNNPredictor(nn.Module):
-    def __init__(self):
+    def __init__(self, cfg):
         super().__init__()
         # self.backbone = backbone
         # self.cfg = read_yaml(cfg)
-        self.cfg = read_yaml(r'./config/wireframe.yaml')
-        self.n_pts0 = self.cfg['model']['n_pts0']
-        self.n_pts1 = self.cfg['model']['n_pts1']
-        self.n_stc_posl = self.cfg['model']['n_stc_posl']
-        self.dim_loi = self.cfg['model']['dim_loi']
-        self.use_conv = self.cfg['model']['use_conv']
-        self.dim_fc = self.cfg['model']['dim_fc']
-        self.n_out_line = self.cfg['model']['n_out_line']
-        self.n_out_junc = self.cfg['model']['n_out_junc']
-        self.loss_weight = self.cfg['model']['loss_weight']
-        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
-        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
-        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
-        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
-        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
-        self.use_cood = self.cfg['model']['use_cood']
-        self.use_slop = self.cfg['model']['use_slop']
-        self.n_stc_negl = self.cfg['model']['n_stc_negl']
-        self.head_size = self.cfg['model']['head_size']
+        # self.cfg = read_yaml(r'./config/wireframe.yaml')
+        self.cfg = cfg
+        self.n_pts0 = self.cfg['n_pts0']
+        self.n_pts1 = self.cfg['n_pts1']
+        self.n_stc_posl = self.cfg['n_stc_posl']
+        self.dim_loi = self.cfg['dim_loi']
+        self.use_conv = self.cfg['use_conv']
+        self.dim_fc = self.cfg['dim_fc']
+        self.n_out_line = self.cfg['n_out_line']
+        self.n_out_junc = self.cfg['n_out_junc']
+        self.loss_weight = self.cfg['loss_weight']
+        self.n_dyn_junc = self.cfg['n_dyn_junc']
+        self.eval_junc_thres = self.cfg['eval_junc_thres']
+        self.n_dyn_posl = self.cfg['n_dyn_posl']
+        self.n_dyn_negl = self.cfg['n_dyn_negl']
+        self.n_dyn_othr = self.cfg['n_dyn_othr']
+        self.use_cood = self.cfg['use_cood']
+        self.use_slop = self.cfg['use_slop']
+        self.n_stc_negl = self.cfg['n_stc_negl']
+        self.head_size = self.cfg['head_size']
+
         self.num_class = sum(sum(self.head_size, []))
         self.head_off = np.cumsum([sum(h) for h in self.head_size])
 
@@ -321,4 +326,5 @@ _COMMON_META = {
     "categories": _COCO_PERSON_CATEGORIES,
     "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
     "min_size": (1, 1),
-}
+}
+

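LineRCNNPredictor no longer hardcodes ./config/wireframe.yaml; the caller supplies an already-loaded cfg mapping. Hypothetical construction, assuming read_yaml returns a plain dict whose top-level keys match line_net.yaml:

    from models.config.config_tool import read_yaml
    from models.line_detect.line_predictor import LineRCNNPredictor

    cfg = read_yaml('models/line_detect/line_net.yaml')
    predictor = LineRCNNPredictor(cfg)  # expects cfg['n_pts0'], cfg['head_size'], etc.
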
+ 0 - 0
predict.py → models/line_detect/postprocess.py


+ 12 - 0
models/line_detect/test_train.py

@@ -0,0 +1,12 @@
+import torch
+
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet
+from models.line_detect.trainer import Trainer
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if __name__ == '__main__':
+
+    model = LineNet('line_net.yaml')
+    # trainer = Trainer()
+    # trainer.train_cfg(model,cfg='./train.yaml')
+    model.train_by_cfg(cfg='train.yaml')

+ 15 - 0
models/line_detect/train.yaml

@@ -0,0 +1,15 @@
+io:
+  logdir: logs/
+  datadir: I:/datasets/wirenet_1000
+  resume_from:
+  num_workers: 8
+  tensorboard_port: 6000
+  validation_interval: 300
+
+optim:
+  name: Adam
+  lr: 4.0e-4
+  amsgrad: True
+  weight_decay: 1.0e-4
+  max_epoch: 1000
+  lr_decay_epoch: 10

+ 169 - 0
models/line_detect/trainer.py

@@ -0,0 +1,169 @@
+import os
+
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from models.base.base_model import BaseModel
+from models.base.base_trainer import BaseTrainer
+from models.config.config_tool import read_yaml
+from models.line_detect.dataset_LD import WirePointDataset
+from models.line_detect.postprocess import box_line_, show_
+from utils.log_util import show_line, save_latest_model, save_best_model
+from tools import utils
+
+
+def _loss(losses):
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        total_loss += loss
+
+    return total_loss
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def move_to_device(data, device):
+    if isinstance(data, (list, tuple)):
+        return type(data)(move_to_device(item, device) for item in data)
+    elif isinstance(data, dict):
+        return {key: move_to_device(value, device) for key, value in data.items()}
+    elif isinstance(data, torch.Tensor):
+        return data.to(device)
+    else:
+        return data  # leave non-tensor data unchanged
+
+class Trainer(BaseTrainer):
+    def __init__(self, model=None,
+                 dataset=None,
+                 device='cuda',
+                 **kwargs):
+
+        super().__init__(model, dataset, device, **kwargs)
+
+    def move_to_device(self, data, device):
+        if isinstance(data, (list, tuple)):
+            return type(data)(self.move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: self.move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # leave non-tensor data unchanged
+
+    def load_best_model(self, model, optimizer, save_path, device):
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            if optimizer is not None:
+                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            epoch = checkpoint['epoch']
+            loss = checkpoint['loss']
+            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return model, optimizer
+
+    def writer_loss(self, writer, losses, epoch):
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            writer.add_scalar(f'loss/{subkey}',
+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                              epoch)
+                elif isinstance(value, torch.Tensor):
+                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
+
+    def train_cfg(self, model:BaseModel, cfg):
+        # cfg = r'./config/wireframe.yaml'
+        cfg = read_yaml(cfg)
+        print(f'cfg:{cfg}')
+        # print(cfg['n_dyn_negl'])
+        self.train(model, **cfg)
+
+    def train(self, model, **kwargs):
+        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
+        train_sampler = torch.utils.data.RandomSampler(dataset_train)
+        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
+        train_collate_fn = utils.collate_fn_wirepoint
+        data_loader_train = torch.utils.data.DataLoader(
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
+        )
+
+        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
+        val_sampler = torch.utils.data.RandomSampler(dataset_val)
+        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
+        val_collate_fn = utils.collate_fn_wirepoint
+        data_loader_val = torch.utils.data.DataLoader(
+            dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
+        )
+
+        # model = linenet_resnet50_fpn().to(self.device)
+
+        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
+        writer = SummaryWriter(kwargs['io']['logdir'])
+        model.to(device)
+
+        # load saved weights
+        save_path = 'logs/pth/best_model.pth'
+        model, optimizer = self.load_best_model(model, optimizer, save_path, device)
+
+        logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
+        os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it doesn't exist
+        latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
+        best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
+        global_step = 0
+        best_loss = float('inf')  # track the best loss across epochs
+
+        for epoch in range(kwargs['optim']['max_epoch']):
+            print(f"epoch:{epoch}")
+            total_train_loss = 0.0
+            model.train()
+
+            for imgs, targets in data_loader_train:
+                imgs = move_to_device(imgs, device)
+                targets = move_to_device(targets, device)
+                # print(f'imgs:{len(imgs)}')
+                # print(f'targets:{len(targets)}')
+                losses = model(imgs, targets)
+                # print(losses)
+                loss = _loss(losses)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                total_train_loss += loss.item()
+                self.writer_loss(writer, losses, epoch)
+
+            model.eval()
+            with torch.no_grad():
+                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                    pred = model(self.move_to_device(imgs, self.device))
+                    pred_ = box_line_(pred)  # match boxes with lines
+                    show_(imgs, pred_, epoch, writer)
+                    if batch_idx == 0:
+                        show_line(imgs[0], pred, epoch, writer)
+                    break
+            avg_train_loss = total_train_loss / len(data_loader_train)
+            writer.add_scalar('loss/train', avg_train_loss, epoch)
+            save_latest_model(
+                model,
+                latest_model_path,
+                epoch,
+                optimizer
+            )
+            best_loss = save_best_model(
+                model,
+                best_model_path,
+                epoch,
+                avg_train_loss,
+                best_loss,
+                optimizer
+            )

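move_to_device walks lists, tuples, and dicts recursively and moves only tensors, so mixed batches of images and metadata transfer safely. Usage, with the module-level helper defined above:

    import torch

    batch = {'img': torch.rand(2, 3, 64, 64), 'meta': [{'id': 1}, {'id': 2}]}
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    moved = move_to_device(batch, dev)  # tensors are moved; the 'meta' dicts pass through
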
+ 0 - 120
models/line_net/fasterrcnn_resnet50.py

@@ -1,120 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision
-from typing import Dict, List, Optional, Tuple
-import torch.nn.functional as F
-from torchvision.ops import MultiScaleRoIAlign
-from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
-from torchvision.models.detection.transform import GeneralizedRCNNTransform
-
-
-def get_model(num_classes):
-    # load a pretrained ResNet-50 FPN backbone
-    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
-
-    # number of input features of the box classifier
-    in_features = model.roi_heads.box_predictor.cls_score.in_features
-
-    # replace the classifier to match the new number of classes
-    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
-
-    return model
-
-
-def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
-    """
-    Computes the loss for Faster R-CNN.
-
-    Args:
-        class_logits (Tensor)
-        box_regression (Tensor)
-        labels (list[BoxList])
-        regression_targets (Tensor)
-
-    Returns:
-        classification_loss (Tensor)
-        box_loss (Tensor)
-    """
-
-    labels = torch.cat(labels, dim=0)
-    regression_targets = torch.cat(regression_targets, dim=0)
-
-    classification_loss = F.cross_entropy(class_logits, labels)
-
-    # get indices that correspond to the regression targets for
-    # the corresponding ground truth labels, to be used with
-    # advanced indexing
-    sampled_pos_inds_subset = torch.where(labels > 0)[0]
-    labels_pos = labels[sampled_pos_inds_subset]
-    N, num_classes = class_logits.shape
-    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
-
-    box_loss = F.smooth_l1_loss(
-        box_regression[sampled_pos_inds_subset, labels_pos],
-        regression_targets[sampled_pos_inds_subset],
-        beta=1 / 9,
-        reduction="sum",
-    )
-    box_loss = box_loss / labels.numel()
-
-    return classification_loss, box_loss
-
-
-class Fasterrcnn_resnet50(nn.Module):
-    def __init__(self, num_classes=5, num_stacks=1):
-        super(Fasterrcnn_resnet50, self).__init__()
-
-        self.model = get_model(num_classes=5)
-        self.backbone = self.model.backbone
-
-        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
-
-        out_channels = self.backbone.out_channels
-        resolution = self.box_roi_pool.output_size[0]
-        representation_size = 1024
-        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
-
-        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
-
-        # multi-task output heads
-        self.score_layers = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(256, 128, kernel_size=3, padding=1),
-                nn.BatchNorm2d(128),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(128, num_classes, kernel_size=1)
-            )
-            for _ in range(num_stacks)
-        ])
-
-    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
-
-        transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
-                                             image_std=[0.229, 0.224, 0.225])
-        images, targets = transform(x, target1)
-        x_ = self.backbone(images.tensors)
-
-        # x_ = self.backbone(x)  # '0'  '1'  '2'  '3'   'pool'
-        # print(f'backbone:{self.backbone}')
-        # print(f'Fasterrcnn_resnet50 x_:{x_}')
-        feature_ = x_['0']  # image feature map
-        outputs = []
-        for score_layer in self.score_layers:
-            output = score_layer(feature_)
-            outputs.append(output)  # one output per head
-
-        if train_or_val == "training":
-            loss_box = self.model(x, target1)
-            return outputs, feature_, loss_box
-        else:
-            box_all = self.model(x, target1)
-            return outputs, feature_, box_all
-
-
-def fasterrcnn_resnet50(**kwargs):
-    model = Fasterrcnn_resnet50(
-        num_classes=kwargs.get("num_classes", 5),
-        num_stacks=kwargs.get("num_stacks", 1)
-    )
-    return model

+ 0 - 0
models/obj/__init__.py


+ 0 - 44
models/utils.py

@@ -1,44 +0,0 @@
-# import torch
-#
-#
-# def evaluate(model, data_loader, device):
-#     n_threads = torch.get_num_threads()
-#     # FIXME remove this and make paste_masks_in_image run on the GPU
-#     torch.set_num_threads(1)
-#     cpu_device = torch.device("cpu")
-#     model.eval()
-#     metric_logger = utils.MetricLogger(delimiter="  ")
-#     header = "Test:"
-#
-#     coco = get_coco_api_from_dataset(data_loader.dataset)
-#     iou_types = _get_iou_types(model)
-#     coco_evaluator = CocoEvaluator(coco, iou_types)
-#
-#     print(f'start to evaluate!!!')
-#     for images, targets in metric_logger.log_every(data_loader, 10, header):
-#         images = list(img.to(device) for img in images)
-#
-#         if torch.cuda.is_available():
-#             torch.cuda.synchronize()
-#         model_time = time.time()
-#         outputs = model(images)
-#
-#         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
-#         model_time = time.time() - model_time
-#
-#         res = {target["image_id"]: output for target, output in zip(targets, outputs)}
-#         evaluator_time = time.time()
-#         coco_evaluator.update(res)
-#         evaluator_time = time.time() - evaluator_time
-#         metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
-#
-#     # gather the stats from all processes
-#     metric_logger.synchronize_between_processes()
-#     print("Averaged stats:", metric_logger)
-#     coco_evaluator.synchronize_between_processes()
-#
-#     # accumulate predictions from all images
-#     coco_evaluator.accumulate()
-#     coco_evaluator.summarize()
-#     torch.set_num_threads(n_threads)
-#     return coco_evaluator

+ 1 - 1
models/wirenet2/trainer.py

@@ -13,7 +13,7 @@ from models.ins.maskrcnn_dataset import MaskRCNNDataset
 from models.keypoint.keypoint_dataset import KeypointDataset
 from tools import utils, presets
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
-    model.train()
+    model.train1()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"

+ 1 - 1
tools/engine.py

@@ -10,7 +10,7 @@ from tools.coco_utils import get_coco_api_from_dataset
 
 
 def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
-    model.train()
+    model.train1()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
     header = f"Epoch: [{epoch}]"

+ 1 - 6
train——line_rcnn.py

@@ -1,8 +1,6 @@
 # 2025/2/9
 import os
-from typing import Optional, Any
 
-import cv2
 import numpy as np
 import torch
 
@@ -11,19 +9,16 @@ from models.line_detect.dataset_LD import WirePointDataset
 from tools import utils
 
 from torch.utils.tensorboard import SummaryWriter
-import matplotlib.pyplot as plt
 import matplotlib as mpl
-from skimage import io
 
 from models.line_detect.line_net import linenet_resnet50_fpn
 from torchvision.utils import draw_bounding_boxes
 from models.wirenet.postprocess import postprocess
 from torchvision import transforms
-from collections import OrderedDict
 
 from PIL import Image
 
-from predict import box_line_, show_
+from models.line_detect.postprocess import box_line_, show_
 import matplotlib.pyplot as plt
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

+ 0 - 0
models/line_net/__init__.py → utils/__init__.py


+ 88 - 0
utils/log_util.py

@@ -0,0 +1,88 @@
+import os
+
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+
+from libs.vision_libs.utils import draw_bounding_boxes
+from models.wirenet.postprocess import postprocess
+from torchvision import transforms
+
+
+def save_latest_model(model, save_path, epoch, optimizer=None):
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+    }
+
+    if optimizer is not None:
+        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+    torch.save(checkpoint, save_path)
+
+def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+    if current_loss < best_loss:
+        checkpoint = {
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+            'loss': current_loss
+        }
+
+        if optimizer is not None:
+            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+        torch.save(checkpoint, save_path)
+        print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
+
+        return current_loss
+
+    return best_loss
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    # score -> color mapping; without this, c(s) below is undefined
+    sm = plt.cm.ScalarMappable(cmap=plt.get_cmap("jet"), norm=plt.Normalize(vmin=0.9, vmax=1.0))
+    sm.set_array([])
+    c = lambda x: sm.to_rgba(x)
+    # print(f'pred[1]:{pred[1]}')
+    H = pred[-1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.85]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
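
save_latest_model overwrites its checkpoint unconditionally, while save_best_model writes only on improvement and returns the updated best loss, so the caller must feed the return value back in (as models/line_detect/trainer.py does). A usage sketch with placeholder model and loss values:

    import torch
    from utils.log_util import save_latest_model, save_best_model

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters())
    best = float('inf')
    for epoch in range(3):
        loss = 1.0 / (epoch + 1)  # stand-in for the real epoch loss
        save_latest_model(model, 'logs/pth/latest_model.pth', epoch, opt)
        best = save_best_model(model, 'logs/pth/best_model.pth', epoch, loss, best, opt)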