diff --git a/NEWS.md b/NEWS.md index f5feae306daea..9f90bc9dadcb5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -84,6 +84,9 @@ Standard library changes #### LinearAlgebra +* A new unexported function `BLAS.with_num_threads` allows you to temporarily change the number of + BLAS threads. ([#41785]) + #### Markdown #### Printf diff --git a/stdlib/LinearAlgebra/docs/src/index.md b/stdlib/LinearAlgebra/docs/src/index.md index 38c48bfe6d8d2..bef4f4353be2c 100644 --- a/stdlib/LinearAlgebra/docs/src/index.md +++ b/stdlib/LinearAlgebra/docs/src/index.md @@ -578,6 +578,7 @@ LinearAlgebra.BLAS.trsv! LinearAlgebra.BLAS.trsv LinearAlgebra.BLAS.set_num_threads LinearAlgebra.BLAS.get_num_threads +LinearAlgebra.BLAS.with_num_threads ``` ## LAPACK functions diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 661e9e2b15617..9f08786a5d960 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -120,6 +120,9 @@ Set the number of threads the BLAS library should use equal to `n::Integer`. Also accepts `nothing`, in which case julia tries to guess the default number of threads. Passing `nothing` is discouraged and mainly exists for historical reasons. + +See also [`with_num_threads`](@ref BLAS.with_num_threads) to temporarily change +number of BLAS threads. """ set_num_threads(nt::Integer)::Nothing = lbt_set_num_threads(Int32(nt)) function set_num_threads(::Nothing) @@ -157,6 +160,56 @@ function check() end end +""" + with_num_threads(f, num_threads::Integer) + +Run function `f()` with BLAS threads `num_threads` and then +restore to previous threads setting. + +!!! compat "Julia 1.8" + `with_num_threads` requires at least Julia 1.8. + +# Example + +Depending on the number of available CPU cores, the result can be different: + +```julia +julia> BLAS.get_num_threads() +8 + +julia> with_num_threads(4) do + BLAS.get_num_threads() + # or doing some basic BLAS computation +end +4 + +julia> BLAS.get_num_threads() +8 +``` + +!!! warning + This function is not thread safe. If there are multiple + threads calling BLAS routines, then the threads they are + using will also be changed until this function finishes. + +!!! warning + This interface is experimental and subject to change or + removal without notice. + +See also [`set_num_threads`](@ref BLAS.set_num_threads) to permanently change +number of BLAS threads. +""" +function with_num_threads(f, num_threads::Integer) + prev_num_threads = BLAS.get_num_threads() + BLAS.set_num_threads(num_threads) + try + return f() + catch + rethrow() + finally + BLAS.set_num_threads(prev_num_threads) + end +end # Level 1 ## copy diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index df29c171b2060..7a6980deb86b3 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -641,6 +641,30 @@ end @test BLAS.get_num_threads() === default end +@testset "with_num_threads" begin + prev_num_threads = BLAS.get_num_threads() + context_num_threads = BLAS.with_num_threads(1) do + BLAS.get_num_threads() + end + @test context_num_threads == 1 + @test prev_num_threads == BLAS.get_num_threads() + + @testset "thread unsafe" begin + prev_num_threads = BLAS.get_num_threads() + context_num_threads = 1 + # task A + t = @async BLAS.with_num_threads(context_num_threads) do + sleep(0.5) + end + sleep(0.1) + # check that main thread is affected by task A + @test BLAS.get_num_threads() == context_num_threads + # when the task finishes, the num threads get restored + wait(t) + @test prev_num_threads == BLAS.get_num_threads() + end +end + # https://github.com/JuliaLang/julia/pull/39845 @test LinearAlgebra.BLAS.libblas == "libblastrampoline" @test LinearAlgebra.BLAS.liblapack == "libblastrampoline"