# -*- coding: utf-8 -*-
# @Time : 2022/2/25 13:36
# @Author : MaochengHu
# @Email : wojiaohumaocheng@gmail.com
# @File : basic_model_handler.py
# @Project : server_develop
  7. import torch
  8. class BasicModelHandler(object):
  9. def __init__(self):
  10. self.device = None
  11. self.context = None
  12. self.map_location = None
  13. self.manifest = None
  14. def initialize(self, context):
  15. """Initialize function loads the model.pt file and initialized the model object.
  16. First try to load torchscript else load eager mode state_dict based model.
  17. Args:
  18. context (context): It is a JSON Object containing information
  19. pertaining to the model artifacts parameters.
  20. Raises:
  21. RuntimeError: Raises the Runtime error when the model.py is missing
  22. """
  23. properties = context.system_properties
  24. self.map_location = "cuda" if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu"
  25. self.device = torch.device(
  26. self.map_location + ":" + str(properties.get("gpu_id"))
  27. if torch.cuda.is_available() and properties.get("gpu_id") is not None
  28. else self.map_location
  29. )
  30. self.manifest = context.manifest
  31. def preprocess(self, ata, cuda=True):
  32. pass
  33. def inference(self, data):
  34. pass
  35. def postprocess(self, data, output_methods):
  36. pass
  37. def handle(self, data, context):
  38. pass
  39. def explain_handle(self, data_preprocess, raw_data):
  40. """Captum explanations handler
  41. Args:
  42. data_preprocess (Torch Tensor): Preprocessed data to be used for captum
  43. raw_data (list): The unprocessed data to get target from the request
  44. Returns:
  45. dict : A dictionary response with the explanations response.
  46. """
  47. pass
  48. def _is_explain(self):
  49. pass
  50. def _is_describe(self):
  51. pass
  52. def describe_handle(self):
  53. """Customized describe handler
  54. Returns:
  55. dict : A dictionary response.
  56. """
  57. # pylint: disable=unnecessary-pass
  58. pass
  59. # pylint: enable=unnecessary-pass
  60. def main():
  61. pass
  62. if __name__ == "__main__":
  63. main()