# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility and schema methods for the chicago_taxi sample."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow_transform import coders as tft_coders
from tensorflow_transform.tf_metadata import dataset_schema
from tensorflow_transform.tf_metadata import schema_utils
from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from tensorflow_metadata.proto.v0 import schema_pb2
# Each categorical feature is assumed to have a known maximum value in the
# dataset.
MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12]

CATEGORICAL_FEATURE_KEYS = [
    'trip_start_hour', 'trip_start_day', 'trip_start_month',
    'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area',
    'dropoff_community_area'
]

DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds']

# Number of buckets used by tf.Transform for encoding each bucketized feature.
FEATURE_BUCKET_COUNT = 10

BUCKET_FEATURE_KEYS = [
    'pickup_latitude', 'pickup_longitude', 'dropoff_latitude',
    'dropoff_longitude'
]

# Number of vocabulary terms used by tf.Transform for encoding each of the
# VOCAB_FEATURE_KEYS.
VOCAB_SIZE = 1000

# Count of out-of-vocab buckets into which unrecognized VOCAB_FEATURE_KEYS
# values are hashed.
OOV_SIZE = 10

VOCAB_FEATURE_KEYS = [
    'payment_type',
    'company',
]

LABEL_KEY = 'tips'
FARE_KEY = 'fare'

# Names of the columns of the input CSV files, in order.
CSV_COLUMN_NAMES = [
'pickup_community_area',
'fare',
'trip_start_month',
'trip_start_hour',
'trip_start_day',
'trip_start_timestamp',
'pickup_latitude',
'pickup_longitude',
'dropoff_latitude',
'dropoff_longitude',
'trip_miles',
'pickup_census_tract',
'dropoff_census_tract',
'payment_type',
'company',
'trip_seconds',
'dropoff_community_area',
'tips',
]


def transformed_name(key):
  """Returns the name used for the tf.Transform output of a raw feature."""
  return key + '_xf'


def transformed_names(keys):
  """Applies transformed_name to each key in a list of feature keys."""
  return [transformed_name(key) for key in keys]
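
# For example, transformed_name('fare') yields 'fare_xf', the feature name
# used after tf.Transform preprocessing.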


# tf.Transform considers these features as "raw".
def get_raw_feature_spec(schema):
  """Extracts the raw feature spec from a Schema proto."""
  return schema_utils.schema_as_feature_spec(schema).feature_spec


def make_proto_coder(schema):
  """Returns a coder for encoding data as tf.Example protos."""
  raw_feature_spec = get_raw_feature_spec(schema)
  raw_schema = dataset_schema.from_feature_spec(raw_feature_spec)
  return tft_coders.ExampleProtoCoder(raw_schema)


def make_csv_coder(schema):
  """Returns a coder for tf.Transform to read CSV files."""
  raw_feature_spec = get_raw_feature_spec(schema)
  parsing_schema = dataset_schema.from_feature_spec(raw_feature_spec)
  return tft_coders.CsvCoder(CSV_COLUMN_NAMES, parsing_schema)
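
# Example usage (a sketch; assumes 'schema.pbtxt' holds a text-format Schema
# proto, e.g. one inferred by TensorFlow Data Validation, and that csv_line
# is one data row with columns in CSV_COLUMN_NAMES order):
#
#   schema = read_schema('schema.pbtxt')
#   csv_coder = make_csv_coder(schema)
#   features = csv_coder.decode(csv_line)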


def clean_raw_data_dict(input_dict, raw_feature_spec):
  """Wraps each raw feature value in a list, keyed by the raw feature spec.

  Features that are missing from input_dict (or whose values are falsy) are
  mapped to an empty list.
  """
  output_dict = {}
  for key in raw_feature_spec:
    if key not in input_dict or not input_dict[key]:
      output_dict[key] = []
    else:
      output_dict[key] = [input_dict[key]]
  return output_dict
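
# For illustration, with a raw_feature_spec covering 'fare' and 'tips':
#
#   clean_raw_data_dict({'fare': 5.85}, raw_feature_spec)
#   # => {'fare': [5.85], 'tips': []}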


def make_sql(table_name, max_rows=None, for_eval=False):
  """Creates the SQL command for pulling data from BigQuery.

  Args:
    table_name: BigQuery table name.
    max_rows: if set, limits the number of rows pulled from BigQuery.
    for_eval: True if this is for evaluation, False otherwise.

  Returns:
    SQL command as a string.
  """
  if for_eval:
    # 1/3 of the dataset is used for eval.
    where_clause = ('WHERE MOD(FARM_FINGERPRINT(unique_key), 3) = 0 '
                    'AND pickup_latitude IS NOT NULL AND pickup_longitude '
                    'IS NOT NULL AND dropoff_latitude IS NOT NULL '
                    'AND dropoff_longitude IS NOT NULL')
  else:
    # The remaining 2/3 of the dataset is used for training.
    where_clause = ('WHERE MOD(FARM_FINGERPRINT(unique_key), 3) > 0 '
                    'AND pickup_latitude IS NOT NULL AND pickup_longitude '
                    'IS NOT NULL AND dropoff_latitude IS NOT NULL '
                    'AND dropoff_longitude IS NOT NULL')

  limit_clause = ''
  if max_rows:
    limit_clause = 'LIMIT {max_rows}'.format(max_rows=max_rows)
return """
SELECT
CAST(pickup_community_area AS string) AS pickup_community_area,
CAST(dropoff_community_area AS string) AS dropoff_community_area,
CAST(pickup_census_tract AS string) AS pickup_census_tract,
CAST(dropoff_census_tract AS string) AS dropoff_census_tract,
fare,
EXTRACT(MONTH FROM trip_start_timestamp) AS trip_start_month,
EXTRACT(HOUR FROM trip_start_timestamp) AS trip_start_hour,
EXTRACT(DAYOFWEEK FROM trip_start_timestamp) AS trip_start_day,
UNIX_SECONDS(trip_start_timestamp) AS trip_start_timestamp,
pickup_latitude,
pickup_longitude,
dropoff_latitude,
dropoff_longitude,
trip_miles,
payment_type,
company,
trip_seconds,
tips
FROM `{table_name}`
{where_clause}
{limit_clause}
""".format(
table_name=table_name, where_clause=where_clause, limit_clause=limit_clause)
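
# Example usage (a sketch; assumes the public Chicago Taxi Trips dataset on
# BigQuery, and 10000 is an arbitrary row cap):
#
#   train_query = make_sql(
#       'bigquery-public-data.chicago_taxi_trips.taxi_trips',
#       max_rows=10000,
#       for_eval=False)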


def read_schema(path):
  """Reads a schema from the provided location.

  Args:
    path: The location of the file holding a serialized Schema proto.

  Returns:
    An instance of Schema.
  """
  result = schema_pb2.Schema()
  contents = file_io.read_file_to_string(path)
  text_format.Parse(contents, result)
  return result
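
# Example usage (a sketch tying the helpers together; assumes 'schema.pbtxt'
# was produced by TensorFlow Data Validation):
#
#   schema = read_schema('schema.pbtxt')
#   feature_spec = get_raw_feature_spec(schema)
#   proto_coder = make_proto_coder(schema)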