# Copyright (c) 2022, SAS Institute Inc., Cary, NC, USA. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import json
from distutils.version import StrictVersion
from pathlib import Path
from typing import Any, Optional, Tuple, Union
import pandas as pd
from pandas import DataFrame
from .._services.model_repository import ModelRepository as mr
from ..core import RestObj, current_session, is_uuid
# xgboost is an optional dependency: it is only needed when hyperparameters are
# generated for an xgboost model. Any ordinary failure while importing it (package
# missing, native library not loadable, ...) leaves ``xgboost`` set to None so the
# rest of the module keeps working. ``Exception`` is caught instead of using a bare
# ``except`` so that KeyboardInterrupt and SystemExit are not swallowed.
try:
    import xgboost
except Exception:
    xgboost = None
# Pairs of (project property name, model property name). Each pair names a
# property as it is keyed on a project and the key under which the same value is
# stored on a model; the two names are not always identical.
MODEL_PROPERTIES = [
    # (project property, model property)
    ("targetVariable", "targetVariable"),
    ("targetLevel", "targetLevel"),
    ("targetEventValue", "targetEvent"),
    ("eventProbabilityVariable", "eventProbVar"),
    ("function", "function"),
]
# TODO: Maybe just move _find_file altogether?
def _find_file(model: Union[str, dict, RestObj], file_name: str) -> Tuple[RestObj, str]:
    """
    Retrieves the contents of the first file from a registered model on SAS Model
    Manager that contains the provided file_name as an exact match or substring.

    The comparison is case-insensitive.

    Parameters
    ----------
    model : str or dict
        The name or id of the model, or a dictionary representation of the model.
    file_name : str
        The name of the desired file or a substring that is contained within the file
        name.

    Returns
    -------
    RestObj, str
        The contents and name of the first file with a name containing file_name.

    Raises
    ------
    ValueError
        If the model cannot be found, or if none of the model's files has a name
        containing file_name.

    """
    # The content URL below needs the model's id. Resolve it up front so that a
    # model name or a dictionary representation does not end up interpolated
    # verbatim into the URL.
    if isinstance(model, dict) and "id" in model:
        model_id = model["id"]
    elif is_uuid(model):
        model_id = model
    else:
        found_model = mr.get_model(model)
        if found_model is None:
            raise ValueError(f'No model "{model}" could be found.')
        model_id = found_model["id"]

    wanted = file_name.lower()
    for file in mr.get_model_contents(model_id):
        if wanted in file.name.lower():
            correct_file = mr.get(f"models/{model_id}/contents/{file.id}/content")
            return correct_file, file.name

    raise ValueError(f'No file containing "{file_name}" exists within model files.')
# Hyperparameter and KPI utilities for models registered in SAS Model Manager.
class ModelParameters:
    """
    Helpers for the hyperparameter JSON file of models registered in SAS Model
    Manager: generating the file locally, reading and extending it on the server,
    merging project KPIs into it, and copying project properties onto models.
    """

    @staticmethod
    def _update_json(model: str, model_json: dict, kpis: DataFrame) -> dict:
        """
        Updates the contents of the hyperparameter json file

        Parameters
        ----------
        model: str
            The id of the model being updated.
        model_json: dict
            The contents of the current KPI/parameters file within SAS Model Manager.
            It is updated in place.
        kpis: pandas.DataFrame
            The dataframe containing the KPI/parameter values stored within SAS Model
            Manager at runtime. Must contain the columns "ModelUUID" and "TimeLabel".

        Returns
        -------
        dict
            The updated hyperparameter json file to be uploaded to SAS Model Manager.
            When no KPI rows exist for the model, model_json is returned unchanged.

        """
        model_rows = kpis.loc[kpis["ModelUUID"] == model]
        if model_rows.empty:
            return model_json

        # One entry per time label, holding every remaining KPI column.
        by_time_label = model_rows.drop(columns=["ModelUUID"]).set_index("TimeLabel")
        # Round-trip through pandas' JSON writer so that pandas/numpy values are
        # converted to plain JSON-compatible Python values.
        model_json["kpis"] = json.loads(by_time_label.to_json(orient="index"))
        return model_json

    @staticmethod
    def _resolve_model_id(model: Union[str, dict, "RestObj"]) -> str:
        """
        Returns the id of a model given its name, id, or dictionary representation.

        Raises
        ------
        ValueError
            If the model has to be looked up and cannot be found.

        """
        if mr.is_uuid(model):
            return model
        if isinstance(model, dict) and "id" in model:
            return model["id"]

        found_model = mr.get_model(model)
        if found_model is None:
            raise ValueError(f'No model "{model}" could be found.')
        return found_model["id"]

    @staticmethod
    def generate_hyperparameters(
        model: Any, model_prefix: str, pickle_path: Union[str, Path]
    ) -> None:
        """
        Generates hyperparameters for a given model and creates a JSON file
        representation.

        The model type is determined from the module in which the model's class is
        defined. Models from scikit-learn, keras, xgboost, h2o, and statsmodels are
        supported.

        This function creates a json file named {model_prefix}Hyperparameters.json.

        Parameters
        ----------
        model : Any
            Python object representing the model.
        model_prefix : str
            Name used to create model files. (e.g. (model_prefix) +
            "Hyperparameters.json")
        pickle_path : str, pathlib.Path
            Directory location of model files.

        Raises
        ------
        RuntimeError
            If an xgboost model is provided but xgboost could not be imported.
        ValueError
            If the model type is not supported.

        """
        module_name = model.__class__.__module__

        # The order of the checks matters: any module path containing "sklearn"
        # (including scikit-learn style wrappers that live in other packages) is
        # handled through ``get_params``.
        if "sklearn" in module_name:
            hyperparameters = model.get_params()
        elif module_name.startswith("keras"):
            hyperparameters = model.get_config()
        elif module_name.startswith("xgboost"):
            if not xgboost:
                raise RuntimeError(
                    "XGBoost is required to generate xgboost hyperparameters."
                )
            hyperparameters = json.loads(model.save_config())
        elif module_name.startswith("h2o"):
            hyperparameters = model.get_params()
        elif module_name.startswith("statsmodels"):
            hyperparameters = {
                "model_type": model.__class__.__name__,
                "input_variables": model.exog_names,
                "weights": model.weights.tolist(),
            }
        else:
            raise ValueError(
                "This model type is not currently supported for hyperparameter "
                "generation."
            )

        output_file = Path(pickle_path) / f"{model_prefix}Hyperparameters.json"
        with open(output_file, "w") as f:
            f.write(json.dumps({"hyperparameters": hyperparameters}, indent=4))

    @classmethod
    def update_kpis(
        cls,
        project: Union[str, dict, "RestObj"],
        server: Optional[str] = "cas-shared-default",
        caslib: Optional[str] = "ModelPerformanceData",
    ) -> None:
        """
        Updates hyperparameter file to include KPIs generated by performance
        definitions, as well as any custom KPIs imported by user to the SAS KPI data
        table.

        Models for which the hyperparameter file cannot be retrieved are skipped
        with a printed notice.

        Parameters
        ----------
        project : str, dict, or RestObj
            The name or id of the project, or a dictionary representation of the
            project.
        server : str, optional
            Server on which the KPI data table is stored. The default value is
            "cas-shared-default".
        caslib : str, optional
            CAS Library on which the KPI data table is stored. The default value is
            "ModelPerformanceData".

        """
        kpis = cls.get_project_kpis(project, server, caslib)

        # Rows without a model id cannot be matched to a model, so skip them.
        for model_id in kpis["ModelUUID"].dropna().unique().tolist():
            try:
                current_params, file_name = _find_file(model_id, "hyperparameters")
            except Exception:
                # Best effort: one model without a retrievable hyperparameter file
                # must not stop the remaining models from being updated.
                model_name = kpis.loc[
                    kpis["ModelUUID"] == model_id, "ModelName"
                ].iloc[0]
                print(
                    f"No hyperparameter file for current model {model_name}. "
                    "Attempting for next model..."
                )
                continue

            updated_json = cls._update_json(model_id, current_params, kpis)
            mr.add_model_content(
                model_id, json.dumps(updated_json, indent=4), file_name
            )

    @staticmethod
    def get_hyperparameters(model: Union[str, dict, "RestObj"]) -> Tuple[dict, str]:
        """
        Retrieves the hyperparameter json file from specified model on SAS Model
        Manager.

        Parameters
        ----------
        model : str, dict, or RestObj
            The name or id of the model, or a dictionary representation of the model.

        Returns
        -------
        dict, str
            Dictionary containing the contents of the hyperparameter file and the file
            name.

        Raises
        ------
        ValueError
            If the model cannot be found or has no hyperparameter file.

        """
        id_ = ModelParameters._resolve_model_id(model)
        file_contents, file_name = _find_file(id_, "hyperparameters")
        return file_contents, file_name

    @classmethod
    def add_hyperparameters(
        cls, model: Union[str, dict, "RestObj"], **kwargs
    ) -> None:
        """
        Adds custom hyperparameters to the hyperparameter file contained within the
        model in SAS Model Manager.

        Parameters
        ----------
        model : str, dict, or RestObj
            The name or id of the model, or a dictionary representation of the model.
        **kwargs
            Named variables pairs representing hyperparameters to be added to the
            hyperparameter file. Existing entries with the same name are replaced.

        Raises
        ------
        ValueError
            If the model cannot be found or has no hyperparameter file.

        """
        id_ = cls._resolve_model_id(model)
        hyperparameters, file_name = cls.get_hyperparameters(id_)

        for key, value in kwargs.items():
            hyperparameters["hyperparameters"][key] = value

        mr.add_model_content(
            id_,
            json.dumps(hyperparameters, indent=4),
            file_name,
        )

    @staticmethod
    def get_project_kpis(
        project: Union[str, dict, "RestObj"],
        server: Optional[str] = "cas-shared-default",
        caslib: Optional[str] = "ModelPerformanceData",
        filter_column: Optional[str] = None,
        filter_value: Optional[str] = None,
    ) -> DataFrame:
        """
        Create a call to CAS to return the MM_STD_KPI table (SAS Model Manager
        Standard KPI) generated when custom KPIs are uploaded or when a performance
        definition is executed on SAS Model Manager on SAS Viya 4.

        Filtering options are available as additional arguments. The filtering is based
        on column name and column value. Currently, only exact matches are available
        when filtering by this method. The filter is applied only when both
        filter_column and filter_value are provided.

        Parameters
        ----------
        project : str, dict, RestObj
            The name or id of the project, or a dictionary representation of the
            project.
        server : str, optional
            SAS Viya 4 server where the MM_STD_KPI table exists. The default value is
            "cas-shared-default".
        caslib : str, optional
            SAS Viya 4 caslib where the MM_STD_KPI table exists. The default value is
            "ModelPerformanceData".
        filter_column : str, optional
            Column name from the MM_STD_KPI table to be filtered. The default value is
            None.
        filter_value : str, optional
            Column value filter by. The default value is None

        Returns
        -------
        kpi_table_df : pandas.DataFrame
            A pandas DataFrame representing the MM_STD_KPI table. Note that SAS
            missing values are replaced with pandas-valid missing values.

        Raises
        ------
        SystemError
            If the KPI table does not exist for the project, or if no KPI rows are
            returned.

        """
        # json_normalize is a top-level pandas function in current releases; fall
        # back to its legacy location for older ones. Importing directly avoids
        # parsing the pandas version string.
        try:
            from pandas import json_normalize
        except ImportError:
            from pandas.io.json import json_normalize

        # Collect the current session for authentication of API calls
        sess = current_session()

        # Step through options to determine project UUID
        if is_uuid(project):
            project_id = project
        elif isinstance(project, dict) and "id" in project:
            project_id = project["id"]
        else:
            project = mr.get_project(project)
            project_id = project["id"]

        table_path = f"caslibs/{caslib}/tables/{project_id}.MM_STD_KPI"

        # TODO: include case for large MM_STD_KPI tables
        # Call the casManagement service to collect the column names in the table
        kpi_table_columns = sess.get(
            f"casManagement/servers/{server}/{table_path}/columns?limit=10000"
        )
        if not kpi_table_columns:
            project = mr.get_project(project)
            raise SystemError(
                f"No KPI table exists for project {project.name}."
                + " Please confirm that the performance definition completed"
                + " or custom KPIs have been uploaded successfully."
            )

        # Parse through the json response to create a readable list of column names
        cols = json_normalize(kpi_table_columns.json(), "items")
        col_names = cols["name"].to_list()

        # Filter rows returned by column and value provided in arguments
        is_filtered = bool(filter_column and filter_value)
        where_statement = ""
        if is_filtered:
            where_statement = f"&where={filter_column}='{filter_value}'"

        # Call the casRowSets service to return row values
        # Optional where statement is included
        kpi_table_rows = sess.get(
            f"casRowSets/servers/{server}/{table_path}/rows?limit=10000"
            + where_statement
        )

        # If no "cells" are found in the json response, return an error
        try:
            kpi_table_df = pd.DataFrame(
                json_normalize(kpi_table_rows.json()["items"])["cells"].to_list(),
                columns=col_names,
            )
        except KeyError:
            if is_filtered:
                raise SystemError(
                    "No KPIs were found when filtering with "
                    f"{filter_column}='{filter_value}'."
                )
            project_name = mr.get_project(project)["name"]
            raise SystemError(f"No KPIs were found for project {project_name}.")

        # Strip leading spaces from cells of KPI table; convert missing values to None
        kpi_table_df = kpi_table_df.apply(lambda x: x.str.strip()).replace(
            {".": None, "": None}
        )
        return kpi_table_df

    @staticmethod
    def sync_model_properties(
        project: Union[str, dict, "RestObj"], overrwrite: Optional[bool] = False
    ) -> None:
        """
        Copies the properties listed in MODEL_PROPERTIES from a project to each of
        the models contained in that project.

        Parameters
        ----------
        project : str, dict, or RestObj
            The name or id of the project, or a dictionary representation of the
            project.
        overrwrite : bool, optional
            If True, a property that is already set on a model is replaced with the
            project's value. If False, only properties missing from the model are
            set. The default value is False. (The spelling of the parameter name is
            kept for backward compatibility.)

        Raises
        ------
        ValueError
            If the project has to be looked up and cannot be found.

        """
        # The property values are read from the project itself, so a name or id
        # must be resolved to the full project representation first.
        if not (isinstance(project, dict) and "id" in project):
            found_project = mr.get_project(project)
            if found_project is None:
                raise ValueError(f'No project "{project}" could be found.')
            project = found_project
        project_id = project["id"]

        # Get List of Models that exist in project
        models = mr.get(f"/projects/{project_id}/models")
        model_ids = [listed_model.id for listed_model in models]

        for model_id in model_ids:
            model = mr.get_model(model_id)
            changed = False
            for project_property, model_property in MODEL_PROPERTIES:
                # Only properties that are set on the project can be copied
                if project_property not in project:
                    continue
                # Set the property when the model lacks it, or always when
                # overwriting was requested
                if model_property not in model or overrwrite:
                    model[model_property] = project[project_property]
                    changed = True
            # Avoid a needless update request when nothing was modified
            if changed:
                mr.update_model(model)