@@ -347,6 +347,43 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
347347    R
348348end 
349349
350+ # # Base interface
351+ 
352+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLVector , dims:: Nothing , init:: Nothing ) = 
353+     accumulate! (op, typed_data (output), typed_data (input); dims= 1 )
354+ 
355+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLArray , dims:: Integer , init:: Nothing ) = 
356+     accumulate! (op, typed_data (output), typed_data (input); dims)
357+ 
358+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLVector , dims:: Nothing , init:: Some ) = 
359+     accumulate! (op, typed_data (output), typed_data (input); dims= 1 , init= something (init))
360+ 
361+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLArray , dims:: Integer , init:: Some ) = 
362+     accumulate! (op, typed_data (output), typed_data (input); dims, init= something (init))
363+ 
364+ Base. accumulate_pairwise! (op, result:: AnyJLVector , v:: AnyJLVector ) =  accumulate! (op, result, v)
365+ 
366+ #  default behavior unless dims are specified by the user
367+ function  Base. accumulate (op, A:: AnyJLArray ;
368+                          dims:: Union{Nothing,Integer} = nothing , kw... )
369+     nt =  values (kw)
370+     if  dims ===  nothing  &&  ! (A isa  AbstractVector)
371+         #  This branch takes care of the cases not handled by `_accumulate!`.
372+         return  reshape (accumulate (op, typed_data (A)[:]; kw... ), size (A))
373+     end 
374+     if  isempty (kw)
375+         out =  similar (A, Base. promote_op (op, eltype (A), eltype (A)))
376+         init =  AK. neutral_element (op, eltype (out))
377+     elseif  keys (nt) ===  (:init ,)
378+         out =  similar (A, Base. promote_op (op, typeof (nt. init), eltype (A)))
379+         init =  nt. init
380+     else 
381+         throw (ArgumentError (" accumulate does not support the keyword arguments $(setdiff (keys (nt), (:init ,))) " 
382+     end 
383+     accumulate! (op, typed_data (out), typed_data (A); dims, init)
384+ end 
385+ 
386+ 
350387# # KernelAbstractions interface
351388
352389KernelAbstractions. get_backend (a:: JLA ) where  JLA <:  JLArray  =  JLBackend ()
0 commit comments