| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| source("scripts/builtin/shapExplainer.dml") as shap; |
| |
| # Unit-test dispatcher for individual functions of the sourced shapExplainer script. |
| # $1 selects the test case. Every branch assigns two matrices: |
| #   result          - output of the shap:: function under test |
| #   expected_result - hard-coded reference matrix for that input |
| # Both matrices are written to files after this chain ($2 and $3). |
| if ($1 == 'prepare_mask_for_permutation') { |
| #prepare_mask_for_permutation |
| # single permutation over 3 features -> expected mask has 6 rows and 3 columns |
| perm = matrix("3 1 2", cols=3, rows=1) |
| result = shap::prepare_mask_for_permutation(permutation=perm) |
| expected_result = matrix("0 0 0 0 0 1 1 0 1 1 1 1 0 1 0 1 1 0", rows=6, cols=3) |
| |
| } else if ($1 == 'prepare_mask_for_partial_permutation') { |
| #prepare_mask_for_partial_permutation |
| # permutation of 3 indices plus n_non_varying_inds=2 -> expected mask has 5 columns |
| perm = matrix("4 1 2", cols=3, rows=1) |
| result = shap::prepare_mask_for_permutation(permutation=perm, n_non_varying_inds=2) |
| expected_result = matrix("0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 0 1 1 0 1 1 1 1 0 1", rows=6, cols=5) |
| |
| } else if ($1 == 'prepare_mask_for_partitioned_permutation') { |
| #prepare_mask_for_partitioned_permutation |
| # same permutation as above, but with a 2x2 partitions matrix passed along |
| perm = matrix("4 1 2", cols=3, rows=1) |
| partitions = matrix("2 4 3 5", cols=2, rows=2) |
| result = shap::prepare_mask_for_permutation(permutation=perm, partitions=partitions) |
| expected_result = matrix("0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 1 0 1 1 0 0 1 1 1 0 0", rows=6, cols=5) |
| |
| } else if ($1 == 'compute_means_from_predictions') { |
| #compute_means_from_predictions |
| # expected values are the means of consecutive pairs of p: (2,3) (3,4) (4,5) |
| p = matrix("2 3 3 4 4 5", rows=6, cols=1) |
| result = shap::compute_means_from_predictions(p, 2) |
| expected_result = matrix("2.5 3.5 4.5", rows=3, cols=1) |
| |
| } else if ($1 == 'compute_phis_from_prediction_means') { |
| #compute_phis_from_prediction_means |
| # one permutation of 5 features with 10 prediction means -> one phi per feature |
| permutation = matrix("2 3 4 1 5", cols=5, rows=1) |
| P_perm = matrix("10 21 22 23 24 100 31 32 33 34", rows=10, cols=1) |
| result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation) |
| expected_result = matrix("1 38.5 1 1 48.5", rows=5, cols=1) |
| |
| } else if ($1 == 'compute_phis_from_prediction_means_non_vars') { |
| #compute_phis_from_prediction_means with non varying inds |
| # feature 2 is passed as non-varying; the expected phi for it is 0 |
| permutation = matrix("3 4 2 1 5", cols=5, rows=1) |
| non_varying_inds= matrix("2", rows=1, cols=1) |
| P_perm = matrix("10 22 23 24 100 31 32 33", rows=8, cols=1) |
| result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation, non_var_inds=non_varying_inds) |
| expected_result = matrix("1 0 39.5 1 48.5", rows=5, cols=1) |
| |
| } else if ($1 == 'prepare_full_mask') { |
| #prepare_full_mask |
| # result must come from prepare_full_mask itself so that this case exercises |
| # the function it is named after; expected: every mask row repeated 3 times in a row |
| mask = matrix("1 0 0 1", rows=2, cols=2) |
| result = shap::prepare_full_mask(mask, 3) |
| expected_result = matrix("1 0 1 0 1 0 0 1 0 1 0 1", rows=6, cols=2) |
| |
| } else if ($1 == 'prepare_masked_X_bg') { |
| #prepare_masked_X_bg |
| # expected: background values remain where the mask is 0, zeros where the mask is 1 |
| mask = matrix("1 0 0 1", rows=2, cols=2) |
| full_mask = shap::prepare_full_mask(mask, 3) |
| X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2) |
| result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0) |
| expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0", rows=6, cols=2) |
| |
| } else if ($1 == 'prepare_masked_X_bg_independent_perms') { |
| #prepare_masked_X_bg for independent perms |
| # 4 mask rows and 6 background rows, third argument 2: the expected output uses |
| # background rows 1-3 for the first two mask rows and rows 4-6 for the last two |
| mask = matrix("1 0 0 1 1 0 0 1", rows=4, cols=2) |
| full_mask = shap::prepare_full_mask(mask, 3) |
| X_bg_samples = matrix("11 12 21 22 31 32 41 42 51 52 61 62", rows=6, cols=2) |
| result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 2) |
| expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0 0 42 0 52 0 62 41 0 51 0 61 0", rows=12, cols=2) |
| |
| } else if ($1 == 'apply_full_mask') { |
| #apply_full_mask |
| # expected: x_row values where the mask is 1, background values where it is 0 |
| x_row = matrix("100 200", rows=1, cols=2) |
| mask = matrix("1 0 0 1", rows=2, cols=2) |
| full_mask = shap::prepare_full_mask(mask, 3) |
| X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2) |
| masked_X_bg = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0) |
| result = shap::apply_full_mask(x_row, full_mask, masked_X_bg) |
| expected_result = matrix("100 12 100 22 100 32 11 200 21 200 31 200", rows=6, cols=2) |
| |
| } else { |
| # Unknown test name: report it and emit two matrices that differ from each other, |
| # presumably so that a downstream comparison of the two outputs fails. |
| print("Test type "+$1+" unknown.") |
| result = matrix("100 100", rows=1, cols=2) |
| expected_result = matrix("0 0", rows=1, cols=2) |
| } |
| |
| # Persist both matrices: the reference goes to $3, the computed output to $2. |
| write(expected_result, $3) |
| write(result, $2) |
| |