blob: f665625e6757ac65b6b56aa04cda4bd7d0c0c677 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import os
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
from data_preparation.data_preparation import LABEL_COLUMN, TEST_CSV, TRAIN_CSV
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
feature_column_names = ['carat', 'cut', 'color', 'clarity', 'depth', 'table', 'x', 'y', 'z']
categorical_cols = ['cut', 'clarity', 'color']
# inference functions
MODEL_JOBLIB_FILENAME = "model.joblib"
preprocessor = ColumnTransformer(
transformers=[('categorical', OneHotEncoder(), categorical_cols)], remainder='passthrough'
)
def train(train_df, test_df, n_jobs, model_dir):
X_train = train_df[feature_column_names]
y_train = train_df[LABEL_COLUMN]
print(X_train.head(3))
X_test = test_df[feature_column_names]
y_test = test_df[LABEL_COLUMN]
# train
diamond_price_model = LinearRegression(n_jobs=n_jobs)
my_pipeline = Pipeline(steps=[('preprocessor', preprocessor), ('model', diamond_price_model)])
my_pipeline.fit(X_train, y_train)
train_predictions = my_pipeline.predict(X_train)
print(train_predictions)
print("validating model")
abs_err = np.abs(my_pipeline.predict(X_test) - y_test)
print(f"Test prediction error:{abs_err} ")
# persist model
path = os.path.join(model_dir, MODEL_JOBLIB_FILENAME)
joblib.dump(my_pipeline, path)
print("model persisted at " + path)
return path
if __name__ == '__main__':
data_path = Path(__file__).parent.parent.absolute().joinpath('data')
parser = argparse.ArgumentParser()
parser.add_argument('--model-dir', type=str, default=os.getenv('SM_MODEL_DIR', f"file://{data_path}"))
parser.add_argument('--n-jobs', default=2)
parser.add_argument(
"--train", type=str, default=os.getenv("SM_CHANNEL_TRAIN", f"file://{data_path.joinpath('train')}")
)
parser.add_argument(
"--test", type=str, default=os.getenv("SM_CHANNEL_TEST", f"file://{data_path.joinpath('test')}")
)
parser.add_argument("--train-file", type=str, default=TRAIN_CSV)
parser.add_argument("--test-file", type=str, default=TEST_CSV)
args = parser.parse_args()
train_df = pd.read_csv(os.path.join(args.train, args.train_file))
test_df = pd.read_csv(os.path.join(args.test, args.test_file))
train(train_df=train_df, test_df=test_df, n_jobs=args.n_jobs, model_dir=args.model_dir)