#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
from pyspark import SparkContext
from slicing.base.Bucket import Bucket
from slicing.base.SparkNode import SparkNode
from slicing.base.top_k import Topk
from slicing.spark_modules import spark_utils, join_data_parallel
from slicing.spark_modules.join_data_parallel import rows_mapper, combiner
from slicing.spark_modules.spark_utils import approved_union_slice
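
# Union-based, data-parallel slice enumeration on Spark: rows are mapped to
# the candidate slices ("buckets") they fall into, per-bucket statistics are
# aggregated with combineByKey, and candidates are pruned against the running
# top-k using their computed bounds.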


def merge_values(a, b):
    # Tighten the bounds of the accumulated bucket `a` using bucket `b`.
    a.minimize_bounds(b)
    return a


def merge_combiners(a, b):
    # Merge the partial aggregate `b` into `a`; Bucket.__add__ is expected to
    # fold b's statistics into a, which is returned as the combined value.
    a + b
    return a
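

# Candidate generation: pair nodes from two earlier levels, keep the pairs
# approved by approved_union_slice, and union their attribute sets into a
# new candidate slice whose bounds are checked against the current top-k.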
def union_enum(left_level, right_level, x_size, alpha, top_k, w, loss, cur_lvl):
    buckets = {}
    for left_node in left_level:
        for right_node in right_level:
            if approved_union_slice(left_node, right_node):
                node = SparkNode(None, None)
                node.attributes = list(set(left_node.attributes) | set(right_node.attributes))
                bucket = Bucket(node, cur_lvl, w, x_size, loss)
                bucket.parents.append(left_node)
                bucket.parents.append(right_node)
                bucket.calc_bounds(w, x_size, loss)
                # Keep the candidate only if its bounds can still reach the top-k.
                if bucket.check_bounds(x_size, alpha, top_k):
                    buckets[bucket.name] = bucket
    return buckets
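

# Driver for the level-wise search: the first level scores one bucket per
# input feature; every further level unions surviving parents (union_enum),
# prunes candidates whose bounds cannot beat the current top-k, and re-scores
# the survivors over the row data.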
def parallel_process(all_features, predictions, loss, sc, debug, alpha, k, w, loss_type):
    top_k = Topk(k)
    cur_lvl = 0
    levels = []
    cur_lvl_nodes = list(all_features)
    pred_pandas = predictions.toPandas()
    x_size = len(pred_pandas)
    b_topk = SparkContext.broadcast(sc, top_k)
    b_cur_lvl = SparkContext.broadcast(sc, cur_lvl)
    # Level 0: one bucket per single-feature slice.
    buckets = {}
    for node in cur_lvl_nodes:
        bucket = Bucket(node, cur_lvl, w, x_size, loss)
        buckets[bucket.name] = bucket
    b_buckets = SparkContext.broadcast(sc, buckets)
    # rows = predictions.rdd.map(lambda row: (row[0], row[1].indices, row[2])) \
    #     .map(lambda item: (item[0], item[1].tolist(), item[2]))
    rows = predictions.rdd.map(lambda row: row[1].indices) \
        .map(lambda item: list(item))
    # Map every row to the buckets it belongs to, then aggregate per bucket.
    mapped = rows.map(lambda row: rows_mapper(row, b_buckets.value, loss_type))
    flattened = mapped.flatMap(lambda line: line.items())
    reduced = flattened.combineByKey(combiner, join_data_parallel.merge_values, join_data_parallel.merge_combiners)
    cur_lvl_nodes = reduced.values() \
        .map(lambda bucket: spark_utils.calc_bucket_metrics(bucket, loss, w, x_size, b_cur_lvl.value))
    if debug:
        cur_lvl_nodes.map(lambda bucket: bucket.print_debug(b_topk.value)).collect()
    cur_lvl = 1
    prev_level = cur_lvl_nodes.collect()
    b_cur_lvl_nodes = SparkContext.broadcast(sc, prev_level)
    levels.append(b_cur_lvl_nodes)
    top_k = top_k.buckets_top_k(prev_level, x_size, alpha, 1)
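    # Expand level by level until no candidate survives pruning and re-scoring.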
    while len(prev_level) > 0:
        b_topk = SparkContext.broadcast(sc, top_k)
        cur_min = top_k.min_score
        b_cur_lvl = SparkContext.broadcast(sc, cur_lvl)
        top_k.print_topk()
        # Enumerate candidates from every complementary pair of earlier levels
        # (left + right == cur_lvl - 1).
        buckets = []
        for left in range(int(cur_lvl / 2) + 1):
            right = cur_lvl - left - 1
            nodes = union_enum(levels[left].value, levels[right].value, x_size, alpha, top_k, w, loss, cur_lvl)
            buckets.append(nodes)
        candidate_rdd = sc.parallelize(buckets)
        all_buckets = candidate_rdd.flatMap(lambda line: line.items())
        combined = dict(all_buckets.combineByKey(combiner, merge_values, merge_combiners).collect())
        b_buckets = SparkContext.broadcast(sc, combined)
        # Keep only candidates whose bounds can still beat the current top-k.
        to_slice = dict(filter(lambda bucket: bucket[1].check_bounds(x_size, alpha, top_k), combined.items()))
        b_to_slice = SparkContext.broadcast(sc, to_slice)
        mapped = rows.map(lambda row: rows_mapper(row, b_to_slice.value, loss_type))
        flattened = mapped.flatMap(lambda line: line.items())
        partial = flattened.combineByKey(combiner, join_data_parallel.merge_values, join_data_parallel.merge_combiners)
        prev_level = partial \
            .map(lambda bucket: spark_utils.calc_bucket_metrics(bucket[1], loss, w, x_size, b_cur_lvl.value)).collect()
        top_k = top_k.buckets_top_k(prev_level, x_size, alpha, cur_min)
        b_topk = SparkContext.broadcast(sc, top_k)
        if debug:
            partial.values().map(lambda bucket: bucket.print_debug(b_topk.value)).collect()
        print("Level " + str(cur_lvl) + " had " +
              str(len(levels[cur_lvl - 1].value) * (len(levels[cur_lvl - 1].value) - 1)) +
              " candidates but after pruning only " + str(len(prev_level)) + " go to the next level")
        print("Program stopped at level " + str(cur_lvl))
        b_cur_lvl_nodes = SparkContext.broadcast(sc, prev_level)
        levels.append(b_cur_lvl_nodes)
        cur_lvl += 1
    print()
    print("Selected slices are: ")
    top_k.print_topk()
    return None
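

if __name__ == "__main__":
    # Minimal, self-contained sketch (not part of the slicing pipeline) of the
    # combineByKey pattern used above: per-row partial statistics keyed by
    # slice name are merged into one record per slice. The slice names and the
    # (size, error) tuples below are illustrative stand-ins for Bucket objects.
    demo_sc = SparkContext(master="local[2]", appName="combine-by-key-demo")
    pairs = demo_sc.parallelize([
        ("slice_a", (1, 0.5)),
        ("slice_a", (2, 0.3)),
        ("slice_b", (4, 1.0)),
    ])
    merged = pairs.combineByKey(
        lambda v: v,                                    # create a combiner from the first value
        lambda acc, v: (acc[0] + v[0], acc[1] + v[1]),  # fold a value into the partition-local combiner
        lambda a, b: (a[0] + b[0], a[1] + b[1]))        # merge combiners across partitions
    for name, (size, error) in sorted(merged.collect()):
        print(name, size, error)
    demo_sc.stop()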