blob: c5f227a67e5352587b2fd763af8075ae50bf4a0a [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
source("scripts/builtin/shapExplainer.dml") as shap;
# Test driver for internal functions of scripts/builtin/shapExplainer.dml.
# Arguments:
#   $1 - name of the test case to run (one of the branch conditions below)
#   $2 - output path for the matrix produced by the function under test
#   $3 - output path for the expected matrix
# Every branch assigns `result` (actual output) and `expected_result`; both are
# written at the end, presumably to be compared by the calling test harness.
if ($1 == 'prepare_mask_for_permutation') {
  # mask for a plain permutation (no non-varying indices, no partitions)
  perm = matrix("3 1 2", cols=3, rows=1)
  result = shap::prepare_mask_for_permutation(permutation=perm)
  expected_result = matrix("0 0 0 0 0 1 1 0 1 1 1 1 0 1 0 1 1 0", rows=6, cols=3)
} else if ($1 == 'prepare_mask_for_partial_permutation') {
  # mask for a permutation with n_non_varying_inds=2; the expected mask
  # has 5 columns for a 3-element permutation
  perm = matrix("4 1 2", cols=3, rows=1)
  result = shap::prepare_mask_for_permutation(permutation=perm, n_non_varying_inds=2)
  expected_result = matrix("0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 0 1 1 0 1 1 1 1 0 1", rows=6, cols=5)
} else if ($1 == 'prepare_mask_for_partitioned_permutation') {
  # mask for a permutation combined with a 2x2 partitions matrix
  perm = matrix("4 1 2", cols=3, rows=1)
  partitions = matrix("2 4 3 5", cols=2, rows=2)
  result = shap::prepare_mask_for_permutation(permutation=perm, partitions=partitions)
  expected_result = matrix("0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 1 0 1 1 0 0 1 1 1 0 0", rows=6, cols=5)
} else if ($1 == 'compute_means_from_predictions') {
  # expected: mean of each consecutive group of 2 predictions
  # (2,3)->2.5, (3,4)->3.5, (4,5)->4.5
  p = matrix("2 3 3 4 4 5", rows=6, cols=1)
  result = shap::compute_means_from_predictions(p, 2)
  expected_result = matrix("2.5 3.5 4.5", rows=3, cols=1)
} else if ($1 == 'compute_phis_from_prediction_means') {
  # phis for a 5-feature permutation from 10 prediction means
  permutation = matrix("2 3 4 1 5", cols=5, rows=1)
  P_perm = matrix("10 21 22 23 24 100 31 32 33 34", rows=10, cols=1)
  result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation)
  expected_result = matrix("1 38.5 1 1 48.5", rows=5, cols=1)
} else if ($1 == 'compute_phis_from_prediction_means_non_vars') {
  # same as above, but feature 2 is passed as non-varying; its expected phi is 0
  # and only 8 prediction means are supplied
  permutation = matrix("3 4 2 1 5", cols=5, rows=1)
  non_varying_inds= matrix("2", rows=1, cols=1)
  P_perm = matrix("10 22 23 24 100 31 32 33", rows=8, cols=1)
  result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation, non_var_inds=non_varying_inds)
  expected_result = matrix("1 0 39.5 1 48.5", rows=5, cols=1)
} else if ($1 == 'prepare_full_mask') {
  # expected: each mask row repeated 3 times, in order.
  # Only prepare_full_mask is checked here; the result is not overwritten by a
  # call to an internal helper, so this case really tests prepare_full_mask.
  mask = matrix("1 0 0 1", rows=2, cols=2)
  result = shap::prepare_full_mask(mask, 3)
  expected_result = matrix("1 0 1 0 1 0 0 1 0 1 0 1", rows=6, cols=2)
} else if ($1 == 'prepare_masked_X_bg') {
  # expected: background rows with the entries where the mask is 1 set to 0
  # (mask row [1 0] applied to the 3 samples, then mask row [0 1])
  mask = matrix("1 0 0 1", rows=2, cols=2)
  full_mask = shap::prepare_full_mask(mask, 3)
  X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2)
  result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0)
  expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0", rows=6, cols=2)
} else if ($1 == 'prepare_masked_X_bg_independent_perms') {
  # with third argument 2: per the expected output, the first two mask rows are
  # applied to samples 1-3 and the last two to samples 4-6
  # (NOTE(review): exact meaning of this argument is defined in shapExplainer.dml)
  mask = matrix("1 0 0 1 1 0 0 1", rows=4, cols=2)
  full_mask = shap::prepare_full_mask(mask, 3)
  X_bg_samples = matrix("11 12 21 22 31 32 41 42 51 52 61 62", rows=6, cols=2)
  result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 2)
  expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0 0 42 0 52 0 62 41 0 51 0 61 0", rows=12, cols=2)
} else if ($1 == 'apply_full_mask') {
  # expected: entries where the mask is 1 take the value from x_row;
  # all other entries keep the masked background value
  x_row = matrix("100 200", rows=1, cols=2)
  mask = matrix("1 0 0 1", rows=2, cols=2)
  full_mask = shap::prepare_full_mask(mask, 3)
  X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2)
  masked_X_bg = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0)
  result = shap::apply_full_mask(x_row, full_mask, masked_X_bg)
  expected_result = matrix("100 12 100 22 100 32 11 200 21 200 31 200", rows=6, cols=2)
} else {
  # unknown test name: emit deliberately different matrices so that
  # a comparison of the two outputs cannot succeed
  print("Test type "+$1+" unknown.")
  result = matrix("100 100", rows=1, cols=2)
  expected_result = matrix("0 0", rows=1, cols=2)
}
# actual result to $2, expected result to $3
write(result, $2)
write(expected_result, $3)