blob: d8c7a1f29b78bf907af7882576f5b56b15d4da17 [file] [log] [blame]
import textwrap
import dlt
import ibis
import ibis.expr.types as ir
import openai
from hamilton.function_modifiers import pipe, source, step
from hamilton.htypes import Collect, Parallelizable
def db_con(pipeline: dlt.Pipeline) -> ibis.BaseBackend:
    """Connect Ibis to the `<pipeline_name>.duckdb` database.

    The connection is also registered as the default Ibis backend via
    `ibis.set_backend()`, so it can be retrieved later with
    `ibis.get_backend()`.
    """
    database_path = f"{pipeline.pipeline_name}.duckdb"
    connection = ibis.connect(database_path)
    ibis.set_backend(connection)
    return connection
def channel(selected_channels: list[str]) -> Parallelizable[str]:
    """Yield each selected channel name, one at a time.

    Each yielded channel is the unit for which messages and replies are loaded.
    """
    yield from selected_channels
def _epoch_microseconds(timestamp: ir.TimestampColumn) -> ir.StringColumn:
    """Render a timestamp as a string of epoch seconds with a fractional part.

    The whole seconds since the epoch are added to the microsecond component
    expressed as a fraction of a second, and the sum is cast to a string.
    Required to meet the Slack format.
    """
    whole_seconds = timestamp.epoch_seconds()
    # 1_000_000 microseconds per second
    fractional_seconds = timestamp.microsecond() / 1_000_000
    epoch_with_fraction = whole_seconds + fractional_seconds
    return epoch_with_fraction.cast(str)
def channel_message(
    channel: str,
    db_con: ibis.BaseBackend,
    pipeline: dlt.Pipeline,
) -> ir.Table:
    """Load the table of parent messages for one channel.

    The table `<channel>_message` is read from the pipeline's dataset, and the
    timestamps `thread_ts` and `ts` are converted to strings.

    `thread_ts` is only set when the message has replies / started a thread.
    Missing values are filled with the message's own `ts` using coalesce, so
    that every row ends up with a `thread_ts`.

    Slack reference: https://api.slack.com/messaging/retrieving#finding_threads
    """
    raw_messages = db_con.table(
        f"{channel}_message",
        schema=pipeline.dataset_name,
        database=pipeline.pipeline_name,
    )

    # first pass: both timestamp columns become strings
    string_timestamps = raw_messages.mutate(
        thread_ts=_epoch_microseconds(ibis._.thread_ts).cast(str),
        ts=_epoch_microseconds(ibis._.ts).cast(str),
    )

    # second pass: fall back to `ts` wherever `thread_ts` is null
    filled_thread_ts = ibis.coalesce(ibis._.thread_ts, ibis._.ts)
    return string_timestamps.mutate(thread_ts=filled_thread_ts)
def channel_replies(
    channel: str,
    db_con: ibis.BaseBackend,
    pipeline: dlt.Pipeline,
) -> ir.Table:
    """Load the table of replies for one channel.

    Reads `<channel>_replies_message` from the pipeline's dataset.
    """
    replies_table_name = f"{channel}_replies_message"
    replies = db_con.table(
        replies_table_name,
        schema=pipeline.dataset_name,
        database=pipeline.pipeline_name,
    )
    return replies
def channel_threads(
    channel_message: ir.Table,
    channel_replies: ir.Table,
) -> ir.Table:
    """Combine parent messages and replies into a single table.

    Both inputs are reduced to the same set of columns before the union.
    The result is sorted by thread start (`thread_ts`), then by message
    timestamp (`ts`).
    """
    shared_columns = ["channel", "thread_ts", "ts", "user", "text", "_dlt_load_id", "_dlt_id"]
    parents = channel_message.select(shared_columns)
    replies = channel_replies.select(shared_columns)
    combined = ibis.union(parents, replies)
    return combined.order_by([ibis._.thread_ts, ibis._.ts])
def channels_collection(channel_threads: Collect[ir.Table]) -> ir.Table:
    """Union the collected `channel_threads` tables of all channels into one table."""
    per_channel_tables = list(channel_threads)
    return ibis.union(*per_channel_tables)
def _format_messages(threads: ir.Table) -> ir.Table:
    """Add a `message` column of the form `<user id>: <text>`.

    The user id is the dense rank of `user` plus one, computed while the
    table is grouped by `thread_ts`, then cast to a string.
    """
    user_rank = ibis.dense_rank().over(order_by="user") + 1
    user_prefix = user_rank.cast(str)
    grouped_by_thread = threads.group_by("thread_ts")
    return grouped_by_thread.mutate(message=user_prefix.concat(": ", ibis._.text))
def _aggregate_thread(threads: ir.Table) -> ir.Table:
    """Collapse each thread into one row, joining its messages into a single string.

    Functions decorated with `@ibis.udf` are loaded by the Ibis backend.
    They aren't meant to be called directly from Python, which is why their
    bodies only raise `NotImplementedError`.
    ref: https://ibis-project.org/how-to/extending/builtin
    """

    # maps to the backend's built-in `string_agg`
    @ibis.udf.agg.builtin(name="string_agg")
    def _string_agg(arg, sep: str = "\n ") -> str:
        raise NotImplementedError

    # maps to the backend's built-in `array_agg`
    @ibis.udf.agg.builtin(name="array_agg")
    def _array_agg(arg) -> list[str]:
        raise NotImplementedError

    # one output column per entry; insertion order is the column order
    aggregations = {
        "thread": _string_agg(ibis._.message),
        "num_messages": ibis._.count(),
        "users": _array_agg(ibis._.user).unique(),
        "_dlt_load_id": ibis._._dlt_load_id.max(),
        "_dlt_id": _array_agg(ibis._._dlt_id),
    }
    return threads.group_by("thread_ts").agg(**aggregations)
def summary_prompt() -> str:
    """Return the LLM prompt template used to summarize a Slack thread.

    The template contains a single `{text}` placeholder, to be filled with
    the discussion thread.
    """
    template = (
        "Hamilton is an open source library to write dataflows in Python. It is used by developers for data engineering, data science, machine learning, and LLM workflows.\n"
        "Next is a discussion thread about Hamilton started by User1. Complete these tasks: identify the issue raised by User1, summarize the discussion, indicate if you think the issue was resolved.\n"
        "DISCUSSION THREAD\n"
        "{text}\n"
    )
    return textwrap.dedent(template)
def _summary(threads: ir.Table, prompt: str) -> ir.Table:
    """Generate a summary for each thread.

    Adds a `summary` column computed from the `thread` column and `prompt`.
    Uses a scalar Python UDF executed by the backend.
    """

    @ibis.udf.scalar.python
    def _openai_completion_udf(
        text: str, prompt_template: str
    ) -> str:  # Ibis requires `str` type hint even if None is allowed
        """Fill `prompt_template` with `text` and use OpenAI chat completion.

        Returns None if:
        - `text` is empty
        - `content` is too long
        - OpenAI call fails
        """
        if not text:
            return None

        content = prompt_template.format(text=text)
        # NOTE(review): presumably a rough ~4 characters per token estimate
        # checked against an 8191-token budget -- confirm against the model limit
        if len(content) // 4 > 8191:
            return None

        # Client construction stays outside the `try` so that a configuration
        # problem (e.g. missing credentials) still surfaces as an error.
        client = openai.OpenAI()
        try:
            # The request itself must be inside the `try`: a failing call for
            # one thread should produce a null summary, not abort the query.
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": content}],
            )
            return response.choices[0].message.content
        except Exception:
            # deliberate best-effort: any failure yields no summary for this row
            return None

    return threads.mutate(summary=_openai_completion_udf(threads.thread, prompt))
# The @pipe decorator keeps the private steps below out of the function/node
# namespace: they are applied in order instead of being declared as nodes.
@pipe(
    step(_format_messages),
    step(_aggregate_thread),
    step(_summary, prompt=source("summary_prompt")),
)
def threads(channels_collection: ir.Table) -> ir.Table:
    """Build the `threads` table.

    The steps declared in `@pipe` format the messages, aggregate them per
    thread, and generate a summary for each thread. The function body itself
    only returns the table it receives.
    """
    return channels_collection
def insert_threads(threads: ir.Table) -> int:
    """Save `threads` table and return row count.

    The backend is the one previously registered with `ibis.set_backend()`.

    `create_table()` receives the `threads` expression, so the new table is
    created with both its schema and its rows. No additional `insert()` is
    issued: inserting the same expression again would store every thread
    twice and evaluate the expression (including the summary UDF) a second
    time.
    """
    backend = ibis.get_backend()
    threads_table = backend.create_table("threads", threads)
    return int(threads_table.count().execute())