blob: d21f02dc22e15984d174b3677d1f6937b3f40b0e [file] [log] [blame]
import logging
import aggregate_data
import click
import list_data
import process_data
import ray
from dask import distributed
from hamilton import driver, log_setup
from hamilton.execution import executors
from hamilton.plugins import h_dask, h_ray
# Configure logging at INFO level so Hamilton's progress output is visible.
log_setup.setup_logging(logging.INFO)
@click.command()
@click.option(
    "--mode",
    type=click.Choice(["local", "multithreading", "dask", "ray"]),
    required=True,
    help="Where to run remote tasks.",
)
def main(mode: str):
    """Run the Hamilton dataflow with the requested remote task executor.

    Builds a driver over the aggregate_data/list_data/process_data modules,
    executes ``statistics_by_city`` against the ``data`` directory, prints the
    result, then tears down any cluster resources the chosen mode created.

    :param mode: backend for remote tasks — one of "local", "multithreading",
        "dask", or "ray" (enforced by click.Choice).
    """
    shutdown = None  # optional cleanup callback for cluster-backed executors
    if mode == "local":
        remote_executor = executors.SynchronousLocalTaskExecutor()
    elif mode == "multithreading":
        remote_executor = executors.MultiThreadingExecutor(max_tasks=100)
    elif mode == "dask":
        cluster = distributed.LocalCluster()
        client = distributed.Client(cluster)
        remote_executor = h_dask.DaskExecutor(client=client)

        def shutdown():
            # Close the client connection before the cluster; the original
            # only closed the cluster and leaked the client.
            client.close()
            cluster.close()

    else:  # mode == "ray" — click.Choice guarantees this is the only remaining value
        remote_executor = h_ray.RayTaskExecutor(num_cpus=4)
        shutdown = ray.shutdown
    dr = (
        driver.Builder()
        .enable_dynamic_execution(allow_experimental_mode=True)
        .with_remote_executor(remote_executor)  # We only need to specify the remote executor.
        # The local executor just runs it synchronously
        .with_modules(aggregate_data, list_data, process_data)
        .build()
    )
    print(
        dr.execute(final_vars=["statistics_by_city"], inputs={"data_dir": "data"})[
            "statistics_by_city"
        ]
    )
    if shutdown:
        shutdown()
# Script entry point: only invoke the CLI when run directly, not on import.
if __name__ == "__main__":
    main()