| :py:mod:`airflow.providers.google.cloud.utils.mlengine_prediction_summary` |
| ========================================================================== |
| |
| .. py:module:: airflow.providers.google.cloud.utils.mlengine_prediction_summary |
| |
| .. autoapi-nested-parse:: |
| |
| A template called by ``DataflowCreatePythonJobOperator`` to summarize BatchPrediction results. |
| |
| It accepts a user function to calculate the metric(s) per instance in |
| the prediction results, then aggregates to output as a summary. |
| |
| It accepts the following arguments: |
| |
| - ``--prediction_path``: |
| The GCS folder that holds the BatchPrediction results, i.e. the |
| ``prediction.results-NNNNN-of-NNNNN`` files in JSON format. |
| The output is also stored in this folder, as ``prediction.summary.json``. |
| - ``--metric_fn_encoded``: |
| An encoded function that calculates and returns a tuple of metric(s) |
| for a given instance (as a dictionary). It should be encoded |
| via ``base64.b64encode(dill.dumps(fn, recurse=True))``. |
| - ``--metric_keys``: |
| Comma-separated key(s) of the aggregated metric(s) in the summary |
| output. The order and number of the keys must match the output |
| of ``metric_fn``. |
| The summary has an additional key, ``count``, that holds the |
| total number of instances, so the keys must not include ``count``. |
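| |
| The constraints on ``--metric_keys`` described above can be sketched as a small check. |
| This is a minimal illustration, not the module's own code: ``parse_metric_keys`` and its |
| ``num_metrics`` parameter are hypothetical names introduced only for this example. |
| |

```python
def parse_metric_keys(raw, num_metrics):
    """Split a comma-separated --metric_keys value and check the documented constraints.

    Hypothetical helper: `num_metrics` is the length of the tuple returned by metric_fn.
    """
    keys = [key.strip() for key in raw.split(",")]
    # 'count' is reserved for the total number of instances in the summary.
    if "count" in keys:
        raise ValueError("metric_keys must not include 'count'")
    # The number of keys must match the number of values metric_fn returns.
    if len(keys) != num_metrics:
        raise ValueError(f"expected {num_metrics} keys, got {len(keys)}")
    return keys


print(parse_metric_keys("log_loss,mse", num_metrics=2))
```

| |
| With the keys from the usage example, the call returns ``['log_loss', 'mse']``; a value such as |
| ``"log_loss,count"`` raises ``ValueError`` in this sketch. |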
| |
| |
| Usage example: |
| |
| .. code-block:: python |
| |
| import base64 |
| |
| import dill |
| |
| from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator |
| |
| |
| def get_metric_fn(): |
|     import math  # imports must live outside metric_fn, the function that is passed |
| |
|     def metric_fn(inst): |
|         label = float(inst["input_label"]) |
|         classes = float(inst["classes"]) |
|         prediction = float(inst["scores"][1]) |
|         log_loss = math.log(1 + math.exp( |
|             -(label * 2 - 1) * math.log(prediction / (1 - prediction)))) |
|         squared_err = (classes - label) ** 2 |
|         return (log_loss, squared_err) |
| |
|     return metric_fn |
| |
| |
| # b64encode returns bytes; decode to str so the value can be passed as an option. |
| metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True)).decode() |
| DataflowCreatePythonJobOperator( |
|     task_id="summary-prediction", |
|     py_options=["-m"], |
|     py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary", |
|     options={ |
|         "prediction_path": prediction_path, |
|         "metric_fn_encoded": metric_fn_encoded, |
|         "metric_keys": "log_loss,mse", |
|     }, |
|     dataflow_default_options={ |
|         "project": "xxx", "region": "us-east1", |
|         "staging_location": "gs://yy", "temp_location": "gs://zz", |
|     }, |
| ) >> dag |
| |
| When the input file is like the following:: |
| |
| {"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]} |
| {"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]} |
| {"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]} |
| {"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]} |
| |
| The output file will be:: |
| |
| {"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25} |
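| |
| The aggregation step can be illustrated in plain Python. The sketch below assumes that the |
| summary is the arithmetic mean of each per-instance metric plus a ``count`` key, which is |
| consistent with the ``mse`` and ``count`` values shown above; the module itself runs as a |
| Dataflow pipeline. ``summarize`` and ``squared_error`` are hypothetical names, and reading the |
| label from the first field of ``inputs`` is an assumption made only for this example. |
| |

```python
import json

# The four prediction lines from the example input.
lines = [
    '{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}',
    '{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}',
    '{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}',
    '{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}',
]


def summarize(lines, metric_fn, metric_keys):
    """Average each metric returned by metric_fn and add an instance count."""
    totals = [0.0] * len(metric_keys)
    count = 0
    for line in lines:
        metrics = metric_fn(json.loads(line))  # tuple, one value per key
        totals = [total + value for total, value in zip(totals, metrics)]
        count += 1
    summary = {key: total / count for key, total in zip(metric_keys, totals)}
    summary["count"] = count
    return summary


def squared_error(inst):
    # Assumption: the label is the first comma-separated field of "inputs".
    label = float(inst["inputs"].split(",")[0])
    return ((float(inst["classes"]) - label) ** 2,)


print(summarize(lines, squared_error, ["mse"]))
```

| |
| Under these assumptions the sketch prints ``{'mse': 0.25, 'count': 4}``, which agrees with the |
| ``mse`` and ``count`` entries of the example output. |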
| |
| To test outside of a DAG: |
| |
| .. code-block:: python |
| |
| import subprocess |
| |
| subprocess.check_call( |
|     [ |
|         "python", |
|         "-m", |
|         "airflow.providers.google.cloud.utils.mlengine_prediction_summary", |
|         "--prediction_path=gs://...", |
|         # metric_fn_encoded must be a str here, not bytes |
|         "--metric_fn_encoded=" + metric_fn_encoded, |
|         "--metric_keys=log_loss,mse", |
|         "--runner=DataflowRunner", |
|         "--staging_location=gs://...", |
|         "--temp_location=gs://...", |
|     ] |
| ) |
| |
| .. spelling:: |
| |
| pcoll |
| |
| |
| |
| Module Contents |
| --------------- |
| |
| Classes |
| ~~~~~~~ |
| |
| .. autoapisummary:: |
| |
| airflow.providers.google.cloud.utils.mlengine_prediction_summary.JsonCoder |
| |
| |
| |
| Functions |
| ~~~~~~~~~ |
| |
| .. autoapisummary:: |
| |
| airflow.providers.google.cloud.utils.mlengine_prediction_summary.MakeSummary |
| airflow.providers.google.cloud.utils.mlengine_prediction_summary.run |
| |
| |
| |
| .. py:class:: JsonCoder |
| |
| JSON encoder/decoder. |
| |
| .. py:method:: encode(x) |
| :staticmethod: |
| |
| JSON encoder. |
| |
| |
| .. py:method:: decode(x) |
| :staticmethod: |
| |
| JSON decoder. |
| |
| |
| |
| .. py:function:: MakeSummary(pcoll, metric_fn, metric_keys) |
| |
| Summary PTransform used in Dataflow. |
| |
| |
| .. py:function:: run(argv=None) |
| |
| Helper for obtaining prediction summary. |
| |
| |