blob: d6957d7b39c2721c369848bf00eb88d99c14672b [file] [log] [blame]
import pandas as pd
import pyspark.sql as ps
import pyspark.sql.functions as F
# See https://github.com/dragansah/tpch-dbgen/blob/master/tpch-queries/8.sql
from hamilton import htypes
from hamilton.plugins import h_spark
def start_date() -> str:
    """Return the first day of the order-date window as a ``YYYY-MM-DD`` string.

    ``order_data_filtered`` uses this value as its inclusive lower bound.
    """
    window_start = "1995-01-01"
    return window_start
def end_date() -> str:
    """Return the last day of the order-date window as a ``YYYY-MM-DD`` string.

    ``order_data_filtered`` uses this value as its inclusive upper bound.
    """
    window_end = "1996-12-31"
    return window_end
def america(region: ps.DataFrame) -> ps.DataFrame:
    """Keep only the rows of ``region`` whose ``r_name`` equals ``"AMERICA"``."""
    is_america = F.col("r_name") == "AMERICA"
    return region.filter(is_america)
def american_nations(nation: ps.DataFrame, america: ps.DataFrame) -> ps.DataFrame:
    """Return the ``n_nationkey`` of every nation that joins to the ``america`` region.

    Nations are matched to the region on ``n_regionkey == r_regionkey``; only
    the ``n_nationkey`` column is kept in the result.
    """
    in_region = nation.n_regionkey == america.r_regionkey
    nations_in_region = nation.join(america, in_region)
    return nations_in_region.select(["n_nationkey"])
def american_customers(customer: ps.DataFrame, american_nations: ps.DataFrame) -> ps.DataFrame:
    """Join customers to ``american_nations`` on ``c_nationkey == n_nationkey``.

    The result carries the customer columns plus the columns of
    ``american_nations``.
    """
    nation_matches = customer.c_nationkey == american_nations.n_nationkey
    return customer.join(american_nations, nation_matches)
def american_orders(orders: ps.DataFrame, american_customers: ps.DataFrame) -> ps.DataFrame:
    """Join orders to ``american_customers`` on ``o_custkey == c_custkey``.

    The result carries the order columns plus the columns of
    ``american_customers``.
    """
    customer_matches = orders.o_custkey == american_customers.c_custkey
    return orders.join(american_customers, customer_matches)
def order_data_augmented(
    american_orders: ps.DataFrame,
    lineitem: ps.DataFrame,
    supplier: ps.DataFrame,
    nation: ps.DataFrame,
    part: ps.DataFrame,
) -> ps.DataFrame:
    """Assemble one wide frame of line items with part, order, supplier and nation data.

    Join chain, in order:

    1. ``lineitem`` to ``part`` on ``l_partkey == p_partkey``
    2. to ``american_orders`` (minus ``n_nationkey``) on ``l_orderkey == o_orderkey``
    3. to ``supplier`` on ``l_suppkey == s_suppkey``
    4. to ``nation`` on ``s_nationkey == n_nationkey``
    """
    # NOTE(review): this drop only has an effect if lineitem/part carry
    # "n_nation" or "n_nationkey" columns -- confirm against the input schema.
    items_with_parts = lineitem.join(part, lineitem.l_partkey == part.p_partkey).drop(
        "n_nation", "n_nationkey"
    )

    # Remove n_nationkey from the orders side before joining; ``nation`` is
    # joined on n_nationkey in the final step below.
    orders_without_nation = american_orders.drop("n_nationkey")
    items_with_orders = items_with_parts.join(
        orders_without_nation,
        items_with_parts.l_orderkey == american_orders.o_orderkey,
    )

    items_with_suppliers = items_with_orders.join(
        supplier,
        items_with_orders.l_suppkey == supplier.s_suppkey,
    )

    return items_with_suppliers.join(
        nation,
        items_with_suppliers.s_nationkey == nation.n_nationkey,
    )
def order_data_filtered(
    order_data_augmented: ps.DataFrame,
    start_date: str,
    end_date: str,
    p_type: str = "ECONOMY ANODIZED STEEL",
) -> ps.DataFrame:
    """Restrict the augmented order data to a date window and a single part type.

    A row is kept when ``o_orderdate`` lies between ``start_date`` and
    ``end_date`` (both ends inclusive) and ``p_type`` equals the given
    ``p_type`` argument. The date strings are converted with ``F.to_date``.
    """
    order_date = F.col("o_orderdate")
    not_before_start = order_date >= F.to_date(F.lit(start_date))
    not_after_end = order_date <= F.to_date(F.lit(end_date))
    has_part_type = F.col("p_type") == p_type

    keep_row = not_before_start & not_after_end & has_part_type
    return order_data_augmented.filter(keep_row)
def o_year(o_orderdate: pd.Series) -> htypes.column[pd.Series, int]:
    """Return the calendar year of each value in ``o_orderdate``.

    The input is first converted with ``pd.to_datetime``.
    """
    order_timestamps = pd.to_datetime(o_orderdate)
    return order_timestamps.dt.year
def volume(l_extendedprice: pd.Series, l_discount: pd.Series) -> htypes.column[pd.Series, float]:
    """Return the discounted price per row: ``l_extendedprice * (1 - l_discount)``."""
    remaining_fraction = 1 - l_discount
    return l_extendedprice * remaining_fraction
def brazil_volume(n_name: pd.Series, volume: pd.Series) -> htypes.column[pd.Series, float]:
    """Return ``volume`` where ``n_name`` is ``"BRAZIL"`` and ``0`` everywhere else."""
    is_brazil = n_name == "BRAZIL"
    return volume.where(is_brazil, 0)
# NOTE(review): "volume" is listed in columns_to_pass and is also one of the
# functions handed to with_columns -- confirm that listing it is intentional.
@h_spark.with_columns(
    o_year,
    volume,
    brazil_volume,
    columns_to_pass=[
        "o_orderdate",
        "l_extendedprice",
        "l_discount",
        "n_name",
        "volume",
    ],
    select=[
        "o_year",
        "volume",
        "brazil_volume",
    ],
)
def processed(order_data_filtered: ps.DataFrame) -> ps.DataFrame:
    """Return ``order_data_filtered`` unchanged; the decorator does the work.

    ``h_spark.with_columns`` is given the ``o_year``, ``volume`` and
    ``brazil_volume`` functions, and the same three names are requested via
    ``select``. Presumably the decorator applies them as column computations
    on the frame -- see the Hamilton documentation for exact semantics.
    """
    return order_data_filtered
def brazil_volume_by_year(processed: ps.DataFrame) -> ps.DataFrame:
    """Sum ``volume`` and ``brazil_volume`` for each ``o_year``.

    The sums are returned as ``sum_volume`` and ``sum_brazil_volume``,
    one row per ``o_year`` group.
    """
    total_volume = F.sum("volume").alias("sum_volume")
    total_brazil_volume = F.sum("brazil_volume").alias("sum_brazil_volume")
    per_year = processed.groupBy("o_year")
    return per_year.agg(total_volume, total_brazil_volume)
def final_data(brazil_volume_by_year: ps.DataFrame) -> pd.DataFrame:
    """Convert the per-year aggregate to a pandas DataFrame via ``toPandas()``."""
    as_pandas = brazil_volume_by_year.toPandas()
    return as_pandas