blob: e32a885e64f47fd601e0ef0abc576400af8b8236 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
/*
* Test utility functions.
*/
all_equal = function(matrix[double] X1, matrix[double] X2)
    return(boolean equivalent) {
  /*
   * Test whether two matrices hold exactly the same values.
   *
   * The comparison is exact and element-wise; no tolerance is
   * applied.
   *
   * Inputs:
   *  - X1: First matrix, of shape (any, any).
   *  - X2: Second matrix, of same shape as X1.
   *
   * Outputs:
   *  - equivalent: TRUE if every cell of X1 equals the corresponding
   *      cell of X2, FALSE otherwise.
   */
  # Count the cells that differ; the matrices match iff there are none.
  num_mismatches = sum(X1 != X2)
  equivalent = (num_mismatches == 0)
}
check_all_equal = function(matrix[double] X1, matrix[double] X2)
    return(boolean equivalent) {
  /*
   * Compare two matrices for exact equality and report a mismatch.
   *
   * Prints an "ERROR" statement if the two matrices are not
   * element-wise equal.  Nothing is printed when they match.
   *
   * Inputs:
   *  - X1: First matrix, of shape (any, any).
   *  - X2: Second matrix, of same shape as X1.
   *
   * Outputs:
   *  - equivalent: Whether or not the two matrices are equivalent.
   */
  # Delegate the actual comparison to `all_equal`
  equivalent = all_equal(X1, X2)

  # Report a mismatch, if any
  mismatch = !equivalent
  if (mismatch) {
    print("ERROR: The two matrices are not equivalent.")
  }
}
compute_rel_error = function(double x1, double x2)
    return (double rel_error) {
  /*
   * Relative error between two scalar values.
   *
   * Computed as |x1 - x2| / max(1e-8, |x1| + |x2|).  The lower bound
   * of 1e-8 on the denominator avoids a division by zero when both
   * values are zero.
   *
   * Inputs:
   *  - x1: First value.
   *  - x2: Second value.
   *
   * Outputs:
   *  - rel_error: Relative error measure between the two values.
   */
  abs_diff = abs(x1 - x2)
  # Smoothed magnitude: never smaller than 1e-8
  magnitude = max(1e-8, abs(x1) + abs(x2))
  rel_error = abs_diff / magnitude
}
check_rel_error = function(double x1, double x2, double thresh_error, double thresh_warn)
    return (double rel_error) {
  /*
   * Compute the relative error between two values and report it if
   * it exceeds the given thresholds.
   *
   *  - Prints an "ERROR" statement if the relative error is
   *    > thresh_error, indicating that the implementation is likely
   *    incorrect.
   *  - Prints a "WARNING" statement if the relative error is
   *    > thresh_warn but <= thresh_error, indicating that the
   *    implementation may be incorrect.
   *  - Prints nothing otherwise.
   *
   * Inputs:
   *  - x1: First value.
   *  - x2: Second value.
   *  - thresh_error: Error threshold.
   *  - thresh_warn: Warning threshold.
   *
   * Outputs:
   *  - rel_error: Relative error measure between the two values.
   */
  rel_error = compute_rel_error(x1, x2)

  # Classify the relative error against the two thresholds
  is_error = (rel_error > thresh_error)
  is_warning = (rel_error > thresh_warn) & (rel_error <= thresh_error)

  # Both messages end with the same description of the compared values
  values_desc = " with " + x1 + " vs " + x2 + "."

  if (is_error) {
    print("ERROR: Relative error " + rel_error + " > " + thresh_error + values_desc)
  }
  else {
    if (is_warning) {
      print("WARNING: Relative error " + rel_error + " > " + thresh_warn
            + " & <= " + thresh_error + values_desc)
    }
  }
}
check_rel_grad_error = function(double dw_a, double dw_n, double lossph, double lossmh)
    return (double rel_error) {
  /*
   * Compute the relative error between an analytical and a numerical
   * partial derivative and report it if it exceeds fixed thresholds.
   *
   *  - Prints an "ERROR" statement if the relative error is > 1e-2,
   *    indicating that the gradient is likely incorrect.
   *  - Prints a "WARNING" statement if the relative error is > 1e-4
   *    but <= 1e-2, indicating that the gradient may be incorrect.
   *  - Prints nothing otherwise.
   *
   * The two loss values are not used in the computation; they are
   * only included in the printed messages.
   *
   * Inputs:
   *  - dw_a: Analytical partial derivative wrt w.
   *  - dw_n: Numerical partial derivative wrt w.
   *  - lossph: Loss evaluated with w set to w+h.
   *  - lossmh: Loss evaluated with w set to w-h.
   *
   * Outputs:
   *  - rel_error: Relative error measure between the two derivatives.
   */
  # Fixed reporting thresholds
  err_threshold = 1e-2
  warn_threshold = 1e-4

  rel_error = compute_rel_error(dw_a, dw_n)

  # Classify the relative error against the two thresholds
  is_error = (rel_error > err_threshold)
  is_warning = (rel_error > warn_threshold) & (rel_error <= err_threshold)

  # Both messages end with the same description of the inputs
  details = " with " + dw_a + " analytical vs " + dw_n + " numerical, with lossph "
            + lossph + " and lossmh " + lossmh

  if (is_error) {
    print("ERROR: Relative error " + rel_error + " > " + err_threshold + details)
  }
  else {
    if (is_warning) {
      print("WARNING: Relative error " + rel_error + " > " + warn_threshold
            + " & <= " + err_threshold + details)
    }
  }
}