blob: 1966100430e98546dac48a0e68c310d724713b8f [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This file contains the pipeline options to configure
the Dataflow pipeline."""
from datetime import datetime
from typing import Any
import config as cfg
from apache_beam.options.pipeline_options import PipelineOptions
def get_pipeline_options(
    project: str,
    job_name: str,
    mode: str,
    device: str,
    num_workers: int = cfg.NUM_WORKERS,
    **kwargs: Any,
) -> PipelineOptions:
    """Build the pipeline options for a Beam/Dataflow job.

    Args:
        project: GCP project to run on.
        job_name: Base name for the job; a timestamp suffix is appended
            so repeated launches get unique job names.
        mode: "local" selects the DirectRunner; any other value selects
            the DataflowRunner.
        device: "GPU" attaches an nvidia-tesla-p4 accelerator experiment
            flag and switches to the custom SDK container image.
        num_workers: Number of workers for running the job in parallel;
            omitted from the options when falsy so Dataflow falls back
            to its own default/autoscaling.
        **kwargs: Extra pipeline options merged into the option dict,
            overriding the defaults above on key collision.

    Returns:
        Dataflow pipeline options.
    """
    # Timestamp suffix keeps job names unique across launches.
    job_name = f'{job_name}-{datetime.now().strftime("%Y%m%d%H%M%S")}'
    # NOTE(review): the bucket is derived from cfg.PROJECT_ID rather than
    # the `project` argument — confirm these are intended to differ.
    staging_bucket = f"gs://{cfg.PROJECT_ID}-ml-examples"
    # For a list of available options, check:
    # https://cloud.google.com/dataflow/docs/guides/specifying-exec-params#setting-other-cloud-dataflow-pipeline-options
    dataflow_options = {
        "runner": "DirectRunner" if mode == "local" else "DataflowRunner",
        "job_name": job_name,
        "project": project,
        "region": cfg.REGION,
        "staging_location": f"{staging_bucket}/dflow-staging",
        "temp_location": f"{staging_bucket}/dflow-temp",
        "setup_file": "./setup.py",
    }
    flags = []
    if device == "GPU":
        # worker_accelerator is an experiment flag, not a regular option,
        # so it is passed through `flags` rather than `dataflow_options`.
        flags = [
            "--experiment=worker_accelerator=type:nvidia-tesla-p4;count:1;"
            "install-nvidia-driver",
        ]
        # GPU workers need the custom SDK container image and an
        # accelerator-compatible machine type.
        dataflow_options.update({
            "sdk_container_image": cfg.DOCKER_IMG,
            "machine_type": "n1-standard-4",
        })
    # Optional parameters
    if num_workers:
        dataflow_options["num_workers"] = num_workers
    # Fix: **kwargs used to be accepted but silently ignored; merge
    # caller-supplied overrides last so they take precedence.
    dataflow_options.update(kwargs)
    return PipelineOptions(flags=flags, **dataflow_options)