| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| /* |
| * Test utility functions. |
| */ |
| |
| all_equal = function(matrix[double] X1, matrix[double] X2) |
|     return(boolean equivalent) { |
|   /* |
|    * Test whether every element of X1 equals the corresponding |
|    * element of X2. |
|    * |
|    * Inputs: |
|    *  - X1: First matrix, of shape (any, any). |
|    *  - X2: Second matrix, of same shape as X1. |
|    * |
|    * Outputs: |
|    *  - equivalent: TRUE if all element-wise comparisons hold, |
|    *      FALSE otherwise. |
|    */ |
|   # Element-wise comparison: 1 where the entries match, 0 elsewhere. |
|   matches = (X1 == X2) |
| |
|   # The product of the indicators is 1 only if none of them is 0. |
|   equivalent = as.logical(prod(matches)) |
| } |
| |
| check_all_equal = function(matrix[double] X1, matrix[double] X2) |
|     return(boolean equivalent) { |
|   /* |
|    * Compare two matrices element-wise and report a mismatch. |
|    * |
|    * Prints an "ERROR" statement when all_equal reports that the |
|    * two matrices are not equivalent; prints nothing otherwise. |
|    * |
|    * Inputs: |
|    *  - X1: First matrix, of shape (any, any). |
|    *  - X2: Second matrix, of same shape as X1. |
|    * |
|    * Outputs: |
|    *  - equivalent: Whether or not the two matrices are equivalent. |
|    */ |
|   # Delegate the element-wise comparison to all_equal. |
|   equivalent = all_equal(X1, X2) |
| |
|   # Report the mismatch, if any. |
|   mismatch = !equivalent |
|   if (mismatch) { |
|     print("ERROR: The two matrices are not equivalent.") |
|   } |
| } |
| |
| compute_rel_error = function(double x1, double x2) |
|     return (double rel_error) { |
|   /* |
|    * Relative error between two scalar values: |
|    * |
|    *   |x1 - x2| / max(1e-8, |x1| + |x2|) |
|    * |
|    * The denominator is floored at 1e-8 so that two zero-valued |
|    * inputs do not cause a division by zero. |
|    * |
|    * Inputs: |
|    *  - x1: First value. |
|    *  - x2: Second value. |
|    * |
|    * Outputs: |
|    *  - rel_error: Relative error measure between the two values. |
|    */ |
|   # Absolute difference between the two values. |
|   abs_diff = abs(x1 - x2) |
| |
|   # Combined magnitude, smoothed to stay strictly positive. |
|   scale = max(1e-8, abs(x1) + abs(x2)) |
| |
|   rel_error = abs_diff / scale |
| } |
| |
| check_rel_error = function(double x1, double x2, double thresh_error, double thresh_warn) |
|     return (double rel_error) { |
|   /* |
|    * Compute the relative error between two values and report it |
|    * when it exceeds the given thresholds. |
|    * |
|    *  - Prints an "ERROR" statement when rel_error > thresh_error. |
|    *  - Otherwise prints a "WARNING" statement when |
|    *    thresh_warn < rel_error <= thresh_error. |
|    *  - Prints nothing in all other cases. |
|    * |
|    * Inputs: |
|    *  - x1: First value. |
|    *  - x2: Second value. |
|    *  - thresh_error: Error threshold. |
|    *  - thresh_warn: Warning threshold. |
|    * |
|    * Outputs: |
|    *  - rel_error: Relative error measure between the two values, |
|    *      as returned by compute_rel_error. |
|    */ |
|   rel_error = compute_rel_error(x1, x2) |
| |
|   # Classify the error against the two thresholds. |
|   above_error = rel_error > thresh_error |
|   above_warn = rel_error > thresh_warn |
|   within_error = rel_error <= thresh_error |
| |
|   # Both messages end with the same description of the inputs. |
|   details = " with " + x1 + " vs " + x2 + "." |
| |
|   if (above_error) { |
|     print("ERROR: Relative error " + rel_error + " > " + thresh_error + details) |
|   } |
|   else if (above_warn & within_error) { |
|     print("WARNING: Relative error " + rel_error + " > " + thresh_warn + |
|           " & <= " + thresh_error + details) |
|   } |
| } |
| |
| check_rel_grad_error = function(double dw_a, double dw_n, double lossph, double lossmh) |
|     return (double rel_error) { |
|   /* |
|    * Compute the relative error between an analytical and a |
|    * numerical partial derivative and report it when it exceeds |
|    * fixed thresholds. |
|    * |
|    *  - Prints an "ERROR" statement for relative errors > 1e-2, |
|    *    indicating that the gradient is likely incorrect. |
|    *  - Otherwise prints a "WARNING" statement for relative errors |
|    *    > 1e-4 and <= 1e-2, indicating that the gradient may be |
|    *    incorrect. |
|    * |
|    * Inputs: |
|    *  - dw_a: Analytical partial derivative wrt w. |
|    *  - dw_n: Numerical partial derivative wrt w. |
|    *  - lossph: Loss evaluated with w set to w+h; only included in |
|    *      the printed message. |
|    *  - lossmh: Loss evaluated with w set to w-h; only included in |
|    *      the printed message. |
|    * |
|    * Outputs: |
|    *  - rel_error: Relative error measure between the two |
|    *      derivatives, as returned by compute_rel_error. |
|    */ |
|   # Fixed thresholds for the gradient check. |
|   thresh_error = 1e-2 |
|   thresh_warn = 1e-4 |
| |
|   rel_error = compute_rel_error(dw_a, dw_n) |
| |
|   # Classify the error against the two thresholds. |
|   above_error = rel_error > thresh_error |
|   above_warn = rel_error > thresh_warn |
|   within_error = rel_error <= thresh_error |
| |
|   # Both messages end with the same description of the inputs. |
|   details = " with " + dw_a + " analytical vs " + dw_n + |
|             " numerical, with lossph " + lossph + " and lossmh " + lossmh |
| |
|   if (above_error) { |
|     print("ERROR: Relative error " + rel_error + " > " + thresh_error + details) |
|   } |
|   else if (above_warn & within_error) { |
|     print("WARNING: Relative error " + rel_error + " > " + thresh_warn + |
|           " & <= " + thresh_error + details) |
|   } |
| } |
| |