# Source code for sasctl.pzmm.pickle_model

# Copyright (c) 2020, SAS Institute Inc., Cary, NC, USA.  All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# %%
import codecs
import gzip
import pickle
import shutil
from pathlib import Path
from typing import Any, Optional, Union

try:
    import h2o
except ImportError:
    h2o = None

from ..utils.misc import check_if_jupyter

# File extension appended to ``model_prefix`` when naming serialized model
# files, and when building the dictionary keys returned by
# PickleModel.pickle_trained_model (``model_prefix + PICKLE``).
PICKLE = ".pickle"
# TODO: Break up function to lower CC


class PickleModel:
    """Serialize trained models to pickle files, H2O model files, or strings."""

    # Evaluated once, when the class is defined; when True, a success message
    # is printed after a model is pickled to disk.
    notebook_output: bool = check_if_jupyter()

    @classmethod
    def pickle_trained_model(
        cls,
        model_prefix: str,
        trained_model: Optional[Any] = None,
        pickle_path: Union[str, Path, None] = None,
        is_h2o_model: bool = False,
        is_binary_model: bool = False,
        is_binary_string: bool = False,
        mlflow_details: Optional[dict] = None,
    ) -> Union[dict, str, None]:
        """
        Write trained model to a binary pickle file, H2O MOJO file, or a binary
        string object.

        The following files are generated by this function:
            * '*.pickle'
                Binary pickle file containing a trained model (also used as the
                file name for binary H2O models).
            * '*.mojo'
                Archived H2O.ai MOJO file containing a trained model.

        Parameters
        ---------------
        model_prefix : str
            Variable name for the model to be displayed in SAS Open Model Manager
            (i.e. hmeqClassTree + [Score.py || .pickle]).
        trained_model : model object
            The trained model to be exported.
        pickle_path : str or Path, optional
            Directory in which the output file is written. The default value is
            None, in which case no file is written (see Returns).
        is_h2o_model : bool, optional
            Sets whether the model file is an H2O.ai MOJO file. If set as True,
            the MOJO file will be gzipped before uploading to SAS Model Manager.
            The default value is False.
        is_binary_model : bool, optional
            Sets whether the H2O model provided is a binary model or a MOJO model.
            The default value is False.
        is_binary_string : bool, optional
            Sets whether the model is to be set as a binary string instead of a
            pickle file. The default value is False.
        mlflow_details : dict, optional
            Model details from an MLFlow model. This dictionary is created by the
            readMLModelFile function and must contain the "mlflowPath" and
            "model_path" keys. The default value is None.

        Returns
        -------
        binary_string : str
            When is_binary_string is True, a base64-encoded pickle of the model
            instead of a pickle or MOJO file.
        dict
            When pickle_path is None, a dictionary keyed by
            ``model_prefix + ".pickle"``. For regular models the value is the
            pickle dump (bytes); for MLFlow models it is the object loaded from
            the MLFlow pickle file. This is not valid for H2O.ai models.
        None
            When a file is written to pickle_path.

        Raises
        ------
        ValueError
            If an H2O.ai model is given without a pickle_path.
        RuntimeError
            If an H2O.ai model file is requested but the h2o package is not
            installed.
        """
        if is_binary_string:
            # For models that use a binary string representation
            return codecs.encode(pickle.dumps(trained_model), "base64").decode()
        if mlflow_details:
            return cls._pickle_mlflow_model(model_prefix, pickle_path, mlflow_details)
        if not is_h2o_model:
            return cls._pickle_python_model(model_prefix, trained_model, pickle_path)
        cls._save_h2o_model(model_prefix, trained_model, pickle_path, is_binary_model)
        return None

    @staticmethod
    def _pickle_mlflow_model(
        model_prefix: str,
        pickle_path: Union[str, Path, None],
        mlflow_details: dict,
    ) -> Optional[dict]:
        """Copy (or load) the pickle file referenced by an MLFlow model."""
        ml_pickle_path = (
            Path(mlflow_details["mlflowPath"]) / mlflow_details["model_path"]
        )
        if not pickle_path:
            # NOTE(review): pickle.load can execute arbitrary code; only use this
            # with trusted MLFlow artifacts. Unlike the non-MLFlow branch, the
            # value here is the loaded object rather than pickled bytes.
            with open(ml_pickle_path, "rb") as pickle_file:
                return {model_prefix + PICKLE: pickle.load(pickle_file)}
        # Copy straight to the final name. Copying into the directory and then
        # renaming broke for nested model_path values, clobbered any existing
        # file with the artifact's name, and fails on Windows if the target
        # already exists.
        shutil.copy(ml_pickle_path, Path(pickle_path) / (model_prefix + PICKLE))
        return None

    @classmethod
    def _pickle_python_model(
        cls,
        model_prefix: str,
        trained_model: Any,
        pickle_path: Union[str, Path, None],
    ) -> Optional[dict]:
        """Pickle a non-H2O model to pickle_path, or return its pickle dump."""
        if not pickle_path:
            return {model_prefix + PICKLE: pickle.dumps(trained_model)}
        output_file = Path(pickle_path) / (model_prefix + PICKLE)
        with open(output_file, "wb") as pickle_file:
            pickle.dump(trained_model, pickle_file)
        if cls.notebook_output:
            print(
                f"Model {model_prefix} was successfully pickled and saved "
                f"to {output_file}."
            )
        return None

    @staticmethod
    def _save_h2o_model(
        model_prefix: str,
        trained_model: Any,
        pickle_path: Union[str, Path, None],
        is_binary_model: bool,
    ) -> None:
        """Save an H2O model as a binary ".pickle" file or a ".mojo" file."""
        if not pickle_path:
            raise ValueError(
                "There is currently no support for file-less H2O.ai model handling."
                " Please include a value for the pickle_path argument."
            )
        if is_binary_model:
            # For binary H2O models, save the binary file as a "pickle" file
            if h2o is None:
                raise RuntimeError(
                    "The h2o package is required to save the model as a binary h2o "
                    "model."
                )
            h2o.save_model(
                model=trained_model,
                force=True,
                path=str(pickle_path),
                filename=model_prefix + PICKLE,
            )
        else:
            # For MOJO H2O models, save as a mojo file with the .mojo extension
            if h2o is None:
                raise RuntimeError(
                    "The h2o package is required to save the model as a mojo model."
                )
            trained_model.save_mojo(
                force=True, path=str(pickle_path), filename=f"{model_prefix}.mojo"
            )