import importlib
import logging
import os
from typing import List, Optional

import boto3
import click
from tenacity import retry, stop_after_delay, wait_random_exponential

from hamilton import dataflows, driver
from hamilton.io.materialization import to
from hamilton.log_setup import setup_logging
| |
# Imported only for its side effects; the result is discarded. Presumably this
# registers the custom materializers used below via ``to.*`` (e.g. ``to.image``,
# ``to.image_s3``) -- NOTE(review): confirm against the ``adapters`` module.
importlib.import_module("adapters")

# Configure logging at INFO level at import time, before anything logs.
setup_logging(logging.INFO)

logger = logging.getLogger(__name__)
| |
| |
def _create_driver():
    """Build the Hamilton driver that wires the captioning and generation dataflows together.

    Both dataflows are loaded with ``dataflows.import_module(<name>, "elijahbenizzy")``
    (captioning first, then generation), and the driver is created with the config
    ``{"include_embeddings": True}``.

    :return: A ``driver.Driver`` over both dataflow modules.
    """
    # TODO -- use the hub
    dataflow_modules = [
        dataflows.import_module(dataflow_name, "elijahbenizzy")
        for dataflow_name in ("caption_images", "generate_images")
    ]
    driver_config = {"include_embeddings": True}
    return driver.Driver(driver_config, *dataflow_modules)
| |
| |
def _parse_indexed_filename(filename: str) -> Optional[tuple[str, int, str]]:
    """Split a ``<prefix>_<index>.<extension>`` file name into its parts.

    ``"image_12.png"`` -> ``("image", 12, "png")``. Returns ``None`` for names that
    do not follow that pattern (``"original.png"``, ``".DS_Store"``, nested keys
    whose prefix contains ``/`` are all rejected by the prefix checks in callers).
    """
    stem, dot, extension = filename.partition(".")
    prefix, underscore, index = stem.rpartition("_")
    if not (dot and underscore and prefix and index.isdecimal()):
        return None
    return prefix, int(index), extension


def determine_state(
    initial_image_path: Optional[str],
    storage_mode: str,
    image_name: str,
    params: dict[str, str],
):
    """Determines where we are in the iteration loop.

    Inspects what is already stored for ``image_name``: the local directory
    ``<params["base_dir"]>/<image_name>`` or the S3 prefix ``<image_name>/`` in
    bucket ``params["bucket"]``.

    Iteration ``i`` counts as completed when both ``metadata_<i>.*`` and
    ``image_<i>.png`` are stored. We resume at the first iteration that is not
    completed, so a run that died between captioning and image generation redoes
    that iteration instead of skipping it. This holds no matter how many extra
    output formats were saved alongside the png.

    :param initial_image_path: Image to start from. Required when no
        ``original.png`` is stored yet; may be None when resuming.
    :param storage_mode: ``"local"`` or ``"s3"``.
    :param image_name: Globally unique image name, used as directory / key prefix.
    :param params: Storage parameters: ``base_dir`` for local, ``bucket`` for s3.
    :return: ``(iteration, image_url, has_original)``. These are the next iteration
        to run, the image to caption in it, and whether ``original.png`` is
        already stored.
    :raises ValueError: On an unknown storage mode, or when no original image is
        stored and ``initial_image_path`` is None.
    """
    key_prefix = f"{image_name}/"
    if storage_mode == "local":
        # os.listdir already returns bare file names.
        paths = os.listdir(os.path.join(params["base_dir"], image_name))
    elif storage_mode == "s3":
        bucket = boto3.resource("s3").Bucket(params["bucket"])
        # removeprefix (not replace) so only the leading "<image_name>/" is stripped.
        paths = [obj.key.removeprefix(key_prefix) for obj in bucket.objects.filter(Prefix=key_prefix)]
    else:
        raise ValueError(f"Invalid storage engine: {storage_mode}")

    if "original.png" not in paths:
        # No original one, starting from zero
        if initial_image_path is None:
            raise ValueError(
                "Must provide initial_image_path if no original image is present in the target location"
            )
        return 0, initial_image_path, False

    def _location(filename: str) -> str:
        """Full local path or s3:// URI of a file stored for this image."""
        if storage_mode == "local":
            return os.path.join(params["base_dir"], image_name, filename)
        return f"s3://{params['bucket']}/{image_name}/{filename}"

    metadata_iterations = set()
    png_iterations = set()
    for path in paths:
        parsed = _parse_indexed_filename(path)
        if parsed is None:
            continue
        kind, index, extension = parsed
        if kind == "metadata":
            metadata_iterations.add(index)
        elif kind == "image" and extension == "png":
            # Only png matters: it is the format fed into the next iteration.
            png_iterations.add(index)

    iteration = 0
    while iteration in metadata_iterations and iteration in png_iterations:
        iteration += 1

    if iteration > 0:
        image_url = _location(f"image_{iteration - 1}.png")
    elif initial_image_path is not None:
        image_url = initial_image_path
    else:
        # Resuming before any image was generated: caption the stored original.
        image_url = _location("original.png")
    return iteration, image_url, True
| |
| |
@retry(stop=stop_after_delay(120), wait=wait_random_exponential(multiplier=1, max=20))
def caption_step(
    dr: driver.Driver,
    image_name: str,
    iteration: int,
    storage_params: dict,
    storage_mode: str,
    image_url: str,
    image_formats: List[str],
    descriptiveness: str,
    save_original_image: bool = False,
) -> str:
    """
    Step to caption an image. Returns the caption generated by OpenAI.

    Runs the driver to compute ``generated_caption`` for ``image_url``. As side
    effects it saves ``<image_name>/metadata_<iteration>.json`` and, when
    ``save_original_image`` is set, ``<image_name>/original.<format>`` for every
    entry of ``image_formats``. Files go to S3 (``storage_params["bucket"]``) or
    under the local ``storage_params["base_dir"]``.

    The whole step is retried on any exception for up to 120 seconds. Retries use
    randomized exponential backoff so a rate-limited API is not polled in a tight
    loop; after that tenacity raises ``RetryError``.

    :param dr: The driver to use
    :param image_name: The name of the image, globally unique
    :param iteration: The iteration number
    :param storage_params: The storage parameters
    :param storage_mode: The storage mode (s3 or local)
    :param image_url: The URL of the image to caption
    :param image_formats: The image formats to save
    :param descriptiveness: The descriptiveness of the image
    :param save_original_image: Whether to save the original image
    :return: The caption generated by OpenAI
    """
    materializers = []
    if storage_mode == "s3":
        bucket = storage_params["bucket"]
        materializers.append(
            to.json_s3(
                bucket=bucket,
                key=f"{image_name}/metadata_{iteration}.json",
                dependencies=["metadata"],
                id="save_metadata",
            )
        )
        if save_original_image:
            materializers.extend(
                to.image_s3(
                    bucket=bucket,
                    key=f"{image_name}/original.{image_format}",
                    dependencies=["image_url"],
                    id=f"save_original_image_{image_format}_s3",
                    format=image_format,
                )
                for image_format in image_formats
            )
    elif storage_mode == "local":
        image_dir = os.path.join(storage_params["base_dir"], image_name)
        materializers.append(
            to.json(
                path=os.path.join(image_dir, f"metadata_{iteration}.json"),
                dependencies=["metadata"],
                id="save_metadata",
            )
        )
        if save_original_image:
            materializers.extend(
                to.image(
                    path=os.path.join(image_dir, f"original.{image_format}"),
                    dependencies=["image_url"],
                    id=f"save_original_image_{image_format}",
                    format=image_format,
                )
                for image_format in image_formats
            )
    _, results = dr.materialize(
        *materializers,
        additional_vars=["generated_caption"],
        inputs={
            "image_url": image_url,
            "descriptiveness": descriptiveness,
            "additional_metadata": {
                "descriptiveness": descriptiveness,
                "iteration": iteration,
                "creator": "github.com/elijahbenizzy",
            },
        },
    )
    return results["generated_caption"]
| |
| |
@retry(stop=stop_after_delay(120), wait=wait_random_exponential(multiplier=1, max=20))
def generate_step(
    dr: driver.Driver,
    image_name: str,
    iteration: int,
    caption: str,
    storage_mode: str,
    formats: List[str],
    storage_params: dict,
) -> List[str]:
    """Step to generate an image from ``caption`` and save it in every requested format.

    Runs the driver with ``image_generation_prompt=caption``. It saves the
    ``generated_image`` node as ``<image_name>/image_<iteration>.<format>`` for each
    entry of ``formats``, either to S3 (``storage_params["bucket"]``) or under the
    local ``storage_params["base_dir"]``.

    The step is retried on any exception for up to 120 seconds, with randomized
    exponential backoff between attempts.

    :param dr: The driver to use
    :param image_name: The name of the image, globally unique
    :param iteration: The iteration number
    :param caption: The caption to use
    :param storage_mode: The storage mode (s3 or local)
    :param formats: The image formats to save
    :param storage_params: The storage parameters to pass to the storage mode
    :return: Where each image was saved (local path or ``s3://`` URI), one per entry
        of ``formats`` in the same order. This is NOT the temporary OpenAI URL.
    """
    materializers = []
    save_uris = []
    if storage_mode == "s3":
        bucket = storage_params["bucket"]
        for image_format in formats:
            key = f"{image_name}/image_{iteration}.{image_format}"
            materializers.append(
                to.image_s3(
                    bucket=bucket,
                    key=key,
                    dependencies=["generated_image"],
                    id=f"save_image_{image_format}_s3",
                    format=image_format,
                )
            )
            save_uris.append(f"s3://{bucket}/{key}")
    elif storage_mode == "local":
        image_dir = os.path.join(storage_params["base_dir"], image_name)
        for image_format in formats:
            image_path = os.path.join(image_dir, f"image_{iteration}.{image_format}")
            materializers.append(
                to.image(
                    path=image_path,
                    dependencies=["generated_image"],
                    id=f"save_image_{image_format}",
                    format=image_format,
                )
            )
            save_uris.append(image_path)
    dr.materialize(*materializers, inputs={"image_generation_prompt": caption})
    return save_uris
| |
| |
@click.command()
@click.option(
    "-o",
    "--output-format",
    required=True,
    help="Output formats to convert to. Will always save to png.",
    type=str,
    multiple=True,
)
@click.option(
    "-i",
    "--image-name",
    required=True,
    help="Image name to convert, must be globally unique",
    type=str,
)
@click.option(
    "-b",
    "--bucket",
    required=False,
    help="S3 bucket to use if mode=='s3'",
    type=str,
    default=None,
)
@click.option(
    "-d",
    "--base-dir",
    required=False,
    help="Base directory to use if mode=='local'",
    type=str,
    default=None,
)
@click.option(
    "-m",
    "--mode",
    required=True,
    help="Mode to use, either 's3' or 'local'",
    type=click.Choice(["s3", "local"]),
)
@click.option(
    "-p",
    "--initial-path",
    required=False,
    help="Initial path to use if not restarting workflow",
    type=str,
    default=None,
)
@click.option(
    "-n",
    "--num-iterations",
    required=False,
    help="Number of iterations to run",
    type=int,
    default=150,
)
# "-d" already belongs to --base-dir. Declaring the same short flag on two options
# made one of them unreachable, so --descriptiveness gets its own "-D".
@click.option(
    "-D",
    "--descriptiveness",
    required=False,
    help="Descriptiveness of the image. Adverb to describe how descriptive it is. E.G. "
    "'quite', 'somewhat', 'not particularly'",
    type=str,
    default=None,
)
def run(
    image_name: str,
    output_format: List[str],
    bucket: str,
    base_dir: str,
    mode: str,
    initial_path: str,
    num_iterations: int,
    descriptiveness: str,
):
    """Play image telephone: repeatedly caption an image, then generate a new image from the caption."""
    logger.info(f"Running image telephone for {image_name} with output formats {output_format}")
    # De-duplicate while keeping the user's order (a set would make it arbitrary).
    # png is always saved because the png output feeds the next iteration.
    output_format = list(dict.fromkeys(output_format))
    if "png" not in output_format:
        output_format.append("png")
    # Validate with real exceptions: `assert` is stripped under `python -O`.
    if mode == "local" and base_dir is None:
        raise click.UsageError("Must provide base_dir if mode is local")
    if mode == "s3" and bucket is None:
        raise click.UsageError("Must provide bucket if mode is s3")
    storage_params = {"bucket": bucket} if mode == "s3" else {"base_dir": base_dir}
    if mode == "local":
        os.makedirs(os.path.join(base_dir, image_name), exist_ok=True)
    iteration, image_url, has_original = determine_state(
        initial_path, mode, image_name, storage_params
    )
    logger.info(f"Starting from iteration {iteration} with image {image_url}")
    dr = _create_driver()
    while iteration < num_iterations:
        logger.info(f" Beginning iteration: {iteration} with image URL: {image_url}")
        generated_caption = caption_step(
            dr,
            image_name,
            iteration,
            storage_params,
            mode,
            image_url,
            output_format,
            descriptiveness=descriptiveness,
            save_original_image=not has_original,
        )
        logger.info(f"Captioned image: {image_url} with caption: {generated_caption}")
        saved_image_uris = generate_step(
            dr,
            image_name,
            iteration,
            generated_caption,
            mode,
            output_format,
            storage_params,
        )
        logger.info(f"Generated image, saved at: {saved_image_uris}")
        iteration += 1
        # Match on ".png" (not just "png") so a format like "apng" cannot produce a
        # second match; exactly one png location exists since formats are de-duplicated.
        (image_url,) = [uri for uri in saved_image_uris if uri.endswith(".png")]
        has_original = True
    logger.info(
        f"Finished {num_iterations} iterations of image telephone for {image_name} with output formats {output_format}"
    )
| |
| |
if __name__ == "__main__":
    # Script entry point: invoke the click command `run` defined above.
    run()