123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- import numpy as np
- import torch
- from abc import ABC, abstractmethod
- from models.base.base_trainer import BaseTrainer
class BaseModel(ABC, torch.nn.Module):
    """Abstract base class for models that are also ``torch.nn.Module`` s.

    Subclasses must implement every ``@abstractmethod`` below; instantiating
    a subclass that leaves any of them unimplemented raises ``TypeError``.

    Note:
        ``train(cfg)`` in this interface shadows
        ``torch.nn.Module.train(mode=True)``, which PyTorch uses to toggle
        training/evaluation mode. Because ``torch.nn.Module.eval()`` is
        implemented as ``self.train(False)``, this class overrides
        :meth:`eval` so it switches mode directly instead of invoking the
        subclass's ``train``. To turn training mode back on, call
        ``torch.nn.Module.train(model, True)``.
        Caveat: if a ``BaseModel`` is nested as a child of another module,
        the parent's ``train()``/``eval()`` still call the child's ``train``
        with a bool argument.
    """

    def __init__(self, **kwargs):
        """Initialize the underlying ``torch.nn.Module`` state.

        Args:
            **kwargs: Accepted so subclasses can forward arbitrary keyword
                arguments; ignored by this base class.
        """
        super().__init__()
        # Placeholders; nothing in this class assigns them afterwards.
        # Presumably set by subclasses or a trainer -- TODO confirm.
        self.cfg = None
        self.trainer = None

    def eval(self):
        """Switch this module (and its children) to evaluation mode.

        Overrides ``torch.nn.Module.eval``. The default implementation calls
        ``self.train(False)``, which here would dispatch to the subclass's
        ``train(cfg)`` with ``cfg=False`` instead of changing the mode.
        Calling the base ``torch.nn.Module.train`` explicitly avoids that.

        Returns:
            self, matching ``torch.nn.Module.eval``.
        """
        return torch.nn.Module.train(self, False)

    @abstractmethod
    def train(self, cfg):
        """Subclass-defined training entry point.

        Args:
            cfg: Configuration consumed by the implementation (its exact
                type/schema is defined by subclasses -- TODO confirm).

        Note:
            This shadows ``torch.nn.Module.train(mode)``; see the class
            docstring.
        """
        return

    @abstractmethod
    def get_loss(self, Loss, results, inputs, device):
        """Computes the loss given the network input and outputs.

        Args:
            Loss: A loss object.
            results: This is the output of the model.
            inputs: This is the input to the model.
            device: The torch device to be used.

        Returns:
            Returns the loss value.
        """
        return

    @abstractmethod
    def get_optimizer(self, cfg_pipeline):
        """Returns an optimizer object for the model.

        Args:
            cfg_pipeline: A Config object with the configuration of the
                pipeline.

        Returns:
            Returns a new optimizer object.
        """
        return

    @abstractmethod
    def preprocess(self, cfg_pipeline):
        """Data preprocessing function.

        This function is called before training to preprocess the data from
        a dataset.

        Args:
            cfg_pipeline: Configuration of the pipeline.

        Returns:
            Returns the preprocessed data.

        NOTE(review): per-sample arguments such as ``data``/``attr`` are not
        part of this signature; confirm how implementations obtain the
        samples they preprocess.
        """
        return

    @abstractmethod
    def transform(self, cfg_pipeline):
        """Transform function for the point cloud and features.

        Args:
            cfg_pipeline: Config file for pipeline.
        """
        return

    @abstractmethod
    def inference_begin(self, data):
        """Function called right before running inference.

        Args:
            data: A data from the dataset.
        """
        return

    @abstractmethod
    def inference_preprocess(self):
        """This function prepares the inputs for the model.

        Returns:
            The inputs to be consumed by the model's forward pass
            (``forward()`` via ``__call__``).
        """
        return

    @abstractmethod
    def inference_end(self, inputs, results):
        """This function is called after the inference.

        This function can be implemented to apply post-processing on the
        network outputs.

        Args:
            inputs: The inputs that were fed to the model (presumably those
                returned by ``inference_preprocess`` -- TODO confirm).
            results: The model outputs as returned by the forward pass.
                Post-processing is applied on this object.

        Returns:
            Returns True if the inference is complete and otherwise False.
            Returning False can be used to implement inference for large
            point clouds which require multiple passes.
        """
        return
|