Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 55 additions & 81 deletions src/array/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Base.Broadcast: BroadcastStyle, DefaultArrayStyle, Style
import Base.Broadcast: BroadcastStyle, DefaultArrayStyle, Style, AbstractArrayStyle, Unknown

const STRICT_BROADCAST_CHECKS = Ref(true)
const STRICT_BROADCAST_DOCS = """
Expand Down Expand Up @@ -35,10 +35,9 @@ strict_broadcast!(x::Bool) = STRICT_BROADCAST_CHECKS[] = x
# It preserves the dimension names.
# `S` should be the `BroadcastStyle` of the wrapped type.
# Copied from NamedDims.jl (thanks @oxinabox).
struct BasicDimensionalStyle{N} <: AbstractArrayStyle{Any} end

struct DimensionalStyle{S<:BroadcastStyle} <: AbstractArrayStyle{Any} end
DimensionalStyle(::S) where {S} = DimensionalStyle{S}()
struct DimensionalStyle{S <: AbstractArrayStyle, N} <: AbstractArrayStyle{N} end
DimensionalStyle(::S) where S<:AbstractArrayStyle{N} where N = DimensionalStyle{S, N}()
DimensionalStyle(::S) where {S<:DimensionalStyle} = S() # avoid nested dimensionalstyle
DimensionalStyle(::S, ::Val{N}) where {S,N} = DimensionalStyle(S(Val(N)))
DimensionalStyle(::Val{N}) where N = DimensionalStyle{DefaultArrayStyle{N}}()
function DimensionalStyle(a::BroadcastStyle, b::BroadcastStyle)
Expand All @@ -51,86 +50,59 @@ function DimensionalStyle(a::BroadcastStyle, b::BroadcastStyle)
end
end

function BroadcastStyle(::Type{<:AbstractDimArray{T,N,D,A}}) where {T,N,D,A}
inner_style = typeof(BroadcastStyle(A))
return DimensionalStyle{inner_style}()
end
BroadcastStyle(::Type{<:AbstractBasicDimArray{T,N}}) where {T,N} =
BasicDimensionalStyle{N}()

BroadcastStyle(::Type{<:AbstractDimArray{T,N,D,A}}) where {T,N,D,A} =
DimensionalStyle(BroadcastStyle(A))
BroadcastStyle(::Type{<:AbstractBasicDimArray{T,N,D}}) where {T,N,D} =
DimensionalStyle(DefaultArrayStyle{N}())
BroadcastStyle(::DimensionalStyle, ::Base.Broadcast.Unknown) = Unknown()
BroadcastStyle(::Base.Broadcast.Unknown, ::DimensionalStyle) = Unknown()
BroadcastStyle(::DimensionalStyle{A}, ::DimensionalStyle{B}) where {A, B} = DimensionalStyle(A(), B())
BroadcastStyle(::DimensionalStyle{A}, b::Style) where {A} = DimensionalStyle(A(), b)
BroadcastStyle(a::Style, ::DimensionalStyle{B}) where {B} = DimensionalStyle(a, B())
BroadcastStyle(::DimensionalStyle{A}, b::AbstractArrayStyle{N}) where {A,N} = DimensionalStyle(A(), b)
BroadcastStyle(::DimensionalStyle{A}, b::DefaultArrayStyle{N}) where {A,N} = DimensionalStyle(A(), b) # ambiguity
BroadcastStyle(::DimensionalStyle{A}, b::Style{Tuple}) where {A} = DimensionalStyle(A(), b)
BroadcastStyle(a::Style{Tuple}, ::DimensionalStyle{B}) where {B} = DimensionalStyle(a, B())
# We need to implement copy because if the wrapper array type does not
# support setindex then the `similar` based default method will not work
function Broadcast.copy(bc::Broadcasted{DimensionalStyle{S}}) where S
A = _firstdimarray(bc)
data = copy(_unwrap_broadcasted(bc))

A isa Nothing && return data # No AbstractDimArray

bdims = _broadcasted_dims(bc)
_comparedims_broadcast(A, bdims...)

data isa AbstractArray || return data # result is a scalar

# unwrap AbstractDimArray data
data = data isa AbstractDimArray ? parent(data) : data
dims = format(Dimensions.promotedims(bdims...; skip_length_one=true), data)
return rebuild(A; data, dims, refdims=refdims(A), name=Symbol(""))
end
function Broadcast.copy(bc::Broadcasted{BasicDimensionalStyle{N}}) where N
# override base instantiate to check dimensions as well as axes
@inline function Broadcast.instantiate(bc::Broadcasted{<:DimensionalStyle{S}}) where S
A = _firstdimarray(bc)
data = collect(bc)
A isa Nothing && return data # No AbstractDimArray

# check if there is any DimArray and unwrap immediately if no
isnothing(A) && return Broadcast.instantiate(_unwrap_broadcasted(bc))
bdims = _broadcasted_dims(bc)
if bc.axes isa Nothing
axes = Base.Broadcast.combine_axes(_unwrap_broadcasted(bc).args...)
ds = Dimensions.promotedims(bdims...; skip_length_one=true)
length(axes) == length(ds) ||
throw(ArgumentError("Number of broadcasted dimensions $(length(axes)) larger than $(ds)"))
axes = map(Dimensions.DimUnitRange, axes, ds)
else # bc already has axes which might have dimensions, e.g. when assigning to a DimArray
axes = bc.axes
Base.Broadcast.check_broadcast_axes(axes, bc.args...)
ds = dims(axes)
isnothing(ds) || _comparedims_broadcast(A, ds, bdims...)
end
_comparedims_broadcast(A, bdims...)

data isa AbstractArray || return data # result is a scalar

# Return an AbstractDimArray
dims = format(Dimensions.promotedims(bdims...; skip_length_one=true), data)
return dimconstructor(dims)(data, dims; refdims=refdims(A), name=Symbol(""))
return Broadcasted(bc.style, bc.f, bc.args, axes)
end

function Base.copyto!(dest::AbstractArray, bc::Broadcasted{DimensionalStyle{S}}) where S
fda = _firstdimarray(bc)
isnothing(fda) || _comparedims_broadcast(fda, _broadcasted_dims(bc)...)
copyto!(dest, _unwrap_broadcasted(bc))
end
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{BasicDimensionalStyle{N}}) where N
fda = _firstdimarray(bc)
isnothing(fda) || _comparedims_broadcast(fda, _broadcasted_dims(bc)...)
copyto!(dest, bc)
end

@inline function Base.Broadcast.materialize!(dest::AbstractDimArray, bc::Base.Broadcast.Broadcasted{<:Any})
# Need to check whether the dims are compatible in dest,
# which are already stripped when sent to copyto!
_comparedims_broadcast(dest, dims(dest), _broadcasted_dims(bc)...)
style = DimensionalData.DimensionalStyle(Base.Broadcast.combine_styles(parent(dest), bc))
Base.Broadcast.materialize!(style, parent(dest), bc)
return dest
end

function Base.similar(bc::Broadcast.Broadcasted{DimensionalStyle{S}}, ::Type{T}) where {S,T}
# Define copy because the inner style S might override copy (e.g. DiskArrays)
function Base.copy(bc::Broadcasted{<:DimensionalStyle{S}}) where S
data = copy(_unwrap_broadcasted(bc))
data isa AbstractArray || return data # in the 0-d case data can be a scalar
# let similar do the work - it will usually call rebuild unless A isa AbstractBasicDimArray
A = _firstdimarray(bc)
data = similar(_unwrap_broadcasted(bc), T, size(bc))
dims, refdims = slicedims(A, axes(bc))
return rebuild(A; data, dims, refdims, name=Symbol(""))
similar(A; data, dims = dims(axes(bc)))
end
function Base.similar(bc::Broadcast.Broadcasted{BasicDimensionalStyle{N}}, ::Type{T}) where {N,T}
# similar is usually only called in broadcast_preserving_zero_d
function Base.similar(bc::Broadcasted{<:DimensionalStyle{S}}, ::Type{T}) where {S,T}
A = _firstdimarray(bc)
data = similar(A, T, size(bc))
dims, refdims = slicedims(A, axes(bc))
return dimconstructor(dims)(data, dims; refdims, name=Symbol(""))
data = similar(_unwrap_broadcasted(bc), T)
similar(A; data, dims = dims(axes(bc)))
end

@inline function Base.materialize!(::DimensionalStyle, dest, bc::Broadcasted)
# check dimensions
bci = Broadcast.instantiate(Broadcasted(bc.style, bc.f, bc.args, axes(dest)))
# unwrap before copying
Base.copyto!(_unwrap_broadcasted(dest), _unwrap_broadcasted(bci))
return dest
end

"""
@d broadcast_expression options
Expand Down Expand Up @@ -407,29 +379,31 @@ end
# Recursively unwraps `AbstractDimArray`s and `DimensionalStyle`s.
# replacing the `AbstractDimArray`s with the wrapped array,
# and `DimensionalStyle` with the wrapped `BroadcastStyle`.
function _unwrap_broadcasted(bc::Broadcasted{DimensionalStyle{S}}) where S

function _unwrap_broadcasted(bc::Broadcasted{<:DimensionalStyle{S}}) where {S}
innerargs = map(_unwrap_broadcasted, bc.args)
return Broadcasted{S}(bc.f, innerargs)
return Broadcasted{S}(bc.f, innerargs, _unwrap_broadcasted(bc.axes))
end
_unwrap_broadcasted(x) = x
_unwrap_broadcasted(nda::AbstractDimArray) = parent(nda)
_unwrap_broadcasted(boda::BroadcastOptionsDimArray) = parent(parent(boda))

_unwrap_broadcasted(bda::AbstractBasicDimArray) = OpaqueArray(bda)
_unwrap_broadcasted(boda::BroadcastOptionsDimArray) = _unwrap_broadcasted(parent(boda))
_unwrap_broadcasted(t::Tuple) = map(_unwrap_broadcasted, t)
_unwrap_broadcasted(du::Dimensions.DimUnitRange) = parent(du)
# Get the first dimensional array in the broadcast
_firstdimarray(x::Broadcasted) = _firstdimarray(x.args)
_firstdimarray(x::Tuple{<:AbstractBasicDimArray,Vararg}) = x[1]
_firstdimarray(x::AbstractBasicDimArray) = x
_firstdimarray(ext::Base.Broadcast.Extruded) = _firstdimarray(ext.x)
function _firstdimarray(x::Tuple{<:Union{Broadcasted,Base.Broadcast.Extruded},Vararg})
function _firstdimarray(x::Tuple)
found = _firstdimarray(x[1])
if found isa Nothing
_firstdimarray(tail(x))
else
found
end
end
_firstdimarray(x::Tuple) = _firstdimarray(tail(x))
_firstdimarray(x::Tuple{}) = nothing
_firstdimarray(ext::Base.Broadcast.Extruded) = _firstdimarray(ext.x)
_firstdimarray(x::AbstractBasicDimArray) = x
_firstdimarray(x) = nothing

# Make sure all arrays have the same dims, and return them
_broadcasted_dims(bc::Broadcasted) = _broadcasted_dims(bc.args...)
Expand Down
4 changes: 2 additions & 2 deletions src/opaque.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ OpaqueArray(st::P) where P<:AbstractDimStack{<:Any,T,N} where {T,N} = OpaqueArra
Base.size(A::OpaqueArray) = size(A.parent)
Base.getindex(A::OpaqueArray, I::Union{StandardIndices,Not{<:StandardIndices}}...) =
Base.getindex(A.parent, I...)
Base.setindex!(A::OpaqueArray, I::Union{StandardIndices,Not{<:StandardIndices}}...) =
Base.setindex!(A.parent, I...)
Base.setindex!(A::OpaqueArray, x, I::Union{StandardIndices,Not{<:StandardIndices}}...) =
Base.setindex!(A.parent, x, I...)
Loading