| import logging |
| from typing import List, Union |
| |
| import numpy as np |
| import pandas as pd |
| from sklearn.linear_model import LinearRegression |
| from sklearn.metrics import r2_score |
| from sklearn.model_selection import train_test_split |
| |
| from hamilton.function_modifiers import extract_fields |
| |
| |
@extract_fields(
    {
        "X_train": pd.DataFrame,
        "X_test": pd.DataFrame,
        "y_train": pd.Series,
        "y_test": pd.Series,
    }
)
def split_data(
    create_model_input_table: pd.DataFrame,
    test_size: float,
    random_state: int,
    features: List[str],
) -> dict:
    """Splits the model input table into train and test features and targets.

    Args:
        create_model_input_table: Table holding the feature columns and the
            "price" column, which is used as the target.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        random_state: Forwarded to ``train_test_split`` as ``random_state``.
        features: Names of the columns to select as the feature matrix.

    Returns:
        Dict with the keys "X_train", "X_test", "y_train" and "y_test",
        matching the fields declared in ``@extract_fields``.
    """
    feature_frame = create_model_input_table[features]
    target = create_model_input_table["price"]
    # train_test_split yields the splits in the order:
    # features-train, features-test, target-train, target-test.
    pieces = train_test_split(
        feature_frame,
        target,
        test_size=test_size,
        random_state=random_state,
    )
    field_names = ("X_train", "X_test", "y_train", "y_test")
    return dict(zip(field_names, pieces))
| |
| |
def train_model(X_train: pd.DataFrame, y_train: pd.Series) -> LinearRegression:
    """Fits a linear regression model on the training data.

    Args:
        X_train: Training data of independent features.
        y_train: Training data for price.

    Returns:
        The fitted ``LinearRegression`` instance.
    """
    # ``fit`` returns the estimator itself, so the call can be chained.
    return LinearRegression().fit(X_train, y_train)
| |
| |
def evaluate_model(
    train_model: LinearRegression,
    X_test: pd.DataFrame,
    y_test: pd.Series,
) -> Union[float, np.ndarray]:
    """Computes the coefficient of determination on test data and logs it.

    Args:
        train_model: Trained model used to predict from ``X_test``.
        X_test: Testing data of independent features.
        y_test: Testing data for price.

    Returns:
        The value produced by ``r2_score(y_test, predictions)``.
    """
    predictions = train_model.predict(X_test)
    r_squared = r2_score(y_test, predictions)
    logging.getLogger(__name__).info(
        "Model has a coefficient R^2 of %.3f on test data.", r_squared
    )
    return r_squared