Rework dict encoding of PooledArray/CategoricalArray to fix outstandi… (#119)
* Rework dict encoding of PooledArray/CategoricalArray to fix outstanding issues
Fixes #117, #116, and #113. For #116, we just need to special case if user happens to pass in a DictEncoded themselves. We need to pass it through to the `toarrowvector` method that no-ops. For #113, we require the new functionality in PooledArrays that allows passing the `signed` and `compress` keyword arguments to ensure we get signed refs for our dict encoding. For #117, we add CategoricalArrays as a test dependency and ensure that if it contains any `missing` value, we *don't* recode the indices values down by 1, since the `missing` ref is 0, so other refs can already be considered "offsets". If there are no `missing`, then we still need to recode down since refs should always start from 0 in arrow format.
* PooledArrays 1.0 compat
* Update src/arraytypes/dictencoding.jl
Co-authored-by: Milan Bouchet-Valat <nalimilan@club.fr>
* Check refpool
* Fix test
Co-authored-by: Milan Bouchet-Valat <nalimilan@club.fr>
diff --git a/.gitignore b/.gitignore
index 4056c1d..c841c99 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@
*.jl.mem
test/_scrap.jl
+.DS_STORE
diff --git a/Project.toml b/Project.toml
index afeeee0..fb4e3d7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -21,17 +21,18 @@
CodecLz4 = "0.4"
CodecZstd = "0.7"
DataAPI = "1"
-PooledArrays = "0.5"
+PooledArrays = "0.5, 1.0"
SentinelArrays = "1"
Tables = "1.1"
TimeZones = "1"
julia = "1.3"
[extras]
+CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
-test = ["Test", "Random", "JSON3", "StructTypes"]
+test = ["Test", "Random", "JSON3", "StructTypes", "CategoricalArrays"]
diff --git a/src/arraytypes/arraytypes.jl b/src/arraytypes/arraytypes.jl
index 005e923..829b3c0 100644
--- a/src/arraytypes/arraytypes.jl
+++ b/src/arraytypes/arraytypes.jl
@@ -51,6 +51,8 @@
function arrowvector(x, i, nl, fi, de, ded, meta; dictencoding::Bool=false, dictencode::Bool=false, kw...)
if !(x isa DictEncode) && !dictencoding && (dictencode || (x isa AbstractArray && DataAPI.refarray(x) !== x))
x = DictEncode(x, dictencodeid(i, nl, fi))
+ elseif x isa DictEncoded
+ return arrowvector(DictEncodeType, x, i, nl, fi, de, ded, meta; dictencode=dictencode, kw...)
end
S = maybemissing(eltype(x))
return arrowvector(S, x, i, nl, fi, de, ded, meta; dictencode=dictencode, kw...)
diff --git a/src/arraytypes/dictencoding.jl b/src/arraytypes/dictencoding.jl
index 47a4965..79aa9fc 100644
--- a/src/arraytypes/dictencoding.jl
+++ b/src/arraytypes/dictencoding.jl
@@ -94,6 +94,7 @@
signedtype(::Type{UInt16}) = Int16
signedtype(::Type{UInt32}) = Int32
signedtype(::Type{UInt64}) = Int64
+signedtype(::Type{T}) where {T <: Signed} = T
indtype(d::DictEncoded{T, S, A}) where {T, S, A} = S
indtype(c::Compressed{Z, A}) where {Z, A <: DictEncoded} = indtype(c.data)
@@ -113,21 +114,31 @@
validity = ValidityBitmap(x)
if !haskey(de, id)
# dict encoding doesn't exist yet, so create for 1st time
- if DataAPI.refarray(x) === x
+ if DataAPI.refarray(x) === x || DataAPI.refpool(x) === nothing
# need to encode ourselves
- x = PooledArray(x, encodingtype(length(x)))
+ x = PooledArray(x; signed=true, compress=true)
inds = DataAPI.refarray(x)
+ pool = DataAPI.refpool(x)
else
- inds = copy(DataAPI.refarray(x))
+ pool = DataAPI.refpool(x)
+ refa = DataAPI.refarray(x)
+ inds = copyto!(similar(Vector{signedtype(eltype(refa))}, length(refa)), refa)
end
- # adjust to "offset" instead of index
- for i = 1:length(inds)
- @inbounds inds[i] -= 1
- end
- pool = DataAPI.refpool(x)
# horrible hack? yes. better than taking CategoricalArrays dependency? also yes.
if typeof(pool).name.name == :CategoricalRefPool
- pool = [get(pool[i]) for i = 1:length(pool)]
+ if eltype(x) >: Missing
+ pool = vcat(missing, DataAPI.levels(x))
+ else
+ pool = DataAPI.levels(x)
+ for i = 1:length(inds)
+ @inbounds inds[i] -= 1
+ end
+ end
+ else
+ # adjust to "offset" instead of index
+ for i = 1:length(inds)
+ @inbounds inds[i] -= 1
+ end
end
data = arrowvector(pool, i, nl, fi, de, ded, nothing; dictencode=dictencodenested, dictencodenested=dictencodenested, dictencoding=true, kw...)
encoding = DictEncoding{eltype(data), typeof(data)}(id, data, false, getmetadata(data))
diff --git a/test/runtests.jl b/test/runtests.jl
index 4e2a54d..47d0f4a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-using Test, Arrow, Tables, Dates, PooledArrays, TimeZones, UUIDs
+using Test, Arrow, Tables, Dates, PooledArrays, TimeZones, UUIDs, CategoricalArrays
include(joinpath(dirname(pathof(Arrow)), "../test/testtables.jl"))
include(joinpath(dirname(pathof(Arrow)), "../test/integrationtest.jl"))
@@ -216,6 +216,25 @@
x2 = Arrow.toarrowvector(x)
@test isequal(copy(x2), x)
+# some dict encoding coverage
+
+# signed indices for DictEncodedType #112 #113 #114
+av = Arrow.toarrowvector(PooledArray(repeat(["a", "b"], inner = 5)))
+@test isa(first(av.indices), Signed)
+
+av = Arrow.toarrowvector(CategoricalArray(repeat(["a", "b"], inner = 5)))
+@test isa(first(av.indices), Signed)
+
+av = Arrow.toarrowvector(CategoricalArray(["a", "bb", missing]))
+@test isa(first(av.indices), Signed)
+@test length(av) == 3
+@test eltype(av) == Union{String, Missing}
+
+av = Arrow.toarrowvector(CategoricalArray(["a", "bb", "ccc"]))
+@test isa(first(av.indices), Signed)
+@test length(av) == 3
+@test eltype(av) == String
+
end # @testset "misc"
end