:mod:`airflow.contrib.hooks.sagemaker_hook`
===========================================
.. py:module:: airflow.contrib.hooks.sagemaker_hook
Module Contents
---------------
.. py:class:: LogState
Bases: :class:`object`
.. attribute:: STARTING
:annotation: = 1
.. attribute:: WAIT_IN_PROGRESS
:annotation: = 2
.. attribute:: TAILING
:annotation: = 3
.. attribute:: JOB_COMPLETE
:annotation: = 4
.. attribute:: COMPLETE
:annotation: = 5
.. data:: Position
.. function:: argmin(arr, f)
Return the index, i, in arr that minimizes f(arr[i])
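The behaviour described above can be sketched as follows. This is a minimal illustration, not the module's own code: the name ``argmin_sketch`` is hypothetical, and skipping ``None`` entries and returning ``None`` for an empty input are assumptions.

```python
def argmin_sketch(arr, f):
    """Return the index i that minimizes f(arr[i]), or None if nothing qualifies."""
    best_index = None
    best_value = None
    for index, item in enumerate(arr):
        if item is None:
            # Assumption: None entries are skipped rather than passed to f.
            continue
        value = f(item)
        if best_value is None or value < best_value:
            best_index, best_value = index, value
    return best_index


words = ["pear", "fig", "banana"]
shortest = argmin_sketch(words, len)  # index of the shortest word
```

Because the comparison is strict, ties resolve to the earliest index.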
.. function:: secondary_training_status_changed(current_job_description, prev_job_description)
Return True if the training job's secondary status message has changed.
:param current_job_description: Current job description, returned from DescribeTrainingJob call.
:type current_job_description: dict
:param prev_job_description: Previous job description, returned from DescribeTrainingJob call.
:type prev_job_description: dict
:return: Whether the secondary status message of a training job changed or not.
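One plausible way to make this comparison is sketched below. It assumes each description is a ``DescribeTrainingJob`` response whose ``SecondaryStatusTransitions`` list carries a ``StatusMessage`` per entry; the function name ``status_message_changed`` is hypothetical and the hook's actual logic may differ.

```python
def status_message_changed(current, previous):
    """Return True if the latest secondary status message differs between two descriptions."""
    transitions = current.get("SecondaryStatusTransitions") or []
    if not transitions:
        # Nothing has been reported yet, so there is nothing new to show.
        return False
    prev_transitions = (previous or {}).get("SecondaryStatusTransitions") or []
    last_message = prev_transitions[-1].get("StatusMessage", "") if prev_transitions else ""
    return transitions[-1].get("StatusMessage", "") != last_message


prev = {"SecondaryStatusTransitions": [
    {"Status": "Starting", "StatusMessage": "Preparing the instances for training"},
]}
curr = {"SecondaryStatusTransitions": prev["SecondaryStatusTransitions"] + [
    {"Status": "Downloading", "StatusMessage": "Downloading input data"},
]}
changed = status_message_changed(curr, prev)
```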
.. function:: secondary_training_status_message(job_description, prev_description)
Return a string containing the start time and the secondary training job status message.
:param job_description: Returned response from DescribeTrainingJob call
:type job_description: dict
:param prev_description: Previous job description from DescribeTrainingJob call
:type prev_description: dict
:return: Job status string to be printed.
.. py:class:: SageMakerHook(*args, **kwargs)
Bases: :class:`airflow.contrib.hooks.aws_hook.AwsHook`
Interact with Amazon SageMaker.
.. attribute:: non_terminal_states
.. attribute:: endpoint_non_terminal_states
.. attribute:: failed_states
.. method:: tar_and_s3_upload(self, path, key, bucket)
Tar the local file or directory and upload it to S3.
:param path: local file or directory
:type path: str
:param key: S3 key
:type key: str
:param bucket: S3 bucket
:type bucket: str
:return: None
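The archiving half of this method can be sketched with the standard library alone. The helper names ``tar_to_bytes`` and ``list_archive`` are hypothetical, gzip compression is an assumption, and the upload step is deliberately left out because it needs AWS credentials.

```python
import io
import os
import tarfile
import tempfile


def tar_to_bytes(path):
    """Return a gzip-compressed tar archive of a file, or of a directory's entries."""
    if os.path.isdir(path):
        members = [os.path.join(path, name) for name in sorted(os.listdir(path))]
    else:
        members = [path]
    buffer = io.BytesIO()
    with tarfile.open(mode="w:gz", fileobj=buffer) as archive:
        for member in members:
            # Store each entry under its base name so the archive has no local prefix.
            archive.add(member, arcname=os.path.basename(member))
    return buffer.getvalue()


def list_archive(data):
    """Return the sorted member names of an in-memory tar.gz archive."""
    with tarfile.open(mode="r:gz", fileobj=io.BytesIO(data)) as archive:
        return sorted(archive.getnames())


with tempfile.TemporaryDirectory() as workdir:
    with open(os.path.join(workdir, "train.csv"), "w") as handle:
        handle.write("x,y\n1,2\n")
    archive_names = list_archive(tar_to_bytes(workdir))
```

In a real upload, the resulting bytes would then be written to the given bucket and key.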
.. method:: configure_s3_resources(self, config)
Extract the S3 operations from the configuration and execute them.
:param config: config of SageMaker operation
:type config: dict
:rtype: dict
.. method:: check_s3_url(self, s3url)
Check if an S3 URL exists
:param s3url: S3 URL
:type s3url: str
:rtype: bool
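Before an S3 URL can be checked, it has to be split into a bucket and a key. The sketch below shows that first step only, using the standard library; ``split_s3_url`` is a hypothetical helper, not part of the hook, and the existence check itself is omitted because it requires an AWS connection.

```python
from urllib.parse import urlparse


def split_s3_url(s3url):
    """Split 's3://bucket/key' into (bucket, key); raise ValueError on other input."""
    parsed = urlparse(s3url)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError("Expected a URL of the form s3://bucket/key, got: %s" % s3url)
    return parsed.netloc, parsed.path.lstrip("/")


bucket, key = split_s3_url("s3://my-bucket/data/train.csv")
```

A key that ends in ``/`` would typically be treated as a prefix rather than a single object, which is an assumption about how a caller might use the result.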
.. method:: check_training_config(self, training_config)
Check if a training configuration is valid
:param training_config: the training configuration to validate
:type training_config: dict
:return: None
.. method:: check_tuning_config(self, tuning_config)
Check if a tuning configuration is valid
:param tuning_config: the tuning configuration to validate
:type tuning_config: dict
:return: None
.. method:: get_conn(self)
Establish an AWS connection for SageMaker
:rtype: :py:class:`SageMaker.Client`
.. method:: get_log_conn(self)
This method is deprecated.
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_conn` instead.
.. method:: log_stream(self, log_group, stream_name, start_time=0, skip=0)
This method is deprecated.
Please use :py:meth:`airflow.contrib.hooks.AwsLogsHook.get_log_events` instead.
.. method:: multi_stream_iter(self, log_group, streams, positions=None)
Iterate over the available events coming from a set of log streams in a single log group,
interleaving the events from each stream so that they are yielded in timestamp order.
:param log_group: The name of the log group.
:type log_group: str
:param streams: A list of the log stream names. The position of the stream in this list is
the stream number.
:type streams: list
:param positions: A list of pairs of (timestamp, skip) which represents the last record
read from each stream.
:type positions: list
:return: A tuple of (stream number, cloudwatch log event).
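The interleaving described above can be sketched on in-memory data. This is an illustration under stated assumptions: each stream is already sorted by its ``timestamp`` field, the function name ``interleave_streams`` is hypothetical, and the position tracking and CloudWatch calls of the real method are left out.

```python
import heapq


def interleave_streams(streams):
    """Yield (stream number, event) pairs in timestamp order.

    Each element of `streams` is an iterable of event dicts that is
    already sorted by event["timestamp"].
    """
    def tagged(number, events):
        for event in events:
            yield event["timestamp"], number, event

    # Merge on (timestamp, stream number) so event dicts are never compared.
    merged = heapq.merge(
        *(tagged(number, events) for number, events in enumerate(streams)),
        key=lambda item: (item[0], item[1]),
    )
    for _, number, event in merged:
        yield number, event


stream_a = [{"timestamp": 1, "message": "a1"}, {"timestamp": 4, "message": "a2"}]
stream_b = [{"timestamp": 2, "message": "b1"}, {"timestamp": 3, "message": "b2"}]
ordered = [(n, e["message"]) for n, e in interleave_streams([stream_a, stream_b])]
```

The stream number in each pair is the position of the stream in the input list, matching the convention stated for ``streams``.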
.. method:: create_training_job(self, config, wait_for_completion=True, print_log=True, check_interval=30, max_ingestion_time=None)
Create a training job
:param config: the config for training
:type config: dict
:param wait_for_completion: whether the program should keep running until the job finishes
:type wait_for_completion: bool
:param check_interval: the interval in seconds at which the operator
checks the status of the SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to training job creation
.. method:: create_tuning_job(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None)
Create a tuning job
:param config: the config for tuning
:type config: dict
:param wait_for_completion: whether the program should keep running until the job finishes
:type wait_for_completion: bool
:param check_interval: the interval in seconds at which the operator
checks the status of the SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to tuning job creation
.. method:: create_transform_job(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None)
Create a transform job
:param config: the config for transform job
:type config: dict
:param wait_for_completion: whether the program should keep running until the job finishes
:type wait_for_completion: bool
:param check_interval: the interval in seconds at which the operator
checks the status of the SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to transform job creation
.. method:: create_model(self, config)
Create a model
:param config: the config for model
:type config: dict
:return: A response to model creation
.. method:: create_endpoint_config(self, config)
Create an endpoint config
:param config: the config for endpoint-config
:type config: dict
:return: A response to endpoint config creation
.. method:: create_endpoint(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None)
Create an endpoint
:param config: the config for endpoint
:type config: dict
:param wait_for_completion: whether the program should keep running until the job finishes
:type wait_for_completion: bool
:param check_interval: the interval in seconds at which the operator
checks the status of the SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to endpoint creation
.. method:: update_endpoint(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None)
Update an endpoint
:param config: the config for endpoint
:type config: dict
:param wait_for_completion: whether the program should keep running until the job finishes
:type wait_for_completion: bool
:param check_interval: the interval in seconds at which the operator
checks the status of the SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: A response to endpoint update
.. method:: describe_training_job(self, name)
Return the training job info associated with the name
:param name: the name of the training job
:type name: str
:return: A dict containing all the training job info
.. method:: describe_training_job_with_log(self, job_name, positions, stream_names, instance_count, state, last_description, last_describe_job_call)
Return the training job info associated with job_name and print CloudWatch logs
.. method:: describe_tuning_job(self, name)
Return the tuning job info associated with the name
:param name: the name of the tuning job
:type name: str
:return: A dict containing all the tuning job info
.. method:: describe_model(self, name)
Return the SageMaker model info associated with the name
:param name: the name of the SageMaker model
:type name: str
:return: A dict containing all the model info
.. method:: describe_transform_job(self, name)
Return the transform job info associated with the name
:param name: the name of the transform job
:type name: str
:return: A dict containing all the transform job info
.. method:: describe_endpoint_config(self, name)
Return the endpoint config info associated with the name
:param name: the name of the endpoint config
:type name: str
:return: A dict containing all the endpoint config info
.. method:: describe_endpoint(self, name)
Return the endpoint info associated with the name
:param name: the name of the endpoint
:type name: str
:return: A dict containing all the endpoint info
.. method:: check_status(self, job_name, key, describe_function, check_interval, max_ingestion_time, non_terminal_states=None)
Check status of a SageMaker job
:param job_name: name of the job to check status
:type job_name: str
:param key: the key of the response dict
that points to the state
:type key: str
:param describe_function: the function used to retrieve the status
:type describe_function: python callable
:param args: the arguments for the function
:param check_interval: the interval in seconds at which the operator
checks the status of the SageMaker job
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:param non_terminal_states: the set of non-terminal states
:type non_terminal_states: set
:return: response of describe call after job is done
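The polling behaviour implied by these parameters can be sketched without any AWS dependency. This is a simplified illustration, not the hook's implementation: ``poll_until_terminal`` and its injectable ``sleep`` argument are hypothetical, elapsed time is approximated by summing the check intervals, and failed-state handling is omitted.

```python
import time


def poll_until_terminal(job_name, key, describe_function, check_interval,
                        max_ingestion_time, non_terminal_states, sleep=time.sleep):
    """Call describe_function until response[key] leaves the non-terminal states."""
    elapsed = 0
    while True:
        response = describe_function(job_name)
        if response[key] not in non_terminal_states:
            return response
        if max_ingestion_time is not None and elapsed >= max_ingestion_time:
            raise TimeoutError("Job %s exceeded %s seconds" % (job_name, max_ingestion_time))
        sleep(check_interval)
        elapsed += check_interval


# A stand-in for a describe call that reports progress, then completion.
statuses = iter(["InProgress", "InProgress", "Completed"])


def fake_describe(name):
    return {"TrainingJobStatus": next(statuses)}


result = poll_until_terminal(
    "demo-job", "TrainingJobStatus", fake_describe,
    check_interval=30, max_ingestion_time=None,
    non_terminal_states={"InProgress", "Stopping"},
    sleep=lambda seconds: None,  # skip real waiting in this example
)
```

Passing ``max_ingestion_time=None`` disables the timeout, mirroring the parameter description above.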
.. method:: check_training_status_with_log(self, job_name, non_terminal_states, failed_states, wait_for_completion, check_interval, max_ingestion_time)
Display the logs for a given training job, optionally tailing them until the
job is complete.
:param job_name: name of the training job to check status and display logs for
:type job_name: str
:param non_terminal_states: the set of non-terminal states
:type non_terminal_states: set
:param failed_states: the set of failed states
:type failed_states: set
:param wait_for_completion: Whether to keep looking for new log entries
until the job completes
:type wait_for_completion: bool
:param check_interval: The interval in seconds between polling for new log entries and job completion
:type check_interval: int
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
:type max_ingestion_time: int
:return: None