diff --git a/Project.toml b/Project.toml index 17852b9..112489d 100644 --- a/Project.toml +++ b/Project.toml @@ -6,16 +6,15 @@ version = "0.3" BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232" CpuId = "adafc99b-e345-5852-983c-f28acb93d879" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] BinaryProvider = "0.5.8" CpuId = "0.2" -SpecialFunctions = "0.8, 0.9, 0.10" julia = "0.7, 1.0" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [targets] -test = ["Test"] +test = ["Test", "SpecialFunctions"] diff --git a/src/IntelVectorMath.jl b/src/IntelVectorMath.jl index 9d8facf..c60188f 100644 --- a/src/IntelVectorMath.jl +++ b/src/IntelVectorMath.jl @@ -1,12 +1,9 @@ -__precompile__() - module IntelVectorMath export IVM const IVM = IntelVectorMath # import Base: .^, ./ -using SpecialFunctions # using Libdl include("../deps/deps.jl") @@ -105,46 +102,40 @@ for t in (Float32, Float64) end """ - @overload exp log sin - -This macro adds a method to each function in `Base` (or perhaps in `SpecialFunctions`), -so that when acting on an array (or two arrays) it calls the `IntelVectorMath` function of the same name. + @vml_overload Base.exp Base.log SpecialFunctions.erfc -The existing action on scalars is unaffected. However, `exp(M::Matrix)` will now mean -element-wise `IntelVectorMath.exp(M) == exp.(M)`, rather than matrix exponentiation. +This macro adds a method to each given function in `Base` or `SpecialFunctions`, +so that when acting on a `Vector` (or two `Vector`s) it calls the `IntelVectorMath` function of the same name. """ -macro overload(funs...) +macro vml_overload(funs...) out = quote end - say = [] for f in funs - if f in _UNARY - if isdefined(Base, f) - push!(out.args, :( Base.$f(A::Array) = IntelVectorMath.$f(A) )) - push!(say, "Base.$f(A)") - elseif isdefined(SpecialFunctions, f) - push!(out.args, :( IntelVectorMath.SpecialFunctions.$f(A::Array) = IntelVectorMath.$f(A) )) - push!(say, "SpecialFunctions.$f(A)") - else - @error "function IntelVectorMath.$f is not defined in Base or SpecialFunctions, so there is nothing to overload" - end + if f.head !== :(.) || !(length(f.args) == 2) || !(f.args[1] isa Symbol && f.args[2] isa QuoteNode) + error("expected a Module.function type of expression, got $f") + end + mod, f = f.args[1], f.args[2].value + if !(mod in (:Base, :SpecialFunctions)) + error("expected module to be either Base or SpecialFunctions, got $mod") + end + if f in keys(_UNARY) + input_types = _UNARY[f] + expr = :($(esc(mod)).$f(A::Vector{T}) where {T <: Union{$(input_types...)}} = + IntelVectorMath.$f(A)) + push!(out.args, expr) end - if f in _BINARY - if isdefined(Base, f) - push!(out.args, :( Base.$f(A::Array, B::Array) = IntelVectorMath.$f(A, B) )) - push!(say, "Base.$f(A, B)") - else - @error "function IntelVectorMath.$f is not defined in Base, so there is nothing to overload" - end + if f in keys(_BINARY) + input_types = _BINARY[f] + expr = :($(esc(mod)).$f(A::Vector{T}, B::Vector{T}) where {T <: Union{$(input_types...)}} = + IntelVectorMath.$f(A, B)) + push!(out.args, expr) end - if !(f in _UNARY) && !(f in _BINARY) - error("there is no function $f defined by IntelVectorMath.jl") + if !(f in keys(_UNARY)) && !(f in keys(_BINARY)) + error("there is no function $f defined in IntelVectorMath.jl") end end - str = string("Overloaded these functions: \n ", join(say, " \n ")) - push!(out.args, str) - esc(out) + return out end -export VML_LA, VML_HA, VML_EP, vml_set_accuracy, vml_get_accuracy, @overload +export VML_LA, VML_HA, VML_EP, vml_set_accuracy, vml_get_accuracy, @vml_overload end diff --git a/src/setup.jl b/src/setup.jl index 75b514b..733be2a 100644 --- a/src/setup.jl +++ b/src/setup.jl @@ -18,8 +18,8 @@ const VML_LA = VMLAccuracy(0x00000001) const VML_HA = VMLAccuracy(0x00000002) const VML_EP = VMLAccuracy(0x00000003) -const _UNARY = [] # for @overload to check -const _BINARY = [] +const _UNARY = Dict{Symbol, Vector{DataType}}() # for @vml_overload to check +const _BINARY = Dict{Symbol, Vector{DataType}}() Base.show(io::IO, m::VMLAccuracy) = print(io, m == VML_LA ? "VML_LA" : m == VML_HA ? "VML_HA" : "VML_EP") @@ -59,13 +59,13 @@ function vml_prefix(t::DataType) error("unknown type $t") end -function def_unary_op(tin, tout, jlname, jlname!, mklname; +function def_unary_op(tin, tout, jlname, jlname!, mklname; vmltype = tin) mklfn = Base.Meta.quot(Symbol("$(vml_prefix(vmltype))$mklname")) exports = Symbol[] (@isdefined jlname) || push!(exports, jlname) (@isdefined jlname!) || push!(exports, jlname!) - push!(_UNARY, jlname) + push!(get!(_UNARY, jlname, DataType[]), tin) @eval begin function ($jlname!)(out::Array{$tout,N}, A::Array{$tin,N}) where {N} size(out) == size(A) || throw(DimensionMismatch()) @@ -97,7 +97,7 @@ function def_binary_op(tin, tout, jlname, jlname!, mklname, broadcast) exports = Symbol[] (@isdefined jlname) || push!(exports, jlname) (@isdefined jlname!) || push!(exports, jlname!) - push!(_BINARY, jlname) + push!(get!(_BINARY, jlname, DataType[]), tin) @eval begin $(isempty(exports) ? nothing : Expr(:export, exports...)) function ($jlname!)(out::Array{$tout,N}, A::Array{$tin,N}, B::Array{$tin,N}) where {N} diff --git a/test/real.jl b/test/real.jl index fe3e6ef..40c808a 100644 --- a/test/real.jl +++ b/test/real.jl @@ -17,7 +17,7 @@ fns = [[x[1:2] for x in base_unary_real]; [x[1:2] for x in base_binary_real]] @testset "Definitions and Comparison with Base for Reals" begin for t in (Float32, Float64), i = 1:length(fns) - base_fn = eval(:($(fns[i][1]).$(fns[i][2]))) + base_fn = eval(:($(fns[i][1]).$(fns[i][2]))) vml_fn = eval(:(IntelVectorMath.$(fns[i][2]))) vml_fn! = eval(:(IntelVectorMath.$(Symbol(fns[i][2], !)))) @@ -28,10 +28,10 @@ fns = [[x[1:2] for x in base_unary_real]; [x[1:2] for x in base_binary_real]] Test.@test vml_fn(input[t][i]...) ≈ baseres # cis changes type (float to complex, does not have mutating function) - + if length(input[t][i]) == 1 - if fns[i][2] != :cis + if fns[i][2] != :cis vml_fn!(input[t][i]...) Test.@test input[t][i][1] ≈ baseres end @@ -60,15 +60,17 @@ end end -@testset "@overload macro" begin - +@testset "@vml_overload macro" begin @test IntelVectorMath.exp([1.0]) ≈ exp.([1.0]) @test_throws MethodError Base.exp([1.0]) - @test (@overload log exp) isa String + @vml_overload Base.log Base.exp @test Base.exp([1.0]) ≈ exp.([1.0]) @test_throws MethodError Base.atan([1.0], [2.0]) - @test (@overload atan) isa String + @vml_overload Base.atan @test Base.atan([1.0], [2.0]) ≈ atan.([1.0], [2.0]) + @test_throws MethodError SpecialFunctions.erfc([1.0, 2.0]) + @vml_overload SpecialFunctions.erfc + @test SpecialFunctions.erfc([1.0, 2.0]) ≈ SpecialFunctions.erfc.([1.0, 2.0]) end diff --git a/test/runtests.jl b/test/runtests.jl index 7c51a37..2a44f59 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using Test using IntelVectorMath +using SpecialFunctions include("common.jl") include("real.jl")