Add Tables.partitions definition for Arrow.Table (#443)
This functionality already existed for `Arrow.Stream`, but it is convenient and
inexpensive to define it for `Arrow.Table` as well.
Fixes #293.
diff --git a/src/table.jl b/src/table.jl
index c32fe5a..66fd584 100644
--- a/src/table.jl
+++ b/src/table.jl
@@ -261,6 +261,7 @@
columns(t::Table) = getfield(t, :columns)
lookup(t::Table) = getfield(t, :lookup)
schema(t::Table) = getfield(t, :schema)
+# Raw accessor for the `metadata` field; uses `getfield` like the sibling accessors
+# above (presumably because property access on a `Table` resolves to columns — confirm).
+metadata(t::Table) = getfield(t, :metadata)
"""
Arrow.getmetadata(x)
@@ -286,6 +287,41 @@
Tables.getcolumn(t::Table, i::Int) = columns(t)[i]
Tables.getcolumn(t::Table, nm::Symbol) = lookup(t)[nm]
+# Iterator over the partitions of an `Arrow.Table`; this is what
+# `Tables.partitions(::Table)` returns.
+struct TablePartitions
+ # the table whose partitions are iterated
+ table::Table
+ # number of partitions `iterate` will yield before returning `nothing`
+ npartitions::Int
+end
+
+# Wrap `table` for partition iteration, inferring the number of partitions from
+# the first column: no columns means zero partitions, a `ChainedVector` first
+# column contributes one partition per entry of its `arrays` field, and any
+# other column type counts as a single partition.
+function TablePartitions(table::Table)
+    tblcols = columns(table)
+    length(tblcols) == 0 && return TablePartitions(table, 0)
+    firstcol = tblcols[1]
+    nparts = firstcol isa ChainedVector ? length(firstcol.arrays) : 1
+    return TablePartitions(table, nparts)
+end
+
+# Iteration protocol for `TablePartitions`: each step yields one partition of the
+# wrapped table as a `Table`, together with the next partition index.
+#
+# - `i` is the 1-based index of the partition to produce; iteration ends (returns
+#   `nothing`) once `i` exceeds `tp.npartitions`.
+# - A table with exactly one partition is yielded as-is, without rebuilding it.
+# - Otherwise every column is expected to expose an `arrays` field (as a
+#   `ChainedVector` does) and the `i`-th chunk of each column is used.
+function Base.iterate(tp::TablePartitions, i=1)
+    i > tp.npartitions && return nothing
+    tp.npartitions == 1 && return tp.table, i + 1
+    cols = columns(tp.table)
+    partcols = AbstractVector[col.arrays[i] for col in cols]
+    nms = names(tp.table)
+    # `k` (not `i`) so the column index is not confused with the partition index
+    collookup = Dict{Symbol, AbstractVector}(
+        nms[k] => partcols[k] for k in eachindex(nms)
+    )
+    tbl = Table(nms, types(tp.table), partcols, collookup, schema(tp.table))
+    return tbl, i + 1
+end
+
+# The partition count is known up front. Without a `length` method, the default
+# `Base.IteratorSize` (`HasLength()`) makes `collect`/`length` on the iterator
+# throw a `MethodError`.
+Base.length(tp::TablePartitions) = tp.npartitions
+Base.eltype(::Type{TablePartitions}) = Table
+
+# Tables.jl partitions hook for `Table`: hand back a `TablePartitions` iterator.
+function Tables.partitions(t::Table)
+    return TablePartitions(t)
+end
+
# high-level user API functions
Table(input, pos::Integer=1, len=nothing; kw...) = Table([ArrowBlob(tobytes(input), pos, len)]; kw...)
Table(input::Vector{UInt8}, pos::Integer=1, len=nothing; kw...) = Table([ArrowBlob(tobytes(input), pos, len)]; kw...)
diff --git a/test/runtests.jl b/test/runtests.jl
index a46a953..8a8bccd 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -674,6 +674,21 @@
end
+@testset "# 293" begin
+
+# Round-trip a two-partition table and check that `Tables.partitions` on the
+# resulting `Arrow.Table` yields each original partition.
+t = (a = [1, 2, 3], b = [1.0, 2.0, 3.0])
+buf = Arrow.tobuffer(t)
+tbl = Arrow.Table(buf)
+parts = Tables.partitioner((t, t))
+buf2 = Arrow.tobuffer(parts)
+tbl2 = Arrow.Table(buf2)
+# Count the partitions explicitly: without this the loop would pass vacuously
+# if the iterator yielded nothing.
+nparts = 0
+for part in Tables.partitions(tbl2)
+    nparts += 1
+    @test part.a == tbl.a
+    @test part.b == tbl.b
+end
+@test nparts == 2
+
+end
+
end # @testset "misc"
end
diff --git a/test/testappend.jl b/test/testappend.jl
index d4834dd..1fc3fad 100644
--- a/test/testappend.jl
+++ b/test/testappend.jl
@@ -129,12 +129,13 @@
arrow_table2 = Arrow.Table(file2)
# now
# arrow_table1: 2 partitions, 20 rows
- # arrow_table2: 2 partitions, 30 rows (both partitions of table1 are appended as single partition)
+ # arrow_table2: 3 partitions, 30 rows (both partitions of table1 are appended as separate partitions)
@test Tables.schema(arrow_table1) == Tables.schema(arrow_table2)
@test length(Tables.columns(arrow_table1)[1]) == 20
@test length(Tables.columns(arrow_table2)[1]) == 30
- @test length(collect(Tables.partitions(Arrow.Stream(file1)))) == length(collect(Tables.partitions(Arrow.Stream(file2))))
+ @test length(collect(Tables.partitions(Arrow.Stream(file1)))) == 2
+ @test length(collect(Tables.partitions(Arrow.Stream(file2)))) == 3
Arrow.append(file1, Arrow.Stream(file2))
arrow_table1 = Arrow.Table(file1)
@@ -145,6 +146,7 @@
@test Tables.schema(arrow_table1) == Tables.schema(arrow_table2)
@test length(Tables.columns(arrow_table1)[1]) == 50
@test length(Tables.columns(arrow_table2)[1]) == 30
- @test length(collect(Tables.partitions(Arrow.Stream(file1)))) == 2 * length(collect(Tables.partitions(Arrow.Stream(file2))))
+ @test length(collect(Tables.partitions(Arrow.Stream(file1)))) == 5
+ @test length(collect(Tables.partitions(Arrow.Stream(file2)))) == 3
end
end