#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
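"""Prepare the Iris dataset with PySpark.

Reads a raw CSV, standard-scales the feature columns, encodes the string
label as a numeric index, and writes the prepared data back out as a single
CSV file.
"""
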
import argparse

import pyspark.sql.functions as F
from pyspark.ml import Pipeline
from pyspark.ml.feature import StandardScaler, StringIndexer, VectorAssembler
from pyspark.ml.functions import vector_to_array
from pyspark.sql import SparkSession


def transform(data):
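    """Standard-scale all feature columns and index the label column.

    Assumes the label is the last column and every other column is a numeric
    feature.
    """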
    columns_to_scale = data.columns[:-1]
    # Assemble the feature columns into a single vector, then standardize it
    # to zero mean and unit variance.
    vectorizer = VectorAssembler(inputCols=columns_to_scale, outputCol="features")
    scaler = StandardScaler(inputCol="features", outputCol="scaled_features", withStd=True, withMean=True)
    # Encode the string label (the last column) as a numeric index.
    labeler = StringIndexer(inputCol=data.columns[-1], outputCol="label")
    pipeline = Pipeline(stages=[vectorizer, scaler, labeler])
    fitted = pipeline.fit(data)
    transformed = fitted.transform(data)
    # Expand the scaled vector back into one named column per original feature.
    result = transformed.withColumn("feature_arr", vector_to_array("scaled_features")).select(
        [F.col("feature_arr")[i].alias(columns_to_scale[i]) for i in range(len(columns_to_scale))] + ["label"]
    )
    return result


def extract(spark, input_uri):
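    """Read the input CSV with a header row, inferring the schema and skipping '#' comment lines."""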
    return spark.read.csv(input_uri, header=True, inferSchema=True, comment="#")


def load(data, output_uri):
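    """Write the prepared data to output_uri as a single CSV file with a header."""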
    data.coalesce(1).write.mode("overwrite").csv(output_uri, header=True)


def data_pipeline(input_uri, output_uri):
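    """Run the extract, transform, and load steps within a single SparkSession."""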
    spark = SparkSession.builder.appName("Prepare Iris Data").getOrCreate()
    raw_data = extract(spark, input_uri)
    data = transform(raw_data)
    load(data, output_uri)
    spark.stop()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare the Iris dataset.")
    parser.add_argument("--input_uri", help="URI of the raw input CSV")
    parser.add_argument("--output_uri", help="URI to write the prepared CSV to")
    args = parser.parse_args()
    data_pipeline(args.input_uri, args.output_uri)
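
# A typical invocation with spark-submit might look like the following; the
# script name and S3 URIs are placeholders, not values defined by this file:
#
#   spark-submit prepare_iris_data.py \
#       --input_uri s3://example-bucket/iris/raw/iris.csv \
#       --output_uri s3://example-bucket/iris/prepared/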