blob: 88a18635990d98bb5c8a622a16c2fe2f287b3530 [file]
import importlib
import logging
import os
from typing import List

import boto3
import click
from hamilton import dataflows, driver
from hamilton.io.materialization import to
from hamilton.log_setup import setup_logging
from tenacity import retry, stop_after_delay, wait_exponential
# Imported purely for its side effects (the module object is not bound to a name).
# NOTE(review): presumably this registers the custom materializers used below
# (e.g. ``to.image``, ``to.image_s3``, ``to.json_s3``) with Hamilton -- confirm
# against the local ``adapters`` module. It must run before those are used.
importlib.import_module("adapters")
# Configure logging at INFO level via Hamilton's ``setup_logging`` helper.
setup_logging(logging.INFO)
# Module-level logger used for progress messages in ``run``.
logger = logging.getLogger(__name__)
def _create_driver():
    """Build the Hamilton driver combining the captioning and image-generation dataflows.

    Both dataflows are loaded (captioning first) from the ``elijahbenizzy`` namespace with
    ``dataflows.import_module`` and handed to a single ``driver.Driver`` whose config sets
    ``include_embeddings`` to ``True``.
    """
    # TODO -- use the hub
    flow_modules = [
        dataflows.import_module(flow_name, "elijahbenizzy")
        for flow_name in ("caption_images", "generate_images")
    ]
    driver_config = {"include_embeddings": True}
    return driver.Driver(driver_config, *flow_modules)
def _iteration_index(filename: str, prefix: str):
    """Return ``N`` for a file named ``<prefix><N>[.<ext>]``, or ``None`` if it does not match.

    A strict parse (rather than a substring test followed by ``int()``) means stray files such as
    ``image_final.png`` or ``notes.txt`` are ignored instead of crashing the resume logic.
    """
    if not filename.startswith(prefix):
        return None
    stem = filename[len(prefix):].split(".", 1)[0]
    return int(stem) if stem.isdecimal() else None


def _stored_location(storage_mode: str, params: dict[str, str], image_name: str, filename: str) -> str:
    """Return the local path or ``s3://`` URI of ``filename`` inside ``image_name``'s storage folder."""
    if storage_mode == "local":
        return os.path.join(params["base_dir"], image_name, filename)
    return f"s3://{params['bucket']}/{image_name}/{filename}"


def determine_state(
    initial_image_path: str, storage_mode: str, image_name: str, params: dict[str, str]
) -> tuple[int, str, bool]:
    """Determine where we are in the iteration loop by inspecting what is already stored.

    Lists the files under ``<image_name>/``. This is ``params["base_dir"]`` for local storage, or
    the S3 bucket ``params["bucket"]``. Stored files are ``original.png``, ``metadata_<N>.json``
    and ``image_<N>.<fmt>`` (one image file per output format).

    :param initial_image_path: Image to start from. May be ``None`` when resuming a run whose
        ``original.png`` is already stored.
    :param storage_mode: ``"local"`` or ``"s3"``.
    :param image_name: Folder / key prefix holding this run's artifacts.
    :param params: Storage parameters: ``base_dir`` for local mode, ``bucket`` for S3 mode.
    :return: ``(iteration, image_url, has_original)``.
        - ``iteration``: the number of iterations already stored, i.e. the next one to run.
        - ``image_url``: the image to caption next.
        - ``has_original``: whether ``original.png`` is already stored.
    :raises ValueError: If ``storage_mode`` is unknown, or if nothing is stored yet and no
        ``initial_image_path`` was given.
    """
    if storage_mode == "local":
        # os.listdir already returns bare file names.
        paths = os.listdir(os.path.join(params["base_dir"], image_name))
    elif storage_mode == "s3":
        prefix = f"{image_name}/"
        bucket = boto3.resource("s3").Bucket(params["bucket"])
        # Every key returned by the Prefix filter starts with ``prefix``. Slice it off instead of
        # using str.replace, which would also mangle later occurrences inside the key.
        paths = [obj.key[len(prefix):] for obj in bucket.objects.filter(Prefix=prefix)]
    else:
        raise ValueError(f"Invalid storage engine: {storage_mode}")

    if "original.png" not in paths:
        # No original one, starting from zero
        if initial_image_path is None:
            raise ValueError(
                "Must provide initial_image_path if no original image is present in the target location"
            )
        return 0, initial_image_path, False

    # Count distinct iteration numbers, not files: each iteration saves one image per output
    # format, so counting image files would overshoot the resume point.
    metadata_iterations = {
        index for index in (_iteration_index(p, "metadata_") for p in paths) if index is not None
    }
    images_by_iteration: dict[int, list[str]] = {}
    for path in paths:
        index = _iteration_index(path, "image_")
        if index is not None:
            images_by_iteration.setdefault(index, []).append(path)
    iteration = min(len(metadata_iterations), len(images_by_iteration))

    if images_by_iteration:
        # Resume from the newest generated image. Prefer its png (always written by ``run``) so
        # the choice does not depend on directory listing order.
        candidates = sorted(images_by_iteration[max(images_by_iteration)])
        latest = next((c for c in candidates if c.endswith(".png")), candidates[0])
        image_url = _stored_location(storage_mode, params, image_name, latest)
    elif initial_image_path is not None:
        image_url = initial_image_path
    else:
        # Resuming before any image was generated: caption the stored original instead of None.
        image_url = _stored_location(storage_mode, params, image_name, "original.png")
    return iteration, image_url, True
@retry(stop=stop_after_delay(120), wait=wait_exponential(multiplier=1, max=10))
def caption_step(
    dr: driver.Driver,
    image_name: str,
    iteration: int,
    storage_params: dict,
    storage_mode: str,
    image_url: str,
    image_formats: List[str],
    descriptiveness: str,
    save_original_image: bool = False,
) -> str:
    """
    Step to caption an image. Returns the caption generated by OpenAI.

    Runs the driver to compute ``generated_caption``. It also materializes ``metadata`` to
    ``<image_name>/metadata_<iteration>.json`` in the chosen storage. When
    ``save_original_image`` is set, ``image_url`` is additionally saved as
    ``<image_name>/original.<fmt>`` for every format in ``image_formats``.

    The whole step is retried for up to 120 seconds. Retries use an exponential back-off
    capped at 10 seconds, so transient failures don't become a tight loop of requests.

    :param dr: The driver to use
    :param image_name: The name of the image, globally unique
    :param iteration: The iteration number
    :param storage_params: The storage parameters (``bucket`` for s3, ``base_dir`` for local)
    :param storage_mode: The storage mode (``"s3"`` or ``"local"``)
    :param image_url: The URL of the image to caption
    :param image_formats: The image formats to save
    :param descriptiveness: The descriptiveness of the image
    :param save_original_image: Whether to save the original image
    :return: The caption generated by OpenAI
    """
    materializers = []
    if storage_mode == "s3":
        bucket = storage_params["bucket"]
        materializers.append(
            to.json_s3(
                bucket=bucket,
                key=f"{image_name}/metadata_{iteration}.json",
                dependencies=["metadata"],
                id="save_metadata",
            )
        )
        if save_original_image:
            for fmt in image_formats:
                materializers.append(
                    to.image_s3(
                        bucket=bucket,
                        key=f"{image_name}/original.{fmt}",
                        dependencies=["image_url"],
                        id=f"save_original_image_{fmt}_s3",
                        format=fmt,
                    )
                )
    elif storage_mode == "local":
        base_dir = storage_params["base_dir"]
        metadata_save_path = os.path.join(image_name, f"metadata_{iteration}.json")
        materializers.append(
            to.json(
                path=os.path.join(base_dir, metadata_save_path),
                dependencies=["metadata"],
                id="save_metadata",
            )
        )
        if save_original_image:
            for fmt in image_formats:
                materializers.append(
                    to.image(
                        path=os.path.join(base_dir, f"{image_name}/original.{fmt}"),
                        dependencies=["image_url"],
                        id=f"save_original_image_{fmt}",
                        format=fmt,
                    )
                )
    # ``additional_vars`` makes the driver return the caption alongside the materialization.
    _materializer_data, results = dr.materialize(
        *materializers,
        additional_vars=["generated_caption"],
        inputs={
            "image_url": image_url,
            "descriptiveness": descriptiveness,
            "additional_metadata": {
                "descriptiveness": descriptiveness,
                "iteration": iteration,
                "creator": "github.com/elijahbenizzy",
            },
        },
    )
    return results["generated_caption"]
@retry(stop=stop_after_delay(120), wait=wait_exponential(multiplier=1, max=10))
def generate_step(
    dr: driver.Driver,
    image_name: str,
    iteration: int,
    caption: str,
    storage_mode: str,
    formats: List[str],
    storage_params: dict,
) -> List[str]:
    """Step to generate an image from a caption and save it in every requested format.

    Runs the driver with ``image_generation_prompt=caption`` and materializes ``generated_image``
    to ``<image_name>/image_<iteration>.<fmt>`` for each format.

    The whole step is retried for up to 120 seconds. Retries use an exponential back-off
    capped at 10 seconds, so transient failures don't become a tight loop of requests.

    :param dr: The driver to use
    :param image_name: The name of the image, globally unique
    :param iteration: The iteration number
    :param caption: The caption to use
    :param storage_mode: The storage mode (s3 or local)
    :param formats: The image formats to save
    :param storage_params: The storage parameters to pass to the storage mode
    :return: The storage locations the generated image was saved to, one per entry of
        ``formats`` and in the same order. These are ``s3://<bucket>/<image_name>/image_<iteration>.<fmt>``
        URIs in s3 mode, or local file paths in local mode.
    """
    materializers = []
    save_uris = []
    if storage_mode == "s3":
        bucket = storage_params["bucket"]
        for fmt in formats:
            key = f"{image_name}/image_{iteration}.{fmt}"
            materializers.append(
                to.image_s3(
                    bucket=bucket,
                    key=key,
                    dependencies=["generated_image"],
                    id=f"save_image_{fmt}_s3",
                    format=fmt,
                )
            )
            save_uris.append(f"s3://{bucket}/{key}")
    elif storage_mode == "local":
        base_dir = storage_params["base_dir"]
        for fmt in formats:
            image_path = os.path.join(base_dir, f"{image_name}/image_{iteration}.{fmt}")
            materializers.append(
                to.image(
                    path=image_path,
                    dependencies=["generated_image"],
                    id=f"save_image_{fmt}",
                    format=fmt,
                )
            )
            save_uris.append(image_path)
    dr.materialize(*materializers, inputs={"image_generation_prompt": caption})
    # Where the image was persisted (not the temporary OpenAI URL).
    return save_uris
@click.command()
@click.option(
    "-o",
    "--output-format",
    required=True,
    help="Output formats to convert to. Will always save to png.",
    type=str,
    multiple=True,
)
@click.option(
    "-i",
    "--image-name",
    required=True,
    help="Image name to convert, must be globally unique",
    type=str,
)
@click.option(
    "-b",
    "--bucket",
    required=False,
    help="S3 bucket to use if mode=='s3'",
    type=str,
    default=None,
)
# NOTE: ``-d`` used to be declared for both --base-dir and --descriptiveness. Click registers
# options in declaration order and the later one wins, so ``-d`` has only ever reached
# --descriptiveness. The dead alias is dropped here so every working invocation keeps working.
@click.option(
    "--base-dir",
    required=False,
    help="Base directory to use if mode=='local'",
    type=str,
    default=None,
)
@click.option(
    "-m",
    "--mode",
    required=True,
    help="Mode to use, either 's3' or 'local'",
    type=click.Choice(["s3", "local"]),
)
@click.option(
    "-p",
    "--initial-path",
    required=False,
    help="Initial path to use if not restarting workflow",
    type=str,
    default=None,
)
@click.option(
    "-n",
    "--num-iterations",
    required=False,
    help="Number of iterations to run",
    type=int,
    default=150,
)
@click.option(
    "-d",
    "--descriptiveness",
    required=False,
    help="Descriptiveness of the image. Adverb to describe how descriptive it is. E.G. "
    "'quite', 'somewhat', 'not particularly'",
    type=str,
    default=None,
)
def run(
    image_name: str,
    output_format: List[str],
    bucket: str,
    base_dir: str,
    mode: str,
    initial_path: str,
    num_iterations: int,
    descriptiveness: str,
):
    """Run image telephone: repeatedly caption an image, then generate a new image from the caption.

    Artifacts are stored under --image-name, either locally (--base-dir) or in S3 (--bucket).
    Re-running with the same name resumes from what is already stored.
    """
    logger.info(f"Running image telephone for {image_name} with output formats {output_format}")
    # Deduplicate while keeping the user's order (a set would make it arbitrary). png is always
    # produced because the png of each generated image is fed back in as the next input.
    output_format = list(dict.fromkeys(output_format))
    if "png" not in output_format:
        output_format.append("png")
    # Real CLI errors rather than ``assert``, which is stripped under ``python -O``.
    if mode == "local" and base_dir is None:
        raise click.UsageError("Must provide base_dir if mode is local")
    if mode == "s3" and bucket is None:
        raise click.UsageError("Must provide bucket if mode is s3")
    storage_params = {"bucket": bucket} if mode == "s3" else {"base_dir": base_dir}
    if mode == "local":
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(os.path.join(base_dir, image_name), exist_ok=True)
    iteration, image_url, has_original = determine_state(
        initial_path, mode, image_name, storage_params
    )
    logger.info(f"Starting from iteration {iteration} with image {image_url}")
    dr = _create_driver()
    while iteration < num_iterations:
        logger.info(f" Beginning iteration: {iteration} with image URL: {image_url}")
        generated_caption = caption_step(
            dr,
            image_name,
            iteration,
            storage_params,
            mode,
            image_url,
            output_format,
            descriptiveness=descriptiveness,
            # Only the first captioned image is persisted as the "original".
            save_original_image=not has_original,
        )
        logger.info(f"Captioned image: {image_url} with caption: {generated_caption}")
        saved_image_uris = generate_step(
            dr,
            image_name,
            iteration,
            generated_caption,
            mode,
            output_format,
            storage_params,
        )
        logger.info(f"Generated image, saved at: {saved_image_uris}")
        iteration += 1
        # Exactly one saved location is the png (formats were deduplicated above); it becomes
        # the input to the next iteration.
        (image_url,) = [uri for uri in saved_image_uris if uri.endswith(".png")]
        has_original = True
    logger.info(
        f"Finished {num_iterations} iterations of image telephone for {image_name} with output formats {output_format}"
    )


if __name__ == "__main__":
    run()