# Source code for sasctl.pzmm.pickle_model
# Copyright (c) 2020, SAS Institute Inc., Cary, NC, USA. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# %%
import codecs
import pickle
import shutil
from pathlib import Path
from typing import Any, Optional, Union
try:
import h2o
except ImportError:
h2o = None
from ..utils.misc import check_if_jupyter
# Extension given to the pickle files written by this module.
PICKLE: str = ".pickle"
class PickleModel:
    """Serialize trained models to files or in-memory objects for SAS Model Manager."""

    notebook_output: bool = check_if_jupyter()

    @classmethod
    def pickle_trained_model(
        cls,
        model_prefix: str,
        trained_model: Optional[Any] = None,
        pickle_path: Union[str, Path, None] = None,
        is_h2o_model: bool = False,
        is_binary_model: bool = False,
        is_binary_string: bool = False,
        mlflow_details: Optional[dict] = None,
    ) -> Union[dict, str, None]:
        r"""
        Write trained model to a binary pickle file, H2O MOJO file, or a binary string
        object.

        The following files are generated by this function:

        * '\*.pickle'
            Binary pickle file containing a trained model. Binary H2O.ai models are
            also written under this extension.
        * '\*.mojo'
            Archived H2O.ai MOJO file containing a trained model.

        Parameters
        ----------
        model_prefix : str
            Variable name for the model to be displayed in SAS Open Model Manager
            (i.e. hmeqClassTree + [Score.py || .pickle]). The prefix is sanitized
            before being used in any file name or dictionary key.
        trained_model : Any
            The trained model to be exported. Ignored when mlflow_details is
            provided.
        pickle_path : str or pathlib.Path, optional
            Directory in which the output file is written. When omitted, the
            serialized model is returned instead (not supported for H2O.ai
            models). The default value is None.
        is_h2o_model : bool, optional
            Sets whether the model is an H2O.ai model. H2O.ai models are saved
            through the h2o API instead of pickle; see is_binary_model. The
            default value is False.
        is_binary_model : bool, optional
            Only used when is_h2o_model is True. Sets whether the H2O.ai model is
            saved as a binary model (with a .pickle extension) or as a MOJO model
            (with a .mojo extension). The default value is False.
        is_binary_string : bool, optional
            Sets whether the model is to be returned as a base64-encoded pickle
            string instead of being written to a file. Takes precedence over all
            other options. The default value is False.
        mlflow_details : dict, optional
            Model details from an MLFlow model. Must contain the "mlflowPath" key
            (the MLFlow model directory) and the "model_path" key (the pickle
            file, relative to "mlflowPath"). The default value is None.

        Returns
        -------
        str
            When is_binary_string is True, a base64-encoded string of the pickled
            model.
        dict
            When pickle_path is None (non-H2O.ai models only), a single-entry
            dictionary mapping the pickle file name to the pickled bytes. For
            MLFlow models, the entry maps to the unpickled model object.
        None
            When the model was written to pickle_path.

        Raises
        ------
        ValueError
            If an H2O.ai model is provided without a pickle_path.
        RuntimeError
            If an H2O.ai model is to be saved but the h2o package is not installed.
        """
        from .write_score_code import ScoreCode

        sanitized_prefix = ScoreCode.sanitize_model_prefix(model_prefix)

        if is_binary_string:
            # For models that use a binary string representation
            return codecs.encode(pickle.dumps(trained_model), "base64").decode()
        if mlflow_details:
            # For models imported from MLFlow
            return cls._pickle_mlflow_model(
                sanitized_prefix, pickle_path, mlflow_details
            )
        if not is_h2o_model:
            # For all other (non-H2O.ai) model types
            return cls._pickle_python_model(
                model_prefix, sanitized_prefix, trained_model, pickle_path
            )
        cls._save_h2o_model(
            sanitized_prefix, trained_model, pickle_path, is_binary_model
        )
        return None

    @staticmethod
    def _pickle_mlflow_model(
        sanitized_prefix: str,
        pickle_path: Union[str, Path, None],
        mlflow_details: dict,
    ) -> Optional[dict]:
        """Copy an MLFlow pickle file into pickle_path, or load and return it."""
        ml_pickle_path = (
            Path(mlflow_details["mlflowPath"]) / mlflow_details["model_path"]
        )
        if pickle_path:
            # Copy straight to the final file name. Copying under the MLFlow name
            # and then renaming fails on Windows when the target already exists,
            # and breaks when "model_path" contains directory components
            # (shutil.copy keeps only the basename in the destination directory).
            shutil.copy(ml_pickle_path, Path(pickle_path) / (sanitized_prefix + PICKLE))
            return None
        # NOTE: unpickling can execute arbitrary code; only load MLFlow artifacts
        # from trusted sources.
        with open(ml_pickle_path, "rb") as pickle_file:
            return {sanitized_prefix + PICKLE: pickle.load(pickle_file)}

    @classmethod
    def _pickle_python_model(
        cls,
        model_prefix: str,
        sanitized_prefix: str,
        trained_model: Any,
        pickle_path: Union[str, Path, None],
    ) -> Optional[dict]:
        """Pickle a model into pickle_path, or return its pickled bytes."""
        if not pickle_path:
            return {sanitized_prefix + PICKLE: pickle.dumps(trained_model)}
        output_file = Path(pickle_path) / (sanitized_prefix + PICKLE)
        with open(output_file, "wb") as pickle_file:
            pickle.dump(trained_model, pickle_file)
        if cls.notebook_output:
            print(
                f"Model {model_prefix} was successfully pickled and saved "
                f"to {output_file}."
            )
        return None

    @staticmethod
    def _save_h2o_model(
        sanitized_prefix: str,
        trained_model: Any,
        pickle_path: Union[str, Path, None],
        is_binary_model: bool,
    ) -> None:
        """Save an H2O.ai model into pickle_path as a binary or MOJO file."""
        if not pickle_path:
            raise ValueError(
                "There is currently no support for file-less H2O.ai model handling."
                " Please include a value for the pickle_path argument."
            )
        if is_binary_model:
            if not h2o:
                raise RuntimeError(
                    "The h2o package is required to save the model as a binary h2o "
                    "model."
                )
            # For binary H2O models, save the binary file as a "pickle" file
            h2o.save_model(
                model=trained_model,
                force=True,
                path=str(pickle_path),
                filename=sanitized_prefix + PICKLE,
            )
        else:
            if not h2o:
                raise RuntimeError(
                    "The h2o package is required to save the model as a mojo model."
                )
            # For MOJO H2O models, save as a mojo file with the .mojo extension
            trained_model.save_mojo(
                force=True,
                path=str(pickle_path),
                filename=f"{sanitized_prefix}.mojo",
            )