import numpy as np
import torch
from abc import ABC, abstractmethod

from models.base.base_trainer import BaseTrainer


class BaseModel(ABC, torch.nn.Module):
    """Abstract base class combining ``abc.ABC`` with ``torch.nn.Module``.

    Subclasses must override :meth:`train_by_cfg`; until they do, the class
    stays abstract and cannot be instantiated.
    """

    def __init__(self, **kwargs) -> None:
        """Run ``torch.nn.Module`` initialisation and reset the model state.

        Args:
            **kwargs: Arbitrary keyword arguments. They are accepted but not
                used by this base class.
        """
        super().__init__()

        # Both attributes start out empty.
        # NOTE(review): presumably populated later by a subclass or by the
        # training entry point -- confirm against the callers.
        self.cfg = None
        self.trainer = None

    @abstractmethod
    def train_by_cfg(self, cfg):
        """Abstract hook that every concrete model has to implement.

        Args:
            cfg: Configuration object handed to the model. NOTE(review): its
                schema is not visible from this class; judging by the method
                name it drives training -- confirm with the subclasses.

        Returns:
            ``None`` in this base implementation.
        """
        return None

    # @abstractmethod
    # def get_loss(self, Loss, results, inputs, device):
    #     """Computes the loss given the network input and outputs.
    #
    #     Args:
    #         Loss: A loss object.
    #         results: This is the output of the model.
    #         inputs: This is the input to the model.
    #         device: The torch device to be used.
    #
    #     Returns:
    #         Returns the loss value.
    #     """
    #     return
    #
    # @abstractmethod
    # def get_optimizer(self, cfg_pipeline):
    #     """Returns an optimizer object for the model.
    #
    #     Args:
    #         cfg_pipeline: A Config object with the configuration of the pipeline.
    #
    #     Returns:
    #         Returns a new optimizer object.
    #     """
    #     return
    #
    # @abstractmethod
    # def preprocess(self, cfg_pipeline):
    #     """Data preprocessing function.
    #
    #     This function is called before training to preprocess the data from a
    #     dataset.
    #
    #     Args:
    #         data: A sample from the dataset.
    #         attr: The corresponding attributes.
    #
    #     Returns:
    #         Returns the preprocessed data
    #     """
    #     return
    # #
    # # @abstractmethod
    # # def transform(self, cfg_pipeline):
    # #     """Transform function for the point cloud and features.
    # #
    # #     Args:
    # #         cfg_pipeline: config file for pipeline.
    # #     """
    # #     return
    #
    # @abstractmethod
    # def inference_begin(self, data):
    #     """Function called right before running inference.
    #
    #     Args:
    #         data: A sample from the dataset.
    #     """
    #     return
    #
    # @abstractmethod
    # def inference_preprocess(self):
    #     """This function prepares the inputs for the model.
    #
    #     Returns:
    #         The inputs to be consumed by the call() function of the model.
    #     """
    #     return
    #
    # @abstractmethod
    # def inference_end(self, inputs, results):
    #     """This function is called after the inference.
    #
    #     This function can be implemented to apply post-processing on the
    #     network outputs.
    #
    #     Args:
    #         results: The model outputs as returned by the call() function.
    #             Post-processing is applied on this object.
    #
    #     Returns:
    #         Returns True if the inference is complete and otherwise False.
    #         Returning False can be used to implement inference for large point
    #         clouds which require multiple passes.
    #     """
    #     return