:py:mod:`airflow.providers.google.cloud.utils.mlengine_operator_utils`
======================================================================

.. py:module:: airflow.providers.google.cloud.utils.mlengine_operator_utils

.. autoapi-nested-parse::

   This module contains helper functions for MLEngine operators.


Module Contents
---------------

Functions
~~~~~~~~~

.. autoapisummary::

   airflow.providers.google.cloud.utils.mlengine_operator_utils.create_evaluate_ops


Attributes
~~~~~~~~~~

.. autoapisummary::

   airflow.providers.google.cloud.utils.mlengine_operator_utils.T


.. py:data:: T


.. py:function:: create_evaluate_ops(task_prefix, data_format, input_paths, prediction_path, metric_fn_and_keys, validate_fn, batch_prediction_job_id=None, region=None, project_id=None, dataflow_options=None, model_uri=None, model_name=None, version_name=None, dag=None, py_interpreter='python3')
   Creates the operators needed for model evaluation and returns them.

   It runs prediction over the inputs via the Cloud ML Engine BatchPrediction
   API by calling MLEngineBatchPredictionOperator, then summarizes and
   validates the result via Cloud Dataflow using DataFlowPythonOperator.

   For details and pricing about Batch Prediction, please refer to
   https://cloud.google.com/ml-engine/docs/how-tos/batch-predict
   and, for Cloud Dataflow, to https://cloud.google.com/dataflow/docs/

   It returns three chained operators for prediction, summary, and validation,
   named ``<prefix>-prediction``, ``<prefix>-summary``, and ``<prefix>-validation``,
   respectively.
   (``<prefix>`` should contain only alphanumeric characters or hyphens.)
   The upstream and downstream dependencies can be set accordingly, e.g.:

   .. code-block:: python

       pred, _, val = create_evaluate_ops(...)
       pred.set_upstream(upstream_op)
       ...
       downstream_op.set_upstream(val)

   Callers provide two Python callables, metric_fn and validate_fn, in order
   to customize the evaluation behavior as they wish.

   - metric_fn receives a dictionary per instance, derived from the JSON in
     the batch prediction result. The keys might vary depending on the model.
     It should return a tuple of metrics.
   - validate_fn receives a dictionary of the metrics that metric_fn
     generated, averaged over all instances.
     The keys/values of the dictionary match what is given by the
     metric_fn_and_keys arg.
     The dictionary contains one additional metric, 'count', representing the
     total number of instances received for evaluation.
     The function should raise an exception to mark the task as failed when
     the validation result is not good enough to proceed (i.e., to set the
     trained version as the default).
   Typical examples look like this:

   .. code-block:: python

       def get_metric_fn_and_keys():
           import math  # imports should be outside of the metric_fn below.

           def error_and_squared_error(inst):
               label = float(inst["input_label"])
               classes = float(inst["classes"])  # 0 or 1
               err = abs(classes - label)
               squared_err = math.pow(classes - label, 2)
               return (err, squared_err)  # returns a tuple.

           return error_and_squared_error, ["err", "mse"]  # key order must match.


       def validate_err_and_count(summary):
           if summary["err"] > 0.2:
               raise ValueError("Too high err>0.2; summary=%s" % summary)
           if summary["mse"] > 0.05:
               raise ValueError("Too high mse>0.05; summary=%s" % summary)
           if summary["count"] < 1000:
               raise ValueError("Too few instances<1000; summary=%s" % summary)
           return summary

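   Conceptually, the summary step averages the per-instance metric tuples and
   adds the ``'count'`` key before handing the result to validate_fn. A
   minimal, Dataflow-free sketch of that aggregation follows; the instance
   dictionaries and the validation thresholds here are made up for
   illustration and are not what the real ``<prefix>-summary`` job reads:

   .. code-block:: python

       import math


       def error_and_squared_error(inst):
           """Per-instance metric_fn: returns (absolute error, squared error)."""
           label = float(inst["input_label"])
           classes = float(inst["classes"])  # 0 or 1
           err = abs(classes - label)
           return (err, math.pow(classes - label, 2))


       metric_keys = ["err", "mse"]  # key order matches the tuple above.

       # Hypothetical batch-prediction instances, stand-ins for the JSON lines
       # that a real BatchPrediction job would write under prediction_path.
       instances = [
           {"input_label": 1.0, "classes": 1.0},
           {"input_label": 0.0, "classes": 1.0},
           {"input_label": 1.0, "classes": 1.0},
           {"input_label": 0.0, "classes": 0.0},
       ]

       # Average each metric over all instances and add 'count', mirroring
       # the summary dictionary that validate_fn receives.
       metric_tuples = [error_and_squared_error(inst) for inst in instances]
       summary = {
           key: sum(t[i] for t in metric_tuples) / len(metric_tuples)
           for i, key in enumerate(metric_keys)
       }
       summary["count"] = len(instances)


       def validate_err_and_count(summary):
           # Illustrative thresholds only; raising marks the task as failed.
           if summary["err"] > 0.5:
               raise ValueError("Too high err>0.5; summary=%s" % summary)
           if summary["count"] < 1:
               raise ValueError("Too few instances; summary=%s" % summary)
           return summary


       validate_err_and_count(summary)  # summary is {'err': 0.25, 'mse': 0.25, 'count': 4}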
   For details on the other BatchPrediction-related arguments (project_id,
   job_id, region, data_format, input_paths, prediction_path, model_uri),
   please also refer to MLEngineBatchPredictionOperator.

   :param task_prefix: a prefix for the tasks. Only alphanumeric characters
       and hyphens are allowed (no underscores), since this will be used as
       the Dataflow job name, which doesn't allow other characters.

   :param data_format: one of 'TEXT', 'TF_RECORD', or 'TF_RECORD_GZIP'.

   :param input_paths: a list of input paths to be sent to BatchPrediction.

   :param prediction_path: GCS path to put the prediction results in.

   :param metric_fn_and_keys: a tuple of metric_fn and metric_keys:

       - metric_fn is a function that accepts a dictionary (for an instance)
         and returns a tuple of the metric(s) that it calculates.

       - metric_keys is a list of strings to denote the key of each metric.

   :param validate_fn: a function to validate whether the averaged metric(s)
       are good enough to push the model.

   :param batch_prediction_job_id: the ID to use for the Cloud ML Batch
       Prediction job. Passed directly to MLEngineBatchPredictionOperator as
       the job_id argument.

   :param project_id: the Google Cloud project id in which to execute the
       Cloud ML Batch Prediction and Dataflow jobs. If None, the ``dag``'s
       ``default_args['project_id']`` will be used.

   :param region: the Google Cloud region in which to execute the Cloud ML
       Batch Prediction and Dataflow jobs. If None, the ``dag``'s
       ``default_args['region']`` will be used.

   :param dataflow_options: options to run Dataflow jobs. If None, the
       ``dag``'s ``default_args['dataflow_default_options']`` will be used.

   :param model_uri: GCS path of the model exported by TensorFlow using
       ``tensorflow.estimator.export_savedmodel()``. It cannot be used with
       model_name or version_name below. See MLEngineBatchPredictionOperator
       for more detail.

   :param model_name: used to indicate a model to use for prediction. Can be
       used in combination with version_name, but cannot be used together
       with model_uri. See MLEngineBatchPredictionOperator for more detail.
       If None, the ``dag``'s ``default_args['model_name']`` will be used.

   :param version_name: used to indicate a model version to use for
       prediction, in combination with model_name. Cannot be used together
       with model_uri. See MLEngineBatchPredictionOperator for more detail.
       If None, the ``dag``'s ``default_args['version_name']`` will be used.

   :param dag: the DAG to use for all operators.

   :param py_interpreter: Python version of the Beam pipeline.
       If None, this defaults to 'python3'.
       To track the Python versions supported by Beam and related issues,
       check https://issues.apache.org/jira/browse/BEAM-1251

   :returns: a tuple of three operators, (prediction, summary, validation)
   :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
       PythonOperator)