blob: 5441e4b9bf7e39cfcea4972a7e966c4de05db9d4 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Minimal example of how to read samples from a dataset generated by `generate_external_dataset_carbon.py`
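using TensorFlow. Samples are read with pycarbon's make_reader(), used here as a batch reader (i.e. in the
style of make_batch_carbon_reader() instead of make_carbon_reader()), and handed to TensorFlow via
tf_tensors() and make_dataset()."""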
from __future__ import print_function
import argparse
import jnius_config
import tensorflow as tf
from pycarbon.tests import DEFAULT_CARBONSDK_PATH
from pycarbon.core.carbon_tf_utils import tf_tensors, make_pycarbon_dataset
from pycarbon.core.carbon_reader import make_batch_carbon_reader
from pycarbon.reader import make_reader
from pycarbon.reader import make_tensor, make_dataset
def _print_id_batch(tensor):
    """Evaluate ``tensor`` in a new ``tf.Session`` and print the ``id`` field of the result.

    The evaluated value must expose an ``id`` attribute; the print happens while
    the session is still open.
    """
    with tf.Session() as session:
        sample = session.run(tensor)
        print("id batch: {0}".format(sample.id))


def tensorflow_hello_world(dataset_url='file:///tmp/carbon_external_dataset'):
    """Fetch and print one ``id`` batch from a carbon dataset, once per TensorFlow bridge.

    A fresh reader is opened for each of the two examples:

    1. ``tf_tensors(reader)`` builds tensors straight from the reader.
    2. ``make_dataset(reader)`` wraps the reader in a ``tf.data.Dataset`` that is
       consumed through a one-shot iterator.

    NOTE(review): each ``session.run`` result is printed as an "id batch", i.e. the
    example assumes ``make_reader()`` yields batches of rows rather than single
    rows -- confirm against ``pycarbon.reader.make_reader``.

    :param dataset_url: URL of the dataset to read, e.g. one produced by
        ``generate_external_dataset_carbon.py``.
    """
    # Example 1: tensors produced directly from the reader.
    with make_reader(dataset_url) as reader:
        _print_id_batch(tf_tensors(reader))

    # Example 2: the tf.data.Dataset API, pulling the first element via a one-shot iterator.
    with make_reader(dataset_url) as reader:
        first_element = make_dataset(reader).make_one_shot_iterator().get_next()
        _print_id_batch(first_element)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Tensorflow hello world')
    parser.add_argument('-c', '--carbon-sdk-path', type=str, default=DEFAULT_CARBONSDK_PATH,
                        help='carbon sdk path')
    # Previously the script could only read the function's hard-coded default
    # location; expose it on the CLI while keeping that default.
    parser.add_argument('-u', '--dataset-url', type=str,
                        default='file:///tmp/carbon_external_dataset',
                        help='URL of the carbon dataset to read (default: %(default)s)')
    args = parser.parse_args()

    # pyjnius only honours jnius_config settings applied before jnius is imported
    # (i.e. before the JVM starts), so configure the classpath before reading.
    jnius_config.set_classpath(args.carbon_sdk_path)

    tensorflow_hello_world(args.dataset_url)