소스 검색

重构Model trainer

RenLiqiang 3 달 전
부모
커밋
695fc785b0

+ 36 - 3
models/base/base_detection_net.py

@@ -10,9 +10,11 @@ import torch
 from torch import nn, Tensor
 
 from libs.vision_libs.utils import _log_api_usage_once
+from models.base.base_model import BaseModel
+from models.line_detect.trainer import Trainer
 
 
-class BaseDetectionNet(nn.Module):
+class BaseDetectionNet(BaseModel):
     """
     Main class for Generalized R-CNN.
 
@@ -25,6 +27,30 @@ class BaseDetectionNet(nn.Module):
             the model
     """
 
+    def train(self, cfg):
+        pass
+
+    def get_loss(self, Loss, results, inputs, device):
+        pass
+
+    def get_optimizer(self, cfg_pipeline):
+        pass
+
+    def preprocess(self, cfg_pipeline):
+        pass
+
+    def transform(self, cfg_pipeline):
+        pass
+
+    def inference_begin(self, data):
+        pass
+
+    def inference_preprocess(self):
+        pass
+
+    def inference_end(self, inputs, results):
+        pass
+
     def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
         super().__init__()
         _log_api_usage_once(self)
@@ -43,6 +69,13 @@ class BaseDetectionNet(nn.Module):
 
         return detections
 
+    def train(self, cfg):
+        self.__trainer.train(self, "test")
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path)
+        self.load_state_dict(state_dict)
+
     def forward(self, images, targets=None):
         # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
         """
@@ -105,12 +138,12 @@ class BaseDetectionNet(nn.Module):
         proposals, proposal_losses = self.rpn(images, features, targets)
 
         detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
-        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
+        detections = self.transform.postprocess(detections, images.image_sizes,
+                                                original_image_sizes)  # type: ignore[operator]
 
         # ->multi task head
         # ->learner,->vectorize
 
-
         losses = {}
         losses.update(detector_losses)
         losses.update(proposal_losses)

+ 106 - 0
models/base/base_model.py

@@ -0,0 +1,106 @@
+import numpy as np
+import torch
+from abc import ABC, abstractmethod
+
+from models.base.base_trainer import BaseTrainer
+
+
+class BaseModel(ABC, torch.nn.Module):
+
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.cfg = None
+        self.trainer = None
+
+    @abstractmethod
+    def train(self, cfg):
+        return
+
+    @abstractmethod
+    def get_loss(self, Loss, results, inputs, device):
+        """Computes the loss given the network input and outputs.
+
+        Args:
+            Loss: A loss object.
+            results: This is the output of the model.
+            inputs: This is the input to the model.
+            device: The torch device to be used.
+
+        Returns:
+            Returns the loss value.
+        """
+        return
+
+    @abstractmethod
+    def get_optimizer(self, cfg_pipeline):
+        """Returns an optimizer object for the model.
+
+        Args:
+            cfg_pipeline: A Config object with the configuration of the pipeline.
+
+        Returns:
+            Returns a new optimizer object.
+        """
+        return
+
+    @abstractmethod
+    def preprocess(self, cfg_pipeline):
+        """Data preprocessing function.
+
+        This function is called before training to preprocess the data from a
+        dataset.
+
+        Args:
+            data: A sample from the dataset.
+            attr: The corresponding attributes.
+
+        Returns:
+            Returns the preprocessed data
+        """
+        return
+
+    @abstractmethod
+    def transform(self, cfg_pipeline):
+        """Transform function for the point cloud and features.
+
+        Args:
+            cfg_pipeline: config file for pipeline.
+        """
+        return
+
+    @abstractmethod
+    def inference_begin(self, data):
+        """Function called right before running inference.
+
+        Args:
+            data: A data from the dataset.
+        """
+        return
+
+    @abstractmethod
+    def inference_preprocess(self):
+        """This function prepares the inputs for the model.
+
+        Returns:
+            The inputs to be consumed by the call() function of the model.
+        """
+        return
+
+    @abstractmethod
+    def inference_end(self, inputs, results):
+        """This function is called after the inference.
+
+        This function can be implemented to apply post-processing on the
+        network outputs.
+
+        Args:
+            results: The model outputs as returned by the call() function.
+                Post-processing is applied on this object.
+
+        Returns:
+            Returns True if the inference is complete and otherwise False.
+            Returning False can be used to implement inference for large point
+            clouds which require multiple passes.
+        """
+        return

+ 21 - 0
models/base/base_trainer.py

@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+
+
+class BaseTrainer(ABC):
+    def __init__(self,
+                 model=None,
+                 dataset=None,
+                 device='cuda',
+                 **kwargs):
+
+        self.model = model
+        self.dataset = dataset
+        self.device=device
+
+    @abstractmethod
+    def train_cfg(self,model,cfg):
+        return
+
+    @abstractmethod
+    def train(self,model, **kwargs):
+        return

+ 264 - 338
models/line_detect/line_net.py

@@ -1,4 +1,3 @@
-
 from typing import Any, Callable, List, Optional, Tuple, Union
 import torch
 from torch import nn
@@ -13,7 +12,6 @@ from libs.vision_libs.ops import misc as misc_nn_ops
 from libs.vision_libs.transforms._presets import ObjectDetection
 from .line_head import LineRCNNHeads
 from .line_predictor import LineRCNNPredictor
-from .roi_heads import RoIHeads
 from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
 from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
@@ -22,10 +20,13 @@ from libs.vision_libs.models.detection._utils import overwrite_eps
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 
-from models.config.config_tool import read_yaml
-import numpy as np
+from .roi_heads import RoIHeads
+from .trainer import Trainer
+from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 
+from ..config.config_tool import read_yaml
+
 FEATURE_DIM = 8
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -40,50 +41,6 @@ __all__ = [
     "linenet_mobilenet_v3_large_fpn",
     "linenet_mobilenet_v3_large_320_fpn",
 ]
-# __all__ = [
-#     "LineNet",
-#     "LineRCNN_ResNet50_FPN_Weights",
-#     "linercnn_resnet50_fpn",
-# ]
-
-
-def non_maximum_suppression(a):
-    ap = F.max_pool2d(a, 3, stride=1, padding=1)
-    mask = (a == ap).float().clamp(min=0.0)
-    return a * mask
-
-
-# class Bottleneck1D(nn.Module):
-#     def __init__(self, inplanes, outplanes):
-#         super(Bottleneck1D, self).__init__()
-#
-#         planes = outplanes // 2
-#         self.op = nn.Sequential(
-#             nn.BatchNorm1d(inplanes),
-#             nn.ReLU(inplace=True),
-#             nn.Conv1d(inplanes, planes, kernel_size=1),
-#             nn.BatchNorm1d(planes),
-#             nn.ReLU(inplace=True),
-#             nn.Conv1d(planes, planes, kernel_size=3, padding=1),
-#             nn.BatchNorm1d(planes),
-#             nn.ReLU(inplace=True),
-#             nn.Conv1d(planes, outplanes, kernel_size=1),
-#         )
-#
-#     def forward(self, x):
-#         return x + self.op(x)
-
-
-
-
-
-
-
-
-
-from .roi_heads import RoIHeads
-
-from ..base.base_detection_net import BaseDetectionNet
 
 
 def _default_anchorgen():
@@ -93,259 +50,229 @@ def _default_anchorgen():
 
 
 class LineNet(BaseDetectionNet):
-    """
-    Implements Faster R-CNN.
-
-    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
-    image, and should be in 0-1 range. Different images can have different sizes.
-
-    The behavior of the model changes depending on if it is in training or evaluation mode.
-
-    During training, the model expects both the input tensors and targets (list of dictionary),
-    containing:
-        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
-          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (Int64Tensor[N]): the class label for each ground-truth box
-
-    The model returns a Dict[Tensor] during training, containing the classification and regression
-    losses for both the RPN and the R-CNN.
-
-    During inference, the model requires only the input tensors, and returns the post-processed
-    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
-    follows:
-        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
-          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
-        - labels (Int64Tensor[N]): the predicted labels for each image
-        - scores (Tensor[N]): the scores or each prediction
-
-    Args:
-        backbone (nn.Module): the network used to compute the features for the model.
-            It should contain an out_channels attribute, which indicates the number of output
-            channels that each feature map has (and it should be the same for all feature maps).
-            The backbone should return a single Tensor or and OrderedDict[Tensor].
-        num_classes (int): number of output classes of the model (including the background).
-            If box_predictor is specified, num_classes should be None.
-        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
-        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
-        image_mean (Tuple[float, float, float]): mean values used for input normalization.
-            They are generally the mean values of the dataset on which the backbone has been trained
-            on
-        image_std (Tuple[float, float, float]): std values used for input normalization.
-            They are generally the std values of the dataset on which the backbone has been trained on
-        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
-            maps.
-        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
-        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
-        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
-        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
-        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
-        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
-        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
-            considered as positive during training of the RPN.
-        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
-            considered as negative during training of the RPN.
-        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
-            for computing the loss
-        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
-            of the RPN
-        rpn_score_thresh (float): during inference, only return proposals with a classification score
-            greater than rpn_score_thresh
-        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
-            the locations indicated by the bounding boxes
-        box_head (nn.Module): module that takes the cropped feature maps as input
-        box_predictor (nn.Module): module that takes the output of box_head and returns the
-            classification logits and box regression deltas.
-        box_score_thresh (float): during inference, only return proposals with a classification score
-            greater than box_score_thresh
-        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
-        box_detections_per_img (int): maximum number of detections per image, for all classes.
-        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
-            considered as positive during training of the classification head
-        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
-            considered as negative during training of the classification head
-        box_batch_size_per_image (int): number of proposals that are sampled during training of the
-            classification head
-        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
-            of the classification head
-        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
-            bounding boxes
-
-    Example::
-
-        >>> import torch
-        >>> import torchvision
-        >>> from torchvision.models.detection import FasterRCNN
-        >>> from torchvision.models.detection.rpn import AnchorGenerator
-        >>> # load a pre-trained model for classification and return
-        >>> # only the features
-        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
-        >>> # FasterRCNN needs to know the number of
-        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
-        >>> # so we need to add it here
-        >>> backbone.out_channels = 1280
-        >>>
-        >>> # let's make the RPN generate 5 x 3 anchors per spatial
-        >>> # location, with 5 different sizes and 3 different aspect
-        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
-        >>> # map could potentially have different sizes and
-        >>> # aspect ratios
-        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
-        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
-        >>>
-        >>> # let's define what are the feature maps that we will
-        >>> # use to perform the region of interest cropping, as well as
-        >>> # the size of the crop after rescaling.
-        >>> # if your backbone returns a Tensor, featmap_names is expected to
-        >>> # be ['0']. More generally, the backbone should return an
-        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
-        >>> # feature maps to use.
-        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
-        >>>                                                 output_size=7,
-        >>>                                                 sampling_ratio=2)
-        >>>
-        >>> # put the pieces together inside a FasterRCNN model
-        >>> model = FasterRCNN(backbone,
-        >>>                    num_classes=2,
-        >>>                    rpn_anchor_generator=anchor_generator,
-        >>>                    box_roi_pool=roi_pooler)
-        >>> model.eval()
-        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
-        >>> predictions = model(x)
-    """
-
-    def __init__(
-        self,
-        backbone,
-        num_classes=None,
-        # transform parameters
-        min_size=512,
-        max_size=1333,
-        image_mean=None,
-        image_std=None,
-        # RPN parameters
-        rpn_anchor_generator=None,
-        rpn_head=None,
-        rpn_pre_nms_top_n_train=2000,
-        rpn_pre_nms_top_n_test=1000,
-        rpn_post_nms_top_n_train=2000,
-        rpn_post_nms_top_n_test=1000,
-        rpn_nms_thresh=0.7,
-        rpn_fg_iou_thresh=0.7,
-        rpn_bg_iou_thresh=0.3,
-        rpn_batch_size_per_image=256,
-        rpn_positive_fraction=0.5,
-        rpn_score_thresh=0.0,
-        # Box parameters
-        box_roi_pool=None,
-        box_head=None,
-        box_predictor=None,
-        box_score_thresh=0.05,
-        box_nms_thresh=0.5,
-        box_detections_per_img=100,
-        box_fg_iou_thresh=0.5,
-        box_bg_iou_thresh=0.5,
-        box_batch_size_per_image=512,
-        box_positive_fraction=0.25,
-        bbox_reg_weights=None,
-        # line parameters
-        line_head=None,
-        line_predictor=None,
-        **kwargs,
-    ):
-
-        if not hasattr(backbone, "out_channels"):
-            raise ValueError(
-                "backbone should contain an attribute out_channels "
-                "specifying the number of output channels (assumed to be the "
-                "same for all the levels)"
-            )
-
-        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
-            raise TypeError(
-                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
-            )
-        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
-            raise TypeError(
-                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
-            )
-
-        if num_classes is not None:
-            if box_predictor is not None:
-                raise ValueError("num_classes should be None when box_predictor is specified")
-        else:
-            if box_predictor is None:
-                raise ValueError("num_classes should not be None when box_predictor is not specified")
-
-        out_channels = backbone.out_channels
-
-        if line_head is None:
-            num_class = 5
-            line_head = LineRCNNHeads(out_channels, num_class)
-
-        if line_predictor is None:
-            line_predictor = LineRCNNPredictor()
-
-        if rpn_anchor_generator is None:
+    def __init__(self, cfg, **kwargs):
+        cfg = read_yaml(cfg)
+        backbone = cfg['model']['backbone']
+        num_classes = cfg['model']['num_classes']
+
+        if backbone == 'resnet50_fpn':
+            is_trained = False
+            trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
+            norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+            backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
+            backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+            out_channels = backbone.out_channels
+
+            min_size = 512,
+            max_size = 1333,
+            rpn_pre_nms_top_n_train = 2000,
+            rpn_pre_nms_top_n_test = 1000,
+            rpn_post_nms_top_n_train = 2000,
+            rpn_post_nms_top_n_test = 1000,
+            rpn_nms_thresh = 0.7,
+            rpn_fg_iou_thresh = 0.7,
+            rpn_bg_iou_thresh = 0.3,
+            rpn_batch_size_per_image = 256,
+            rpn_positive_fraction = 0.5,
+            rpn_score_thresh = 0.0,
+            box_score_thresh = 0.05,
+            box_nms_thresh = 0.5,
+            box_detections_per_img = 100,
+            box_fg_iou_thresh = 0.5,
+            box_bg_iou_thresh = 0.5,
+            box_batch_size_per_image = 512,
+            box_positive_fraction = 0.25,
+            bbox_reg_weights = None,
+
+            line_head = LineRCNNHeads(out_channels, 5)
+            line_predictor = LineRCNNPredictor(cfg)
             rpn_anchor_generator = _default_anchorgen()
-        if rpn_head is None:
             rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+            rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+            rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+
+            rpn = RegionProposalNetwork(
+                rpn_anchor_generator,
+                rpn_head,
+                rpn_fg_iou_thresh,
+                rpn_bg_iou_thresh,
+                rpn_batch_size_per_image,
+                rpn_positive_fraction,
+                rpn_pre_nms_top_n,
+                rpn_post_nms_top_n,
+                rpn_nms_thresh,
+                score_thresh=rpn_score_thresh,
+            )
 
-        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
-        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
-
-        rpn = RegionProposalNetwork(
-            rpn_anchor_generator,
-            rpn_head,
-            rpn_fg_iou_thresh,
-            rpn_bg_iou_thresh,
-            rpn_batch_size_per_image,
-            rpn_positive_fraction,
-            rpn_pre_nms_top_n,
-            rpn_post_nms_top_n,
-            rpn_nms_thresh,
-            score_thresh=rpn_score_thresh,
-        )
-
-        if box_roi_pool is None:
             box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
 
-        if box_head is None:
             resolution = box_roi_pool.output_size[0]
             representation_size = 1024
-            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)
-
-        if box_predictor is None:
+            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
             representation_size = 1024
             box_predictor = BoxPredictor(representation_size, num_classes)
 
-        roi_heads = RoIHeads(
-            # Box
-            box_roi_pool,
-            box_head,
-            box_predictor,
-            line_head,
-            line_predictor,
-            box_fg_iou_thresh,
-            box_bg_iou_thresh,
-            box_batch_size_per_image,
-            box_positive_fraction,
-            bbox_reg_weights,
-            box_score_thresh,
-            box_nms_thresh,
-            box_detections_per_img,
-        )
-
-        if image_mean is None:
+            roi_heads = RoIHeads(
+                # Box
+                box_roi_pool,
+                box_head,
+                box_predictor,
+                line_head,
+                line_predictor,
+                box_fg_iou_thresh,
+                box_bg_iou_thresh,
+                box_batch_size_per_image,
+                box_positive_fraction,
+                bbox_reg_weights,
+                box_score_thresh,
+                box_nms_thresh,
+                box_detections_per_img,
+            )
             image_mean = [0.485, 0.456, 0.406]
-        if image_std is None:
             image_std = [0.229, 0.224, 0.225]
-        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
-
-        super().__init__(backbone, rpn, roi_heads, transform)
-
-        self.roi_heads = roi_heads
-        # self.roi_heads.line_head = line_head
-        # self.roi_heads.line_predictor = line_predictor
+            transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+            super().__init__(backbone, rpn, roi_heads, transform)
+            self.roi_heads = roi_heads
+
+    # def __init__(
+    #         self,
+    #         backbone,
+    #         num_classes=None,
+    #         # transform parameters
+    #         min_size=512,
+    #         max_size=1333,
+    #         image_mean=None,
+    #         image_std=None,
+    #         # RPN parameters
+    #         rpn_anchor_generator=None,
+    #         rpn_head=None,
+    #         rpn_pre_nms_top_n_train=2000,
+    #         rpn_pre_nms_top_n_test=1000,
+    #         rpn_post_nms_top_n_train=2000,
+    #         rpn_post_nms_top_n_test=1000,
+    #         rpn_nms_thresh=0.7,
+    #         rpn_fg_iou_thresh=0.7,
+    #         rpn_bg_iou_thresh=0.3,
+    #         rpn_batch_size_per_image=256,
+    #         rpn_positive_fraction=0.5,
+    #         rpn_score_thresh=0.0,
+    #         # Box parameters
+    #         box_roi_pool=None,
+    #         box_head=None,
+    #         box_predictor=None,
+    #         box_score_thresh=0.05,
+    #         box_nms_thresh=0.5,
+    #         box_detections_per_img=100,
+    #         box_fg_iou_thresh=0.5,
+    #         box_bg_iou_thresh=0.5,
+    #         box_batch_size_per_image=512,
+    #         box_positive_fraction=0.25,
+    #         bbox_reg_weights=None,
+    #         # line parameters
+    #         line_head=None,
+    #         line_predictor=None,
+    #         **kwargs,
+    # ):
+    #
+    #     if not hasattr(backbone, "out_channels"):
+    #         raise ValueError(
+    #             "backbone should contain an attribute out_channels "
+    #             "specifying the number of output channels (assumed to be the "
+    #             "same for all the levels)"
+    #         )
+    #
+    #     if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
+    #         raise TypeError(
+    #             f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
+    #         )
+    #     if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
+    #         raise TypeError(
+    #             f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
+    #         )
+    #
+    #     if num_classes is not None:
+    #         if box_predictor is not None:
+    #             raise ValueError("num_classes should be None when box_predictor is specified")
+    #     else:
+    #         if box_predictor is None:
+    #             raise ValueError("num_classes should not be None when box_predictor is not specified")
+    #
+    #     out_channels = backbone.out_channels
+    #
+    #     if line_head is None:
+    #         num_class = 5
+    #         line_head = LineRCNNHeads(out_channels, num_class)
+    #
+    #     if line_predictor is None:
+    #         line_predictor = LineRCNNPredictor()
+    #
+    #     if rpn_anchor_generator is None:
+    #         rpn_anchor_generator = _default_anchorgen()
+    #     if rpn_head is None:
+    #         rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+    #
+    #     rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+    #     rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+    #
+    #     rpn = RegionProposalNetwork(
+    #         rpn_anchor_generator,
+    #         rpn_head,
+    #         rpn_fg_iou_thresh,
+    #         rpn_bg_iou_thresh,
+    #         rpn_batch_size_per_image,
+    #         rpn_positive_fraction,
+    #         rpn_pre_nms_top_n,
+    #         rpn_post_nms_top_n,
+    #         rpn_nms_thresh,
+    #         score_thresh=rpn_score_thresh,
+    #     )
+    #
+    #     if box_roi_pool is None:
+    #         box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
+    #
+    #     if box_head is None:
+    #         resolution = box_roi_pool.output_size[0]
+    #         representation_size = 1024
+    #         box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+    #
+    #     if box_predictor is None:
+    #         representation_size = 1024
+    #         box_predictor = BoxPredictor(representation_size, num_classes)
+    #
+    #     roi_heads = RoIHeads(
+    #         # Box
+    #         box_roi_pool,
+    #         box_head,
+    #         box_predictor,
+    #         line_head,
+    #         line_predictor,
+    #         box_fg_iou_thresh,
+    #         box_bg_iou_thresh,
+    #         box_batch_size_per_image,
+    #         box_positive_fraction,
+    #         bbox_reg_weights,
+    #         box_score_thresh,
+    #         box_nms_thresh,
+    #         box_detections_per_img,
+    #     )
+    #
+    #     if image_mean is None:
+    #         image_mean = [0.485, 0.456, 0.406]
+    #     if image_std is None:
+    #         image_std = [0.229, 0.224, 0.225]
+    #     transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+    #
+    #     super().__init__(backbone, rpn, roi_heads, transform)
+    #
+    #     self.roi_heads = roi_heads
+
+    # self.roi_heads.line_head = line_head
+    # self.roi_heads.line_predictor = line_predictor
+
+    def train(self, cfg):
+        # cfg = read_yaml(cfg)
+        self.trainer = Trainer()
+        self.trainer.train_cfg(model=self, cfg=cfg)
 
 
 class TwoMLPHead(nn.Module):
@@ -374,11 +301,11 @@ class TwoMLPHead(nn.Module):
 
 class LineNetConvFCHead(nn.Sequential):
     def __init__(
-        self,
-        input_size: Tuple[int, int, int],
-        conv_layers: List[int],
-        fc_layers: List[int],
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+            self,
+            input_size: Tuple[int, int, int],
+            conv_layers: List[int],
+            fc_layers: List[int],
+            norm_layer: Optional[Callable[..., nn.Module]] = None,
     ):
         """
         Args:
@@ -533,13 +460,13 @@ class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
 )
 def linenet_resnet50_fpn(
-    *,
-    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
+        *,
+        weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
 ) -> LineNet:
     """
     Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
@@ -652,13 +579,13 @@ def linenet_resnet50_fpn(
     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
 )
 def linenet_resnet50_fpn_v2(
-    *,
-    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[ResNet50_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
+        *,
+        weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = None,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
 ) -> LineNet:
     """
     Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
@@ -727,13 +654,13 @@ def linenet_resnet50_fpn_v2(
 
 
 def _linenet_mobilenet_v3_large_fpn(
-    *,
-    weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
-    progress: bool,
-    num_classes: Optional[int],
-    weights_backbone: Optional[MobileNet_V3_Large_Weights],
-    trainable_backbone_layers: Optional[int],
-    **kwargs: Any,
+        *,
+        weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
+        progress: bool,
+        num_classes: Optional[int],
+        weights_backbone: Optional[MobileNet_V3_Large_Weights],
+        trainable_backbone_layers: Optional[int],
+        **kwargs: Any,
 ) -> LineNet:
     if weights is not None:
         weights_backbone = None
@@ -748,14 +675,14 @@ def _linenet_mobilenet_v3_large_fpn(
     backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
     backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
     anchor_sizes = (
-        (
-            32,
-            64,
-            128,
-            256,
-            512,
-        ),
-    ) * 3
+                       (
+                           32,
+                           64,
+                           128,
+                           256,
+                           512,
+                       ),
+                   ) * 3
     aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
     model = LineNet(
         backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
@@ -773,13 +700,13 @@ def _linenet_mobilenet_v3_large_fpn(
     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
 )
 def linenet_mobilenet_v3_large_320_fpn(
-    *,
-    weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
+        *,
+        weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
 ) -> LineNet:
     """
     Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
@@ -847,13 +774,13 @@ def linenet_mobilenet_v3_large_320_fpn(
     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
 )
 def linenet_mobilenet_v3_large_fpn(
-    *,
-    weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
+        *,
+        weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
 ) -> LineNet:
     """
     Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
@@ -909,4 +836,3 @@ def linenet_mobilenet_v3_large_fpn(
         trainable_backbone_layers=trainable_backbone_layers,
         **kwargs,
     )
-

+ 68 - 0
models/line_detect/line_net.yaml

@@ -0,0 +1,68 @@
+
+
+
+
+model:
+  image:
+      mean: [109.730, 103.832, 98.681]
+      stddev: [22.275, 22.124, 23.229]
+
+  batch_size: 4
+  batch_size_eval: 2
+
+  # backbone multi-task parameters
+  head_size: [[2], [1], [2]]
+  loss_weight:
+    jmap: 8.0
+    lmap: 0.5
+    joff: 0.25
+    lpos: 1
+    lneg: 1
+    boxes: 1.0
+
+  # backbone parameters
+  backbone: resnet50_fpn
+#  backbone: unet
+  depth: 4
+  num_stacks: 1
+  num_blocks: 1
+  num_classes: 2
+
+  # sampler parameters
+  ## static sampler
+  n_stc_posl: 300
+  n_stc_negl: 40
+
+  ## dynamic sampler
+  n_dyn_junc: 300
+  n_dyn_posl: 300
+  n_dyn_negl: 80
+  n_dyn_othr: 600
+
+  # LOIPool layer parameters
+  n_pts0: 32
+  n_pts1: 8
+
+  # line verification network parameters
+  dim_loi: 128
+  dim_fc: 1024
+
+  # maximum junction and line outputs
+  n_out_junc: 250
+  n_out_line: 2500
+
+  # additional ablation study parameters
+  use_cood: 0
+  use_slop: 0
+  use_conv: 0
+
+  # junction threashold for evaluation (See #5)
+  eval_junc_thres: 0.008
+
+optim:
+  name: Adam
+  lr: 4.0e-4
+  amsgrad: True
+  weight_decay: 1.0e-4
+  max_epoch: 1000
+  lr_decay_epoch: 10

+ 7 - 4
models/line_detect/line_predictor.py

@@ -20,17 +20,21 @@ import numpy as np
 import torch.nn.functional as F
 
 FEATURE_DIM = 8
+
+
 def non_maximum_suppression(a):
     ap = F.max_pool2d(a, 3, stride=1, padding=1)
     mask = (a == ap).float().clamp(min=0.0)
     return a * mask
 
+
 class LineRCNNPredictor(nn.Module):
-    def __init__(self):
+    def __init__(self, cfg):
         super().__init__()
         # self.backbone = backbone
         # self.cfg = read_yaml(cfg)
-        self.cfg = read_yaml(r'./config/wireframe.yaml')
+        # self.cfg = read_yaml(r'./config/wireframe.yaml')
+        self.cfg = cfg
         self.n_pts0 = self.cfg['model']['n_pts0']
         self.n_pts1 = self.cfg['model']['n_pts1']
         self.n_stc_posl = self.cfg['model']['n_stc_posl']
@@ -316,9 +320,8 @@ class LineRCNNPredictor(nn.Module):
             return line, label.float(), feat, jcs
 
 
-
 _COMMON_META = {
     "categories": _COCO_PERSON_CATEGORIES,
     "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
     "min_size": (1, 1),
-}
+}

+ 9 - 0
models/line_detect/test_train.py

@@ -0,0 +1,9 @@
+import torch
+
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if __name__ == '__main__':
+
+    model = LineNet('line_net.yaml')
+    model.train(cfg='./train.yaml')

+ 71 - 0
models/line_detect/train.yaml

@@ -0,0 +1,71 @@
+io:
+  logdir: logs/
+  datadir: I:/datasets/wirenet_lm
+  resume_from:
+  num_workers: 8
+  tensorboard_port: 6000
+  validation_interval: 300
+
+model:
+  image:
+      mean: [109.730, 103.832, 98.681]
+      stddev: [22.275, 22.124, 23.229]
+
+  batch_size: 4
+  batch_size_eval: 2
+
+  # backbone multi-task parameters
+  head_size: [[2], [1], [2]]
+  loss_weight:
+    jmap: 8.0
+    lmap: 0.5
+    joff: 0.25
+    lpos: 1
+    lneg: 1
+    boxes: 1.0
+
+  # backbone parameters
+  backbone: fasterrcnn_resnet50
+#  backbone: unet
+  depth: 4
+  num_stacks: 1
+  num_blocks: 1
+
+  # sampler parameters
+  ## static sampler
+  n_stc_posl: 300
+  n_stc_negl: 40
+
+  ## dynamic sampler
+  n_dyn_junc: 300
+  n_dyn_posl: 300
+  n_dyn_negl: 80
+  n_dyn_othr: 600
+
+  # LOIPool layer parameters
+  n_pts0: 32
+  n_pts1: 8
+
+  # line verification network parameters
+  dim_loi: 128
+  dim_fc: 1024
+
+  # maximum junction and line outputs
+  n_out_junc: 250
+  n_out_line: 2500
+
+  # additional ablation study parameters
+  use_cood: 0
+  use_slop: 0
+  use_conv: 0
+
+  # junction threashold for evaluation (See #5)
+  eval_junc_thres: 0.008
+
+optim:
+  name: Adam
+  lr: 4.0e-4
+  amsgrad: True
+  weight_decay: 1.0e-4
+  max_epoch: 1000
+  lr_decay_epoch: 10

+ 107 - 0
models/line_detect/trainer.py

@@ -0,0 +1,107 @@
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from models.base.base_trainer import BaseTrainer
+from models.config.config_tool import read_yaml
+from models.line_detect.dataset_LD import WirePointDataset
+from utils.log_util import show_line
+from tools import utils
+
+
+def _loss(losses):
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        total_loss += loss
+
+    return total_loss
+
+
+class Trainer(BaseTrainer):
+    def __init__(self, model=None,
+                 dataset=None,
+                 device='cuda',
+                 **kwargs):
+        super().__init__(model,dataset,device,**kwargs)
+
+    def move_to_device(self, data, device):
+        if isinstance(data, (list, tuple)):
+            return type(data)(self.move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: self.move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # 对于非张量类型的数据不做任何改变
+
+    def writer_loss(self, writer, losses, epoch):
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            writer.add_scalar(f'loss/{subkey}',
+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                              epoch)
+                elif isinstance(value, torch.Tensor):
+                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
+
+    def train_cfg(self, model, cfg):
+        # cfg = r'./config/wireframe.yaml'
+        cfg = read_yaml(cfg)
+        print(f'cfg:{cfg}')
+        print(cfg['model']['n_dyn_negl'])
+        self.train(model, **cfg)
+
+    def train(self, model, **cfg):
+        dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
+        train_sampler = torch.utils.data.RandomSampler(dataset_train)
+        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
+        train_collate_fn = utils.collate_fn_wirepoint
+        data_loader_train = torch.utils.data.DataLoader(
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
+        )
+
+        dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+        val_sampler = torch.utils.data.RandomSampler(dataset_val)
+        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
+        val_collate_fn = utils.collate_fn_wirepoint
+        data_loader_val = torch.utils.data.DataLoader(
+            dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
+        )
+
+        # model = linenet_resnet50_fpn().to(self.device)
+
+        optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
+        writer = SummaryWriter(cfg['io']['logdir'])
+
+        for epoch in range(cfg['optim']['max_epoch']):
+            print(f"epoch:{epoch}")
+            model.train()
+
+            for imgs, targets in data_loader_train:
+                losses = model(self.move_to_device(imgs, self.device), self.move_to_device(targets, self.device))
+                # print(losses)
+                loss = _loss(losses)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                self.writer_loss(writer, losses, epoch)
+
+            model.eval()
+            with torch.no_grad():
+                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                    pred = model(self.move_to_device(imgs, self.device))
+                    if batch_idx == 0:
+                        show_line(imgs[0], pred, epoch, writer)
+                    break

+ 0 - 44
models/utils.py

@@ -1,44 +0,0 @@
-# import torch
-#
-#
-# def evaluate(model, data_loader, device):
-#     n_threads = torch.get_num_threads()
-#     # FIXME remove this and make paste_masks_in_image run on the GPU
-#     torch.set_num_threads(1)
-#     cpu_device = torch.device("cpu")
-#     model.eval()
-#     metric_logger = utils.MetricLogger(delimiter="  ")
-#     header = "Test:"
-#
-#     coco = get_coco_api_from_dataset(data_loader.dataset)
-#     iou_types = _get_iou_types(model)
-#     coco_evaluator = CocoEvaluator(coco, iou_types)
-#
-#     print(f'start to evaluate!!!')
-#     for images, targets in metric_logger.log_every(data_loader, 10, header):
-#         images = list(img.to(device) for img in images)
-#
-#         if torch.cuda.is_available():
-#             torch.cuda.synchronize()
-#         model_time = time.time()
-#         outputs = model(images)
-#
-#         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
-#         model_time = time.time() - model_time
-#
-#         res = {target["image_id"]: output for target, output in zip(targets, outputs)}
-#         evaluator_time = time.time()
-#         coco_evaluator.update(res)
-#         evaluator_time = time.time() - evaluator_time
-#         metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
-#
-#     # gather the stats from all processes
-#     metric_logger.synchronize_between_processes()
-#     print("Averaged stats:", metric_logger)
-#     coco_evaluator.synchronize_between_processes()
-#
-#     # accumulate predictions from all images
-#     coco_evaluator.accumulate()
-#     coco_evaluator.summarize()
-#     torch.set_num_threads(n_threads)
-#     return coco_evaluator

+ 0 - 0
utils/__init__.py


+ 52 - 0
utils/log_util.py

@@ -0,0 +1,52 @@
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+
+from libs.vision_libs.utils import draw_bounding_boxes
+from models.wirenet.postprocess import postprocess
+from torchvision import transforms
+
+def show_line(img, pred,  epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    H = pred[1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.8]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)