ARROW-10882: [Python] Allow writing dataset from iterator of batches

This binds InMemoryDataset to Python, allowing us to create and write back out datasets from iterables of record batches and various other objects.

Closes #9802 from lidavidm/arrow-10882

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index df15578..2df3414 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -66,8 +66,11 @@
 
 InMemoryFragment::InMemoryFragment(RecordBatchVector record_batches,
                                    Expression partition_expression)
-    : InMemoryFragment(record_batches.empty() ? schema({}) : record_batches[0]->schema(),
-                       std::move(record_batches), std::move(partition_expression)) {}
+    : Fragment(std::move(partition_expression), /*schema=*/nullptr),
+      record_batches_(std::move(record_batches)) {
+  // Order of argument evaluation is undefined, so compute physical_schema here
+  physical_schema_ = record_batches_.empty() ? schema({}) : record_batches_[0]->schema();
+}
 
 Result<ScanTaskIterator> InMemoryFragment::Scan(std::shared_ptr<ScanOptions> options) {
   // Make an explicit copy of record_batches_ to ensure Scan can be called
@@ -148,6 +151,28 @@
     : Dataset(table->schema()),
       get_batches_(new TableRecordBatchGenerator(std::move(table))) {}
 
+struct ReaderRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator {
+  explicit ReaderRecordBatchGenerator(std::shared_ptr<RecordBatchReader> reader)
+      : reader_(std::move(reader)), consumed_(false) {}
+
+  RecordBatchIterator Get() const final {
+    if (consumed_) {
+      return MakeErrorIterator<std::shared_ptr<RecordBatch>>(Status::Invalid(
+          "RecordBatchReader-backed InMemoryDataset was already consumed"));
+    }
+    consumed_ = true;
+    auto reader = reader_;
+    return MakeFunctionIterator([reader] { return reader->Next(); });
+  }
+
+  std::shared_ptr<RecordBatchReader> reader_;
+  mutable bool consumed_;
+};
+
+InMemoryDataset::InMemoryDataset(std::shared_ptr<RecordBatchReader> reader)
+    : Dataset(reader->schema()),
+      get_batches_(new ReaderRecordBatchGenerator(std::move(reader))) {}
+
 Result<std::shared_ptr<Dataset>> InMemoryDataset::ReplaceSchema(
     std::shared_ptr<Schema> schema) const {
   RETURN_NOT_OK(CheckProjectable(*schema_, *schema));
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index a28b798..6be8305 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -184,6 +184,7 @@
   InMemoryDataset(std::shared_ptr<Schema> schema, RecordBatchVector batches);
 
   explicit InMemoryDataset(std::shared_ptr<Table> table);
+  explicit InMemoryDataset(std::shared_ptr<RecordBatchReader> reader);
 
   std::string type_name() const override { return "in-memory"; }
 
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index a360355..1db96b8 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -79,6 +79,23 @@
                     .status());
 }
 
+TEST_F(TestInMemoryDataset, FromReader) {
+  constexpr int64_t kBatchSize = 1024;
+  constexpr int64_t kNumberBatches = 16;
+
+  SetSchema({field("i32", int32()), field("f64", float64())});
+  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  auto source_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
+  auto target_reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch);
+
+  auto dataset = std::make_shared<InMemoryDataset>(source_reader);
+
+  AssertDatasetEquals(target_reader.get(), dataset.get());
+  // Such datasets can only be scanned once
+  ASSERT_OK_AND_ASSIGN(auto fragments, dataset->GetFragments());
+  ASSERT_RAISES(Invalid, fragments.Next());
+}
+
 TEST_F(TestInMemoryDataset, GetFragments) {
   constexpr int64_t kBatchSize = 1024;
   constexpr int64_t kNumberBatches = 16;
@@ -93,6 +110,20 @@
   AssertDatasetEquals(reader.get(), dataset.get());
 }
 
+TEST_F(TestInMemoryDataset, InMemoryFragment) {
+  constexpr int64_t kBatchSize = 1024;
+
+  SetSchema({field("i32", int32()), field("f64", float64())});
+  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
+  RecordBatchVector batches{batch};
+
+  // Regression test: previously this constructor relied on undefined behavior (order of
+  // evaluation of arguments) leading to fragments being constructed with empty schemas
+  auto fragment = std::make_shared<InMemoryFragment>(batches);
+  ASSERT_OK_AND_ASSIGN(auto schema, fragment->ReadPhysicalSchema());
+  AssertSchemaEqual(batch->schema(), schema);
+}
+
 class TestUnionDataset : public DatasetFixtureMixin {};
 
 TEST_F(TestUnionDataset, ReplaceSchema) {
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index b012f97..17eedb3 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -26,11 +26,11 @@
 
 import pyarrow as pa
 from pyarrow.lib cimport *
-from pyarrow.lib import frombytes, tobytes
+from pyarrow.lib import ArrowTypeError, frombytes, tobytes
 from pyarrow.includes.libarrow_dataset cimport *
 from pyarrow._fs cimport FileSystem, FileInfo, FileSelector
 from pyarrow._csv cimport ConvertOptions, ParseOptions, ReadOptions
-from pyarrow.util import _is_path_like, _stringify_path
+from pyarrow.util import _is_iterable, _is_path_like, _stringify_path
 
 from pyarrow._parquet cimport (
     _create_writer_properties, _create_arrow_writer_properties,
@@ -442,6 +442,76 @@
         return pyarrow_wrap_schema(self.dataset.schema())
 
 
+cdef class InMemoryDataset(Dataset):
+    """A Dataset wrapping in-memory data.
+
+    Parameters
+    ----------
+    source
+        The data for this dataset. Can be a RecordBatch, Table, list of
+        RecordBatch/Table, iterable of RecordBatch, or a RecordBatchReader.
+        If an iterable is provided, the schema must also be provided.
+    schema : Schema, optional
+        Only required if passing an iterable as the source.
+    """
+
+    cdef:
+        CInMemoryDataset* in_memory_dataset
+
+    def __init__(self, source, Schema schema=None):
+        cdef:
+            RecordBatchReader reader
+            shared_ptr[CInMemoryDataset] in_memory_dataset
+
+        if isinstance(source, (pa.RecordBatch, pa.Table)):
+            source = [source]
+
+        if isinstance(source, (list, tuple)):
+            batches = []
+            for item in source:
+                if isinstance(item, pa.RecordBatch):
+                    batches.append(item)
+                elif isinstance(item, pa.Table):
+                    batches.extend(item.to_batches())
+                else:
+                    raise TypeError(
+                        'Expected a list of tables or batches. The given list '
+                        'contains a ' + type(item).__name__)
+                if schema is None:
+                    schema = item.schema
+                elif not schema.equals(item.schema):
+                    raise ArrowTypeError(
+                        f'Item has schema\n{item.schema}\nwhich does not '
+                        f'match expected schema\n{schema}')
+            if not batches and schema is None:
+                raise ValueError('Must provide schema to construct in-memory '
+                                 'dataset from an empty list')
+            table = pa.Table.from_batches(batches, schema=schema)
+            in_memory_dataset = make_shared[CInMemoryDataset](
+                pyarrow_unwrap_table(table))
+        elif isinstance(source, pa.ipc.RecordBatchReader):
+            reader = source
+            in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
+        elif _is_iterable(source):
+            if schema is None:
+                raise ValueError('Must provide schema to construct in-memory '
+                                 'dataset from an iterable')
+            reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
+            in_memory_dataset = make_shared[CInMemoryDataset](reader.reader)
+        else:
+            raise TypeError(
+                'Expected a table, batch, iterable of tables/batches, or a '
+                'record batch reader instead of the given type: ' +
+                type(source).__name__
+            )
+
+        self.init(<shared_ptr[CDataset]> in_memory_dataset)
+
+    cdef void init(self, const shared_ptr[CDataset]& sp):
+        Dataset.init(self, sp)
+        self.in_memory_dataset = <CInMemoryDataset*> sp.get()
+
+
 cdef class UnionDataset(Dataset):
     """A Dataset wrapping child datasets.
 
@@ -2603,10 +2673,14 @@
 
 
 def _filesystemdataset_write(
-    data not None, object base_dir not None, str basename_template not None,
-    Schema schema not None, FileSystem filesystem not None,
+    Dataset data not None,
+    object base_dir not None,
+    str basename_template not None,
+    Schema schema not None,
+    FileSystem filesystem not None,
     Partitioning partitioning not None,
-    FileWriteOptions file_options not None, bint use_threads,
+    FileWriteOptions file_options not None,
+    bint use_threads,
     int max_partitions,
 ):
     """
@@ -2624,21 +2698,7 @@
     c_options.max_partitions = max_partitions
     c_options.basename_template = tobytes(basename_template)
 
-    if isinstance(data, Dataset):
-        scanner = data._scanner(use_threads=use_threads)
-    else:
-        # data is list of batches/tables
-        for table in data:
-            if isinstance(table, Table):
-                for batch in table.to_batches():
-                    c_batches.push_back((<RecordBatch> batch).sp_batch)
-            else:
-                c_batches.push_back((<RecordBatch> table).sp_batch)
-
-        data = Fragment.wrap(shared_ptr[CFragment](
-            new CInMemoryFragment(move(c_batches), _true.unwrap())))
-
-        scanner = Scanner.from_fragment(data, schema, use_threads=use_threads)
+    scanner = data._scanner(use_threads=use_threads)
 
     c_scanner = (<Scanner> scanner).unwrap()
     with nogil:
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index 195d414..0c65070 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -18,7 +18,7 @@
 """Dataset is currently unstable. APIs subject to change without notice."""
 
 import pyarrow as pa
-from pyarrow.util import _stringify_path, _is_path_like
+from pyarrow.util import _is_iterable, _stringify_path, _is_path_like
 
 from pyarrow._dataset import (  # noqa
     CsvFileFormat,
@@ -37,6 +37,7 @@
     HivePartitioning,
     IpcFileFormat,
     IpcFileWriteOptions,
+    InMemoryDataset,
     ParquetDatasetFactory,
     ParquetFactoryOptions,
     ParquetFileFormat,
@@ -408,6 +409,13 @@
     return factory.finish(schema)
 
 
+def _in_memory_dataset(source, schema=None, **kwargs):
+    if any(v is not None for v in kwargs.values()):
+        raise ValueError(
+            "For in-memory datasets, you cannot pass any additional arguments")
+    return InMemoryDataset(source, schema)
+
+
 def _union_dataset(children, schema=None, **kwargs):
     if any(v is not None for v in kwargs.values()):
         raise ValueError(
@@ -508,7 +516,8 @@
 
     Parameters
     ----------
-    source : path, list of paths, dataset, list of datasets or URI
+    source : path, list of paths, dataset, list of datasets, (list of) batches
+             or tables, iterable of batches, RecordBatchReader, or URI
         Path pointing to a single file:
             Open a FileSystemDataset from a single file.
         Path pointing to a directory:
@@ -524,6 +533,11 @@
             A nested UnionDataset gets constructed, it allows arbitrary
             composition of other datasets.
             Note that additional keyword arguments are not allowed.
+        (List of) batches or tables, iterable of batches, or RecordBatchReader:
+            Create an InMemoryDataset. If an iterable or empty list is given,
+            a schema must also be given. If an iterable or RecordBatchReader
+            is given, the resulting dataset can only be scanned once; further
+            attempts will raise an error.
     schema : Schema, optional
         Optionally provide the Schema for the Dataset, in which case it will
         not be inferred from the source.
@@ -636,7 +650,6 @@
         selector_ignore_prefixes=ignore_prefixes
     )
 
-    # TODO(kszucs): support InMemoryDataset for a table input
     if _is_path_like(source):
         return _filesystem_dataset(source, **kwargs)
     elif isinstance(source, (tuple, list)):
@@ -644,13 +657,22 @@
             return _filesystem_dataset(source, **kwargs)
         elif all(isinstance(elem, Dataset) for elem in source):
             return _union_dataset(source, **kwargs)
+        elif all(isinstance(elem, (pa.RecordBatch, pa.Table))
+                 for elem in source):
+            return _in_memory_dataset(source, **kwargs)
         else:
             unique_types = set(type(elem).__name__ for elem in source)
             type_names = ', '.join('{}'.format(t) for t in unique_types)
             raise TypeError(
-                'Expected a list of path-like or dataset objects. The given '
-                'list contains the following types: {}'.format(type_names)
+                'Expected a list of path-like or dataset objects, or a list '
+                'of batches or tables. The given list contains the following '
+                'types: {}'.format(type_names)
             )
+    elif isinstance(source, (pa.RecordBatch, pa.ipc.RecordBatchReader,
+                             pa.Table)):
+        return _in_memory_dataset(source, **kwargs)
+    elif _is_iterable(source):
+        return _in_memory_dataset(source, **kwargs)
     else:
         raise TypeError(
             'Expected a path-like, list of path-likes or a list of Datasets '
@@ -676,9 +698,11 @@
 
     Parameters
     ----------
-    data : Dataset, Table/RecordBatch, or list of Table/RecordBatch
+    data : Dataset, Table/RecordBatch, RecordBatchReader, list of
+           Table/RecordBatch, or iterable of RecordBatch
         The data to write. This can be a Dataset instance or
-        in-memory Arrow data.
+        in-memory Arrow data. If an iterable is given, the schema must
+        also be given.
     base_dir : str
         The root directory where to write the dataset.
     basename_template : str, optional
@@ -710,15 +734,17 @@
 
     if isinstance(data, Dataset):
         schema = schema or data.schema
-    elif isinstance(data, (pa.Table, pa.RecordBatch)):
-        schema = schema or data.schema
-        data = [data]
-    elif isinstance(data, list):
+    elif isinstance(data, (list, tuple)):
         schema = schema or data[0].schema
+        data = InMemoryDataset(data, schema=schema)
+    elif isinstance(data, (pa.RecordBatch, pa.ipc.RecordBatchReader,
+                           pa.Table)) or _is_iterable(data):
+        data = InMemoryDataset(data, schema=schema)
+        schema = schema or data.schema
     else:
         raise ValueError(
-            "Only Dataset, Table/RecordBatch or a list of Table/RecordBatch "
-            "objects are supported."
+            "Only Dataset, Table/RecordBatch, RecordBatchReader, a list "
+            "of Tables/RecordBatches, or iterable of batches are supported."
         )
 
     if format is None and isinstance(data, FileSystemDataset):
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index f7f2a14..db2e73a 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -124,6 +124,11 @@
 
         CResult[shared_ptr[CScannerBuilder]] NewScan()
 
+    cdef cppclass CInMemoryDataset "arrow::dataset::InMemoryDataset"(
+            CDataset):
+        CInMemoryDataset(shared_ptr[CRecordBatchReader])
+        CInMemoryDataset(shared_ptr[CTable])
+
     cdef cppclass CUnionDataset "arrow::dataset::UnionDataset"(
             CDataset):
         @staticmethod
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 36cff99..a7dd152 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1622,13 +1622,16 @@
         fs.FileSelector('/schema'),
         format=ds.ParquetFileFormat()
     )
+    batch1 = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+    batch2 = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["b"])
 
     with pytest.raises(TypeError, match='Expected.*FileSystemDatasetFactory'):
         ds.dataset([child1, child2])
 
     expected = (
-        "Expected a list of path-like or dataset objects. The given list "
-        "contains the following types: int"
+        "Expected a list of path-like or dataset objects, or a list "
+        "of batches or tables. The given list contains the following "
+        "types: int"
     )
     with pytest.raises(TypeError, match=expected):
         ds.dataset([1, 2, 3])
@@ -1640,6 +1643,85 @@
     with pytest.raises(TypeError, match=expected):
         ds.dataset(None)
 
+    expected = (
+        "Must provide schema to construct in-memory dataset from an iterable"
+    )
+    with pytest.raises(ValueError, match=expected):
+        ds.dataset((batch1 for _ in range(3)))
+
+    expected = (
+        "Must provide schema to construct in-memory dataset from an empty list"
+    )
+    with pytest.raises(ValueError, match=expected):
+        ds.InMemoryDataset([])
+
+    expected = (
+        "Item has schema\nb: int64\nwhich does not match expected schema\n"
+        "a: int64"
+    )
+    with pytest.raises(TypeError, match=expected):
+        ds.dataset([batch1, batch2])
+
+    expected = (
+        "Expected a list of path-like or dataset objects, or a list of "
+        "batches or tables. The given list contains the following types:"
+    )
+    with pytest.raises(TypeError, match=expected):
+        ds.dataset([batch1, 0])
+
+    expected = (
+        "Expected a list of tables or batches. The given list contains a int"
+    )
+    with pytest.raises(TypeError, match=expected):
+        ds.InMemoryDataset([batch1, 0])
+
+
+def test_construct_in_memory():
+    batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+    table = pa.Table.from_batches([batch])
+    reader = pa.ipc.RecordBatchReader.from_batches(batch.schema, [batch])
+    iterable = (batch for _ in range(1))
+
+    for source in (batch, table, reader, [batch], [table]):
+        dataset = ds.dataset(source)
+        assert dataset.to_table() == table
+
+    assert ds.dataset(iterable, schema=batch.schema).to_table().equals(table)
+    assert ds.dataset([], schema=pa.schema([])).to_table() == pa.table([])
+
+    # When constructed from batches/tables, should be reusable
+    for source in (batch, table, [batch], [table]):
+        dataset = ds.dataset(source)
+        assert len(list(dataset.get_fragments())) == 1
+        assert len(list(dataset.get_fragments())) == 1
+        assert dataset.to_table() == table
+        assert dataset.to_table() == table
+        assert next(dataset.get_fragments()).to_table() == table
+
+    # When constructed from readers/iterators, should be one-shot
+    match = "InMemoryDataset was already consumed"
+    for factory in (
+            lambda: pa.ipc.RecordBatchReader.from_batches(
+                batch.schema, [batch]),
+            lambda: (batch for _ in range(1)),
+    ):
+        dataset = ds.dataset(factory(), schema=batch.schema)
+        # Getting fragments consumes the underlying iterator
+        fragments = list(dataset.get_fragments())
+        assert len(fragments) == 1
+        assert fragments[0].to_table() == table
+        with pytest.raises(pa.ArrowInvalid, match=match):
+            list(dataset.get_fragments())
+        with pytest.raises(pa.ArrowInvalid, match=match):
+            dataset.to_table()
+        # Materializing consumes the underlying iterator
+        dataset = ds.dataset(factory(), schema=batch.schema)
+        assert dataset.to_table() == table
+        with pytest.raises(pa.ArrowInvalid, match=match):
+            list(dataset.get_fragments())
+        with pytest.raises(pa.ArrowInvalid, match=match):
+            dataset.to_table()
+
 
 @pytest.mark.parquet
 def test_open_dataset_partitioned_directory(tempdir):
@@ -2856,6 +2938,28 @@
     )
 
 
+def test_write_iterable(tempdir):
+    table = pa.table([
+        pa.array(range(20)), pa.array(np.random.randn(20)),
+        pa.array(np.repeat(['a', 'b'], 10))
+    ], names=["f1", "f2", "part"])
+
+    base_dir = tempdir / 'inmemory_iterable'
+    ds.write_dataset((batch for batch in table.to_batches()), base_dir,
+                     schema=table.schema,
+                     basename_template='dat_{i}.arrow', format="feather")
+    result = ds.dataset(base_dir, format="ipc").to_table()
+    assert result.equals(table)
+
+    base_dir = tempdir / 'inmemory_reader'
+    reader = pa.ipc.RecordBatchReader.from_batches(table.schema,
+                                                   table.to_batches())
+    ds.write_dataset(reader, base_dir,
+                     basename_template='dat_{i}.arrow', format="feather")
+    result = ds.dataset(base_dir, format="ipc").to_table()
+    assert result.equals(table)
+
+
 def test_write_table_partitioned_dict(tempdir):
     # ensure writing table partitioned on a dictionary column works without
     # specifying the dictionary values explicitly
diff --git a/python/pyarrow/util.py b/python/pyarrow/util.py
index e91294a..446e673 100644
--- a/python/pyarrow/util.py
+++ b/python/pyarrow/util.py
@@ -62,6 +62,14 @@
     return _DeprecatedMeta(old_name, (new_class,), {})
 
 
+def _is_iterable(obj):
+    try:
+        iter(obj)
+        return True
+    except TypeError:
+        return False
+
+
 def _is_path_like(path):
     # PEP519 filesystem path protocol is available from python 3.6, so pathlib
     # doesn't implement __fspath__ for earlier versions