| /* |
| * |
| * @file dt.c |
| * |
| * @brief Aggregate and utility functions written in C for C45 and RF in MADlib |
| * |
| * @date April 10, 2012 |
| */ |
| |
| #include <float.h> |
| #include <math.h> |
| #include <stdlib.h> |
| #include <time.h> |
| |
| #include "postgres.h" |
| #include "fmgr.h" |
| #include "access/tupmacs.h" |
| #include "utils/array.h" |
| #include "utils/lsyscache.h" |
| #include "utils/builtins.h" |
| #include "utils/typcache.h" |
| #include "catalog/pg_type.h" |
| #include "catalog/namespace.h" |
| #include "nodes/execnodes.h" |
| #include "nodes/nodes.h" |
| #include "funcapi.h" |
| |
| #ifndef NO_PG_MODULE_MAGIC |
| PG_MODULE_MAGIC; |
| #endif |
| /*#define __DT_SHOW_DEBUG_INFO__*/ |
| #ifdef __DT_SHOW_DEBUG_INFO__ |
| #define dtelog(...) elog(__VA_ARGS__) |
| #else |
| #define dtelog(...) |
| #endif |
| |
| |
| /* |
| * Postgres8.4 doesn't have such macro, so we add here |
| */ |
| #ifndef ARRAY_SIZE |
| #define ARRAY_SIZE(x) (sizeof(x) / sizeof(*(x))) |
| #endif |
| |
| |
| /* |
| * This macro is used to get the mask bit of the given feature |
| * id. |
| * fid - ((fid >> power) << power) equals to fid % (2^power) |
| */ |
| #define dt_fid_mask(fid, power) \ |
| (1 << (fid - ((fid >> power) << power))) |
| |
| |
| /* |
| * We use a lot of floating number operations during the training. |
| * For these operations, DBL_EPSILON defined in float.h, leads to error |
| * add-up and wrong results. For our calculations, we need to redefine |
| * that to a bigger number. Any floating number whose absolute value is |
| * smaller than the one defined here will be treated as zero. |
| */ |
| #define DT_EPSILON 0.000000001 |
| |
| /* |
| * This macro is used to test if a float value is 0. |
| * Due to the precision loss of floating numbers, we can not |
| * compare them directly with 0. |
| */ |
| #define dt_is_float_zero(value) \ |
| ((value) < DT_EPSILON && (value) > -DT_EPSILON) |
| |
| |
| /* |
| * calculate the value of (val)log(val) |
| * |
| * @param val the value to be calculated |
| * |
| * NOTE: when x approximates 0, x*log(x) also approximates 0. |
| * Therefore, we directly return 0 when v is 0. |
| */ |
| #define dt_cal_log(v) (dt_is_float_zero(v) ? 0.0 : (v) * log(v)) |
| |
| #define dt_cal_sqr(v) ((v) * (v)) |
| |
| #define dt_cal_sqr_div(v1, v2) (dt_is_float_zero(v2) ? \ |
| 0.0 : ((v1) * (v1))/(v2)) |
| |
| /* |
| * For Error Based Pruning (EBP), we need to compute the additional errors |
| * if the error rate increases to the upper limit of the confidence level. |
| * The coefficient is the square of the number of standard deviations |
| * corresponding to the selected confidence level. |
| * (Excerpt from Documenta Geigy Scientific Tables (Sixth Edition), |
| * p185 (with modifications).) |
| */ |
| static float8 DT_CONFIDENCE_LEVEL[] = |
| {0, 0.001, 0.005, 0.01, 0.05, 0.10, 0.20, 0.40, 1.00}; |
| static float8 DT_CONFIDENCE_DEV[] = |
| {4.0, 3.09, 2.58, 2.33, 1.65, 1.28, 0.84, 0.25, 0.00}; |
| |
| |
| #define MIN_DT_CONFIDENCE_LEVEL 0.001 |
| #define MAX_DT_CONFIDENCE_LEVEL 100.0 |
| |
| |
| #define dt_check_error_value(condition, message, value) \ |
| do { \ |
| if (!(condition)) \ |
| ereport(ERROR, \ |
| (errcode(ERRCODE_RAISE_EXCEPTION), \ |
| errmsg(message, (value)) \ |
| ) \ |
| ); \ |
| } while (0) |
| |
| |
| #define dt_check_error(condition, message) \ |
| do { \ |
| if (!(condition)) \ |
| ereport(ERROR, \ |
| (errcode(ERRCODE_RAISE_EXCEPTION), \ |
| errmsg(message) \ |
| ) \ |
| ); \ |
| } while (0) |
| |
| |
| /* |
| * a forward declaration. |
| */ |
| static |
| float8 |
| dt_ebp_calc_additional_errors |
| ( |
| float8 total_samples, |
| float8 num_errors, |
| float8 conf_level, |
| float8 coeff |
| ); |
| |
| |
| /* |
| * @brief Calculates the total errors used by Error Based Pruning (EBP). |
| * |
| * @param total The number of total samples represented by the node |
| * being processed. |
| * @param probability The probability to mis-classify samples represented |
| * by the child nodes if they are pruned with EBP. |
| * @param conf_level A certainty factor to calculate the confidence limits |
| * for the probability of error using the binomial theorem. |
| * |
| * @return The computed total error. |
| * |
| */ |
| Datum |
| dt_ebp_calc_errors |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| float8 total_samples = PG_GETARG_FLOAT8(0); |
| float8 probability = PG_GETARG_FLOAT8(1); |
| float8 conf_level = PG_GETARG_FLOAT8(2); |
| float8 result = 1.0L; |
| float8 coeff = 0.0L; |
| unsigned int i = 0; |
| |
| if (!dt_is_float_zero(100 - conf_level)) |
| { |
| dt_check_error_value |
| ( |
| !( |
| conf_level < MIN_DT_CONFIDENCE_LEVEL || |
| conf_level > MAX_DT_CONFIDENCE_LEVEL |
| ), |
| "invalid confidence level: %lf." |
| "Confidence level must be in range from 0.001 to 100", |
| conf_level |
| ); |
| |
| dt_check_error_value |
| ( |
| total_samples > 0, |
| "invalid number: %lf. " |
| "The number of samples must be greater than 0", |
| total_samples |
| ); |
| |
| dt_check_error_value |
| ( |
| !(probability < 0 || probability > 1), |
| "invalid probability: %lf. " |
| "The probability must be in range from 0 to 1", |
| probability |
| ); |
| |
| /* |
| * Confidence level value is in range from 0.001 to 1.0. |
| * It should be divided by 100 when calculate addition error. |
| * Therefore, the range of conf_level here is [0.00001, 1.0]. |
| */ |
| conf_level = conf_level * 0.01; |
| |
| /* |
| * Since the conf_level is in [0.00001, 1.0], |
| * the value of i will be in [1, length(DT_CONFIDENCE_LEVEL) - 1] |
| */ |
| while (conf_level > DT_CONFIDENCE_LEVEL[i]) i++; |
| |
| dt_check_error_value |
| ( |
| i > 0 && i < ARRAY_SIZE(DT_CONFIDENCE_LEVEL), |
| "invalid value: %d. " |
| "The index of confidence level must be in range from 0 to 8", |
| i |
| ); |
| |
| coeff = DT_CONFIDENCE_DEV[i-1] + |
| (DT_CONFIDENCE_DEV[i] - DT_CONFIDENCE_DEV[i-1]) * |
| (conf_level - DT_CONFIDENCE_LEVEL[i-1]) / |
| (DT_CONFIDENCE_LEVEL[i] - DT_CONFIDENCE_LEVEL[i-1]); |
| |
| coeff *= coeff; |
| |
| float8 num_errors = total_samples * (1 - probability); |
| result = dt_ebp_calc_additional_errors |
| ( |
| total_samples, |
| num_errors, |
| conf_level, |
| coeff |
| ) + num_errors; |
| } |
| |
| PG_RETURN_FLOAT8((float8)result); |
| } |
| PG_FUNCTION_INFO_V1(dt_ebp_calc_errors); |
| |
| |
| /* |
| * @brief This function calculates the additional errors for EBP. |
| * Detailed description of that pruning strategy can be found in the paper |
| * 'Error-Based Pruning of Decision Trees Grown on Very Large Data Sets |
| * Can Work!'. |
| * |
| * @param total_samples The number of total samples represented by the node |
| * being processed. |
| * @param num_errors The number of mis-classified samples represented |
| * by the child nodes if they are pruned with EBP. |
| * @param conf_level A certainty factor to calculate the confidence limits |
| * for the probability of error using the binomial theorem. |
| * |
| * @return The additional errors if we prune the node being processed. |
| * |
| */ |
| static |
| float8 |
| dt_ebp_calc_additional_errors |
| ( |
| float8 total_samples, |
| float8 num_errors, |
| float8 conf_level, |
| float8 coeff |
| ) |
| { |
| if (num_errors < 1E-6) |
| { |
| return total_samples * (1 - exp(log(conf_level) / total_samples)); |
| } |
| else |
| if (num_errors < 0.9999) |
| { |
| float8 tmp = total_samples * (1 - exp(log(conf_level) / total_samples)); |
| return tmp + |
| num_errors * |
| ( |
| dt_ebp_calc_additional_errors |
| (total_samples, 1.0, conf_level, coeff) - |
| tmp |
| ); |
| } |
| else |
| if (num_errors + 0.5 >= total_samples) |
| { |
| return 0.67 * (total_samples - num_errors); |
| } |
| else |
| { |
| float8 tmp = |
| ( |
| num_errors + 0.5 + coeff/2 + |
| sqrt(coeff * ((num_errors + 0.5) * |
| (1 - (num_errors + 0.5)/total_samples) + coeff/4)) |
| ) |
| / (total_samples + coeff); |
| |
| return (total_samples * tmp - num_errors); |
| } |
| } |
| |
| |
| /* |
| * @brief The step function for aggregating the class counts while |
| * doing Reduce Error Pruning (REP). |
| * The input for this aggregation is the result of an internal join |
| * between validation set's classification result and encoded table. |
| * |
| * @param class_count_array The array used to store the accumulated information. |
| * [0]: the total number of mis-classified samples |
| * [i]: the number of samples belonging to the ith class |
| * @param classified_class The predicted class based on our trained DT model. |
| * @param original_class The real class value provided in the validation set. |
| * @param max_num_of_classes The total number of distinct class values. |
| * |
| * @return An updated state array. |
| * |
| */ |
| Datum |
| dt_rep_aggr_class_count_sfunc |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| ArrayType *pg_class_count = NULL; |
| int array_dim = 0; |
| int *p_array_dim = NULL; |
| int array_length = 0; |
| int64 *class_count = NULL; |
| int classified_class = PG_GETARG_INT32(1); |
| int original_class = PG_GETARG_INT32(2); |
| int max_num_of_classes = PG_GETARG_INT32(3); |
| bool rebuild_array = false; |
| |
| dt_check_error_value |
| ( |
| max_num_of_classes >= 2, |
| "invalid value: %d. " |
| "The number of classes must be greater than or equal to 2", |
| max_num_of_classes |
| ); |
| |
| dt_check_error_value |
| ( |
| original_class > 0 && original_class <= max_num_of_classes, |
| "invalid real class value: %d. " |
| "It must be in range from 1 to the number of classes", |
| original_class |
| ); |
| |
| dt_check_error_value |
| ( |
| classified_class > 0 && classified_class <= max_num_of_classes, |
| "invalid classified class value: %d. " |
| "It must be in range from 1 to the number of classes", |
| classified_class |
| ); |
| |
| /* test if the first argument (class count array) is null */ |
| if (PG_ARGISNULL(0)) |
| { |
| /* |
| * We assume the maximum number of classes is limited (up to millions), |
| * so that the allocated array won't break our memory limitation. |
| */ |
| class_count = palloc0(sizeof(int64) * (max_num_of_classes + 1)); |
| array_length = max_num_of_classes + 1; |
| rebuild_array = true; |
| |
| } |
| else |
| { |
| if (fcinfo->context && IsA(fcinfo->context, AggState)) |
| pg_class_count = PG_GETARG_ARRAYTYPE_P(0); |
| else |
| pg_class_count = PG_GETARG_ARRAYTYPE_P_COPY(0); |
| |
| dt_check_error |
| ( |
| pg_class_count, |
| "invalid aggregation state array" |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(pg_class_count), |
| "dt_rep_aggr_class_count_sfunc cannot accept arrays with NULL values" |
| ); |
| |
| array_dim = ARR_NDIM(pg_class_count); |
| |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of class count array must be equal to 1", |
| array_dim |
| ); |
| |
| p_array_dim = ARR_DIMS(pg_class_count); |
| array_length = ArrayGetNItems(array_dim,p_array_dim); |
| class_count = (int64 *)ARR_DATA_PTR(pg_class_count); |
| |
| dt_check_error_value |
| ( |
| array_length == max_num_of_classes + 1, |
| "dt_rep_aggr_class_count_sfunc invalid array length: %d. " |
| "The length of class count array must be " |
| "equal to the total number classes + 1", |
| array_length |
| ); |
| } |
| |
| /* |
| * If the condition is met, then the current record |
| * has been mis-classified. Therefore, we will need |
| * to increase the first element. |
| */ |
| if (original_class != classified_class) |
| ++class_count[0]; |
| |
| /* In any sample, we will update the original class count */ |
| ++class_count[original_class]; |
| |
| if (rebuild_array) |
| { |
| /* construct a new array to keep the aggr states. */ |
| pg_class_count = |
| construct_array( |
| (Datum *)class_count, |
| array_length, |
| INT8OID, |
| sizeof(int64), |
| true, |
| 'd' |
| ); |
| } |
| |
| PG_RETURN_ARRAYTYPE_P(pg_class_count); |
| } |
| PG_FUNCTION_INFO_V1(dt_rep_aggr_class_count_sfunc); |
| |
| |
| /* |
| * @brief It takes two bigint arrays and add them together. |
| * If this function is used in an aggregation's context, |
| * we store the added information to |
| * |
| * @param 1 arg The array 1. |
| * @param 2 arg The array 2. |
| * |
| * @return The array with the added information. |
| * |
| */ |
| Datum |
| bigint_array_add |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| ArrayType *pg_array1 = NULL; |
| int array_dim = 0; |
| int *p_array_dim = NULL; |
| int array_length = 0; |
| int64 *array1 = NULL; |
| |
| ArrayType *pg_array2 = NULL; |
| int array_dim2 = 0; |
| int *p_array_dim2 = NULL; |
| int array_length2 = 0; |
| int64 *array2 = NULL; |
| |
| if (PG_ARGISNULL(0) && PG_ARGISNULL(1)) |
| PG_RETURN_NULL(); |
| else if (PG_ARGISNULL(1) || PG_ARGISNULL(0)) |
| { |
| /* |
| * If one of the two array is null, |
| * just return the non-null array directly |
| */ |
| PG_RETURN_ARRAYTYPE_P(PG_ARGISNULL(1) ? |
| PG_GETARG_ARRAYTYPE_P(0) : |
| PG_GETARG_ARRAYTYPE_P(1)); |
| } |
| else |
| { |
| /* If both arrays are not null, we will add them together */ |
| if (fcinfo->context && IsA(fcinfo->context, AggState)) |
| { |
| /* We can safely modify the original array in an aggregate */ |
| pg_array1 = PG_GETARG_ARRAYTYPE_P(0); |
| } |
| else |
| { |
| /* |
| * We must not modify the original array out of aggregate's |
| * context. We simply use copy here to avoid the tedious work |
| * to allocate new arrays. There is no explicit facility to |
| * do that. |
| */ |
| pg_array1 = PG_GETARG_ARRAYTYPE_P_COPY(0); |
| } |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(pg_array1), |
| "bigint_array_add cannot accept arrays with NULL values" |
| ); |
| |
| dt_check_error |
| ( |
| pg_array1, |
| "invalid aggregation state array" |
| ); |
| |
| array_dim = ARR_NDIM(pg_array1); |
| |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of array1 must be equal to 1", |
| array_dim |
| ); |
| |
| p_array_dim = ARR_DIMS(pg_array1); |
| array_length = ArrayGetNItems(array_dim,p_array_dim); |
| array1 = (int64 *)ARR_DATA_PTR(pg_array1); |
| |
| pg_array2 = PG_GETARG_ARRAYTYPE_P(1); |
| array_dim2 = ARR_NDIM(pg_array2); |
| dt_check_error_value |
| ( |
| array_dim2 == 1, |
| "invalid array dimension: %d. " |
| "The dimension of array2 must be equal to 1", |
| array_dim2 |
| ); |
| |
| p_array_dim2 = ARR_DIMS(pg_array2); |
| array_length2 = ArrayGetNItems(array_dim2,p_array_dim2); |
| array2 = (int64 *)ARR_DATA_PTR(pg_array2); |
| |
| dt_check_error |
| ( |
| array_length == array_length2, |
| "the size of the two array must be the same" |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(pg_array2), |
| "bigint_array_add cannot accept arrays with NULL values" |
| ); |
| |
| for (int index = 0; index < array_length; index++) |
| array1[index] += array2[index]; |
| |
| PG_RETURN_ARRAYTYPE_P(pg_array1); |
| } |
| } |
| PG_FUNCTION_INFO_V1(bigint_array_add); |
| |
| |
| /* |
| * @brief The final function for aggregating the class counts for REP. |
| * It takes the class count array produced by the step function. |
| * |
| * @param class_count_array The array used to store the accumulated information. |
| * [0]: the total number of mis-classified samples |
| * [i]: the number of samples belonging to the ith class |
| * |
| * @return A two-element array. The first element is the ID of the class that |
| * has the maximum number of samples represented by the root node of |
| * the subtree being processed. The second element is the number of |
| * reduced misclassified samples if the leaf nodes of the subtree are pruned. |
| * |
| */ |
| Datum |
| dt_rep_aggr_class_count_ffunc |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| ArrayType *pg_class_count = PG_GETARG_ARRAYTYPE_P(0); |
| int array_dim = ARR_NDIM(pg_class_count); |
| |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of class count array must be equal to 1", |
| array_dim |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(pg_class_count), |
| "dt_rep_aggr_class_count_ffunc cannot accept arrays with NULL values" |
| ); |
| |
| int *p_array_dim = ARR_DIMS(pg_class_count); |
| int array_length = ArrayGetNItems(array_dim,p_array_dim); |
| int64 *class_count = (int64 *)ARR_DATA_PTR(pg_class_count); |
| int64 *result = palloc(sizeof(int64)*2); |
| |
| dt_check_error |
| ( |
| result, |
| "memory allocation failure" |
| ); |
| |
| int64 max = class_count[1]; |
| int64 sum = max; |
| int maxid = 1; |
| for(int i = 2; i < array_length; ++i) |
| { |
| if(max < class_count[i]) |
| { |
| max = class_count[i]; |
| maxid = i; |
| } |
| |
| sum += class_count[i]; |
| } |
| |
| /* maxid is the id of the class, which has the most samples */ |
| result[0] = maxid; |
| |
| /* |
| * (sum - max) is the number of mis-classified samples represented by |
| * the root node of the subtree being processed |
| * class_count_data[0] the total number of mis-classified samples |
| */ |
| result[1] = class_count[0] - (sum - max); |
| |
| ArrayType* result_array = |
| construct_array( |
| (Datum *)result, |
| 2, |
| INT8OID, |
| sizeof(int64), |
| true, |
| 'd' |
| ); |
| |
| PG_RETURN_ARRAYTYPE_P(result_array); |
| } |
| PG_FUNCTION_INFO_V1(dt_rep_aggr_class_count_ffunc); |
| |
| |
| /* |
| * Calculating Split Criteria Values (SCVs for short) is a major |
| * step for growing a decision tree. While the formulas for different |
| * criteria are well defined and understood, the process for calculating |
| * them are not. In the database context, we can not follow the classical |
| * approach to keep all needed counts data in memory resident structures, |
| * as the memory requirement is usually proportional to the size of |
| * the train sets. For big data, this requirement is usually hard to fulfill. |
| * |
| * When building DT in databases, we try to leverage the DB's aggregation |
| * mechanism to do the same thing. This will also give us the opportunity |
| * to leverage database's parallelization infrastructure. |
| * |
| * For that purpose, we will process the train set into something we call |
| * Attribute Class Statistic (ACS for short) with a set of transformations |
| * and use aggregate functions to work on that. Details of how an ACS is |
| * generated can be found in DT design doc. The following is an example ACS |
| * for the golf data set: |
| * |
| * tid | nid | fid | split_value | is_cont | le | total |
| * -----+-----+-----+-------------+---------+-------+------- |
| * 1 | 1 | 4 | | f | {2,6} | {5,9} |
| * 1 | 1 | 4 | | f | {3,3} | {5,9} |
| * 1 | 1 | 3 | | f | {2,3} | {5,9} |
| * 1 | 1 | 3 | | f | {0,4} | {5,9} |
| * 1 | 1 | 3 | | f | {3,2} | {5,9} |
| * 1 | 1 | 2 | 64 | t | {0,1} | {5,9} |
| * 1 | 1 | 2 | 65 | t | {1,1} | {5,9} |
| * 1 | 1 | 2 | 68 | t | {1,2} | {5,9} |
| * 1 | 1 | 2 | 69 | t | {1,3} | {5,9} |
| * 1 | 1 | 2 | 70 | t | {1,4} | {5,9} |
| * 1 | 1 | 2 | 71 | t | {2,4} | {5,9} |
| * 1 | 1 | 2 | 72 | t | {3,5} | {5,9} |
| * 1 | 1 | 2 | 75 | t | {3,7} | {5,9} |
| * 1 | 1 | 2 | 80 | t | {4,7} | {5,9} |
| * 1 | 1 | 2 | 81 | t | {4,8} | {5,9} |
| * 1 | 1 | 2 | 83 | t | {4,9} | {5,9} |
| * 1 | 1 | 2 | 85 | t | {5,9} | {5,9} |
| * 1 | 1 | 1 | 65 | t | {0,1} | {5,9} |
| * 1 | 1 | 1 | 70 | t | {1,3} | {5,9} |
| * 1 | 1 | 1 | 75 | t | {1,4} | {5,9} |
| * 1 | 1 | 1 | 78 | t | {1,5} | {5,9} |
| * 1 | 1 | 1 | 80 | t | {2,7} | {5,9} |
| * 1 | 1 | 1 | 85 | t | {3,7} | {5,9} |
| * 1 | 1 | 1 | 90 | t | {4,8} | {5,9} |
| * 1 | 1 | 1 | 95 | t | {5,8} | {5,9} |
| * 1 | 1 | 1 | 96 | t | {5,9} | {5,9} |
| * (26 rows) |
| * |
| * The fields of ACS is explained below. |
| * tid The ID of the tree. |
| * |
| * nid The ID of the node in the specified tree. |
| * |
| * fid The ID of the selected feature. |
| * |
| * split_value |
| * For continuous features, each distinct value is one candidate |
| * split value. For discrete features, this field is always NULL. |
| * |
| * is_cont Whether the feature fid is continuous or not. This column can be |
| * eliminated if we check (split_value IS NOT NULL) |
| * |
| * le An m-element array, where m is the total number of distinct |
| * classes. le[i] is the number of samples whose class labels are |
| * class i and whose feature fid holds a distinct value equal to |
| * (for discrete features) or less-than or equal to (for continuous |
| * features) the feature value corresponding to the current row. The |
| * corresponding value is split_value for a continous feature, or one |
| * of its distinct values for a discrete feature. |
| * |
| * total An m-element array, where m is the total number of distinct classes. |
| * total[i] is the total number of samples whose class labels are class i. |
| * |
| * The rows are grouped by (tid, nid, fid, split_value). For a discrete feature, |
| * split_value always contains NULL. For a discrete feature with n distinct values, |
| * the group for that feature contains n rows. For a continuous feature, |
| * its group has only one row. For each group, we will calculate an SCV based |
| * on the specified splitting criterion and then choose the split with the |
| * maximum scv value. Because groups are independent, calculating SCVs can be done |
| * in parallel. |
| * |
| * Given the format of the input data stream, SCV calculation is different |
| * from using the SCV formulas directly. There is one row for each distinct |
| * value of a feature. For information gain, we can further transform the |
| * formula as below. We assume there are n distinct values for feature a and |
| * m distinct classes. We denote c[j] as the total number of samples whose class |
| * labels are class j. The cardinality of S is defined as |S|. |Si| is the |
| * total count of samples whose feature value is the ith distinct value. We |
| * denote d[i][j] as the count of samples whose class is j and feature value is |
| * the ith distinct value. |
| * |
| * We define the entropy of S, denoted as info(S), as: |
| * |
| * info(S) = (c[1]/|S|)log(|S|/c[1])+...+(c[m]/|S|)log(|S|/c[m]) |
| * |
| * Suppose using the distinct values of feature a, S is split into n subsets |
| * {S1, S2, ..., Sn}. We define info(S, a) as the weighted entropy of all the |
| * subsets after splitting S using feature a: |
| * |
| * info(S, a) = (|S1|/|S|)info(S1)+...+(|Sn|/|S|)info(Sn) |
| * |
| * The information gain of using a to split S can be defined as: |
| * |
| * IG(S, a)= info(S) - info(S, a) |
| * = log(t) - ( u + v - w ) / t, |
| * |
| * where t, u, v and w are defined as: |
| * |
| * t = |S| |
| * u = (c[1])log(c[1])+...+(c[m])log(c[m]) |
| * v = |S1|log(|S1|)+...+|Sn|log(|Sn|) |
| * w = (d[1][1])log(d[1][1])+(d[1][2])log(d[1][2])+...+(d[n][m])log(d[n][m]) |
| * |
| * In the above formulas, c[j] actually equals to total[j] within the ACS set. |
| * |S| equals to the sum of all elements in total. For the i-th distinct value |
| * of a discrete feature, d[i][j] equals to le[j] of the ACS row corresponding |
| * to the i-th value. With that, |Si| then equals to the sum of all d[i][j]s. |
| * |
| * Therefore, we can define an aggregate function to process the rows in ACS |
| * to calcualte the information gain of all features. The aggregate can calculate |
| * t, u, v, and w incrementally as the rows come in. Their intermediate values will |
| * be kept in the aggregate state variables. In the final function, we can get the |
| * information with log(t) - ( u + v - w ) / t. |
| * |
| * This way, we successfully remove the need to keep all attribute-class counts |
| * in a possibly very big in-memory array. The calculation process fits quite |
| * well with the aggregate mechanism, which are widely available on most data |
| * processing systems. |
| * |
| * When using gain ratio as the split criterion, besides IG(S, a), we also need |
| * Split_info(S, a), which can be defined as: |
| * |
| * Split_info(S, a) = (|S1|/|S|)log(|S|/|S1|)+...+(|Sn|/|S|)log(|S|/|Sn|) |
| * |
| * With ACS in place, we can get |S| and |Si| for each incoming row, based on which |
| * part of Split_info can be calculated. Then in the final function, the gain ratio |
| * of using a to split S can be calculated as: |
| * |
| * GR(S, a) = IG(S, a) / Split_info(s, a) |
| * |
| * For gini, the computation can be reduced to formula below. |
| * |
| * GI(S, a) = (W1/V1+W2/V2+...+Wn/Vn)/t - u/(t^2) |
| * |
| * where u,t,Wi and Vi is defined below. |
| * |
| * t = |S| |
| * u = (c[1])^2+(c[2])^2+...+(c[m])^2. |
| * Wi = (d[i][1])^2+(d[i][2])^2+...+(d[i][m])^2. |
| * Vi = d[i][1]+d[i][2]+...+d[i][m] |
| * |
| * We do not need to store Wi and Vi into separate variables. Instead, we only |
| * need two variables to keep the accumulated results of Wi and Vi. |
| * This way the gini index can also be calculated with aggregates using constant |
| * memory. |
| * |
| * Based on this understanding, we will define the following structures, |
| * types, and aggregate functions to calculate SCVs. |
| * |
| */ |
| |
| |
| /* |
| * We use a 9-element array to keep the state of the |
| * aggregate for calculating splitting criteria values (SCVs). |
| * The enum types defines which element of that array is used |
| * for which purpose. |
| * |
| */ |
| enum DT_SCV_STATE_ARRAY_INDEX |
| { |
| /* 1 infogain, 2 gainratio, 3 gini */ |
| SCV_CODE = 0, |
| |
| /* is continuous or not*/ |
| SCV_IS_CONT, |
| |
| /* the u component */ |
| SCV_U, |
| |
| /* the v component */ |
| SCV_V, |
| |
| /* the w component */ |
| SCV_W, |
| |
| /* the t component */ |
| SCV_T, |
| |
| /* the total number of samples in the training set */ |
| SCV_SAMPLE_TOTAL, |
| |
| /* the ID of the class with the largest number of samples */ |
| SCV_MAX_CLASS_ID, |
| |
| /* the total number of samples belonging to MAX_CLASS */ |
| SCV_MAX_CLASS_COUNT |
| |
| }; |
| |
| |
| /* |
| * We use a 5-element array to keep the final result of the |
| * aggregate for calculating splitting criteria values (SCVs). |
| * The enum types defines which element of that array is used |
| * for which purpose. |
| * |
| */ |
| enum DT_SCV_FINAL_ARRAY_INDEX |
| { |
| /* Calculated SCV */ |
| SCV_FINAL_VALUE = 0, |
| |
| /* Whether the selected feature is continuous or discrete */ |
| SCV_FINAL_IS_CONT, |
| |
| /* The ID of the class with the largest number of samples */ |
| SCV_FINAL_CLASS_ID, |
| |
| /* The percentage of samples belonging to MAX_CLASS */ |
| SCV_FINAL_CLASS_PROB, |
| |
| /* Total count of samples */ |
| SCV_FINAL_TOTAL_COUNT |
| }; |
| |
| |
| /* Codes for different split criteria. */ |
| #define DT_SC_INFOGAIN 1 |
| #define DT_SC_GAINRATIO 2 |
| #define DT_SC_GINI 3 |
| |
| |
| /* |
| * @brief The step function for the aggregation used to find the best SCV. |
| * |
| * @param best_scv_array This array stores the internal aggregation state. Its |
| * definition is the same as the returned array. |
| * @param scv_final_array This array contains the computed splitting criteria |
| * values. Please refer to the definition of |
| * DT_SCV_FINAL_ARRAY_INDEX. |
| * @param fid The ID of the feature used by this split. |
| * @param split_value The split_value for this split. For discrete features, |
| * it is always NULL. |
| * |
| * @return A seven-element array. Please refer to the definition of |
| * DT_SCV_FINAL_ARRAY_INDEX for the first five elements. The |
| * last two elements of this array is fid and split_value. |
| */ |
| Datum |
| dt_best_scv_sfunc |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| ArrayType* best_scv_array = NULL; |
| if (fcinfo->context && IsA(fcinfo->context, AggState)) |
| best_scv_array = PG_GETARG_ARRAYTYPE_P(0); |
| else |
| best_scv_array = PG_GETARG_ARRAYTYPE_P_COPY(0); |
| |
| dt_check_error |
| ( |
| best_scv_array, |
| "invalid aggregation state array" |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(best_scv_array), |
| "the first array passed to dt_best_scv_sfunc cannot contain NULL values" |
| ); |
| |
| int array_dim = ARR_NDIM(best_scv_array); |
| |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of scv state array must be equal to 1", |
| array_dim |
| ); |
| |
| int* p_array_dim = ARR_DIMS(best_scv_array); |
| int array_length = ArrayGetNItems(array_dim, p_array_dim); |
| |
| dt_check_error_value |
| ( |
| array_length == SCV_FINAL_TOTAL_COUNT + 3, |
| "dt_best_scv_sfunc invalid array length: %d", |
| array_length |
| ); |
| |
| float8 *best_scv_data = (float8 *)ARR_DATA_PTR(best_scv_array); |
| dt_check_error |
| ( |
| best_scv_data, |
| "invalid aggregation data array" |
| ); |
| |
| // scv array |
| ArrayType* scv_array = PG_GETARG_ARRAYTYPE_P(1); |
| dt_check_error(scv_array, "invalid scv array"); |
| array_dim = ARR_NDIM(scv_array); |
| dt_check_error(array_dim == 1, |
| "the dimension of scv array must be equal to 1"); |
| dt_check_error |
| ( |
| !ARR_HASNULL(scv_array), |
| "the second array passed to dt_best_scv_sfunc cannot contain NULL values" |
| ); |
| |
| p_array_dim = ARR_DIMS(scv_array); |
| array_length = ArrayGetNItems(array_dim, p_array_dim); |
| |
| dt_check_error_value |
| ( |
| array_length == SCV_FINAL_TOTAL_COUNT + 1, |
| "dt_best_scv_sfunc invalid array length: %d", |
| array_length |
| ); |
| |
| float8 *scv_data = (float8 *)ARR_DATA_PTR(scv_array); |
| dt_check_error(scv_data, "invalid scv data array"); |
| |
| float8 scvdiff = 0.0; |
| int i = 0; |
| int fid = PG_GETARG_INT32(2); |
| float8 sp_val = PG_GETARG_FLOAT8(3); |
| |
| scvdiff = scv_data[SCV_FINAL_VALUE] - best_scv_data[SCV_FINAL_VALUE]; |
| |
| dtelog( NOTICE, |
| "cur:%lf, %lf, best:%lf, %lf", |
| scv_data[SCV_FINAL_VALUE], |
| fid, |
| best_scv_data[SCV_FINAL_VALUE], |
| best_scv_data[SCV_FINAL_TOTAL_COUNT + 1]); |
| |
| /* |
| * When the SCVs for two features tie, we will use the fid and split_value |
| * as the tie breakers. This ensures that we consistently choose the same |
| * feature/splitting value as the split. |
| */ |
| if ( (scvdiff > DT_EPSILON) || |
| ( |
| dt_is_float_zero(scvdiff) && |
| ( |
| (best_scv_data[SCV_FINAL_TOTAL_COUNT + 1] < fid) || |
| ( dt_is_float_zero |
| ( |
| best_scv_data[SCV_FINAL_TOTAL_COUNT + 1]-fid |
| ) && |
| best_scv_data[SCV_FINAL_TOTAL_COUNT + 2] < sp_val |
| ) |
| ) |
| ) |
| ) |
| { |
| for (i = 0; i <= SCV_FINAL_TOTAL_COUNT; ++i) |
| { |
| best_scv_data[i] = scv_data[i]; |
| } |
| |
| best_scv_data[i] = fid; |
| best_scv_data[i + 1] = sp_val; |
| } |
| |
| PG_RETURN_ARRAYTYPE_P(best_scv_array); |
| } |
| PG_FUNCTION_INFO_V1(dt_best_scv_sfunc); |
| |
| |
| /* |
| * @brief The pre-function for finding the best splitting criteria values. |
| * |
| * @param scv_state_array The array from sfunc1. |
| * @param scv_state_array The array from sfunc2. |
| * |
| * @return A seven element array. Please refer to the definition of |
| * DT_SCV_FINAL_ARRAY_INDEX for the first five elements. The |
| * last two elements of this array is fid and split_value. |
| * |
| */ |
| Datum |
| dt_best_scv_prefunc |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| ArrayType* scv_state_array = NULL; |
| if (fcinfo->context && IsA(fcinfo->context, AggState)) |
| scv_state_array = PG_GETARG_ARRAYTYPE_P(0); |
| else |
| scv_state_array = PG_GETARG_ARRAYTYPE_P_COPY(0); |
| |
| dt_check_error |
| ( |
| scv_state_array, |
| "invalid aggregation state array" |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(scv_state_array), |
| "the first array passed to dt_best_scv_prefunc cannot contain NULL values" |
| ); |
| |
| int array_dim = ARR_NDIM(scv_state_array); |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of scv state array must be equal to 1", |
| array_dim |
| ); |
| |
| int *p_array_dim = ARR_DIMS(scv_state_array); |
| int array_length = ArrayGetNItems(array_dim, p_array_dim); |
| dt_check_error_value |
| ( |
| array_length == SCV_FINAL_TOTAL_COUNT + 3, |
| "dt_scv_aggr_prefunc invalid array length: %d", |
| array_length |
| ); |
| |
| /* the scv state data from a segment */ |
| float8 *scv_state_data = (float8 *)ARR_DATA_PTR(scv_state_array); |
| dt_check_error |
| ( |
| scv_state_data, |
| "invalid aggregation data array" |
| ); |
| |
| ArrayType* scv_state_array2 = PG_GETARG_ARRAYTYPE_P(1); |
| dt_check_error |
| ( |
| scv_state_array2, |
| "invalid aggregation state array" |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(scv_state_array2), |
| "the second array passed to dt_best_scv_prefunc cannot contain NULL values" |
| ); |
| |
| array_dim = ARR_NDIM(scv_state_array2); |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of scv state array must be equal to 1", |
| array_dim |
| ); |
| p_array_dim = ARR_DIMS(scv_state_array2); |
| array_length = ArrayGetNItems(array_dim, p_array_dim); |
| dt_check_error_value |
| ( |
| array_length == SCV_FINAL_TOTAL_COUNT + 3, |
| "dt_scv_aggr_prefunc invalid array length: %d", |
| array_length |
| ); |
| |
| /* the scv state data from another segment */ |
| float8 *scv_state_data2 = (float8 *)ARR_DATA_PTR(scv_state_array2); |
| dt_check_error |
| ( |
| scv_state_data2, |
| "invalid aggregation data array" |
| ); |
| |
| float8 scvdiff = scv_state_data2[SCV_FINAL_VALUE] - |
| scv_state_data[SCV_FINAL_VALUE]; |
| int i = 0; |
| |
| float8 array2_fid = scv_state_data2[SCV_FINAL_TOTAL_COUNT + 1]; |
| float8 array2_sp_val = scv_state_data2[SCV_FINAL_TOTAL_COUNT + 2]; |
| |
| float8 array1_fid = scv_state_data[SCV_FINAL_TOTAL_COUNT + 1]; |
| float8 array1_sp_val = scv_state_data[SCV_FINAL_TOTAL_COUNT + 2]; |
| |
| /* |
| * When the SCVs for two features tie, we will use the fid and split_value |
| * as the tie breakers. This ensures that we consistently choose the same |
| * feature/splitting value as the split. |
| */ |
| if ((scvdiff > DT_EPSILON) || |
| ( |
| dt_is_float_zero(scvdiff) && |
| ( |
| (array1_fid < array2_fid) || |
| ( dt_is_float_zero |
| ( |
| array1_fid-array2_fid |
| ) && |
| array1_sp_val < array2_sp_val |
| ) |
| ) |
| ) |
| ) |
| { |
| for (i = 0; i <= SCV_FINAL_TOTAL_COUNT + 2; ++i) |
| { |
| scv_state_data[i] = scv_state_data2[i]; |
| } |
| } |
| |
| PG_RETURN_ARRAYTYPE_P(scv_state_array); |
| } |
| PG_FUNCTION_INFO_V1(dt_best_scv_prefunc); |
| |
| |
| /* |
| * @brief The step function for the aggregation of SCV. |
| * It accumulates all the information for SCV calculation |
| * and stores to a nine-element array. |
| * |
| * @param scv_state_array The array used to accumulate all the information |
| * for the calculation of SCV. |
| * Please refer to the definition of |
| * DT_SCV_STATE_ARRAY_INDEX. |
| * @param sc_code 1- infogain; 2- gainratio; 3- gini. |
| * @param feature_val The feature value of current record under processing. |
| * @param class The class of current record under processing. |
| * @param is_cont_feature True - The feature is continuous. |
| * False - The feature is discrete. |
| * @param le The le component of an ACS record. |
| * @param total The total component of an ACS record. |
| * @param true_total_count If there is any missing value, true_total_count is larger |
| * than the total count computed in the aggregation. Thus, |
| * we should multiply a ratio for the computed gain. |
| * |
| * @return A nine-element array. Please refer to the definition of |
| * DT_SCV_STATE_ARRAY_INDEX for the detailed information of this array. |
| */ |
| Datum |
| dt_scv_aggr_sfunc |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| ArrayType* scv_state_array = NULL; |
| if (fcinfo->context && IsA(fcinfo->context, AggState)) |
| scv_state_array = PG_GETARG_ARRAYTYPE_P(0); |
| else |
| scv_state_array = PG_GETARG_ARRAYTYPE_P_COPY(0); |
| |
| dt_check_error |
| ( |
| scv_state_array, |
| "invalid aggregation state array" |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(scv_state_array), |
| "the first array passed to dt_scv_aggr_sfunc cannot contain NULL values" |
| ); |
| |
| int array_dim = ARR_NDIM(scv_state_array); |
| |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of scv state array must be equal to 1", |
| array_dim |
| ); |
| |
| int* p_array_dim = ARR_DIMS(scv_state_array); |
| int array_length = ArrayGetNItems(array_dim, p_array_dim); |
| |
| dt_check_error_value |
| ( |
| array_length == SCV_MAX_CLASS_COUNT + 1, |
| "dt_scv_aggr_sfunc invalid array length: %d", |
| array_length |
| ); |
| |
| float8 *scv_state_data = (float8 *)ARR_DATA_PTR(scv_state_array); |
| dt_check_error |
| ( |
| scv_state_data, |
| "invalid aggregation data array" |
| ); |
| |
| int sc_type = PG_GETARG_INT32(1); |
| bool is_cont_feat = PG_ARGISNULL(2) ? 0 : PG_GETARG_BOOL(2); |
| int num_class = PG_ARGISNULL(3) ? 0 : PG_GETARG_INT32(3); |
| |
| // we only read the data from le-array and total-array |
| ArrayType* le_array = PG_GETARG_ARRAYTYPE_P(4); |
| dt_check_error(le_array, "invalid le array"); |
| array_dim = ARR_NDIM(le_array); |
| dt_check_error(array_dim == 1, "the dimemsion of le array must be 1"); |
| p_array_dim = ARR_DIMS(le_array); |
| array_length = ArrayGetNItems(array_dim, p_array_dim); |
| dt_check_error |
| ( |
| array_length == num_class, |
| "the size of le array must be the number of class" |
| ); |
| float8* le_data = (float8 *)ARR_DATA_PTR(le_array); |
| |
| // total array |
| ArrayType* total_array = PG_GETARG_ARRAYTYPE_P(5); |
| dt_check_error(total_array, "invalid total array"); |
| array_dim = ARR_NDIM(total_array); |
| dt_check_error(array_dim == 1, "the dimemsion of total array must be 1"); |
| p_array_dim = ARR_DIMS(total_array); |
| array_length = ArrayGetNItems(array_dim, p_array_dim); |
| dt_check_error |
| ( |
| array_length == num_class, |
| "the size of total array must be the number of class" |
| ); |
| float8* total_data = (float8 *)ARR_DATA_PTR(total_array); |
| |
| int i = 0; |
| float8 feat_le = 0.0; |
| float8 feat_cnts = 0.0; |
| |
| dt_check_error_value |
| ( |
| DT_SC_INFOGAIN == sc_type || |
| DT_SC_GAINRATIO == sc_type || |
| DT_SC_GINI == sc_type, |
| "invalid split criterion: %d. " |
| "It must be 1(infogain), 2(gainratio) or 3(gini)", |
| sc_type |
| ); |
| |
| scv_state_data[SCV_CODE] = sc_type; |
| scv_state_data[SCV_SAMPLE_TOTAL] = PG_ARGISNULL(6) ? 0 : PG_GETARG_INT64(6); |
| scv_state_data[SCV_IS_CONT] = is_cont_feat; |
| |
| dtelog(NOTICE, "array: %lf, %lf, %lf, %lf", |
| le_data[0], le_data[1], total_data[0], total_data[1]); |
| |
| // processing the continuous feature |
| if (is_cont_feat) |
| { |
| // the definitions of t, u, v and w are the same between IG and GR |
| if (DT_SC_INFOGAIN == sc_type || DT_SC_GAINRATIO == sc_type) |
| { |
| scv_state_data[SCV_MAX_CLASS_COUNT] = 0; |
| for (i = 0; i < num_class; ++i) |
| { |
| dt_check_error_value |
| ( |
| total_data[i] >= le_data[i], |
| "the difference: %lf", |
| total_data[i] - le_data[i] |
| ); |
| |
| feat_le += le_data[i]; |
| feat_cnts += total_data[i]; |
| |
| // max class count and ID |
| if (scv_state_data[SCV_MAX_CLASS_COUNT] < total_data[i]) |
| { |
| scv_state_data[SCV_MAX_CLASS_COUNT] = total_data[i]; |
| scv_state_data[SCV_MAX_CLASS_ID] = i + 1; |
| } |
| |
| // calculate the statistic info for class |
| scv_state_data[SCV_U] += dt_cal_log(total_data[i]); |
| |
| // calculate the statistic info for the class label and the feature value |
| scv_state_data[SCV_W] += |
| (dt_cal_log(le_data[i]) + dt_cal_log(total_data[i] - le_data[i])); |
| } |
| |
| // calculate the statistic info for the feature |
| scv_state_data[SCV_V] += |
| (dt_cal_log(feat_le) + dt_cal_log(feat_cnts - feat_le)); |
| |
| // calculate the number of non-null elements |
| scv_state_data[SCV_T] = feat_cnts; |
| } |
| else |
| { |
| // gini index |
| scv_state_data[SCV_MAX_CLASS_COUNT] = 0; |
| for (i = 0; i < num_class; ++i) |
| { |
| dt_check_error_value |
| ( |
| total_data[i] >= le_data[i], |
| "the difference: %lf", |
| total_data[i] - le_data[i] |
| ); |
| |
| feat_le += le_data[i]; |
| feat_cnts += total_data[i]; |
| |
| // max class count and ID |
| if (scv_state_data[SCV_MAX_CLASS_COUNT] < total_data[i]) |
| { |
| scv_state_data[SCV_MAX_CLASS_COUNT] = total_data[i]; |
| scv_state_data[SCV_MAX_CLASS_ID] = i + 1; |
| } |
| |
| // calculate the statistic info for class |
| scv_state_data[SCV_U] += dt_cal_sqr(total_data[i]); |
| } |
| |
| // calculate the number of non-null elements |
| scv_state_data[SCV_T] = feat_cnts; |
| |
| // calculate the statistic info for the class label and the feature value |
| feat_cnts -= feat_le; |
| |
| for (i = 0; i < num_class; ++i) |
| { |
| scv_state_data[SCV_W] += |
| ( |
| dt_cal_sqr_div(le_data[i], feat_le) + |
| dt_cal_sqr_div(total_data[i] - le_data[i], feat_cnts) |
| ); |
| } |
| } |
| } |
| else // processing the discrete feature |
| { |
| // the definitions of t, u, v and w are the same between IG and GR |
| if (DT_SC_INFOGAIN == sc_type || DT_SC_GAINRATIO == sc_type) |
| { |
| /* |
| * calculate the value of count, the max class and class info |
| * we only need to write once |
| */ |
| if (dt_is_float_zero(scv_state_data[SCV_T])) |
| { |
| scv_state_data[SCV_MAX_CLASS_COUNT] = 0; |
| for (i = 0; i < num_class; ++i) |
| { |
| feat_cnts += total_data[i]; |
| |
| // max class count and ID |
| if (scv_state_data[SCV_MAX_CLASS_COUNT] < total_data[i]) |
| { |
| scv_state_data[SCV_MAX_CLASS_COUNT] = total_data[i]; |
| scv_state_data[SCV_MAX_CLASS_ID] = i + 1; |
| } |
| |
| // calculate the statistic info for class |
| scv_state_data[SCV_U] += dt_cal_log(total_data[i]); |
| } |
| |
| // calculate the count |
| scv_state_data[SCV_T] = feat_cnts; |
| } |
| |
| // calculate the statistic info for the class label and the feature value |
| for (i = 0; i < num_class; ++i) |
| { |
| scv_state_data[SCV_W] += dt_cal_log(le_data[i]); |
| feat_le += le_data[i]; |
| } |
| |
| // calculate the statistic info for the feature |
| scv_state_data[SCV_V] += dt_cal_log(feat_le); |
| } |
| else |
| { |
| /* |
| * calculate the value of count, the max class and class info |
| * we only need to write once |
| */ |
| if (dt_is_float_zero(scv_state_data[SCV_T])) |
| { |
| scv_state_data[SCV_MAX_CLASS_COUNT] = 0; |
| for (i = 0; i < num_class; ++i) |
| { |
| feat_cnts += total_data[i]; |
| |
| // max class count and ID |
| if (scv_state_data[SCV_MAX_CLASS_COUNT] < total_data[i]) |
| { |
| scv_state_data[SCV_MAX_CLASS_COUNT] = total_data[i]; |
| scv_state_data[SCV_MAX_CLASS_ID] = i + 1; |
| } |
| |
| // calculate the statistic info for class |
| scv_state_data[SCV_U] += dt_cal_sqr(total_data[i]); |
| } |
| |
| // calculate the count |
| scv_state_data[SCV_T] = feat_cnts; |
| } |
| |
| // calculate the statistic info for the class label and the feature value |
| for (i = 0; i < num_class; ++i) |
| { |
| feat_le += le_data[i]; |
| } |
| |
| for (i = 0; i < num_class; ++i) |
| { |
| scv_state_data[SCV_W] += dt_cal_sqr_div(le_data[i], feat_le); |
| } |
| } |
| } |
| dtelog(NOTICE, "data: %lf, %lf, %lf, %lf", |
| scv_state_data[SCV_W], |
| scv_state_data[SCV_U], |
| scv_state_data[SCV_V], |
| scv_state_data[SCV_T]); |
| |
| PG_RETURN_ARRAYTYPE_P(scv_state_array); |
| } |
| PG_FUNCTION_INFO_V1(dt_scv_aggr_sfunc); |
| |
| |
| /* |
| * @brief The pre-function for the aggregation of SCV. It takes the state |
| * array produced by two sfunc and combine them together. |
| * |
| * @param scv_state_array The array from sfunc1. |
| * @param scv_state_array The array from sfunc2. |
| * |
| * @return A nine-element array. Please refer to the definition of |
| * DT_SCV_STATE_ARRAY_INDEX for the detailed information of this array. |
| * |
| */ |
| Datum |
| dt_scv_aggr_prefunc |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| ArrayType* scv_state_array = NULL; |
| if (fcinfo->context && IsA(fcinfo->context, AggState)) |
| scv_state_array = PG_GETARG_ARRAYTYPE_P(0); |
| else |
| scv_state_array = PG_GETARG_ARRAYTYPE_P_COPY(0); |
| |
| dt_check_error |
| ( |
| scv_state_array, |
| "invalid aggregation state array" |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(scv_state_array), |
| "the first array passed to dt_scv_aggr_prefunc cannot contain NULL values" |
| ); |
| |
| int array_dim = ARR_NDIM(scv_state_array); |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of scv state array must be equal to 1", |
| array_dim |
| ); |
| |
| int *p_array_dim = ARR_DIMS(scv_state_array); |
| int array_length = ArrayGetNItems(array_dim, p_array_dim); |
| dt_check_error_value |
| ( |
| array_length == SCV_MAX_CLASS_COUNT + 1, |
| "dt_scv_aggr_prefunc invalid array length: %d", |
| array_length |
| ); |
| |
| /* the scv state data from a segment */ |
| float8 *scv_state_data = (float8 *)ARR_DATA_PTR(scv_state_array); |
| dt_check_error |
| ( |
| scv_state_data, |
| "invalid aggregation data array" |
| ); |
| |
| ArrayType* scv_state_array2 = PG_GETARG_ARRAYTYPE_P(1); |
| dt_check_error |
| ( |
| scv_state_array2, |
| "invalid aggregation state array" |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(scv_state_array2), |
| "the second array passed to dt_scv_aggr_prefunc cannot contain NULL values" |
| ); |
| |
| array_dim = ARR_NDIM(scv_state_array2); |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of scv state array must be equal to 1", |
| array_dim |
| ); |
| p_array_dim = ARR_DIMS(scv_state_array2); |
| array_length = ArrayGetNItems(array_dim, p_array_dim); |
| dt_check_error_value |
| ( |
| array_length == SCV_MAX_CLASS_COUNT + 1, |
| "dt_scv_aggr_prefunc invalid array length: %d", |
| array_length |
| ); |
| |
| /* the scv state data from another segment */ |
| float8 *scv_state_data2 = (float8 *)ARR_DATA_PTR(scv_state_array2); |
| dt_check_error |
| ( |
| scv_state_data2, |
| "invalid aggregation data array" |
| ); |
| |
| /* |
| * For the following data, such as entropy, gini and split info, |
| * we need to combine the accumulated value from multiple segments. |
| */ |
| scv_state_data[SCV_W] += scv_state_data2[SCV_W]; |
| scv_state_data[SCV_V] += scv_state_data2[SCV_V]; |
| |
| if (dt_is_float_zero(scv_state_data[SCV_T])) |
| { |
| scv_state_data[SCV_T] = scv_state_data2[SCV_T]; |
| scv_state_data[SCV_U] = scv_state_data2[SCV_U]; |
| scv_state_data[SCV_IS_CONT] = scv_state_data2[SCV_IS_CONT]; |
| scv_state_data[SCV_CODE] = scv_state_data2[SCV_CODE]; |
| } |
| |
| /* |
| * We should compare the results from different segments and |
| * find the class with maximum samples. |
| */ |
| if (scv_state_data[SCV_MAX_CLASS_COUNT] < |
| scv_state_data2[SCV_MAX_CLASS_COUNT]) |
| { |
| scv_state_data[SCV_MAX_CLASS_COUNT] = |
| scv_state_data2[SCV_MAX_CLASS_COUNT]; |
| scv_state_data[SCV_MAX_CLASS_ID] = |
| scv_state_data2[SCV_MAX_CLASS_ID]; |
| } |
| |
| PG_RETURN_ARRAYTYPE_P(scv_state_array); |
| } |
| PG_FUNCTION_INFO_V1(dt_scv_aggr_prefunc); |
| |
| |
| /* |
| * @brief The final function for the aggregation of SCV. |
| * It takes the state array produced by the prefunc and produces |
| * a five-element array. |
| * |
| * @param scv_state_array The array containing all the information for the |
| * calculation of SCV. |
| * |
| * @return A five-element array. Please refer to the definition of |
| * DT_SCV_FINAL_ARRAY_INDEX for the detailed information of this array. |
| * |
| */ |
| Datum |
| dt_scv_aggr_ffunc |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| ArrayType* scv_state_array = PG_GETARG_ARRAYTYPE_P(0); |
| dt_check_error |
| ( |
| scv_state_array, |
| "invalid aggregation state array" |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_HASNULL(scv_state_array), |
| "the first array passed to dt_scv_aggr_ffunc cannot contain NULL values" |
| ); |
| |
| int array_dim = ARR_NDIM(scv_state_array); |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of state array must be equal to 1", |
| array_dim |
| ); |
| |
| int* p_array_dim = ARR_DIMS(scv_state_array); |
| int array_length = ArrayGetNItems(array_dim, p_array_dim); |
| |
| dt_check_error_value |
| ( |
| array_length == SCV_MAX_CLASS_COUNT + 1, |
| "dt_scv_aggr_ffunc: invalid array length: %d", |
| array_length |
| ); |
| |
| dtelog(NOTICE, "dt_scv_aggr_ffunc array_length:%d",array_length); |
| |
| float8 *scv_state_data = (float8 *)ARR_DATA_PTR(scv_state_array); |
| dt_check_error |
| ( |
| scv_state_data, |
| "invalid aggregation data array" |
| ); |
| |
| |
| dtelog(NOTICE, "final: %lf, %lf, %lf, %lf", |
| scv_state_data[SCV_W], |
| scv_state_data[SCV_U], |
| scv_state_data[SCV_V], |
| scv_state_data[SCV_T]); |
| |
| int result_size = SCV_FINAL_TOTAL_COUNT + 1; |
| float8 *result = palloc0(sizeof(float8) * result_size); |
| float8 tmp = 0.0; |
| |
| dtelog( NOTICE, |
| "total:%lf, %lf", |
| scv_state_data[SCV_SAMPLE_TOTAL], |
| scv_state_data[SCV_T]); |
| |
| /* If true total count is 0/null, there is no missing values*/ |
| if (dt_is_float_zero(scv_state_data[SCV_SAMPLE_TOTAL])) |
| { |
| scv_state_data[SCV_SAMPLE_TOTAL] = |
| scv_state_data[SCV_T]; |
| } |
| |
| /* true total count should be greater than 0*/ |
| dt_check_error |
| ( |
| scv_state_data[SCV_SAMPLE_TOTAL] > 0 && scv_state_data[SCV_T] > 0, |
| "true total count should be greater than 0" |
| ); |
| |
| /* |
| * For the following elements, such as max class id, we should copy |
| * them from step function array to final function array for returning. |
| */ |
| result[SCV_FINAL_CLASS_ID] = scv_state_data[SCV_MAX_CLASS_ID]; |
| result[SCV_FINAL_IS_CONT] = scv_state_data[SCV_IS_CONT]; |
| result[SCV_FINAL_TOTAL_COUNT] = scv_state_data[SCV_SAMPLE_TOTAL]; |
| result[SCV_FINAL_CLASS_PROB] = |
| scv_state_data[SCV_MAX_CLASS_COUNT] / scv_state_data[SCV_SAMPLE_TOTAL]; |
| |
| |
| if (DT_SC_INFOGAIN == ((int)scv_state_data[SCV_CODE])) |
| { |
| // info gain |
| result[SCV_FINAL_VALUE] = |
| log(scv_state_data[SCV_T]) - |
| ((scv_state_data[SCV_U] + scv_state_data[SCV_V] - |
| scv_state_data[SCV_W]) / scv_state_data[SCV_T]); |
| } |
| else if (DT_SC_GAINRATIO == ((int)scv_state_data[SCV_CODE])) |
| { |
| // gain ratio |
| tmp = dt_cal_log(scv_state_data[SCV_T]) - scv_state_data[SCV_V]; |
| result[SCV_FINAL_VALUE] = dt_is_float_zero(tmp) ? 0.0 : |
| 1 + (scv_state_data[SCV_W] - scv_state_data[SCV_U]) / tmp; |
| } |
| else |
| { |
| //gini index |
| result[SCV_FINAL_VALUE] = |
| (scv_state_data[SCV_W] / scv_state_data[SCV_T]) - |
| (scv_state_data[SCV_U]) / dt_cal_sqr(scv_state_data[SCV_T]); |
| } |
| |
| result[SCV_FINAL_VALUE] *= (scv_state_data[SCV_T] / |
| scv_state_data[SCV_SAMPLE_TOTAL]); |
| |
| dtelog(NOTICE, "final value: %lf", result[SCV_FINAL_VALUE]); |
| |
| ArrayType* result_array = |
| construct_array( |
| (Datum *)result, |
| result_size, |
| FLOAT8OID, |
| sizeof(float8), |
| true, |
| 'd' |
| ); |
| |
| PG_RETURN_ARRAYTYPE_P(result_array); |
| } |
| PG_FUNCTION_INFO_V1(dt_scv_aggr_ffunc); |
| |
| |
| /* |
| * @brief The function samples a set of integer values between low and high. |
| * The sample method is 'sample with replacement', which means a sample |
| * could be chosen multiple times. |
| * |
| * @param sample_size Number of records to be sampled. |
| * @param low Low limit of sampled values. |
| * @param high High limit of sampled values. |
| * @param seed Seed for random number. |
| * |
| * @return A set of integer values sampled randomly between [low, high]. |
| * |
| */ |
| Datum |
| dt_sample_within_range |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| FuncCallContext *funcctx = NULL; |
| int64 call_cntr = 0; |
| int64 max_calls = 0; |
| |
| /* stuff done only on the first call of the function */ |
| if (SRF_IS_FIRSTCALL()) |
| { |
| MemoryContext oldcontext; |
| |
| /* create a function context for cross-call persistence */ |
| funcctx = SRF_FIRSTCALL_INIT(); |
| |
| /* switch to memory context appropriate for multiple function calls */ |
| oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx); |
| int64 low = PG_GETARG_INT64(1); |
| int64 high = PG_GETARG_INT64(2); |
| |
| dt_check_error |
| ( |
| low<=high && low>=0, |
| "The low margin must not be greater than the high margin. " |
| "And negative numbers are not accepted" |
| ); |
| |
| /* total number of samples to be returned */ |
| funcctx->max_calls = PG_GETARG_INT64(0); |
| MemoryContextSwitchTo(oldcontext); |
| } |
| |
| /* stuff done on every call of the function */ |
| funcctx = SRF_PERCALL_SETUP(); |
| call_cntr = funcctx->call_cntr; |
| max_calls = funcctx->max_calls; |
| |
| /* when there is more records to return */ |
| if (call_cntr < max_calls) |
| { |
| int64 low = PG_GETARG_INT64(1); |
| int64 high = PG_GETARG_INT64(2); |
| float8 rand_num = (random()/(float8)(RAND_MAX+1.0)); |
| int64 selection = (int64)(rand_num*(high-low+1)+low); |
| SRF_RETURN_NEXT(funcctx, Int64GetDatum(selection)); |
| } |
| |
| /* when there is no more records left */ |
| SRF_RETURN_DONE(funcctx); |
| } |
| PG_FUNCTION_INFO_V1(dt_sample_within_range); |
| |
| |
| /* |
| * @brief Retrieve the specified number of unique features for a node. |
| * Discrete features used by ancestor nodes will be excluded. |
| * If the number of remaining features is less or equal than the |
| * requested number of features, then all the remaining features |
| * will be returned. Otherwise, we will sample the requested |
| * number of features from the remaining features. |
| * |
| * @param num_req_features The number of requested features. |
| * @param num_features The total number of features. |
| * @param nid The ID of the node for which the |
| * features are sampled. |
| * @param dp_fids The IDs of the discrete features |
| * used by the ancestors. |
| * |
| * @return An array containing all the IDs of sampled features. |
| * |
| */ |
| Datum |
| dt_get_node_split_fids |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| int32 num_req_features = PG_ARGISNULL(0) ? 0 : PG_GETARG_INT32(0); |
| int32 num_features = PG_ARGISNULL(1) ? 0 : PG_GETARG_INT32(1); |
| int32 nid = PG_ARGISNULL(2) ? 0 : PG_GETARG_INT32(2); |
| |
| dt_check_error |
| ( |
| num_req_features > 0 && num_features > 0 && nid > 0, |
| "the first three arguments can not be null" |
| ); |
| |
| int32 n_remain_fids = num_features; |
| int32 *dp_fids = NULL; |
| Datum *result = NULL; |
| ArrayType *result_array = NULL; |
| int32 power_uint32 = 5; |
| |
| /* bit map for whether a feature was chosen before or not */ |
| uint32 n_bitmap = (num_features + (1 << power_uint32) - 1) >> power_uint32; |
| uint32 *bitmap = (uint32*)palloc0(n_bitmap * sizeof(uint32)); |
| |
| if (!PG_ARGISNULL(3)) |
| { |
| ArrayType *dp_fids_array = PG_GETARG_ARRAYTYPE_P(3); |
| int dim_nids = ARR_NDIM(dp_fids_array); |
| dt_check_error_value |
| ( |
| 1 == dim_nids, |
| "invalid array dimension: %d. " |
| "The dimension of the array must be equal to 1", |
| dim_nids |
| ); |
| |
| dt_check_error_value |
| ( |
| !ARR_HASNULL(dp_fids_array), |
| "the first array passed to %s cannot contain NULL values", |
| __FUNCTION__ |
| ); |
| |
| int *p_dim_nids = ARR_DIMS(dp_fids_array); |
| int len_nids = ArrayGetNItems(dim_nids, p_dim_nids); |
| dt_check_error_value |
| ( |
| len_nids <= num_features, |
| "dt_get_node_split_fids invalid array length: %d", |
| len_nids |
| ); |
| |
| dp_fids = (int *)ARR_DATA_PTR(dp_fids_array); |
| dt_check_error (dp_fids, "invalid data array"); |
| |
| /* |
| * the feature ID starts from 1 |
| * if the feature was already chosen, then set the bit to 1 |
| */ |
| for (int i = 0; i < len_nids; ++i) |
| bitmap[(dp_fids[i] - 1) >> power_uint32] |= |
| dt_fid_mask((dp_fids[i] - 1), power_uint32); |
| |
| n_remain_fids = num_features - len_nids; |
| } |
| |
| result = palloc0 |
| ( |
| ((n_remain_fids > num_req_features) ? |
| num_req_features : |
| n_remain_fids ? n_remain_fids : 1) * sizeof(Datum) |
| ); |
| /* |
| * Sample the features if the number of remaining features is greater |
| * than the request number |
| */ |
| if (n_remain_fids > num_req_features) |
| { |
| for (int i = 0; i < num_req_features; ++i) |
| { |
| int32 fid = 0; |
| |
| /* |
| * if sample a duplicated number, then sample again until |
| * we found a unique number |
| */ |
| do |
| { |
| fid = random() % num_features; |
| } |
| while (0 < (bitmap[fid >> power_uint32] & dt_fid_mask(fid, power_uint32))); |
| |
| result[i] = Int32GetDatum(fid + 1); |
| |
| /* set the bit to true for the sampled number*/ |
| bitmap[fid >> power_uint32] |= dt_fid_mask(fid, power_uint32); |
| } |
| } |
| else if (0 == n_remain_fids) |
| { |
| /* |
| * if no features left, then simply return any one of the features |
| * so that the best split information can be retrieved |
| */ |
| num_req_features = 1; |
| result[0] = Int32GetDatum(1); |
| } |
| else |
| { |
| /* |
| * If the number of remain features are less than or equal randomly |
| * chosen features then return the remain features directly. |
| * n_remain_fids <= num_req_features |
| */ |
| |
| num_req_features = n_remain_fids; |
| |
| /* if the features weren't chosen, then choose them */ |
| for (int32 i = 0; i < num_features; ++i) |
| if (0 == (bitmap[i >> power_uint32] & dt_fid_mask(i, power_uint32))) |
| result[--n_remain_fids] = Int32GetDatum(i + 1); |
| |
| dt_check_error_value |
| ( |
| 0 == n_remain_fids, |
| "the number of random chosen features is wrong, total:%d", |
| n_remain_fids |
| ); |
| |
| } |
| |
| /* free the bitmap */ |
| pfree(bitmap); |
| |
| /* |
| * the number of elements in the result array must be |
| * greater than or equal to 1 |
| */ |
| dt_check_error_value |
| ( |
| num_req_features > 0, |
| "the number of chosen features for node %d is zero", |
| nid |
| ); |
| |
| result_array = |
| construct_array( |
| result, |
| num_req_features, |
| INT4OID, |
| sizeof(int32), |
| true, |
| 'i' |
| ); |
| |
| PG_RETURN_ARRAYTYPE_P(result_array); |
| } |
| PG_FUNCTION_INFO_V1(dt_get_node_split_fids); |
| |
| |
| /* |
| * @brief Use % as the delimiter to split the given string. The char '\' is used |
| * to escape %. We will not change the default behavior of '\' in PG/GP. |
| * For example, assume the given string is E"\\\\\\\\\\%123%123". Then it only |
| * has one delimiter; the string will be split to two substrings: |
| * E'\\\\\\\\\\%123' and '123'; the position array size is 1, where position[0] = 9; |
| * ; (*len) = 13. |
| * |
| * @param str The string to be split. |
| * @param position An array to store the position of each un-escaped % in the string. |
| * @param num_pos The expected number of un-escaped %s in the string. |
| * @param len The length of the string. It doesn't include the terminal. |
| * |
| * @return The position array which records the positions of all un-escaped %s |
| * in the give string. |
| * |
| * @note If the number of %s in the string is not equal to the expected number, |
| * we will report error via elog. |
| */ |
| static |
| int* |
| dt_split_string |
| ( |
| char *str, |
| int *position, |
| int num_pos, |
| int *len |
| ) |
| { |
| int i = 0; |
| int j = 0; |
| |
| /* the number of the escape chars which occur continuously */ |
| int num_cont_escapes = 0; |
| |
| for(; str != NULL && *str != '\0'; ++str, ++j) |
| { |
| if ('%' == *str) |
| { |
| /* |
| * if the number of the escapes is even number |
| * then no need to escape. Otherwise escape the delimiter |
| */ |
| if (!(num_cont_escapes & 0x01)) |
| { |
| dt_check_error |
| ( |
| i < num_pos, |
| "the number of the elements in the array is less than " |
| "the format string expects." |
| ); |
| |
| /* convert the char '%' to '\0' */ |
| position[i++] = j; |
| *str = '\0'; |
| } |
| |
| /* reset the number of the continuous escape chars */ |
| num_cont_escapes = 0; |
| } |
| else if ('\\' == *str) |
| { |
| /* increase the number of continuous escape chars */ |
| ++num_cont_escapes; |
| } |
| else |
| { |
| /* reset the number of the continuous escape chars */ |
| num_cont_escapes = 0; |
| } |
| } |
| |
| *len = j; |
| |
| dt_check_error |
| ( |
| i == num_pos, |
| "the number of the elements in the array is greater than " |
| "the format string expects. " |
| ); |
| |
| return position; |
| } |
| |
| |
| /* |
| * @brief Change all occurrences of '\%' in the give string to '%'. Our split |
| * method will ensure that the char immediately before a '%' must be a '\'. |
| * We traverse the string from left to right, if we meet a '%', then |
| * move the substring after the current '\%' to the right place until |
| * we meet next '\%' or the '\0'. Finally, set the terminal symbol for |
| * the replaced string. |
| * |
| * @param str The null terminated string to be escaped. |
| * The char immediately before a '%' must be a '\'. |
| * |
| * @return The new string with \% changed to %. |
| * |
| */ |
| static |
| char* |
| dt_escape_pct_sym |
| ( |
| char *str |
| ) |
| { |
| int num_escapes = 0; |
| |
| /* remember the start address of the escaped string */ |
| char *p_new_string = str; |
| |
| while(str != NULL && *str != '\0') |
| { |
| if ('%' == *str) |
| { |
| dt_check_error_value |
| ( |
| (str - 1) && ('\\' == *(str - 1)), |
| "The char immediately before a %s must be a \\", |
| "%" |
| ); |
| |
| /* |
| * The char immediately before % is \ |
| * increase the number of escape chars |
| */ |
| ++num_escapes; |
| do |
| { |
| /* |
| * move the string which is between the current "\%" |
| * and next "\%" |
| */ |
| *(str - num_escapes) = *str; |
| ++str; |
| } while (str != NULL && *str != '\0' && *str != '%'); |
| } |
| else |
| { |
| ++str; |
| } |
| } |
| |
| /* if there is no escapes, then set the end symbol for the string */ |
| if (num_escapes > 0) |
| *(str - num_escapes) = '\0'; |
| |
| return p_new_string; |
| } |
| |
| |
| /* |
| * @brief We need to build a lot of query strings based on a set of arguments. For that |
| * purpose, this function will take a format string (the template) and an array |
| * of values, scan through the format string, and replace the %s in the format |
| * string with the corresponding values in the array. The result string is |
| * returned as a PG/GP text Datum. The escape char for '%' is '\'. And we will |
| * not change it's default behavior in PG/GP. For example, assume that |
| * fmt = E'\\\\\\\\ % \\% %', args[] = {"100", "20"}, then the returned text |
| * of this function is E'\\\\\\\\ 100 % 20' |
| * |
| * @param fmt The format string. %s are used to indicate a position |
| * where a value should be filled in. |
| * @param args An array of values that should be used for replacements. |
| * args[i] replaces the i-th % in fmt. The array length should |
| * equal to the number of %s in fmt. |
| * |
| * @return A string with all %s which were not escaped in first argument replaced |
| * with the corresponding values in the second argument. |
| * |
| */ |
| Datum |
| dt_text_format |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| dt_check_error |
| ( |
| !(PG_ARGISNULL(0) || PG_ARGISNULL(1)), |
| "the format string and its arguments must not be null" |
| ); |
| |
| char *fmt = text_to_cstring(PG_GETARG_TEXT_PP(0)); |
| ArrayType *args_array = PG_GETARG_ARRAYTYPE_P(1); |
| |
| dt_check_error_value |
| ( |
| !ARR_HASNULL(args_array), |
| "the first array passed to %s cannot contain NULL values", |
| __FUNCTION__ |
| ); |
| |
| dt_check_error |
| ( |
| !ARR_NULLBITMAP(args_array), |
| "the argument array must not has null value" |
| ); |
| |
| int nitems = 0; |
| int *dims = NULL; |
| int ndims = 0; |
| Oid element_type= 0; |
| int typlen = 0; |
| bool typbyval = false; |
| char typalign = '\0'; |
| char *p = NULL; |
| int i = 0; |
| |
| ArrayMetaState *my_extra= NULL; |
| StringInfoData buf; |
| |
| ndims = ARR_NDIM(args_array); |
| dims = ARR_DIMS(args_array); |
| nitems = ArrayGetNItems(ndims, dims); |
| |
| /* if there are no elements, return the format string directly */ |
| if (nitems == 0) |
| PG_RETURN_TEXT_P(cstring_to_text(fmt)); |
| |
| int *position = (int*)palloc0(nitems * sizeof(int)); |
| |
| int last_pos = 0; |
| int len_fmt = 0; |
| |
| /* |
| * split the format string, so that later we can replace the delimiters |
| * with the given arguments |
| */ |
| dt_split_string(fmt, position, nitems, &len_fmt); |
| |
| element_type = ARR_ELEMTYPE(args_array); |
| initStringInfo(&buf); |
| |
| /* |
| * We arrange to look up info about element type, including its output |
| * conversion proc, only once per series of calls, assuming the element |
| * type doesn't change underneath us. |
| */ |
| my_extra = (ArrayMetaState *) fcinfo->flinfo->fn_extra; |
| if (my_extra == NULL) |
| { |
| fcinfo->flinfo->fn_extra = MemoryContextAlloc |
| ( |
| fcinfo->flinfo->fn_mcxt, |
| sizeof(ArrayMetaState) |
| ); |
| my_extra = (ArrayMetaState *) fcinfo->flinfo->fn_extra; |
| my_extra->element_type = ~element_type; |
| } |
| |
| if (my_extra->element_type != element_type) |
| { |
| /* |
| * Get info about element type, including its output conversion proc |
| */ |
| get_type_io_data |
| ( |
| element_type, |
| IOFunc_output, |
| &my_extra->typlen, |
| &my_extra->typbyval, |
| &my_extra->typalign, |
| &my_extra->typdelim, |
| &my_extra->typioparam, |
| &my_extra->typiofunc |
| ); |
| fmgr_info_cxt |
| ( |
| my_extra->typiofunc, |
| &my_extra->proc, |
| fcinfo->flinfo->fn_mcxt |
| ); |
| my_extra->element_type = element_type; |
| } |
| typlen = my_extra->typlen; |
| typbyval = my_extra->typbyval; |
| typalign = my_extra->typalign; |
| p = ARR_DATA_PTR(args_array); |
| |
| for (i = 0; i < nitems; i++) |
| { |
| Datum itemvalue; |
| char *value; |
| |
| itemvalue = fetch_att(p, typbyval, typlen); |
| value = OutputFunctionCall(&my_extra->proc, itemvalue); |
| |
| /* there is no string before the delimiter */ |
| if (last_pos == position[i]) |
| { |
| appendStringInfo(&buf, "%s", value); |
| ++last_pos; |
| } |
| else |
| { |
| /* |
| * has a string before the delimiter |
| * we replace "\%" in the string to "%", since "%" is escaped |
| * then combine the string and argument string together |
| */ |
| appendStringInfo |
| ( |
| &buf, |
| "%s%s", |
| dt_escape_pct_sym(fmt + last_pos), |
| value |
| ); |
| |
| last_pos = position[i] + 1; |
| } |
| |
| p = att_addlength_pointer(p, typlen, p); |
| p = (char *) att_align_nominal(p, typalign); |
| } |
| |
| /* the last char in the format string is not delimiter */ |
| if (last_pos < len_fmt) |
| appendStringInfo(&buf, "%s", fmt + last_pos); |
| |
| PG_RETURN_TEXT_P(cstring_to_text_with_len(buf.data, buf.len)); |
| } |
| PG_FUNCTION_INFO_V1(dt_text_format); |
| |
| |
| /* |
| * @brief This function checks whether the specified table exists or not. |
| * |
| * @param input The table name to be tested. |
| * |
| * @return A boolean value indicating whether the table exists or not. |
| */ |
| Datum table_exists(PG_FUNCTION_ARGS) |
| { |
| text* input; |
| List* names; |
| Oid relid; |
| |
| if (PG_ARGISNULL(0)) |
| PG_RETURN_BOOL(false); |
| |
| input = PG_GETARG_TEXT_PP(0); |
| |
| names = textToQualifiedNameList(input); |
| #if PG_VERSION_NUM >= 90200 |
| relid = RangeVarGetRelid(makeRangeVarFromNameList(names), NoLock, true); |
| #else |
| relid = RangeVarGetRelid(makeRangeVarFromNameList(names), true); |
| #endif |
| PG_RETURN_BOOL(OidIsValid(relid)); |
| } |
| PG_FUNCTION_INFO_V1(table_exists); |
| |
| |
| /* |
| * @brief The step function for generating the acc counts. |
| * |
| * @param class_count_array The array used to store the count information. |
| * The length of the array equals max_num_of_classes. |
| * @param max_num_of_classes The total number of distinct class values. |
| * @param count The count value to be accumulated. |
| * @param class The current class value. |
| * |
| * @return The updated version of class_count_array. |
| * |
| */ |
| Datum |
| dt_acc_count_sfunc |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| ArrayType *pg_count_array = NULL; |
| int array_dim = 0; |
| int *p_array_dim = NULL; |
| int array_length = 0; |
| int64 *count_array = NULL; |
| dt_check_error_value |
| ( |
| !PG_ARGISNULL(1), |
| "In function: %s. " |
| "The parameter of 'max_num_of_classes' should not be null", |
| __FUNCTION__ |
| ); |
| int max_num_of_classes = PG_GETARG_INT32(1); |
| int64 count = PG_ARGISNULL(2)?0:PG_GETARG_INT64(2); |
| int class = PG_ARGISNULL(3)?0:PG_GETARG_INT32(3); |
| bool rebuild_array = false; |
| |
| dt_check_error_value |
| ( |
| max_num_of_classes >= 2 && max_num_of_classes <= 1e6, |
| "invalid value: %d. " |
| "The number of classes must be in the range of [2, 1e6]", |
| max_num_of_classes |
| ); |
| |
| dt_check_error_value |
| ( |
| class >= 1 && class <= max_num_of_classes, |
| "invalid real class value: %d. " |
| "It must be in range from 1 to the number of classes", |
| class |
| ); |
| |
| /* test if the first argument (class count array) is null */ |
| if (PG_ARGISNULL(0)) |
| { |
| /* |
| * We assume the maximum number of classes is limited (up to millions), |
| * so that the allocated array won't break our memory limitation. |
| */ |
| count_array = palloc0(sizeof(int64) * max_num_of_classes); |
| array_length = max_num_of_classes; |
| rebuild_array = true; |
| |
| } |
| else |
| { |
| if (fcinfo->context && IsA(fcinfo->context, AggState)) |
| pg_count_array = PG_GETARG_ARRAYTYPE_P(0); |
| else |
| pg_count_array = PG_GETARG_ARRAYTYPE_P_COPY(0); |
| |
| dt_check_error |
| ( |
| pg_count_array, |
| "invalid aggregation state array" |
| ); |
| |
| dt_check_error_value |
| ( |
| !ARR_HASNULL(pg_count_array), |
| "the first array passed to %s cannot contain NULL values", |
| __FUNCTION__ |
| ); |
| |
| array_dim = ARR_NDIM(pg_count_array); |
| |
| dt_check_error_value |
| ( |
| array_dim == 1, |
| "invalid array dimension: %d. " |
| "The dimension of class count array must be equal to 1", |
| array_dim |
| ); |
| |
| p_array_dim = ARR_DIMS(pg_count_array); |
| array_length = ArrayGetNItems(array_dim,p_array_dim); |
| count_array = (int64 *)ARR_DATA_PTR(pg_count_array); |
| |
| dt_check_error_value |
| ( |
| array_length == max_num_of_classes, |
| "dt_acc_count_sfunc invalid array length: %d. " |
| "The length of class count array must be " |
| "equal to the total number classes", |
| array_length |
| ); |
| } |
| |
| count_array[class - 1] += count; |
| |
| if (rebuild_array) |
| { |
| /* construct a new array to keep the aggr states. */ |
| pg_count_array = |
| construct_array( |
| (Datum *)count_array, |
| array_length, |
| INT8OID, |
| sizeof(int64), |
| true, |
| 'd' |
| ); |
| } |
| |
| PG_RETURN_ARRAYTYPE_P(pg_count_array); |
| } |
| PG_FUNCTION_INFO_V1(dt_acc_count_sfunc); |
| |
| |
| /* |
| * @brief Cast a value to text. On some databases, there |
| * are no such casts for certain data types, such as |
| * the cast for bool to text. |
| * |
| * @param value The value with any specific type |
| * |
| * @note This is a strict function. |
| * |
| */ |
| Datum |
| dt_to_text |
| ( |
| PG_FUNCTION_ARGS |
| ) |
| { |
| Datum value = PG_GETARG_DATUM(0); |
| Oid valtype = get_fn_expr_argtype(fcinfo->flinfo, 0); |
| Oid typoutput = 0; |
| bool typIsVarlena = 0; |
| char *result = NULL; |
| |
| getTypeOutputInfo(valtype, &typoutput, &typIsVarlena); |
| |
| // call the output function of the type to convert |
| result = OidOutputFunctionCall(typoutput, value); |
| |
| PG_RETURN_TEXT_P(cstring_to_text(result)); |
| } |
| PG_FUNCTION_INFO_V1(dt_to_text); |
| |
| |
| /* |
| * @brief The step function of the aggregate array_indexed_agg. |
| * To avoid allocating memory in each step function and manipulating |
| * the array bitmap for null values, we keep the null values by |
| * ourself. The solution is that, we use two items in the state |
| * array to represent one result item. The 2*i-th item in the state |
| * array represents the actual value of the i-th result item, |
| * and the 2*i+1-th item in the state array represents whether |
| * the i-th result item is NULL. |
| * |
| * @param state The step state array of the aggregate function. |
| * @param elem The element to be filled into the state array. |
| * @param elem_cnt The number of elements. |
| * @param elem_idx The subscript of "elem" in the state array. |
| * |
| */ |
| Datum dt_array_indexed_agg_sfunc(PG_FUNCTION_ARGS) |
| { |
| ArrayType *state; |
| ArrayBuildState build_state; |
| Datum elem; |
| Oid elem_typ = FLOAT8OID; |
| int32_t elem_cnt; |
| int32_t elem_idx; |
| int32_t iterator_idx; |
| |
| dt_check_error_value |
| ( |
| (fcinfo->context && IsA(fcinfo->context, AggState)), |
| "%s can only be used in aggregations", |
| __FUNCTION__ |
| ); |
| |
| state = PG_ARGISNULL(0) ? NULL : PG_GETARG_ARRAYTYPE_P(0); |
| elem = PG_ARGISNULL(1) ? (Datum) 0 : PG_GETARG_DATUM(1); |
| elem_cnt = PG_GETARG_INT64(2); |
| elem_idx = PG_GETARG_INT64(3) - 1; |
| |
| dt_check_error_value |
| ( |
| elem_cnt > 0, |
| "array_size:%d should be bigger than zero", |
| elem_cnt |
| ); |
| |
| dt_check_error_value |
| ( |
| elem_idx >= 0 && elem_idx < elem_cnt, |
| "the subscript %d is out of range", |
| elem_idx |
| ); |
| |
| get_typlenbyvalalign |
| ( |
| elem_typ, |
| &build_state.typlen, |
| &build_state.typbyval, |
| &build_state.typalign |
| ); |
| |
| if (NULL == state) |
| { |
| build_state.mcontext = NULL; |
| |
| /* |
| * allocate two element for each index, the first one is the value, |
| * the second one indicates whether the item is null |
| */ |
| build_state.alen = (elem_cnt << 1); |
| build_state.dvalues = (Datum *) palloc(build_state.alen * sizeof(Datum)); |
| build_state.dnulls = NULL; |
| build_state.nelems = build_state.alen; |
| build_state.element_type = elem_typ; |
| |
| for (iterator_idx = 0; iterator_idx < build_state.alen; iterator_idx++) |
| { |
| build_state.dvalues[iterator_idx] = Float8GetDatum(1); |
| } |
| |
| /* put the elem into the target slot */ |
| build_state.dvalues[elem_idx << 1] = elem; |
| build_state.dvalues[(elem_idx << 1) + 1] = |
| Float8GetDatum(PG_ARGISNULL(1) ? 1 : 0); |
| |
| state = construct_array(build_state.dvalues, build_state.nelems, |
| build_state.element_type, build_state.typlen, |
| build_state.typbyval, build_state.typalign); |
| |
| PG_RETURN_ARRAYTYPE_P(state); |
| } |
| |
| dt_check_error_value |
| ( |
| !ARR_HASNULL(state), |
| "the first array passed to %s cannot contain NULL values", |
| __FUNCTION__ |
| ); |
| |
| dt_check_error_value |
| ( |
| ARR_DIMS(state)[0] == (elem_cnt << 1), |
| "The dimension of state array should be %d", |
| (elem_cnt << 1) |
| ); |
| |
| ((float8*)ARR_DATA_PTR(state))[(elem_idx << 1)] = DatumGetFloat8(elem); |
| ((float8*)ARR_DATA_PTR(state))[(elem_idx << 1) + 1] = PG_ARGISNULL(1) ? 1 : 0; |
| |
| PG_RETURN_ARRAYTYPE_P(state); |
| } |
| PG_FUNCTION_INFO_V1(dt_array_indexed_agg_sfunc); |
| |
| |
| /* |
| * @brief The pre-function of the aggregate array_indexed_agg. |
| * |
| * @param arg0 The first state array. |
| * @param arg1 The second state array. |
| * |
| * @return The combined state. |
| * |
| */ |
| Datum dt_array_indexed_agg_prefunc(PG_FUNCTION_ARGS) |
| { |
| ArrayType *arg0, *arg1; |
| int64 iterator_idx; |
| int32 elem_cnt; |
| int64 elem_idx; |
| |
| dt_check_error_value |
| ( |
| (fcinfo->context && IsA(fcinfo->context, AggState)), |
| "%s can only be used in aggregations", |
| __FUNCTION__ |
| ); |
| |
| arg0 = PG_ARGISNULL(0) ? NULL : PG_GETARG_ARRAYTYPE_P(0); |
| arg1 = PG_ARGISNULL(1) ? NULL : PG_GETARG_ARRAYTYPE_P(1); |
| |
| if (NULL == arg0) |
| { |
| PG_RETURN_ARRAYTYPE_P(arg1); |
| } |
| else if (NULL == arg1) |
| { |
| PG_RETURN_ARRAYTYPE_P(arg0); |
| } |
| |
| dt_check_error |
| ( |
| ARR_NDIM(arg0) == ARR_NDIM(arg1), |
| "the dimension of the two state array should be the same" |
| ); |
| |
| dt_check_error |
| ( |
| 1 == ARR_NDIM(arg0), |
| "the dimension of state array must be equal to 1" |
| ); |
| |
| dt_check_error |
| ( |
| ARR_DIMS(arg0)[0] == ARR_DIMS(arg1)[0], |
| "the size of the two state array must be the same" |
| ); |
| |
| elem_cnt = (ARR_DIMS(arg0)[0]) >> 1; |
| |
| for (iterator_idx = 0; iterator_idx < elem_cnt; iterator_idx++) |
| { |
| elem_idx = iterator_idx << 1; |
| |
| /* |
| * just taking the non-null one, pre-steps must make |
| * sure there is no duplicate |
| */ |
| if (0 == (int)((float8*)ARR_DATA_PTR(arg1))[elem_idx + 1]) |
| { |
| ((float8*)ARR_DATA_PTR(arg0))[elem_idx] = |
| ((float8*)ARR_DATA_PTR(arg1))[elem_idx]; |
| ((float8*)ARR_DATA_PTR(arg0))[elem_idx + 1] = 0; |
| } |
| } |
| |
| PG_RETURN_ARRAYTYPE_P(arg0); |
| } |
| PG_FUNCTION_INFO_V1(dt_array_indexed_agg_prefunc); |
| |
| |
| /* |
| * @brief The final function of array_indexed_agg. |
| * |
| * @param state The state array. |
| * |
| * @return The aggregate result. |
| * |
| */ |
| Datum dt_array_indexed_agg_ffunc(PG_FUNCTION_ARGS) |
| { |
| ArrayType *state, *result; |
| ArrayBuildState build_state; |
| Oid elem_typ = FLOAT8OID; |
| int32_t elem_cnt; |
| int32_t iterator_idx; |
| int lbs[1]; |
| |
| dt_check_error_value |
| ( |
| (fcinfo->context && IsA(fcinfo->context, AggState)), |
| "%s can only be used in aggregations", |
| __FUNCTION__ |
| ); |
| |
| state = PG_ARGISNULL(0) ? NULL : PG_GETARG_ARRAYTYPE_P(0); |
| |
| dt_check_error |
| ( |
| NULL != state, |
| "the state array that fed into the final aggregate " |
| "should not be null" |
| ); |
| |
| dt_check_error |
| ( |
| 1 == ARR_NDIM(state), |
| "the dimension of the state array should be equal to 1" |
| ); |
| |
| dt_check_error |
| ( |
| 0 == (ARR_DIMS(state)[0] & 0x01), |
| "invalid state array" |
| ); |
| |
| elem_cnt = (ARR_DIMS(state)[0]) >> 1; |
| |
| get_typlenbyvalalign |
| ( |
| elem_typ, |
| &build_state.typlen, |
| &build_state.typbyval, |
| &build_state.typalign |
| ); |
| |
| build_state.mcontext = NULL; |
| build_state.alen = elem_cnt; |
| build_state.dvalues = (Datum *) palloc(build_state.alen * sizeof(Datum)); |
| build_state.dnulls = (bool *) palloc(build_state.alen * sizeof(bool)); |
| build_state.nelems = build_state.alen; |
| build_state.element_type= elem_typ; |
| |
| for (iterator_idx = 0; iterator_idx < elem_cnt; iterator_idx ++) |
| { |
| build_state.dnulls[iterator_idx] = |
| (int)((float8*)ARR_DATA_PTR(state))[(iterator_idx << 1) + 1]; |
| build_state.dvalues[iterator_idx] = |
| Float8GetDatum(((float8*)ARR_DATA_PTR(state))[(iterator_idx << 1)]); |
| } |
| |
| lbs[0] = 1; |
| result = construct_md_array |
| ( |
| build_state.dvalues, |
| build_state.dnulls, |
| 1, |
| &(build_state.nelems), |
| lbs, |
| build_state.element_type, |
| build_state.typlen, |
| build_state.typbyval, |
| build_state.typalign |
| ); |
| |
| PG_RETURN_ARRAYTYPE_P(result); |
| } |
| PG_FUNCTION_INFO_V1(dt_array_indexed_agg_ffunc); |