123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- # -*- coding: utf-8 -*-
- # @Time : 2022/2/25 13:36
- # @Author : MaochengHu
- # @Email : wojiaohumaocheng@gmail.com
- # @File : basic_model_handler.py
- # @Project : server_develop
- import torch
class BasicModelHandler(object):
    """Skeleton model handler: picks a torch device and exposes no-op hooks.

    Only ``__init__`` and ``initialize`` do real work. Every other method in
    this class is a placeholder that does nothing and returns ``None``;
    presumably subclasses override them (confirm against the concrete
    handlers in the project).
    """

    def __init__(self):
        # All state is filled in later; ``initialize`` sets ``map_location``,
        # ``device`` and ``manifest``. Nothing in this class assigns
        # ``context``.
        self.device = None
        self.context = None
        self.map_location = None
        self.manifest = None

    def initialize(self, context):
        """Select the compute device and remember the manifest.

        CUDA is selected only when ``torch.cuda.is_available()`` is true AND
        the system properties contain a non-``None`` ``"gpu_id"``; in every
        other case the CPU is used. This base implementation does not load
        any model artifact.

        Args:
            context: Object exposing ``system_properties`` (a mapping with a
                ``get`` method) and ``manifest``.

        Side effects:
            Sets ``self.map_location`` to ``"cuda"`` or ``"cpu"``,
            ``self.device`` to the matching ``torch.device`` (``cuda:<gpu_id>``
            or ``cpu``) and ``self.manifest`` to ``context.manifest``.
        """
        properties = context.system_properties
        gpu_id = properties.get("gpu_id")

        # Evaluate the predicate once so map_location and device can never
        # disagree about whether the GPU is in use.
        use_gpu = torch.cuda.is_available() and gpu_id is not None

        if use_gpu:
            self.map_location = "cuda"
            self.device = torch.device("{}:{}".format(self.map_location, gpu_id))
        else:
            self.map_location = "cpu"
            self.device = torch.device(self.map_location)

        self.manifest = context.manifest

    # NOTE(review): ``ata`` looks like a typo for ``data``; the name is kept
    # so that existing keyword callers and overrides keep working.
    def preprocess(self, ata, cuda=True):
        """Placeholder pre-processing hook; does nothing and returns ``None``."""
        pass

    def inference(self, data):
        """Placeholder inference hook; does nothing and returns ``None``."""
        pass

    def postprocess(self, data, output_methods):
        """Placeholder post-processing hook; does nothing and returns ``None``."""
        pass

    def handle(self, data, context):
        """Placeholder request entry point; does nothing and returns ``None``."""
        pass

    def explain_handle(self, data_preprocess, raw_data):
        """Placeholder for a Captum explanations handler.

        The base implementation does nothing and returns ``None``.

        Args:
            data_preprocess (Torch Tensor): Preprocessed data to be used for
                captum.
            raw_data (list): The unprocessed data to get target from the
                request.

        Returns:
            None here; an implementing override is expected to return a dict
            with the explanations response (per the original contract note).
        """
        pass

    def _is_explain(self):
        """Placeholder predicate; does nothing and returns ``None``."""
        pass

    def _is_describe(self):
        """Placeholder predicate; does nothing and returns ``None``."""
        pass

    def describe_handle(self):
        """Placeholder for a customized describe handler.

        Returns:
            None here; an implementing override is expected to return a dict
            response (per the original contract note).
        """
        # pylint: disable=unnecessary-pass
        pass
        # pylint: enable=unnecessary-pass
def main():
    """Script entry point; currently a no-op that returns ``None``."""
    return None
# Invoke main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
|