# blob: c8c580fc6c5258ed90a58034c7424389a4375180
import json
from functools import singledispatch
from typing import Any, Dict
import pandas as pd
from hamilton_sdk.tracking import sql_utils
@singledispatch
def compute_stats(result, node_name: str, node_tags: dict) -> Dict[str, Any]:
    """Fallback stats handler, used when no type-specific handler matches.

    Type-specific implementations are attached with the
    `@compute_stats.register` decorator. This function only runs for result
    types without a registered handler and reports the type as unsupported.

    :param result: the value to describe; only its type is inspected here.
    :param node_name: not used by this fallback.
    :param node_tags: not used by this fallback.
    :return: dict with the keys "observability_type" ("unsupported"),
        "observability_value" and "observability_schema_version".
    """
    unsupported_details = {
        "unsupported_type": str(type(result)),
        "action": "reach out to the DAGWorks team to add support for this type.",
    }
    return dict(
        observability_type="unsupported",
        observability_value=unsupported_details,
        observability_schema_version="0.0.1",
    )
@compute_stats.register(str)
@compute_stats.register(int)
@compute_stats.register(float)
@compute_stats.register(bool)
def compute_stats_primitives(result, node_name: str, node_tags: dict) -> Dict[str, Any]:
    """Describe a str, int, float or bool result by echoing its type and value.

    :param result: the primitive value to report.
    :param node_name: not used by this handler.
    :param node_tags: not used by this handler.
    :return: dict with "observability_type" set to "primitive" and an
        "observability_value" holding the type string and the value itself.
    """
    described = {
        "type": str(type(result)),
        "value": result,
    }
    return dict(
        observability_type="primitive",
        observability_value=described,
        observability_schema_version="0.0.1",
    )
@compute_stats.register(dict)
def compute_stats_dict(result: dict, node_name: str, node_tags: dict) -> Dict[str, Any]:
    """Summarise a dict result, value by value where needed.

    If the whole dict passes `json.dumps`, the dict object itself is reported.
    Otherwise a new dict is built with the same keys, where each value is
    either kept, replaced by its computed stats, or stringified.

    :param result: the dict to summarise.
    :param node_name: forwarded to `compute_stats` for non-primitive values.
    :param node_tags: forwarded to `compute_stats` for non-primitive values.
    :return: dict with "observability_type" set to "dict"; the summarised
        values are under "observability_value" -> "value".
    """
    try:
        # Probe only: the encoded string is discarded.
        json.dumps(result)
    except Exception:
        summarized = {}
        for key, item in result.items():
            if isinstance(item, (str, int, float, bool)):
                # Plain scalars are kept without dispatching.
                summarized[key] = item
            else:
                # Anything else (dict, list, tuple, ...) goes through dispatch.
                item_stats = compute_stats(item, node_name, node_tags)
                kind = item_stats["observability_type"]
                if kind == "primitive":
                    summarized[key] = item
                elif kind == "unsupported":
                    # No handler for this type: keep at most 200 chars of str().
                    as_text = str(item)
                    summarized[key] = as_text if len(as_text) <= 200 else as_text[:200] + "..."
                else:
                    # A handler produced stats (e.g. DF, Series): take them whole.
                    summarized[key] = item_stats["observability_value"]
    else:
        # JSON serializable as a whole, so take it unchanged.
        summarized = result

    value_payload = {
        "type": str(type(result)),
        "value": summarized,
    }
    return dict(
        observability_type="dict",
        observability_value=value_payload,
        observability_schema_version="0.0.2",
    )
@compute_stats.register(tuple)
def compute_stats_tuple(result: tuple, node_name: str, node_tags: dict) -> Dict[str, Any]:
    """Summarise a tuple result; only data-loader style tuples are supported.

    When `node_tags["hamilton.data_loader"]` is exactly `True` and the second
    element of the tuple is a dict, that dict is reported as metadata:

    * If it round-trips through JSON, the decoded copy is enriched with
      "QUERIED_TABLES" (when it has a "SQL_QUERY" key) and with
      "DF_MEMORY_TOTAL" / "DF_MEMORY_BREAKDOWN" (when the first element is a
      pandas DataFrame).
    * If it does not, its `str()` form is reported, cut to 1000 characters.

    Every other tuple, including one with fewer than two elements, is
    reported as unsupported.

    :param result: the tuple to summarise.
    :param node_name: not used by this handler.
    :param node_tags: node tags; only "hamilton.data_loader" is consulted.
    :return: dict with "observability_type" of "dict" or "unsupported".
    """
    is_data_loader = node_tags.get("hamilton.data_loader") is True
    # The length guard sends short tuples to the "unsupported" result below
    # instead of raising IndexError on result[1].
    if is_data_loader and len(result) > 1 and isinstance(result[1], dict):
        try:
            # Round-trip through JSON: this checks the metadata is serializable
            # and gives a separate copy, so enriching it leaves result[1] alone.
            raw_data = json.dumps(result[1])
            _metadata = json.loads(raw_data)
        except Exception:
            # Best effort: fall back to a bounded string form.
            _metadata = str(result[1])
            if len(_metadata) > 1000:
                _metadata = _metadata[:1000] + "..."
        else:
            # Enrich it.
            if "SQL_QUERY" in _metadata:
                # we might need to think how to make this a constant...
                _metadata["QUERIED_TABLES"] = sql_utils.parse_sql_query(_metadata["SQL_QUERY"])
            if isinstance(result[0], pd.DataFrame):
                # TODO: move this to dataframe stats collection
                _memory = result[0].memory_usage(deep=True)
                _metadata["DF_MEMORY_TOTAL"] = int(_memory.sum())
                _metadata["DF_MEMORY_BREAKDOWN"] = _memory.to_dict()
        return {
            "observability_type": "dict",
            "observability_value": {
                "type": str(type(result[1])),
                "value": _metadata,
            },
            "observability_schema_version": "0.0.2",
        }
    return {
        "observability_type": "unsupported",
        "observability_value": {
            "unsupported_type": str(type(result)),
            "action": "reach out to the DAGWorks team to add support for this type.",
        },
        "observability_schema_version": "0.0.1",
    }
@compute_stats.register(list)
def compute_stats_list(result: list, node_name: str, node_tags: dict) -> Dict[str, Any]:
    """Summarise a list result, element by element where needed.

    If the whole list passes `json.dumps`, the list object itself is reported.
    Otherwise a new list is built:

    * list/dict/tuple elements are kept when they pass `json.dumps` on their
      own, and otherwise replaced by their `str()` form (max 200 characters);
    * any other element is dispatched through `compute_stats`. Primitives are
      kept, unsupported types are stringified (max 200 characters), and every
      other observability type is replaced by its computed
      "observability_value", the same rule the dict handler applies.

    :param result: the list to summarise.
    :param node_name: forwarded to `compute_stats` for non-container elements.
    :param node_tags: forwarded to `compute_stats` for non-container elements.
    :return: dict with "observability_type" set to "dict" (see the comment at
        the return); the summarised elements are under
        "observability_value" -> "value".
    """
    try:
        # if it's JSON serializable, take it.
        json.dumps(result)
        result_values = result
    except Exception:
        result_values = []
        for v in result:
            if isinstance(v, (list, dict, tuple)):
                try:
                    json.dumps(v)
                except Exception:
                    v = _list_item_as_text(v)
            else:
                v_result = compute_stats(v, node_name, node_tags)
                # determine what to pull out of the result for the value
                observed_type = v_result["observability_type"]
                if observed_type == "unsupported":
                    v = _list_item_as_text(v)
                elif observed_type != "primitive":
                    # A handler produced stats (DF, Series, ...): take the
                    # full result instead of keeping the raw object.
                    v = v_result["observability_value"]
            result_values.append(v)
    return {
        # yes dict type -- that's so that we can display in the UI. It's a hack.
        "observability_type": "dict",
        "observability_value": {
            "type": str(type(result)),
            "value": result_values,
        },
        "observability_schema_version": "0.0.2",
    }


def _list_item_as_text(value: Any, limit: int = 200) -> str:
    """Return `str(value)`, cut to `limit` characters plus "..." when longer."""
    text = str(value)
    if len(text) > limit:
        text = text[:limit] + "..."
    return text