blob: 7f7e0db24aa354e4bca625b34981f90f66c3ee96 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import os
# Required to run the script easily on PySpark's root directory on the Spark repo.
sys.path.append(os.getcwd())
import uuid
import time
import random
from typing import List
from pyspark.sql.types import (
StringType,
StructType,
StructField,
)
from pyspark.sql.streaming.stateful_processor_api_client import (
ListTimerIterator,
StatefulProcessorApiClient,
)
from pyspark.sql.streaming.benchmark.utils import print_percentiles
from pyspark.sql.streaming.benchmark.tws_utils import get_list_state, get_map_state, get_value_state
def benchmark_value_state(api_client: StatefulProcessorApiClient, params: List[str]) -> None:
    """Benchmark ValueState operation latencies (implicit-key set, get, update).

    params[0]: approximate size in characters of the random string written on each update.
    """
    payload_size = int(params[0])
    value_state = get_value_state(
        api_client, "example_value_state", StructType([StructField("value", StringType(), True)])
    )

    latencies_implicit_key: List[float] = []
    latencies_get: List[float] = []
    latencies_update: List[float] = []

    # Pre-generate enough UUID text to cover the requested payload size.
    uuid_pool = [str(uuid.uuid4()) for _ in range(int(payload_size / 32) + 1)]

    # TODO: Use streaming quantiles in Apache DataSketch if we want to run this longer
    for _ in range(1000000):
        # Build a fresh pseudo-random payload by reshuffling the pooled UUID strings.
        random.shuffle(uuid_pool)
        payload = "".join(uuid_pool)[:payload_size]

        begin = time.perf_counter_ns()
        api_client.set_implicit_key(("example_grouping_key",))
        latencies_implicit_key.append((time.perf_counter_ns() - begin) / 1000)

        begin = time.perf_counter_ns()
        value_state.get()
        latencies_get.append((time.perf_counter_ns() - begin) / 1000)

        begin = time.perf_counter_ns()
        value_state.update((payload,))
        latencies_update.append((time.perf_counter_ns() - begin) / 1000)

    percentiles = [50, 95, 99, 99.9, 100]
    print(" ==================== SET IMPLICIT KEY latency (micros) ======================")
    print_percentiles(latencies_implicit_key, percentiles)
    print(" ==================== GET latency (micros) ======================")
    print_percentiles(latencies_get, percentiles)
    print(" ==================== UPDATE latency (micros) ======================")
    print_percentiles(latencies_update, percentiles)
def benchmark_list_state(api_client: StatefulProcessorApiClient, params: List[str]) -> None:
    """Benchmark ListState operation latencies (get, put, clear, appendValue).

    params[0]: approximate size in characters of each random string element.
    params[1]: list length above which the list is cleared.
    """
    payload_size = int(params[0])
    max_list_length = int(params[1])
    # get and rewrite the list - the actual behavior depends on the server side implementation
    list_state = get_list_state(
        api_client, "example_list_state", StructType([StructField("value", StringType(), True)])
    )

    latencies_implicit_key: List[float] = []
    latencies_get: List[float] = []
    latencies_put: List[float] = []
    latencies_clear: List[float] = []
    latencies_append_value: List[float] = []

    # Pre-generate enough UUID text to cover the requested payload size.
    uuid_pool = [str(uuid.uuid4()) for _ in range(int(payload_size / 32) + 1)]

    # TODO: Use streaming quantiles in Apache DataSketch if we want to run this longer
    for _ in range(1000000):
        # Build a fresh pseudo-random payload by reshuffling the pooled UUID strings.
        random.shuffle(uuid_pool)
        payload = "".join(uuid_pool)[:payload_size]

        begin = time.perf_counter_ns()
        api_client.set_implicit_key(("example_grouping_key",))
        latencies_implicit_key.append((time.perf_counter_ns() - begin) / 1000)

        begin = time.perf_counter_ns()
        elements = list(list_state.get())
        latencies_get.append((time.perf_counter_ns() - begin) / 1000)

        if len(elements) > max_list_length:
            begin = time.perf_counter_ns()
            list_state.clear()
            latencies_clear.append((time.perf_counter_ns() - begin) / 1000)
        elif len(elements) % 2 == 0:
            # Rewrite the whole list with one extra element appended.
            begin = time.perf_counter_ns()
            elements.append((payload,))
            list_state.put(elements)
            latencies_put.append((time.perf_counter_ns() - begin) / 1000)
        else:
            begin = time.perf_counter_ns()
            list_state.appendValue((payload,))
            latencies_append_value.append((time.perf_counter_ns() - begin) / 1000)

    percentiles = [50, 95, 99, 99.9, 100]
    print(" ==================== SET IMPLICIT KEY latency (micros) ======================")
    print_percentiles(latencies_implicit_key, percentiles)
    print(" ==================== GET latency (micros) ======================")
    print_percentiles(latencies_get, percentiles)
    print(" ==================== PUT latency (micros) ======================")
    print_percentiles(latencies_put, percentiles)
    print(" ==================== CLEAR latency (micros) ======================")
    print_percentiles(latencies_clear, percentiles)
    print(" ==================== APPEND VALUE latency (micros) ======================")
    print_percentiles(latencies_append_value, percentiles)
def benchmark_map_state(api_client: StatefulProcessorApiClient, params: List[str]) -> None:
    """Benchmark MapState operation latencies.

    Covers keys, iterator, clear, containsKey, updateValue and removeKey. Once the map
    grows past ``params[1]`` keys, it alternates between removing keys one by one and
    clearing the map outright.

    params[0]: approximate size in characters of each random string value.
    params[1]: number of keys above which the map is drained or cleared.
    """
    payload_size = int(params[0])
    max_map_length = int(params[1])
    map_state = get_map_state(
        api_client,
        "example_map_state",
        StructType(
            [
                StructField("key", StringType(), True),
            ]
        ),
        StructType([StructField("value", StringType(), True)]),
    )

    latencies_implicit_key: List[float] = []
    latencies_keys: List[float] = []
    latencies_iterator: List[float] = []
    latencies_clear: List[float] = []
    latencies_contains_key: List[float] = []
    latencies_update_value: List[float] = []
    latencies_remove_key: List[float] = []

    # Pre-generate enough UUID text to cover the requested payload size.
    uuid_pool = [str(uuid.uuid4()) for _ in range(int(payload_size / 32) + 1)]

    # TODO: Use streaming quantiles in Apache DataSketch if we want to run this longer
    clear_on_next_overflow = False
    for iteration in range(1000000):
        # Build a fresh pseudo-random payload by reshuffling the pooled UUID strings.
        random.shuffle(uuid_pool)
        payload = "".join(uuid_pool)[:payload_size]

        begin = time.perf_counter_ns()
        api_client.set_implicit_key(("example_grouping_key",))
        latencies_implicit_key.append((time.perf_counter_ns() - begin) / 1000)

        # Alternate between listing keys only and listing full key-value pairs.
        if iteration % 2 == 0:
            begin = time.perf_counter_ns()
            keys = list(map_state.keys())
            latencies_keys.append((time.perf_counter_ns() - begin) / 1000)
        else:
            begin = time.perf_counter_ns()
            pairs = list(map_state.iterator())
            latencies_iterator.append((time.perf_counter_ns() - begin) / 1000)
            keys = [pair[0] for pair in pairs]

        if len(keys) > max_map_length and clear_on_next_overflow:
            begin = time.perf_counter_ns()
            map_state.clear()
            latencies_clear.append((time.perf_counter_ns() - begin) / 1000)
            clear_on_next_overflow = False
        elif len(keys) > max_map_length:
            # Drain the map one key at a time, timing containsKey and removeKey separately.
            for key in keys:
                begin = time.perf_counter_ns()
                map_state.containsKey(key)
                latencies_contains_key.append((time.perf_counter_ns() - begin) / 1000)

                begin = time.perf_counter_ns()
                map_state.removeKey(key)
                latencies_remove_key.append((time.perf_counter_ns() - begin) / 1000)
            clear_on_next_overflow = True
        else:
            begin = time.perf_counter_ns()
            map_state.updateValue((str(uuid.uuid4()),), (payload,))
            latencies_update_value.append((time.perf_counter_ns() - begin) / 1000)

    percentiles = [50, 95, 99, 99.9, 100]
    print(" ==================== SET IMPLICIT KEY latency (micros) ======================")
    print_percentiles(latencies_implicit_key, percentiles)
    print(" ==================== KEYS latency (micros) ======================")
    print_percentiles(latencies_keys, percentiles)
    print(" ==================== ITERATOR latency (micros) ======================")
    print_percentiles(latencies_iterator, percentiles)
    print(" ==================== CLEAR latency (micros) ======================")
    print_percentiles(latencies_clear, percentiles)
    print(" ==================== CONTAINS KEY latency (micros) ======================")
    print_percentiles(latencies_contains_key, percentiles)
    print(" ==================== UPDATE VALUE latency (micros) ======================")
    print_percentiles(latencies_update_value, percentiles)
    print(" ==================== REMOVE KEY latency (micros) ======================")
    print_percentiles(latencies_remove_key, percentiles)
def benchmark_timer(api_client: StatefulProcessorApiClient, params: List[str]) -> None:
    """Benchmark timer operation latencies (register, delete, list).

    params[0]: number of registered timers above which the first listed timer is deleted.
    """
    max_timers = int(params[0])

    latencies_implicit_key: List[float] = []
    latencies_register: List[float] = []
    latencies_delete: List[float] = []
    latencies_list: List[float] = []

    # TODO: Use streaming quantiles in Apache DataSketch if we want to run this longer
    for _ in range(1000000):
        expiry_ts_ms = random.randint(0, 10000000)

        begin = time.perf_counter_ns()
        api_client.set_implicit_key(("example_grouping_key",))
        latencies_implicit_key.append((time.perf_counter_ns() - begin) / 1000)

        begin = time.perf_counter_ns()
        timers = list(ListTimerIterator(api_client))
        latencies_list.append((time.perf_counter_ns() - begin) / 1000)

        # Keep the number of registered timers bounded by deleting the first listed one.
        if len(timers) > max_timers:
            begin = time.perf_counter_ns()
            api_client.delete_timer(timers[0])
            latencies_delete.append((time.perf_counter_ns() - begin) / 1000)

        begin = time.perf_counter_ns()
        api_client.register_timer(expiry_ts_ms)
        latencies_register.append((time.perf_counter_ns() - begin) / 1000)

    percentiles = [50, 95, 99, 99.9, 100]
    print(" ==================== SET IMPLICIT KEY latency (micros) ======================")
    print_percentiles(latencies_implicit_key, percentiles)
    print(" ==================== REGISTER latency (micros) ======================")
    print_percentiles(latencies_register, percentiles)
    print(" ==================== DELETE latency (micros) ======================")
    print_percentiles(latencies_delete, percentiles)
    print(" ==================== LIST latency (micros) ======================")
    print_percentiles(latencies_list, percentiles)
def main(state_server_port: str, benchmark_type: str) -> None:
    """Connect to the state server and run the requested benchmark.

    Args:
        state_server_port: TCP port number of the state server, or the path to its
            unix domain socket file (anything that does not parse as an int).
        benchmark_type: one of "value", "list", "map", "timer".

    Raises:
        ValueError: if ``benchmark_type`` is not a known benchmark.
    """
    key_schema = StructType(
        [
            StructField("key", StringType(), True),
        ]
    )
    # A numeric argument is treated as a TCP port; anything else as a UDS file path.
    try:
        state_server_id = int(state_server_port)
    except ValueError:
        state_server_id = state_server_port  # type: ignore[assignment]

    benchmarks = {
        "value": benchmark_value_state,
        "list": benchmark_list_state,
        "map": benchmark_map_state,
        "timer": benchmark_timer,
    }
    # Validate up front so an unknown type fails with a clear message rather than
    # an opaque KeyError after the client connection is already established.
    if benchmark_type not in benchmarks:
        raise ValueError(
            "Unknown benchmark type '%s'; expected one of: %s"
            % (benchmark_type, ", ".join(sorted(benchmarks)))
        )

    api_client = StatefulProcessorApiClient(
        state_server_port=state_server_id,
        key_schema=key_schema,
    )
    # Remaining command-line arguments are forwarded to the benchmark as its params.
    benchmarks[benchmark_type](api_client, sys.argv[3:])
if __name__ == "__main__":
    """
    Instructions to run the benchmark:
    (assuming you installed required dependencies for PySpark)
    1. `cd python`
    2. `python3 pyspark/sql/streaming/benchmark/benchmark_tws_state_server.py
       <port/uds file of state server> <state type> <params if required>`
    Currently, state type can be one of the following:
    - value
    - list
    - map
    - timer
    Please take a look at the benchmark functions to see the parameters required for each state
    type.
    """
    # Fail fast with usage help instead of an IndexError when arguments are missing.
    if len(sys.argv) < 3:
        print(
            "Usage: benchmark_tws_state_server.py <port/uds file of state server> "
            "<state type> <params if required>",
            file=sys.stderr,
        )
        sys.exit(1)
    print("Starting the benchmark code... state server port: " + sys.argv[1])
    main(sys.argv[1], sys.argv[2])