# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is a utility function that will consumer the data generated by dbgen from TPC-H and convert
it into a parquet file with the column names as expected by the TPC-H specification. It assumes
the data generated resides in a path ../../benchmarks/tpch/data relative to the current file,
as will be generated by the script provided in this repository.
"""
from pathlib import Path
import datafusion
import pyarrow as pa
ctx = datafusion.SessionContext()
all_schemas = {}
all_schemas["customer"] = [
("C_CUSTKEY", pa.int64()),
("C_NAME", pa.string()),
("C_ADDRESS", pa.string()),
("C_NATIONKEY", pa.int64()),
("C_PHONE", pa.string()),
("C_ACCTBAL", pa.decimal128(15, 2)),
("C_MKTSEGMENT", pa.string()),
("C_COMMENT", pa.string()),
]
all_schemas["lineitem"] = [
("L_ORDERKEY", pa.int64()),
("L_PARTKEY", pa.int64()),
("L_SUPPKEY", pa.int64()),
("L_LINENUMBER", pa.int32()),
("L_QUANTITY", pa.decimal128(15, 2)),
("L_EXTENDEDPRICE", pa.decimal128(15, 2)),
("L_DISCOUNT", pa.decimal128(15, 2)),
("L_TAX", pa.decimal128(15, 2)),
("L_RETURNFLAG", pa.string()),
("L_LINESTATUS", pa.string()),
("L_SHIPDATE", pa.date32()),
("L_COMMITDATE", pa.date32()),
("L_RECEIPTDATE", pa.date32()),
("L_SHIPINSTRUCT", pa.string()),
("L_SHIPMODE", pa.string()),
("L_COMMENT", pa.string()),
]
all_schemas["nation"] = [
("N_NATIONKEY", pa.int64()),
("N_NAME", pa.string()),
("N_REGIONKEY", pa.int64()),
("N_COMMENT", pa.string()),
]
all_schemas["orders"] = [
("O_ORDERKEY", pa.int64()),
("O_CUSTKEY", pa.int64()),
("O_ORDERSTATUS", pa.string()),
("O_TOTALPRICE", pa.decimal128(15, 2)),
("O_ORDERDATE", pa.date32()),
("O_ORDERPRIORITY", pa.string()),
("O_CLERK", pa.string()),
("O_SHIPPRIORITY", pa.int32()),
("O_COMMENT", pa.string()),
]
all_schemas["part"] = [
("P_PARTKEY", pa.int64()),
("P_NAME", pa.string()),
("P_MFGR", pa.string()),
("P_BRAND", pa.string()),
("P_TYPE", pa.string()),
("P_SIZE", pa.int32()),
("P_CONTAINER", pa.string()),
("P_RETAILPRICE", pa.decimal128(15, 2)),
("P_COMMENT", pa.string()),
]
all_schemas["partsupp"] = [
("PS_PARTKEY", pa.int64()),
("PS_SUPPKEY", pa.int64()),
("PS_AVAILQTY", pa.int32()),
("PS_SUPPLYCOST", pa.decimal128(15, 2)),
("PS_COMMENT", pa.string()),
]
all_schemas["region"] = [
("r_REGIONKEY", pa.int64()),
("r_NAME", pa.string()),
("r_COMMENT", pa.string()),
]
all_schemas["supplier"] = [
("S_SUPPKEY", pa.int64()),
("S_NAME", pa.string()),
("S_ADDRESS", pa.string()),
("S_NATIONKEY", pa.int32()),
("S_PHONE", pa.string()),
("S_ACCTBAL", pa.decimal128(15, 2)),
("S_COMMENT", pa.string()),
]
curr_dir = Path(__file__).resolve().parent
for filename, curr_schema_val in all_schemas.items():
# For convenience, go ahead and convert the schema column names to lowercase
curr_schema = [(s[0].lower(), s[1]) for s in curr_schema_val]
    # Pre-collect the output column names so we can drop the extra null field
    # added below to handle the trailing | on each row
    output_cols = [r[0] for r in curr_schema]
    curr_schema = [pa.field(r[0], r[1], nullable=False) for r in curr_schema]
    # The trailing | on each dbgen row parses as an extra empty column, so append
    # a nullable placeholder field for it
    curr_schema.append(pa.field("some_null", pa.null()))
schema = pa.schema(curr_schema)
source_file = (curr_dir / f"../../benchmarks/tpch/data/{filename}.csv").resolve()
dest_file = (curr_dir / f"./data/{filename}.parquet").resolve()
df = ctx.read_csv(source_file, schema=schema, has_header=False, delimiter="|")
df = df.select(*output_cols)
df.write_parquet(dest_file, compression="snappy")