| # utf-8 |
| """ |
| @file margins_common.py_in |
| |
| @brief Marginal Effects with Interactions: Builds the main interaction effects |
| class and describes multiple functions that are common to the various methods. |
| These functions are not directly related to database constructs. Functions that |
| depend on the database should not be included in this file. |
| |
| @namespace marginal |
| """ |
| import re |
| from collections import defaultdict |
| from itertools import chain |
| from collections import namedtuple |
| import operator |
| |
| # import plpy |
| # ============================================================================= |
| |
CONSTANT = "constant"

# Immutable record holding only the identifier string of a basis variable.
# Equality and hashing ignore surrounding spaces and double quotes, so
# 'age', ' age ' and '"age"' all name the same term.
TermBase = namedtuple("TermBase", ["identifier"])


def _eq(self, other):
    """
    Return True if both terms have the same identifier once surrounding
    spaces and double quotes are stripped.

    Objects without a string-like 'identifier' attribute (None, ints, plain
    tuples, ...) are reported as not equal instead of raising
    AttributeError, so terms can safely be compared against arbitrary
    values (e.g. during list membership tests).
    """
    own_identifier = self.identifier.strip(' "')
    try:
        return own_identifier == other.identifier.strip(' "')
    except AttributeError:
        return False


def _ne(self, other):
    """ Return the negation of _eq (Python 2 does not derive != from ==) """
    return not _eq(self, other)


def _hash(self):
    """ Hash on the stripped identifier so that it is consistent with _eq """
    return hash(self.identifier.strip(' "'))


def _term_to_string(self):
    """
    Return string representation of a term (the identifier as given,
    including any quotes)
    """
    return str(self.identifier)

# namedtuple() builds the class dynamically, so the custom comparison and
# string behavior is attached after creation.
TermBase.__str__ = _term_to_string
TermBase.__repr__ = _term_to_string
TermBase.__eq__ = _eq
TermBase.__ne__ = _ne
TermBase.__hash__ = _hash
| |
| |
class ContinuousTerm(TermBase):
    """Basis term for a continuous variable.

    The identifier is stored by TermBase; 'index' is attached as a plain
    attribute holding the position supplied by the caller.
    """
    def __new__(cls, identifier, index):
        term = super(ContinuousTerm, cls).__new__(cls, identifier)
        term.index = index
        return term
| #------------------------------------------------------------------------------ |
| |
| |
class IndicatorTerm(TermBase):
    """Basis term for one dummy-coded level of a categorical variable.

    Besides the identifier (stored by TermBase) each instance carries:
        index: position supplied by the caller
        category: name of the categorical variable, '' when none is given
        is_reference: True if this level is the reference of its category
    """
    def __new__(cls, identifier, index, category=None, is_reference=False):
        term = super(IndicatorTerm, cls).__new__(cls, identifier)
        term.index = index
        term.category = category or ''
        term.is_reference = is_reference
        return term
| #------------------------------------------------------------------------------ |
| |
| |
class InteractionTerm(object):
    """Collection of product of (power of) multiple basis terms (continuous or indicator)

    Attributes:
        index: position of the term given at construction (derivatives
               created by compute_partial_deriv use 0)
        all_terms: defaultdict(int) mapping each basis term to its power
        constant_term: multiplicative constant of the product (default 1)

    Example:
        (all '(<str>)' represent objects with <str> as identifier)

        Product 'age * height' at index 5
        InteractionTerm object:
            {self.index = 5,
             self.all_terms = {(age): 1, (height): 1}
             self.constant_term = 1
            }

        Product '3 * is_female * age^2' at index 10
        InteractionTerm object:
            {self.index = 10,
             self.all_terms = {(age): 2, (is_female): 1}
             self.constant_term = 3
            }
    """
    def __init__(self, index):
        self.index = index

        # all_terms is a dictionary that maps each base term to a power
        # value of that term. The default power value is 0.
        #   key = TermBase (ContinuousTerm or IndicatorTerm)
        #   value = power of the base term (int)
        self.all_terms = defaultdict(int)

        # an interaction term can contain a single constant term (default = 1)
        self.constant_term = 1

    def __hash__(self):
        # NOTE: the hash depends only on 'index' while equality ignores
        # 'index', so two equal terms created with different indices hash
        # differently. Kept as is for compatibility with existing callers.
        return hash(self.index)

    def __eq__(self, other):
        """ Two interaction terms are equal if they have the same constant
        and exactly the same base terms raised to the same powers.

        The comparison is symmetric: a term is NOT equal to another term that
        merely contains all of its factors plus additional ones.
        """
        if not isinstance(other, InteractionTerm):
            return NotImplemented
        if self.constant_term != other.constant_term:
            return False
        # dict equality checks the key sets and the powers in both directions
        return dict(self.all_terms) == dict(other.all_terms)

    def __ne__(self, other):
        outcome = self.__eq__(other)
        if outcome is NotImplemented:
            return outcome
        return not outcome

    def __contains__(self, item):
        """ Function to test membership of a base term """
        return (item in self.all_terms)

    def __repr__(self):
        """ Return string representation of interaction term """
        str_rst = []
        for each_term, each_power in self.all_terms.items():
            power_str = "^" + str(each_power) if each_power > 1 else ""
            str_rst.append(str(each_term) + power_str)
        if self.constant_term > 1:
            str_rst.append(str(self.constant_term))
        return ' * '.join(str_rst)

    def __nonzero__(self):
        """ Function to implement truth value testing and the built-in
        operation bool(): True if the term has any base term or a constant
        greater than 1 """
        return bool(self.all_terms) or self.constant_term > 1

    # Python 3 looks up __bool__ instead of __nonzero__; without the alias
    # every InteractionTerm would be truthy there.
    __bool__ = __nonzero__

    def add_base_term(self, base_term, power):
        """ Multiply the product by base_term^power (powers accumulate) """
        self.all_terms[base_term] += power

    def compute_partial_deriv(self, ref):
        """
        Args:
            @param ref: TermBase. Basis term with respect to which the
                        partial derivative is computed

        Returns:
            InteractionTerm. Partial derivative of self. If 'ref' is not a
            factor of self (or only with a power below 1) the result has no
            base terms and constant_term = 0, which makes it falsy.

        Note: The index value of result is 0 since this term is not present in
        the original variable list.
        """
        result = InteractionTerm(0)
        if not self.all_terms:
            return result

        # .get() avoids inserting 'ref' into the defaultdict as a side effect
        ref_power = self.all_terms.get(ref, 0)
        if ref_power < 1:
            # self does not depend on ref (x^0 == 1), derivative is zero
            result.constant_term = 0
            return result

        # d/dx (c * x^p * rest) = (c * p) * x^(p-1) * rest
        result.constant_term = self.constant_term * ref_power
        for base_term, power in self.all_terms.items():
            if base_term == ref:
                if power > 1:
                    result.add_base_term(base_term, power - 1)
            else:
                result.add_base_term(base_term, power)
        return result

    def str_using_indices(self, array_name, quoted=False, exclude_terms=None):
        """
        Return string representation of an interaction term using the indices
        of the individual terms.

        Args:
            @param array_name: string. If the int_terms are to be treated as
                                indices of an array then 'array_name' can be
                                used while building the string. If None, then
                                terms are printed as is.
            @param quoted: bool. If true, double quotes are added around the
                            array_name. Ignored if array_name is None.
            @param exclude_terms: list. List of terms to exclude for output
                                    string
        Returns:
            str. String representation of interaction term. Factors are
            ordered by the index of the base term, factors with a power
            below 1 are skipped, a power of 1 is not printed and a constant
            greater than 1 is appended last.

        Example:
            all_terms = {ContinuousTerm('a', 2):1, ContinuousTerm('c', 1):2}
            array_name = 'x'
            quoted = True
            output = '"x"[1]^2 * "x"[2]'

            all_terms = {ContinuousTerm('age', 1): 2,
                         IndicatorTerm('gender', 3): 4}
            constant_term = 3
            array_name = 'temp'
            quoted = False
            output = 'temp[1]^2 * temp[3]^4 * 3'
        """
        if not exclude_terms:
            exclude_terms = []
        if not array_name:
            return str(self)
        quote = "\"" if quoted else ""
        str_rst = []
        for term in sorted(self.all_terms.keys(), key=operator.attrgetter('index')):
            if term in exclude_terms:
                continue
            power = self.all_terms[term]
            if power < 1:
                continue
            power_suffix = "^" + str(power) if power > 1 else ""
            prefix_str = ("{q}{0}{q}[{1}]".
                          format(array_name, term.index, q=quote))
            str_rst.append(prefix_str + power_suffix)
        if self.constant_term > 1:
            str_rst.append(str(self.constant_term))
        return ' * '.join(str_rst)
| #------------------------------------------------------------------------------- |
| |
| |
class MarginalEffectsBuilder(object):
    """
    Parse an x_design string into basis terms (continuous and indicator) and
    interaction terms, and build the string expressions (array references,
    partial derivatives, discrete differences) derived from them.
    """

    # Regular expressions are compiled once for the class. In each of them a
    # double-quoted substring is consumed as a unit, so separator characters
    # inside quotes are ignored.
    _DOT_REG = re.compile(r'".*?"|[^\.]+')
    _PRODUCT_REG = re.compile(r'((?:[^\*"]|"[^"]*")+)')
    _POWER_REG = re.compile(r'((?:[^\^"]|"[^"]*")+)')
    _COMMA_REG = re.compile(r'((?:[^,"]|"[^"]*")+)')
    _INTERACTION_CHAR_REG = re.compile(r'".*?"|[\*\^]')
    _PLAIN_IDENTIFIER_REG = re.compile(r'^\w+$')

    def __init__(self, design_str):
        """
        Args:
            @param self
            @param design_str: string. Comma-separated list of basis and
                               interaction terms, see _parse_design_string()

        Raises:
            ValueError, if the design string is empty or invalid
        """
        self.design_str = design_str

        # All the basis terms in the independent_variable
        #   key = index of the basis term in original independent_variable
        #   value = TermBase object
        self.basis_terms = dict()

        # inverse index of the basis terms to quickly get the index for each term
        #   key = TermBase object
        #   value = index of the basis term in original independent_variable
        self.inverse_basis_terms = dict()

        # indicator_terms (dict): Elements in 'basis_terms' that are indicators
        #   key = category of the indicator variable,
        #   value = list of all IndicatorTerm for a particular category
        self.indicator_terms = defaultdict(list)

        # reference_terms (dict): For each category in indicator terms, if the
        # reference variable is provided then it is added to this dictionary
        #   key = category of the indicator variable,
        #   value = reference IndicatorTerm
        self.reference_terms = dict()

        # interaction_terms (dict): Each interaction term in the independent_variable
        #   key = index from independent_variable
        #   value = InteractionTerm
        #           see _parse_design_string() for an example of the format
        #           of a single interaction term.
        self.interaction_terms = dict()

        # parse the design string to populate all variables
        self.n_basis_terms, self.n_int_terms = self._parse_design_string()

        # indices of a subset of basis_terms for which we want to compute margins
        self._subset_basis = None
    #---------------------------------------------------------------------------

    @staticmethod
    def _array_ref(array_name, index, quoted):
        """ Return '<array_name>[<index>]', double-quoting the name if 'quoted' """
        quote_str = "\"" if quoted else ""
        return "{q}{0}{q}[{1}]".format(array_name, index, q=quote_str)
    #---------------------------------------------------------------------------

    @property
    def n_total_terms(self):
        """
        Get the total number of terms in the x_design string
        """
        return len(self.basis_terms) + len(self.interaction_terms)
    #---------------------------------------------------------------------------

    @property
    def subset_basis(self):
        """ List of indices of the selected basis terms (None if not set) """
        return self._subset_basis

    @subset_basis.setter
    def subset_basis(self, list_value):
        """
        Args:
            @param list_value: list. List of identifiers of the basis terms
                                to include in the subset. They are stored as
                                the corresponding basis term indices.

        Raises:
            ValueError, if an identifier is not a basis term
        """
        # Materialize the identifiers: they are traversed twice (validation
        # and lookup) and map() returns a one-shot iterator on Python 3.
        identifiers = [str(i) for i in list_value]
        if not all(TermBase(i) in self.inverse_basis_terms for i in identifiers):
            raise ValueError("Margins Error: Subset basis contains identifier not present in basis term")
        try:
            self._subset_basis = [int(self.inverse_basis_terms[TermBase(i)])
                                  for i in identifiers]
        except TypeError:
            raise TypeError("Margins Error: Invalid marginal vars argument ({0})".
                            format(str(identifiers)))

    def contains_interaction(self):
        """ Return True if the design contains at least one interaction term """
        return len(self.interaction_terms) > 0

    def contains_indicators(self):
        """ Return True if the design contains at least one indicator term """
        return len(self.indicator_terms) > 0

    def is_categorical(self, basis_term):
        """ Return True if basis_term is an IndicatorTerm """
        return isinstance(basis_term, IndicatorTerm)
    # --------------------------------------------------------------------------

    def get_indicator_reference(self, term):
        """
        Get the reference variable for a given indicator term
        Args:
            @param term: IndicatorTerm

        Returns:
            IndicatorTerm. Reference term registered for the category of
            'term'; None if 'term' has no category or the category has no
            reference term.
        """
        if not self.indicator_terms or not term or not term.category:
            return None
        return self.reference_terms.get(term.category)
    #---------------------------------------------------------------------------

    def get_siblings(self, ind_term):
        """
        Args:
            @param self
            @param ind_term: IndicatorTerm

        Returns:
            list. A list of terms that are in the same category as the ind_term
            (ind_term itself excluded). An empty list if ind_term has no
            category; None if ind_term is not an IndicatorTerm.
        """
        if not isinstance(ind_term, IndicatorTerm):
            return None
        if not ind_term.category:
            return []
        # .get() keeps the defaultdict from growing an empty entry when the
        # category is unknown (which would flip contains_indicators()).
        same_category = self.indicator_terms.get(ind_term.category, [])
        return [term for term in same_category if term != ind_term]
    #---------------------------------------------------------------------------

    def _get_term_from_str(self, term_str, index=None):
        """
        Strip any prefix and suffix present in the term_str input and return
        a term object

        Args:
            @param term_str: string, String input of a variable,
                    General format of input is [i|ir|ib].<name>.<label>
                    - The string before the first '.' is considered a prefix and
                        should be one of 'i', 'ir' or 'ib' ('ir' and 'ib' mark
                        the reference term of a category)
                    - The string between the first two '.' is the identifier
                        name of the object.
                    - The string after the second '.' is the label (category)
                        of the variable
                    A string without any '.' is a continuous variable.
            @param index: int, The index of independent variable in which the
                term was present. A new continuous or indicator term is created
                using this index. If index is None, it indicates that
                the term is part of an interaction term. In this case the term
                is returned from existing basis terms.

        Returns:
            TermBase. Either a ContinuousTerm or an IndicatorTerm

        Raises:
            ValueError, if input is invalid

        Examples:
            (Input -- Output):
            'age', index=11         -- ContinuousTerm('age', 11)
            'Age', index=11         -- ContinuousTerm('age', 11)
            'i.2.color', index=2    -- IndicatorTerm('2', 2, category='color')
            'ir.1.color', index=1   -- IndicatorTerm('1', 1, category='color',
                                                     is_reference=True)
            'i.is_female', index=4  -- IndicatorTerm('is_female', 4)
            'age', index=None       -- the basis term already registered
                                       with identifier 'age'
        """
        term_str = term_str.strip()
        parts = self._DOT_REG.findall(term_str)
        if not parts:
            raise ValueError("Invalid string '{0}' in x_design".format(str(term_str)))

        # default values
        is_indicator, is_reference = False, False
        indicator_category = None

        if len(parts) > 1:
            # (0) is prefix, (1) is identifier, (2) is suffix
            identifier = parts[1]
            prefix = parts[0].strip()
            if prefix not in ('i', 'ir', 'ib'):
                raise ValueError("Invalid prefix string '{0}'".format(prefix))
            is_reference = (prefix in ('ir', 'ib'))
            is_indicator = True
            if len(parts) > 2:
                indicator_category = parts[2].strip()
        else:
            # whole term is identifier (may include quotes)
            identifier = parts[0]

        # NOTE(review): the indentation of the original file was lost; the
        # check below is assumed to apply to prefixed and unprefixed
        # identifiers alike -- confirm against the upstream source.
        if not (identifier.startswith('"') and identifier.endswith('"')):
            if not self._PLAIN_IDENTIFIER_REG.match(identifier):
                raise ValueError("Invalid identifier '{0}' in x_design. Quote "
                                 "an identifier with non-alphanumeric "
                                 "characters using double quotes.".
                                 format(identifier))
            # if identifier is not quoted then always keep a lower-case copy
            # (i.e. identifier is case-insensitive); a quoted identifier is
            # kept as is, including its quotes
            identifier = identifier.lower()

        # A falsy index (None, or 0 since design indices are 1-based) means
        # that the term has to be looked up among the existing basis terms.
        if index:
            if is_indicator:
                return IndicatorTerm(identifier, index,
                                     category=indicator_category,
                                     is_reference=is_reference)
            return ContinuousTerm(identifier, index)

        try:
            term_index = self.inverse_basis_terms[TermBase(identifier)]
            return self.basis_terms[term_index]
        except (KeyError, ValueError):
            raise ValueError(
                "Invalid input in x_design. Interaction term '{0}' "
                "contains variable '{1}' not defined as a "
                "basis variable.".
                format(str(term_str), str(identifier)))
    # --------------------------------------------------------------------------

    def _parse_interaction_term(self, interaction_str, index):
        """
        Parse interaction term containing product of powers of basis terms
        to generate a list of the terms involved in the product.

        Args:
            @param interaction_str: string, Interaction between variables
                    represented as a product of powers of variables.
            @param index: int, Index of the interaction term in the design
                    string; stored as the index of the result.

        Returns:
            InteractionTerm

        Raises:
            ValueError, if the string is a plain basis term, contains an empty
            factor, more than one '^' per factor, a non-integer or negative
            power, or a variable that is not a basis term
        """
        if (interaction_str.startswith('"') and interaction_str.endswith('"')):
            # if quotes are present around the term then it is to be treated as
            # is - i.e. as a basis variable
            raise ValueError("Inputed interaction term {0} is actually a "
                             "basis term".format(interaction_str))

        # split string by '*'; eg: '1^2 * 2^2' would give two elements,
        #                           'age^2' would give 1 element
        # ensure quoted strings are not split. The pattern has a capturing
        # group, so split() returns the factors at the odd positions.
        product_terms = [i.strip() for i in
                         self._PRODUCT_REG.split(interaction_str)[1::2]]

        if not product_terms:
            raise ValueError("Not able to process interaction term: {0}".
                             format(interaction_str))
        if len(product_terms) == 1 and "^" not in product_terms[0]:
            # a single variable is input without any power or product.
            # this should be treated as basis term not interaction term
            raise ValueError("Inputed interaction term {0} is actually a "
                             "basis term".format(interaction_str))

        interaction_term = InteractionTerm(index)
        for each_term in product_terms:
            each_term_split = [i.strip() for i in
                               self._POWER_REG.split(each_term.strip())[1::2]]

            # an empty list comes from an empty factor, eg: 'a * * b'
            if not each_term_split or len(each_term_split) > 2:
                raise ValueError("Invalid interaction term {0}".format(each_term))
            basis_term = self._get_term_from_str(each_term_split[0])

            if len(each_term_split) > 1:
                try:
                    basis_term_power = int(each_term_split[1])
                except (TypeError, ValueError):
                    raise ValueError("1. Invalid input for interaction term: {0}".
                                     format(each_term))
                if basis_term_power < 0:
                    # negative powers not allowed
                    raise ValueError("2. Invalid input for interaction term: {0}".
                                     format(each_term))
            else:
                # if no power is provided (eg. '1 * 2') then default power = 1
                basis_term_power = 1
            interaction_term.add_base_term(basis_term, basis_term_power)
        return interaction_term
    # --------------------------------------------------------------------------

    def _parse_design_string(self):
        """
        Parse the design string to compute the basis terms (continuous and
        indicators) and the interaction terms in the independent variables list.

        Returns:
            Tuple with elements:
                - int: number of basis terms
                - int: number of interaction terms

        Updates:
            self.basis_terms, self.inverse_basis_terms,
            self.indicator_terms, self.reference_terms, self.interaction_terms

        Raises:
            ValueError

        Example:
            self.design_str = "ir.1.color, i.2.color, i.3.color, i.4.gender,
                               i.5, i.5*i.2.color, i.5*i.3.color,
                               i.5*i.4.gender, i.5*i.4.gender*i.2.color,
                               i.5*i.4.gender*i.3.color, age, age^2, i.5*age,
                               i.5*age^2, height"

            # all terms in () are objects with the identifier string inside
            self.basis_terms = {1: (1), 2: (2), 3: (3), 4: (4), 5: (5),
                                11: (age), 15: (height)},
            self.interaction_terms = {6: {(5): 1, (2): 1},
                                      7: {(5): 1, (3): 1},
                                      8: {(5): 1, (4): 1},
                                      9: {(5): 1, (4): 1, (2): 1},
                                      10: {(5): 1, (4): 1, (3): 1},
                                      12: {(age): 2},
                                      13: {(5): 1, (age): 1},
                                      14: {(5): 1, (age): 2}}
            self.indicator_terms = {
                        'color': [(1), (2), (3)],
                        'gender': [(4)], '': [(5)]}
            self.reference_terms = {'color': (1)}

        """
        if not self.design_str:
            raise ValueError("Invalid input for x_design: {0}".
                             format(str(self.design_str)))

        # read a comma-separated string quoted using '"'
        design_terms = [i.strip() for i in
                        self._COMMA_REG.split(self.design_str)[1::2]]
        interaction_term_strings = {}
        for index, each_term in enumerate(design_terms):
            # Database indexing starts from 1 instead of 0
            base_1_index = index + 1

            # find if it is an interaction term: interaction_chars
            # has matches for * and ^ while skipping double-quoted strings
            interaction_chars = self._INTERACTION_CHAR_REG.findall(each_term)
            if any(char in interaction_chars for char in ('*', '^')):
                # we process all interaction terms after the basis terms have
                # been processed to enable using an interaction term
                # before the basis term is defined in the design string
                interaction_term_strings[base_1_index] = each_term
                continue

            basis_term = self._get_term_from_str(each_term, base_1_index)
            if basis_term in self.inverse_basis_terms:
                # basis term already seen
                raise ValueError(
                    "Invalid input in x_design. Duplicate identifier string "
                    "({0}) for basis terms ".format(basis_term))

            self.basis_terms[base_1_index] = basis_term
            self.inverse_basis_terms[basis_term] = base_1_index

            if isinstance(basis_term, IndicatorTerm):
                self.indicator_terms[basis_term.category].append(basis_term)
                if basis_term.is_reference:
                    if basis_term.category in self.reference_terms:
                        raise ValueError(
                            "Invalid input in x_design. Multiple reference "
                            "terms present for category '{0}'".
                            format(basis_term.category))
                    self.reference_terms[basis_term.category] = basis_term

        # parse all interaction terms after the basis terms are parsed
        for each_index, each_term in interaction_term_strings.items():
            self.interaction_terms[each_index] = \
                self._parse_interaction_term(each_term, each_index)

        return len(self.basis_terms), len(self.interaction_terms)
    # --------------------------------------------------------------------------

    def str_sum_int_terms(self, coef_array_name, indep_array_name=None, quoted=False):
        """
        Build the sum of '<coef>[i] * <interaction term i>' over all
        interaction terms (in increasing index order).

        Args:
            @param coef_array_name: str, Name of the coefficient array
            @param indep_array_name: str, Name of the data array. If None the
                    interaction terms are printed using their identifiers
            @param quoted: bool, If True the array names are double-quoted

        Returns:
            str. Summands joined with ' + '; '' if there is no interaction term
        """
        if not self.interaction_terms:
            return ""
        result = []
        # sorted for a deterministic output (dict order is arbitrary)
        for each_index in sorted(self.interaction_terms):
            each_term = self.interaction_terms[each_index]
            coef_str = self._array_ref(coef_array_name, str(each_index), quoted)
            result.append(coef_str + " * " +
                          each_term.str_using_indices(indep_array_name, quoted=quoted))
        return " + ".join(result)
    #---------------------------------------------------------------------------

    def str_sum_deriv_int_terms(self, ref, coef_array_name, indep_array_name=None, quoted=False):
        """
        Build the sum of '<coef>[i] * <d(interaction term i)/d(ref)>' over all
        interaction terms (in increasing index order) whose partial
        derivative is truthy.

        Args:
            @param ref: TermBase, Basis term to differentiate with respect to
            @param coef_array_name: str, Name of the coefficient array
            @param indep_array_name: str, Name of the data array. If None the
                    derivatives are printed using their identifiers
            @param quoted: bool, If True the array names are double-quoted

        Returns:
            str. Summands joined with ' + '; '' if nothing is to be summed
        """
        if not self.interaction_terms:
            return ""
        result = []
        for each_index in sorted(self.interaction_terms):
            each_term = self.interaction_terms[each_index]
            deriv_int_term = each_term.compute_partial_deriv(ref)
            if deriv_int_term:
                coef_str = self._array_ref(coef_array_name, str(each_index), quoted)
                result.append(coef_str + " * " +
                              deriv_int_term.str_using_indices(indep_array_name,
                                                               quoted=quoted))
        return " + ".join(result)
    #---------------------------------------------------------------------------

    def get_all_indicator_terms(self):
        """ Return all indicator terms (of every category) sorted by index """
        return sorted(chain(*self.indicator_terms.values()),
                      key=operator.attrgetter('index'))

    def get_subset_basis_terms(self):
        """
        Get the basis terms whose indices are present in subset_basis.
        If subset_basis is None (or empty) then all basis terms are returned,
        ordered by their index.

        Returns:
            List. Each element is a TermBase
        """
        if self.subset_basis:
            return [self.basis_terms[i] for i in self.subset_basis]
        # Sorted list instead of dict.values(): the order is deterministic
        # (it matches the sorted columns used by create_2nd_derivative_matrix)
        # and the result is a real list on Python 3 as well.
        return [self.basis_terms[i] for i in sorted(self.basis_terms)]

    def get_subset_indicator_terms(self):
        """
        Get the indicator terms whose indices are present in subset_basis.
        If subset_basis is None then all indicator terms in the object are
        returned.

        Returns:
            List. Each element is an IndicatorTerm (sorted by index)
        """
        if not self.subset_basis:
            return self.get_all_indicator_terms()
        subset_indicator_terms = (set(self.get_all_indicator_terms()) &
                                  set(self.get_subset_basis_terms()))
        return sorted(subset_indicator_terms, key=operator.attrgetter('index'))
    #---------------------------------------------------------------------------

    def get_sorted_basis_identifiers(self):
        """ Return all identifiers sorted by the index in basis terms """
        return [self.basis_terms[key].identifier
                for key in sorted(self.basis_terms)]

    def get_all_basis_identifiers(self):
        """ Return a dict mapping each basis identifier to its index """
        return dict((val.identifier, key)
                    for key, val in self.basis_terms.items())

    def get_all_indicator_indices(self):
        """ Get the indices for all indicator terms present in the object."""
        return [i.index for i in self.get_all_indicator_terms()]
    #---------------------------------------------------------------------------

    def get_subset_indicator_indices(self):
        """ Get the indices of the indicator terms selected by subset_basis """
        return [i.index for i in self.get_subset_indicator_terms()]
    #---------------------------------------------------------------------------

    def get_all_reference_indices(self):
        """ Get the sorted indices of all reference indicator terms """
        return sorted([i.index for i in self.reference_terms.values()])
    #---------------------------------------------------------------------------

    def get_interaction_terms_containing(self, ref):
        """
        Args:
            @param ref: TermBase. The reference term to search in each
                        interaction term

        Returns:
            Dict. A subset of self.interaction_terms where each interaction
            term contains 'ref'. Always a new dict, so changing it does not
            alter self.interaction_terms.
        """
        return dict((index, int_term)
                    for index, int_term in self.interaction_terms.items()
                    if ref in int_term)
    #---------------------------------------------------------------------------

    def partial_deriv_interaction_terms(self, ref):
        """
        Args:
            @param ref: TermBase. Basis term to differentiate with respect to

        Returns:
            Dict. key = index of an interaction term, value = its partial
            derivative (InteractionTerm) w.r.t. 'ref'. Only truthy
            derivatives are included.
        """
        result = dict()
        for each_index, each_int_term in self.interaction_terms.items():
            each_pder = each_int_term.compute_partial_deriv(ref)
            if each_pder:
                result[each_index] = each_pder
        return result
    #---------------------------------------------------------------------------

    def create_2nd_derivative_matrix(self, data_array_name=None,
                                     quoted=False, discrete=True):
        r"""
        Compute the \frac{\partial f_i}{\partial x_k} for all i and k,
        where i is index of independent variable and k is index of basis variable

        Args:
            @param data_array_name: str, Array name of the data
            @param quoted: bool, If True the data_array_name will be quoted
                            using double quotes
            @param discrete: bool, If True, discrete differences will be used
                            for categorical variables instead of the partial
                            derivative

        Returns:
            List of lists of str. One row for each basis term returned by
            get_subset_basis_terms(). Without interaction terms a row has one
            entry ('1' or '0') per basis term; otherwise it has one string
            expression per term of the design.
        """
        actual_basis = self.get_subset_basis_terms()
        actual_indices = [self.inverse_basis_terms[i] for i in actual_basis]

        derivative_matrix_str_list = []
        if not self.interaction_terms:
            # return (rows of) an identity matrix of size
            # num_basis_terms x num_basis_terms
            sorted_basis_indices = sorted(self.basis_terms.keys())
            for outer in actual_indices:
                derivative_matrix_str_list.append(
                    ['1' if outer == inner else '0'
                     for inner in sorted_basis_indices])
            return derivative_matrix_str_list

        for curr_basis in actual_basis:
            if discrete and isinstance(curr_basis, IndicatorTerm):
                set_values, unset_values = self.get_discrete_diff_arrays(
                    curr_basis, data_array_name, quoted)
                append_str_list = ["0" if i == j else
                                   "{0}-{1}".format(str(i), str(j))
                                   for i, j in zip(set_values, unset_values)]
                derivative_matrix_str_list.append(append_str_list)
            else:
                all_partial_deriv_list = ['0'] * self.n_total_terms
                # indices are 1-base, the python list is 0-base
                all_partial_deriv_list[curr_basis.index - 1] = '1'
                curr_partial_deriv_dict = self.partial_deriv_interaction_terms(curr_basis)
                for each_index, each_term in curr_partial_deriv_dict.items():
                    all_partial_deriv_list[each_index - 1] = \
                        each_term.str_using_indices(data_array_name, quoted)
                derivative_matrix_str_list.append(all_partial_deriv_list)
        return derivative_matrix_str_list
    #---------------------------------------------------------------------------

    def get_discrete_diff_arrays(self, term_to_explore, array_name='x',
                                 quoted=False, shortened_output=False):
        """
        Create a pair of lists that corresponds to the array expressions for
        computing the marginal effect for an indicator term using
        discrete differences.

        All variables that are not siblings of the term_to_explore
        are expressed as they were in original expression.

        Args:
            @param term_to_explore: IndicatorTerm, Term for which to create
                    discrete difference expression
            @param array_name: str, Name of the array that contains values
                                        for the variables.
            @param quoted: bool, If True the array_name is double-quoted
            @param shortened_output: bool, If there is no interaction, then
                all non-categorical items will be the same across the two lists.
                If shortened_output=True (and there is no interaction term),
                then only the categorical terms are filled in for the set and
                unset lists; all other positions are left as None.

        Returns:
            tuple. A pair of lists (one string per term of the design), each
            representing an array of data.
            First one corresponding to term_to_explore
            being 'set' (=1) with other siblings unset (=0) and the second one
            corresponding to term_to_explore as 'unset' and other siblings also
            'unset' (reference variable, if present, is 'set'). In both cases,
            variables unrelated to term_to_explore are unchanged.
        """
        term_siblings = self.get_siblings(term_to_explore)  # includes reference
        # None if the category has no reference term
        ref_term = self.reference_terms.get(term_to_explore.category)
        n_terms = self.n_total_terms

        indicator_set_list = [None] * n_terms
        indicator_unset_list = [None] * n_terms

        def assign_discrete_diff(index, set_val, unset_val):
            indicator_set_list[index] = set_val
            indicator_unset_list[index] = unset_val

        for basis_index, basis_term in self.basis_terms.items():
            # basis_index is for the original array (in DB) which is 1-base
            if basis_term == term_to_explore:
                # set/unset the current variable ...
                set_val, unset_val = '1', '0'
            elif basis_term in term_siblings:
                if basis_term.is_reference:
                    # set reference when unsetting term_to_explore and vice-versa
                    set_val, unset_val = '0', '1'
                else:
                    # other siblings are always unset
                    set_val, unset_val = '0', '0'
            else:
                # all other variables are added in as is
                basis_str = self._array_ref(array_name, basis_term.index, quoted)
                set_val, unset_val = basis_str, basis_str

            if (self.is_categorical(basis_term) or
                    self.contains_interaction() or
                    not shortened_output):
                assign_discrete_diff(basis_index - 1, set_val, unset_val)

        # repeat the process for interaction terms
        for int_index, int_term in self.interaction_terms.items():
            # again 'int_index' is for the original array (in DB) which is 1-base
            sibling_in_interaction = any(not sib.is_reference and sib in int_term
                                         for sib in term_siblings)
            if sibling_in_interaction:
                # sibling is unset in both lists
                assign_discrete_diff(int_index - 1, '0', '0')
            elif term_to_explore in int_term:
                if ref_term and ref_term in int_term:
                    # reference is unset, when term_to_explore is set
                    set_val = '0'
                else:
                    # term_to_explore == 1, so it drops out of the product;
                    # an empty product is '1'
                    set_val = int_term.str_using_indices(
                        array_name, quoted, [term_to_explore])
                    set_val = '1' if not set_val else set_val
                assign_discrete_diff(int_index - 1, set_val, '0')
            elif ref_term and ref_term in int_term:
                # reference == 1 in the 'unset' case, so it drops out of the
                # product; an empty product is '1'
                unset_val = int_term.str_using_indices(array_name, quoted,
                                                       [ref_term])
                unset_val = '1' if not unset_val else unset_val
                assign_discrete_diff(int_index - 1, '0', unset_val)
            else:
                term_as_is = int_term.str_using_indices(array_name, quoted)
                assign_discrete_diff(int_index - 1, term_as_is, term_as_is)
        return (indicator_set_list, indicator_unset_list)
| #------------------------------------------------------------------------------ |
| |
| |
| import unittest |
| from itertools import izip_longest |
| |
| |
| class MarginsTestCase(unittest.TestCase): |
| |
def setUp(self):
    """ Build the three MarginalEffectsBuilder fixtures shared by the tests """
    self.maxDiff = None

    # design with indicators (one reference term), continuous terms and
    # interaction terms that refer to basis terms by identifier
    self.xd = ("ir.1.color, i.2.color, i.3.color, "
               "i.is_female, i.degree,"
               "degree * 2, "
               "i.degree.degree * i.3.color, "
               "degree * i.is_female.gender, "
               "degree * is_female * i.2.color, "
               "degree*is_female*3, "
               "age, age^2, degree*age, degree*age^2, "
               "height")
    # design with purely numeric identifiers and interaction terms
    self.xd2 = ('1, 2, 3, 4, 5, 6, 7, 8, '
                '3*2, 4*3*5, 5^2*4*6, 6^2, 7^3*8')
    # design without any interaction term
    self.xd3 = '1, 2, 3, 4, 5, 6, 7, 8'

    self.int_obj = MarginalEffectsBuilder(self.xd)
    self.int_obj2 = MarginalEffectsBuilder(self.xd2)
    self.int_obj3 = MarginalEffectsBuilder(self.xd3)
| |
def test_interaction_term(self):
    """Check equality and inequality of InteractionTerm objects."""
    product = InteractionTerm(4)  # "1*2"
    for identifier, position in (('1', 1), ('2', 2)):
        product.add_base_term(ContinuousTerm(identifier, position), 1)
    # A term always compares equal to itself
    self.assertEqual(product, product)

    other = InteractionTerm(3)
    for identifier, position in (('a1111', 1), ('h1111', 2)):
        other.add_base_term(ContinuousTerm(identifier, position), 1)
    self.assertNotEqual(product, other)
    # Still different once a constant factor is attached
    other.constant_term = 5
    self.assertNotEqual(product, other)
| |
def test_ind_prefix(self):
    """Check that single term strings are turned into the expected terms.

    Each case is (term string, index handed to the parser, expected term).
    """
    cases = (
        ('i.1.color', 1, IndicatorTerm('1', 1, 'color')),
        ('1', 1, ContinuousTerm('1', 1)),
        ('10', 10, ContinuousTerm('10', 10)),
        ('age', 1, ContinuousTerm('age', 1)),
        ('i.10.gender', 5, IndicatorTerm('10', 5, 'gender')),
        ('i.is_female', 10, IndicatorTerm('is_female', 10)),
    )
    for term_str, term_index, expected_term in cases:
        parsed = self.int_obj._get_term_from_str(term_str, term_index)
        self.assertEqual(parsed, expected_term)
| |
def test_parse_interaction(self):
    """Check parsing of interaction strings into InteractionTerm objects."""
    parse = self.int_obj._parse_interaction_term

    # Plain product of two continuous terms
    product = InteractionTerm(4)
    product.add_base_term(ContinuousTerm('1', 1), 1)
    product.add_base_term(ContinuousTerm('2', 2), 1)
    self.assertEqual(parse('1*2', 4), product)

    # Single term raised to a power
    squared = InteractionTerm(4)
    squared.add_base_term(ContinuousTerm('age', 1), 2)
    self.assertEqual(parse('age^2', 4), squared)

    # Power combined with an indicator and a reference indicator
    mixed = InteractionTerm(4)
    mixed.add_base_term(ContinuousTerm('age', 1), 2)
    mixed.add_base_term(IndicatorTerm('is_female', 2, 'gender'), 1)
    mixed.add_base_term(IndicatorTerm('1', 3, 'color', True), 1)
    self.assertEqual(parse('age^2*i.is_female.gender*ir.1.color', 4), mixed)

    # Inputs containing a negative exponent, a fractional exponent and a
    # '/' operator are all expected to be rejected
    for malformed in ('1^-2*2^7*3^4', '1^2*2^7.1*3^4', '1^2/2^7.1*3^4'):
        self.assertRaises(ValueError, parse, malformed, 2)
| |
def test_str_interaction(self):
    """Check the string renderings of an interaction term (a * b^4 * c^2 * 5)."""
    term = InteractionTerm(4)
    for name, position, power in (('a', 1, 1), ('b', 2, 4), ('c', 3, 2)):
        term.add_base_term(ContinuousTerm(name, position), power)
    term.constant_term = 5

    # Identifier-based rendering: str() and str_using_indices(None) agree
    by_identifier = 'a * c^2 * b^4 * 5'
    self.assertEqual(str(term), by_identifier)
    self.assertEqual(term.str_using_indices(None), by_identifier)

    # Index-based rendering against the array name 'x', unquoted
    by_index = term.str_using_indices('x', False)
    self.assertEqual(by_index, 'x[1] * x[2]^4 * x[3]^2 * 5')
| |
def test_x_design(self):
    """Check the basis terms of the purely continuous design ``xd2``."""
    builder = self.int_obj2
    # Identifiers 1..8 are the basis terms, each at its own position
    expected_basis = dict((pos, ContinuousTerm(str(pos), pos))
                          for pos in range(1, 9))
    self.assertEqual(builder.basis_terms, expected_basis)
    # No indicator or reference terms are expected for this design
    self.assertEqual(builder.indicator_terms, {})
    self.assertEqual(builder.reference_terms, {})
| |
def test_x_design1(self):
    """Check a design whose basis includes the quoted identifier "1^2"."""
    builder = MarginalEffectsBuilder(
        '1, 2, 1*2, "1^2", 5, 6, 1^1*2^2*5^4, 5^2, "1^2"*5')

    identifiers = [term.identifier for term in builder.basis_terms.values()]
    self.assertEqual(identifiers, ['1', '2', '"1^2"', '5', '6'])

    self.assertEqual(builder.indicator_terms, {})
    self.assertEqual(builder.reference_terms, {})
    self.assertEqual(builder.contains_indicators(), False)
    self.assertEqual(builder.contains_interaction(), True)

    # NOTE(review): the expected text fixes the key order 8, 9, 3, 7 -
    # presumably the iteration order of the underlying mapping; confirm
    interactions_text = str(builder.interaction_terms)
    self.assertEqual(interactions_text,
                     '{8: 5^2, 9: 5 * "1^2", 3: 1 * 2, 7: 1 * 2^2 * 5^4}')
| |
def test_x_design2(self):
    """Exercise a design with one quoted indicator term and a basis subset.

    Covers the basis / indicator / reference look-ups, restricting the
    builder to a subset of its basis terms through ``subset_basis``, and the
    string form of the summed interaction terms.
    """
    xd = '1, i."red".color, 1*red, 5, 1^1*red^2* 5^4, 6, 5^2'
    # Expected basis: continuous terms at positions 1, 4 and 6 plus the
    # quoted indicator "red" (category 'color', not a reference) at 2
    basis_terms_dict = {}
    for ind, var in zip((1, 4, 6), ('1', '5', '6')):
        basis_terms_dict[ind] = ContinuousTerm(var, ind)
    basis_terms_dict[2] = IndicatorTerm('"red"', 2, 'color', False)
    int_obj = MarginalEffectsBuilder(xd)
    self.assertEqual(int_obj.basis_terms, basis_terms_dict)
    self.assertEqual(int_obj.indicator_terms, {'color': [basis_terms_dict[2]]})
    self.assertEqual(int_obj.get_all_indicator_terms(), [basis_terms_dict[2]])
    self.assertEqual(int_obj.get_all_indicator_indices(), [2])
    self.assertEqual(int_obj.get_all_reference_indices(), [])
    self.assertEqual(int_obj.get_subset_indicator_indices(), [2])
    self.assertEqual(int_obj.get_subset_indicator_terms(), [basis_terms_dict[2]])

    # subset_basis is set by assignment (as done right below), so the
    # rejection of an unknown identifier is triggered through setattr.
    # Handing int_obj.subset_basis itself to assertRaises would merely try
    # to call the attribute's current value instead of assigning to it.
    # NOTE(review): assumes the assignment itself validates its input and
    # raises ValueError for 3, which is not a basis term of xd - confirm
    self.assertRaises(ValueError, setattr, int_obj, 'subset_basis', [1, 3])

    # Restrict to the continuous basis terms, addressed by identifier
    int_obj.subset_basis = [1, 5, 6]
    self.assertEqual(int_obj.get_subset_indicator_indices(), [])
    self.assertEqual(int_obj.get_subset_basis_terms(),
                     [basis_terms_dict[1], basis_terms_dict[4], basis_terms_dict[6]])
    self.assertEqual(int_obj.get_subset_indicator_terms(), [])
    # The subset does not change whether the design contains indicators
    self.assertEqual(int_obj.contains_indicators(), True)
    self.assertEqual(int_obj.get_indicator_reference(basis_terms_dict[2]), None)
    self.assertEqual(int_obj.get_indicator_reference(None), None)
    self.assertEqual(int_obj.reference_terms, {})
    self.assertEqual(int_obj.str_sum_int_terms('B', 'x'),
                     "B[3] * x[1] * x[2] + B[5] * x[1] * x[2]^2 * "
                     "x[4]^4 + B[7] * x[4]^2")
| |
def test_x_design3(self):
    """Check a design with a quoted indicator identifier containing '*'.

    Terms compare by identifier only (surrounding quotes and blanks are
    ignored), so a ContinuousTerm can stand in for the indicator at
    position 1 in the expected basis.
    """
    builder = MarginalEffectsBuilder('i."1*2".color, Age, 3, "1*2"*age')

    # The unquoted 'Age' is expected back as 'age'
    expected_basis = {
        1: ContinuousTerm('1*2', 1),
        2: ContinuousTerm('age', 2),
        3: ContinuousTerm('3', 3),
    }
    expected_indicators = {'color': [IndicatorTerm('1*2', 1, 'color')]}

    self.assertEqual(builder.basis_terms, expected_basis)
    self.assertEqual(builder.indicator_terms, expected_indicators)
    self.assertEqual(builder.reference_terms, {})
| |
def test_x_design_failures(self):
    """Check that each malformed design string raises ValueError."""
    invalid_designs = (
        # 5.color is invalid input
        '1, 2, 1*5, 1*2^2, 5.color, i.6',
        # age not defined as basis
        '1, 2, 1*age, 1*2^2, i.6',
        # invalid prefix "test"
        '1, test.2.color, 1*age, 1*2^2, i.6',
        # invalid prefix "test"; quoted "age" not defined as basis
        '1, test.2.color, 1*"age", 1*2^2, i.6',
        # neither "Age" nor age defined as basis
        '1, i.2.color, 1*"Age", 1*age^2, i.6',
        # age$ not quoted
        '1, i.red.color, age$, 1*red^2, i.6',
    )
    for design in invalid_designs:
        self.assertRaises(ValueError, MarginalEffectsBuilder, design)
| |
def test_x_design5(self):
    """Check that an interaction using the undefined term '3' is rejected."""
    # Only '1' and the indicator '2' are declared; '3' appears solely
    # inside the last interaction term
    bad_design = '1, i.2, 1*i.2.color, 1^1*2^2*3^4'
    self.assertRaises(ValueError, MarginalEffectsBuilder, bad_design)
| |
def test_x_design6(self):
    """Check basis, indicator, reference and sibling look-ups for ``xd``."""
    builder = self.int_obj

    # Continuous basis terms sit at positions 11 and 15
    expected = dict(zip(
        [11, 15],
        [ContinuousTerm(name, pos)
         for pos, name in zip([11, 15], ['age', 'height'])]))

    # Indicator basis terms occupy positions 1..5. The category list is
    # shorter than the others, so izip_longest pads it with None for
    # 'is_female' and 'degree'; IndicatorTerm stores that as category ''.
    positions = [1, 2, 3, 4, 5]
    names = ['1', '2', '3', 'is_female', 'degree']
    categories = ['color'] * 3
    reference_flags = [True] + [False] * 4
    for pos, name, category, is_reference in izip_longest(
            positions, names, categories, reference_flags):
        expected[pos] = IndicatorTerm(name, pos, category, is_reference)

    self.assertEqual(builder.basis_terms, expected)
    self.assertEqual(builder.indicator_terms,
                     {'color': [expected[pos] for pos in range(1, 4)],
                      '': [expected[4], expected[5]]})
    self.assertEqual(builder.get_indicator_reference(expected[2]), expected[1])
    self.assertEqual(builder.reference_terms, {'color': expected[1]})
    self.assertEqual(builder.get_all_reference_indices(), [1])

    all_interactions = str(builder.interaction_terms)
    self.assertEqual(all_interactions,
                     "{6: 2 * degree, 7: 3 * degree, "
                     "8: is_female * degree, 9: is_female * 2 * degree, "
                     "10: is_female * 3 * degree, 12: age^2, "
                     "13: age * degree, 14: age^2 * degree}")
    with_color_2 = str(builder.get_interaction_terms_containing(expected[2]))
    self.assertEqual(with_color_2,
                     "{9: is_female * 2 * degree, 6: 2 * degree}")

    # A continuous term has no siblings; each 'color' indicator has the
    # two other 'color' indicators as siblings
    self.assertEqual(builder.get_siblings(expected[11]), None)
    for own, first, second in ((1, 2, 3), (2, 1, 3), (3, 1, 2)):
        self.assertEqual(builder.get_siblings(expected[own]),
                         [expected[first], expected[second]])
| |
def test_compute_partial_deriv(self):
    """Check partial derivatives of the term 5 * a * b^4 * c^2."""
    source = InteractionTerm(4)
    for name, position, power in (('a', 1, 1), ('b', 2, 4), ('c', 3, 2)):
        source.add_base_term(ContinuousTerm(name, position), power)
    source.constant_term = 5

    # With respect to a: 5 * b^4 * c^2
    wrt_a = InteractionTerm(0)
    wrt_a.constant_term = 5
    for name, position, power in (('b', 2, 4), ('c', 3, 2)):
        wrt_a.add_base_term(ContinuousTerm(name, position), power)
    self.assertEqual(source.compute_partial_deriv(ContinuousTerm('a', 1)),
                     wrt_a)

    # With respect to b: 20 * a * b^3 * c^2
    wrt_b = InteractionTerm(0)
    wrt_b.constant_term = 20
    for name, position, power in (('a', 1, 1), ('b', 2, 3), ('c', 3, 2)):
        wrt_b.add_base_term(ContinuousTerm(name, position), power)
    self.assertEqual(source.compute_partial_deriv(ContinuousTerm('b', 2)),
                     wrt_b)
| |
def test_partial_deriv_all(self):
    """Check the derivatives of all interaction terms of ``xd`` w.r.t. a term."""
    derive = self.int_obj.partial_deriv_interaction_terms

    wrt_degree = str(derive(IndicatorTerm("degree", 5)))
    self.assertEqual(wrt_degree,
                     "{6: 2, 7: 3, 8: is_female, 9: is_female * 2, "
                     "10: is_female * 3, 13: age, 14: age^2}")

    wrt_color_2 = str(derive(IndicatorTerm("2", 2)))
    self.assertEqual(wrt_color_2, "{9: is_female * degree, 6: degree}")

    wrt_age = str(derive(IndicatorTerm('age', 0)))
    self.assertEqual(wrt_age,
                     "{12: age * 2, 13: degree, 14: age * degree * 2}")
| |
def test_deriv_str(self):
    """Check the per-basis derivative strings built from ``xd``.

    For every basis term the entry is its own coefficient "B"[k] followed,
    when the builder returns a non-empty string, by the summed derivative
    of the interaction terms with respect to that basis term.
    """
    builder = self.int_obj
    coeff_name = 'B'
    entries = []
    for position, basis in sorted(builder.basis_terms.items()):
        deriv_sum = builder.str_sum_deriv_int_terms(
            basis, coeff_name, 'x', quoted=True)
        parts = ['"{0}"[{1}]'.format(coeff_name, position)]
        if deriv_sum:
            parts.append(deriv_sum)
        entries.append(" + ".join(parts))

    expected = (
        '"B"[1], '
        '"B"[2] + "B"[6] * "x"[5] + "B"[9] * "x"[4] * "x"[5], '
        '"B"[3] + "B"[7] * "x"[5] + "B"[10] * "x"[4] * "x"[5], '
        '"B"[4] + "B"[8] * "x"[5] + "B"[9] * "x"[2] * "x"[5] + "B"[10] * "x"[3] * "x"[5], '
        '"B"[5] + "B"[6] * "x"[2] + "B"[7] * "x"[3] + "B"[8] * "x"[4] + '
        '"B"[9] * "x"[2] * "x"[4] + "B"[10] * "x"[3] * "x"[4] + '
        '"B"[13] * "x"[11] + "B"[14] * "x"[11]^2, '
        '"B"[11] + "B"[12] * "x"[11] * 2 + "B"[13] * "x"[5] + "B"[14] * "x"[5] * "x"[11] * 2, '
        '"B"[15]')
    self.assertEqual(', '.join(entries), expected)
| |
def test_create_indicator_me_expr(self):
    """Check the 'set' / 'unset' arrays returned for indicator terms.

    get_discrete_diff_arrays returns two lists of expression strings; both
    are compared entry by entry against the expected expressions.
    """
    diff_arrays = self.int_obj.get_discrete_diff_arrays

    # Indicator without a category ('degree'), quoted array name
    on_degree, off_degree = diff_arrays(IndicatorTerm('degree', 0), 'x', True)
    expected_on = [
        '"x"[1]', '"x"[2]', '"x"[3]', '"x"[4]', '1', '"x"[2]',
        '"x"[3]', '"x"[4]', '"x"[2] * "x"[4]', '"x"[3] * "x"[4]',
        '"x"[11]', '"x"[11]^2', '"x"[11]', '"x"[11]^2', '"x"[15]']
    expected_off = [
        '"x"[1]', '"x"[2]', '"x"[3]', '"x"[4]', '0', '0', '0',
        '0', '0', '0', '"x"[11]', '"x"[11]^2', '0', '0', '"x"[15]']
    self.assertEqual(on_degree, expected_on)
    self.assertEqual(off_degree, expected_off)

    # Indicator '2' of category 'color', whose reference is term 1
    on_color, off_color = diff_arrays(IndicatorTerm('2', 0, 'color'), 'x', True)
    expected_on = [
        '0', '1', '0', '"x"[4]', '"x"[5]', '"x"[5]', '0',
        '"x"[4] * "x"[5]', '"x"[4] * "x"[5]', '0', '"x"[11]',
        '"x"[11]^2', '"x"[5] * "x"[11]', '"x"[5] * "x"[11]^2',
        '"x"[15]']
    expected_off = [
        '1', '0', '0', '"x"[4]', '"x"[5]', '0', '0',
        '"x"[4] * "x"[5]', '0', '0', '"x"[11]', '"x"[11]^2',
        '"x"[5] * "x"[11]', '"x"[5] * "x"[11]^2', '"x"[15]']
    self.assertEqual(on_color, expected_on)
    self.assertEqual(off_color, expected_off)

    # Second design: numeric identifiers only, every indicator categorised
    numeric_design = ("ir.1.color, i.2.color, i.3.color, i.4.gender, i.5.degree,"
                      "5*2, 5*3, 5*4, 5^2, 1*2, 4*1")
    numeric_builder = MarginalEffectsBuilder(numeric_design)

    on_terms, off_terms = numeric_builder.get_discrete_diff_arrays(
        IndicatorTerm('5', 0, 'degree'), 'x', True)
    expected_on = [
        '"x"[1]', '"x"[2]', '"x"[3]', '"x"[4]', '1', '"x"[2]',
        '"x"[3]', '"x"[4]', '1', '"x"[1] * "x"[2]', '"x"[1] * "x"[4]']
    expected_off = [
        '"x"[1]', '"x"[2]', '"x"[3]', '"x"[4]', '0', '0',
        '0', '0', '0', '"x"[1] * "x"[2]', '"x"[1] * "x"[4]']
    self.assertEqual(on_terms, expected_on)
    self.assertEqual(off_terms, expected_off)

    # Same design, unquoted array name, indicator '2' of 'color'
    on_terms, off_terms = numeric_builder.get_discrete_diff_arrays(
        IndicatorTerm('2', 0, 'color'), 'x', quoted=False)
    expected_on = [
        '0', '1', '0', 'x[4]', 'x[5]', 'x[5]',
        '0', 'x[4] * x[5]', 'x[5]^2', '0', '0']
    expected_off = [
        '1', '0', '0', 'x[4]', 'x[5]', '0',
        '0', 'x[4] * x[5]', 'x[5]^2', '0', 'x[4]']
    self.assertEqual(on_terms, expected_on)
    self.assertEqual(off_terms, expected_off)

    # Same design, unquoted array name, indicator '3' of 'color'
    on_terms, off_terms = numeric_builder.get_discrete_diff_arrays(
        IndicatorTerm('3', 0, 'color'), 'x', quoted=False)
    expected_on = [
        '0', '0', '1', 'x[4]', 'x[5]', '0',
        'x[5]', 'x[4] * x[5]', 'x[5]^2', '0', '0']
    expected_off = [
        '1', '0', '0', 'x[4]', 'x[5]', '0',
        '0', 'x[4] * x[5]', 'x[5]^2', '0', 'x[4]']
    self.assertEqual(on_terms, expected_on)
    self.assertEqual(off_terms, expected_off)
| |
def test_derivative_matrix(self):
    """Check create_2nd_derivative_matrix for the three fixture designs.

    The expected matrices hold one row per basis term and one column per
    term of the design string; entries are expression strings.
    """
    # Design xd: 7 basis terms (5 indicators, 'age', 'height'), 15 terms.
    # Entries written as "a-b" occur in the rows of the indicator terms.
    mixed_expected = [
        ["1-0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0"],
        ["0-1", "1-0", "0", "0", "0", "x[5]-0", "0", "0", "x[4] * x[5]-0", "0", "0", "0", "0", "0", "0"],
        ["0-1", "0", "1-0", "0", "0", "0", "x[5]-0", "0", "0", "x[4] * x[5]-0", "0", "0", "0", "0", "0"],
        ["0", "0", "0", "1-0", "0", "0", "0", "x[5]-0", "x[2] * x[5]-0", "x[3] * x[5]-0", "0", "0", "0", "0", "0"],
        ["0", "0", "0", "0", "1-0", "x[2]-0", "x[3]-0", "x[4]-0",
         "x[2] * x[4]-0", "x[3] * x[4]-0", "0", "0", "x[11]-0", "x[11]^2-0", "0"],
        ["0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "1", "x[11] * 2", "x[5]", "x[5] * x[11] * 2", "0"],
        ["0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "1"],
    ]
    self.assertEqual(self.int_obj.create_2nd_derivative_matrix('x'),
                     mixed_expected)

    # Design xd2: 8 continuous basis terms, 13 terms
    continuous_expected = [
        ["1", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0"],
        ["0", "1", "0", "0", "0", "0", "0", "0", "x[3]", "0", "0", "0", "0"],
        ["0", "0", "1", "0", "0", "0", "0", "0", "x[2]", "x[4] * x[5]", "0", "0", "0"],
        ["0", "0", "0", "1", "0", "0", "0", "0", "0", "x[3] * x[5]", "x[5]^2 * x[6]", "0", "0"],
        ["0", "0", "0", "0", "1", "0", "0", "0", "0", "x[3] * x[4]", "x[4] * x[5] * x[6] * 2", "0", "0"],
        ["0", "0", "0", "0", "0", "1", "0", "0", "0", "0", "x[4] * x[5]^2", "x[6] * 2", "0"],
        ["0", "0", "0", "0", "0", "0", "1", "0", "0", "0", "0", "0", "x[7]^2 * x[8] * 3"],
        ["0", "0", "0", "0", "0", "0", "0", "1", "0", "0", "0", "0", "x[7]^3"],
    ]
    self.assertEqual(self.int_obj2.create_2nd_derivative_matrix('x'),
                     continuous_expected)

    # The third fixture is replaced by a design squaring a term via '2*2';
    # the array name is quoted in the output this time
    self.int_obj3 = MarginalEffectsBuilder('1, 2, 2*2')
    squared_expected = [["1", "0", "0"], ["0", "1", '"x"[2] * 2']]
    self.assertEqual(
        self.int_obj3.create_2nd_derivative_matrix('x', quoted=True),
        squared_expected)
| |
def test_subset_basis_derivative(self):
    """Check the derivative matrix after restricting the basis to a subset.

    The expected matrices keep one column per design term but only one row
    per basis term selected through ``subset_basis``.
    """
    mixed_design = (
        "1, i.2.color, i.3.color, i.4.gender, i.5.degree,"
        "i.5.degree*i.2.color, i.5.degree*i.3.color, i.5.degree*i.4.gender,"
        "i.5.degree*i.4.gender*i.2.color, i.5.degree*i.4.gender*i.3.color,"
        "age, age^2, i.5.degree*age, i.5.degree*age^2, height")
    mixed_builder = MarginalEffectsBuilder(mixed_design)
    # Subset given by identifier: two indicators and two continuous terms
    mixed_builder.subset_basis = ['2', '3', 'age', 'height']
    mixed_expected = [
        ["0", "1-0", "0", "0", "0", "x[5]-0", "0", "0", "x[4] * x[5]-0", "0", "0", "0", "0", "0", "0"],
        ["0", "0", "1-0", "0", "0", "0", "x[5]-0", "0", "0", "x[4] * x[5]-0", "0", "0", "0", "0", "0"],
        ["0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "1", "x[11] * 2", "x[5]", "x[5] * x[11] * 2", "0"],
        ["0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "1"],
    ]
    self.assertEqual(mixed_builder.create_2nd_derivative_matrix('x'),
                     mixed_expected)

    plain_builder = MarginalEffectsBuilder("1, 2, 3, 4, 5, 6, 7")
    plain_builder.subset_basis = [3, 4, 5]
    # Rows 3, 4 and 5 of a 7-column identity matrix of '1' / '0' strings
    plain_expected = [["1" if column == row else "0" for column in range(1, 8)]
                      for row in (3, 4, 5)]
    self.assertEqual(plain_builder.create_2nd_derivative_matrix('x'),
                     plain_expected)
| |
def test_no_interaction_derivative(self):
    """Check the derivative matrix of a design without interaction terms."""
    design = "ir.1.color, i.2.color, i.3.color, i.4.gender, i.5.degree"
    builder = MarginalEffectsBuilder(design)
    # Expected: the 5 x 5 identity matrix written with '1' / '0' strings
    size = 5
    identity = [['1' if row == column else '0' for column in range(size)]
                for row in range(size)]
    self.assertEqual(builder.create_2nd_derivative_matrix('x'), identity)
| |
| |
if __name__ == '__main__':
    # Run the unit tests of this module when it is executed as a script
    unittest.main()