123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- import numpy as np
- import torch
- from abc import ABC, abstractmethod
- from models.base.base_trainer import BaseTrainer
class BaseModel(ABC, torch.nn.Module):
    """Abstract base class for models built on ``torch.nn.Module``.

    Mixing in ``abc.ABC`` means a subclass cannot be instantiated until it
    overrides every method marked ``@abstractmethod`` (currently only
    ``train_by_cfg``).

    Attributes:
        cfg: Configuration object. Initialised to ``None``; this class never
            assigns it again (presumably subclasses or callers do -- confirm).
        trainer: Trainer object. Initialised to ``None``; likewise not
            assigned anywhere else in this class.
    """

    def __init__(self, **kwargs):
        """Initialise the underlying module and clear ``cfg``/``trainer``.

        Args:
            **kwargs: Accepted so callers may pass arbitrary keywords, but
                currently ignored -- nothing is forwarded to
                ``torch.nn.Module.__init__``.
        """
        # ABC defines no __init__ of its own, so this resolves to
        # torch.nn.Module.__init__; it runs before any attribute is assigned.
        super(BaseModel, self).__init__()
        # Chained assignment binds left to right: cfg first, then trainer.
        self.cfg = self.trainer = None

    @abstractmethod
    def train_by_cfg(self, cfg):
        """Train the model according to ``cfg``.

        Concrete subclasses must override this. The base implementation does
        nothing and returns ``None``, so it is safe to call via ``super()``.

        Args:
            cfg: Training configuration; its type is not constrained here.
        """
- # @abstractmethod
- # def get_loss(self, Loss, results, inputs, device):
- # """Computes the loss given the network input and outputs.
- #
- # Args:
- # Loss: A loss object.
- # results: This is the output of the model.
- # inputs: This is the input to the model.
- # device: The torch device to be used.
- #
- # Returns:
- # Returns the loss value.
- # """
- # return
- #
- # @abstractmethod
- # def get_optimizer(self, cfg_pipeline):
- # """Returns an optimizer object for the model.
- #
- # Args:
- # cfg_pipeline: A Config object with the configuration of the pipeline.
- #
- # Returns:
- # Returns a new optimizer object.
- # """
- # return
- #
- # @abstractmethod
- # def preprocess(self, cfg_pipeline):
- # """Data preprocessing function.
- #
- # This function is called before training to preprocess the data from a
- # dataset.
- #
- # Args:
- # data: A sample from the dataset.
- # attr: The corresponding attributes.
- #
- # Returns:
- # Returns the preprocessed data
- # """
- # return
- # #
- # # @abstractmethod
- # # def transform(self, cfg_pipeline):
- # # """Transform function for the point cloud and features.
- # #
- # # Args:
- # # cfg_pipeline: config file for pipeline.
- # # """
- # # return
- #
- # @abstractmethod
- # def inference_begin(self, data):
- # """Function called right before running inference.
- #
- # Args:
- # data: A data from the dataset.
- # """
- # return
- #
- # @abstractmethod
- # def inference_preprocess(self):
- # """This function prepares the inputs for the model.
- #
- # Returns:
- # The inputs to be consumed by the call() function of the model.
- # """
- # return
- #
- # @abstractmethod
- # def inference_end(self, inputs, results):
- # """This function is called after the inference.
- #
- # This function can be implemented to apply post-processing on the
- # network outputs.
- #
- # Args:
- # results: The model outputs as returned by the call() function.
- # Post-processing is applied on this object.
- #
- # Returns:
- # Returns True if the inference is complete and otherwise False.
- # Returning False can be used to implement inference for large point
- # clouds which require multiple passes.
- # """
- # return
|