# base_model.py (2.7 KB)

  1. import numpy as np
  2. import torch
  3. from abc import ABC, abstractmethod
  4. from models.base.base_trainer import BaseTrainer
  5. class BaseModel(ABC, torch.nn.Module):
  6. def __init__(self, **kwargs):
  7. super().__init__()
  8. self.cfg = None
  9. self.trainer = None
  10. @abstractmethod
  11. def train(self, cfg):
  12. return
  13. @abstractmethod
  14. def get_loss(self, Loss, results, inputs, device):
  15. """Computes the loss given the network input and outputs.
  16. Args:
  17. Loss: A loss object.
  18. results: This is the output of the model.
  19. inputs: This is the input to the model.
  20. device: The torch device to be used.
  21. Returns:
  22. Returns the loss value.
  23. """
  24. return
  25. @abstractmethod
  26. def get_optimizer(self, cfg_pipeline):
  27. """Returns an optimizer object for the model.
  28. Args:
  29. cfg_pipeline: A Config object with the configuration of the pipeline.
  30. Returns:
  31. Returns a new optimizer object.
  32. """
  33. return
  34. @abstractmethod
  35. def preprocess(self, cfg_pipeline):
  36. """Data preprocessing function.
  37. This function is called before training to preprocess the data from a
  38. dataset.
  39. Args:
  40. data: A sample from the dataset.
  41. attr: The corresponding attributes.
  42. Returns:
  43. Returns the preprocessed data
  44. """
  45. return
  46. @abstractmethod
  47. def transform(self, cfg_pipeline):
  48. """Transform function for the point cloud and features.
  49. Args:
  50. cfg_pipeline: config file for pipeline.
  51. """
  52. return
  53. @abstractmethod
  54. def inference_begin(self, data):
  55. """Function called right before running inference.
  56. Args:
  57. data: A data from the dataset.
  58. """
  59. return
  60. @abstractmethod
  61. def inference_preprocess(self):
  62. """This function prepares the inputs for the model.
  63. Returns:
  64. The inputs to be consumed by the call() function of the model.
  65. """
  66. return
  67. @abstractmethod
  68. def inference_end(self, inputs, results):
  69. """This function is called after the inference.
  70. This function can be implemented to apply post-processing on the
  71. network outputs.
  72. Args:
  73. results: The model outputs as returned by the call() function.
  74. Post-processing is applied on this object.
  75. Returns:
  76. Returns True if the inference is complete and otherwise False.
  77. Returning False can be used to implement inference for large point
  78. clouds which require multiple passes.
  79. """
  80. return