base_model.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import numpy as np
  2. import torch
  3. from abc import ABC, abstractmethod
  4. from models.base.base_trainer import BaseTrainer
  5. class BaseModel(ABC, torch.nn.Module):
  6. def __init__(self, **kwargs):
  7. super().__init__()
  8. self.cfg = None
  9. self.trainer = None
  10. @abstractmethod
  11. def train_by_cfg(self, cfg):
  12. return
  13. # @abstractmethod
  14. # def get_loss(self, Loss, results, inputs, device):
  15. # """Computes the loss given the network input and outputs.
  16. #
  17. # Args:
  18. # Loss: A loss object.
  19. # results: This is the output of the model.
  20. # inputs: This is the input to the model.
  21. # device: The torch device to be used.
  22. #
  23. # Returns:
  24. # Returns the loss value.
  25. # """
  26. # return
  27. #
  28. # @abstractmethod
  29. # def get_optimizer(self, cfg_pipeline):
  30. # """Returns an optimizer object for the model.
  31. #
  32. # Args:
  33. # cfg_pipeline: A Config object with the configuration of the pipeline.
  34. #
  35. # Returns:
  36. # Returns a new optimizer object.
  37. # """
  38. # return
  39. #
  40. # @abstractmethod
  41. # def preprocess(self, cfg_pipeline):
  42. # """Data preprocessing function.
  43. #
  44. # This function is called before training to preprocess the data from a
  45. # dataset.
  46. #
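  47. # Args (NOTE(review): data/attr below do not match the cfg_pipeline parameter in the signature above; confirm which is intended):
  48. # data: A sample from the dataset.
  49. # attr: The corresponding attributes.
  50. #
  51. # Returns:
  52. # Returns the preprocessed data.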
  53. # """
  54. # return
  55. # #
  56. # # @abstractmethod
  57. # # def transform(self, cfg_pipeline):
  58. # # """Transform function for the point cloud and features.
  59. # #
  60. # # Args:
  61. # # cfg_pipeline: config file for pipeline.
  62. # # """
  63. # # return
  64. #
  65. # @abstractmethod
  66. # def inference_begin(self, data):
  67. # """Function called right before running inference.
  68. #
  69. # Args:
  70. # data: A sample from the dataset.
  71. # """
  72. # return
  73. #
  74. # @abstractmethod
  75. # def inference_preprocess(self):
  76. # """This function prepares the inputs for the model.
  77. #
  78. # Returns:
  79. # The inputs to be consumed by the model's forward() method.
  80. # """
  81. # return
  82. #
  83. # @abstractmethod
  84. # def inference_end(self, inputs, results):
  85. # """This function is called after the inference.
  86. #
  87. # This function can be implemented to apply post-processing on the
  88. # network outputs.
  89. #
  90. # Args (NOTE(review): the `inputs` parameter in the signature is not documented here):
  91. # results: The model outputs as returned by the model's forward() method.
  92. # Post-processing is applied on this object.
  93. #
  94. # Returns:
  95. # Returns True if the inference is complete and otherwise False.
  96. # Returning False can be used to implement inference for large point
  97. # clouds which require multiple passes.
  98. # """
  99. # return