blob: 833b483ca3218ea5b7314952e78fe21e8d546c5d [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Product reduction for `NDArray`. The `dims` keyword (default `:`) is passed
# positionally to `_prod`, which dispatches on it to choose the reduction form.
function Base.prod(x::NDArray; dims = :)
    return _prod(x, dims)
end
# Methods of `_prod`, the helper that `Base.prod(::NDArray; dims)` calls.
# They are generated by `@_remap`, which is defined elsewhere in the package.
# NOTE(review): presumably `@_remap sig impl` emits a method with signature
# `sig` that invokes the backend operator named in `impl` — confirm there.
#
# `dims = :` form: the `prod` operator is invoked without an `axis` argument.
@_remap _prod(x::NDArray, ::Colon) prod(x)
# Any other `dims`: negated elementwise (`0 .- dims`) and passed as `axis`,
# together with `keepdims = true`.
# NOTE(review): the negation presumably translates Julia dimension numbers
# into the backend's axis numbering — confirm against the operator docs.
@_remap _prod(x::NDArray, dims) prod(x; axis = 0 .- dims, keepdims = true)
# Maximum reduction for `NDArray`. The `dims` keyword (default `:`) is passed
# positionally to `_nd_maximum`, which dispatches on it to choose the
# reduction form.
function Base.maximum(x::NDArray; dims = :)
    return _nd_maximum(x, dims)
end
# Methods of `_nd_maximum`, the helper that `Base.maximum(::NDArray; dims)`
# calls. They are generated by `@_remap`, which is defined elsewhere in the
# package.
# NOTE(review): presumably `@_remap sig impl` emits a method with signature
# `sig` that invokes the backend operator named in `impl` — confirm there.
#
# `dims = :` form: the `max` operator is invoked without an `axis` argument.
@_remap _nd_maximum(x::NDArray, ::Colon) max(x)
# Any other `dims`: negated elementwise (`0 .- dims`) and passed as `axis`,
# together with `keepdims = true`.
# NOTE(review): the negation presumably translates Julia dimension numbers
# into the backend's axis numbering — confirm against the operator docs.
@_remap _nd_maximum(x::NDArray, dims) max(x; axis = 0 .- dims, keepdims = true)
# Minimum reduction for `NDArray`. The `dims` keyword (default `:`) is passed
# positionally to `_nd_minimum`, which dispatches on it to choose the
# reduction form.
function Base.minimum(x::NDArray; dims = :)
    return _nd_minimum(x, dims)
end
# Methods of `_nd_minimum`, the helper that `Base.minimum(::NDArray; dims)`
# calls. They are generated by `@_remap`, which is defined elsewhere in the
# package.
# NOTE(review): presumably `@_remap sig impl` emits a method with signature
# `sig` that invokes the backend operator named in `impl` — confirm there.
#
# `dims = :` form: the `min` operator is invoked without an `axis` argument.
@_remap _nd_minimum(x::NDArray, ::Colon) min(x)
# Any other `dims`: negated elementwise (`0 .- dims`) and passed as `axis`,
# together with `keepdims = true`.
# NOTE(review): the negation presumably translates Julia dimension numbers
# into the backend's axis numbering — confirm against the operator docs.
@_remap _nd_minimum(x::NDArray, dims) min(x; axis = 0 .- dims, keepdims = true)
###############################################################################
# min/max
###############################################################################
import Base: min, max
# Elementwise maximum of two NDArrays that share the element type `T`;
# forwards to `_broadcast_maximum`.
# NOTE(review): assumes `broadcasted` is `Base.Broadcast.broadcasted`, brought
# into scope elsewhere in the package, so that `max.(x, y)` lands here — confirm.
function broadcasted(::typeof(max), x::NDArray{T}, y::NDArray{T}) where {T}
    return _broadcast_maximum(x, y)
end
# Elementwise minimum of two NDArrays that share the element type `T`;
# forwards to `_broadcast_minimum`.
# NOTE(review): assumes `broadcasted` is `Base.Broadcast.broadcasted`, brought
# into scope elsewhere in the package, so that `min.(x, y)` lands here — confirm.
function broadcasted(::typeof(min), x::NDArray{T}, y::NDArray{T}) where {T}
    return _broadcast_minimum(x, y)
end
###############################################################################
# argmin/argmax
###############################################################################
# TODO: support CartesianIndex ?
"""
argmax(x::NDArray; dims) -> indices
Note that `NaN` is skipped during comparison.
This is different from Julia `Base.argmax`.
## Examples
```julia-repl
julia> x = NDArray([0. 1 2; 3 4 5])
2×3 NDArray{Float64,2} @ CPU0:
0.0 1.0 2.0
3.0 4.0 5.0
julia> argmax(x, dims = 1)
1×3 NDArray{Float64,2} @ CPU0:
2.0 2.0 2.0
julia> argmax(x, dims = 2)
2×1 NDArray{Float64,2} @ CPU0:
3.0
3.0
```
See also [`argmin`](@ref mx.argmin).
"""
Base.argmax(x::NDArray; dims = :) = _argmax(x, dims) .+ 1
# Methods of `_argmax`, the helper that `Base.argmax(::NDArray; dims)` calls
# before adding 1 to the result. They are generated by `@_remap`, which is
# defined elsewhere in the package.
# NOTE(review): presumably `@_remap sig impl` emits a method with signature
# `sig` that invokes the backend operator named in `impl` — confirm there.
#
# `dims = :` form: the `argmax` operator is invoked without an `axis` argument.
@_remap _argmax(x::NDArray, ::Colon) argmax(x)
# Any other `dims`: negated elementwise (`0 .- dims`) and passed as `axis`,
# together with `keepdims = true`.
# NOTE(review): the negation presumably translates Julia dimension numbers
# into the backend's axis numbering — confirm against the operator docs.
@_remap _argmax(x::NDArray, dims) argmax(x; axis = 0 .- dims, keepdims = true)
"""
argmin(x::NDArray; dims) -> indices
Note that `NaN` is skipped during comparison.
This is different from Julia `Base.argmin`.
## Examples
```julia-repl
julia> x = NDArray([0. 1 2; 3 4 5])
2×3 NDArray{Float64,2} @ CPU0:
0.0 1.0 2.0
3.0 4.0 5.0
julia> argmax(x, dims = 1)
1×3 NDArray{Float64,2} @ CPU0:
2.0 2.0 2.0
julia> argmax(x, dims = 2)
2×1 NDArray{Float64,2} @ CPU0:
3.0
3.0
```
See also [`argmax`](@ref mx.argmax).
"""
Base.argmin(x::NDArray; dims = :) = _argmin(x, dims) .+ 1
# Methods of `_argmin`, the helper that `Base.argmin(::NDArray; dims)` calls
# before adding 1 to the result. They are generated by `@_remap`, which is
# defined elsewhere in the package.
# NOTE(review): presumably `@_remap sig impl` emits a method with signature
# `sig` that invokes the backend operator named in `impl` — confirm there.
#
# `dims = :` form: the `argmin` operator is invoked without an `axis` argument.
@_remap _argmin(x::NDArray, ::Colon) argmin(x)
# Any other `dims`: negated elementwise (`0 .- dims`) and passed as `axis`,
# together with `keepdims = true`.
# NOTE(review): the negation presumably translates Julia dimension numbers
# into the backend's axis numbering — confirm against the operator docs.
@_remap _argmin(x::NDArray, dims) argmin(x; axis = 0 .- dims, keepdims = true)
################################################################################
# remapping to solve type instability
################################################################################
# Bindings to the backend `broadcast_maximum` / `broadcast_minimum` operators,
# generated by `@_remap` (defined elsewhere in the package).
# The non-`!` forms are the functions called by the
# `broadcasted(::typeof(max), …)` and `broadcasted(::typeof(min), …)` methods
# in this file.
# NOTE(review): the `!` forms map to the same operator; presumably `@_remap`
# makes them store the result in `x` instead of returning a new array —
# confirm against the macro definition.
@_remap _broadcast_maximum(x::NDArray, y::NDArray) broadcast_maximum(x, y)
@_remap _broadcast_maximum!(x::NDArray, y::NDArray) broadcast_maximum(x, y)
@_remap _broadcast_minimum(x::NDArray, y::NDArray) broadcast_minimum(x, y)
@_remap _broadcast_minimum!(x::NDArray, y::NDArray) broadcast_minimum(x, y)