Comments (6)

jrevels commented:

> Not sure if that would win over a properly-optimised dual number implementation, but it's worth exploring.

idk where the code for this kernel is, but if your broadcast kernel is sufficiently static and has higher input arity than output arity, then you should be able to beat ForwardDiff there. It's when the kernel is dynamic that forward mode shines.

Note, though, that if the input/output arity ratio is close enough to 1 (e.g. ~3), then LLVM might optimize the ForwardDiff calculation enough that it becomes essentially the same as the reverse-mode computation. Usually it's a combination of scalar-level vectorization and reduction via sparsity (e.g. LLVM figures out that something is 0 and thus not worth computing) that achieves this.
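As a hedged illustration of the arity point (the kernel f and both gradient helpers below are made up for this sketch, not code from the thread): for a static scalar kernel with three inputs and one output, forward mode effectively pays one directional derivative per input, while a hand-written reverse rule recovers all three partials from a single pass.

using ForwardDiff

f(a, b, c) = a * b + sin(c)  # static kernel: input arity 3, output arity 1

# Forward mode: effectively three directional derivatives, one per input.
grad_fwd(a, b, c) = ForwardDiff.gradient(v -> f(v[1], v[2], v[3]), [a, b, c])

# Reverse mode, written by hand: one primal pass plus one adjoint pass.
grad_rev(a, b, c) = (b, a, cos(c))  # (∂f/∂a, ∂f/∂b, ∂f/∂c)

With a ratio this close to 1, LLVM can often vectorize the dual-number arithmetic well enough that the two approaches end up comparable, which is the caveat above.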

MikeInnes commented:

I can do you one better than an issue -- JuliaDiff/ForwardDiff.jl#357

MikeInnes commented:

Most of these are done now:

julia> @btime logsumexp(x)
  954.130 ns (1 allocation: 896 bytes)
5.141724249721802

julia> @btime derivative(logsumexp, x)
  1.262 μs (21 allocations: 2.19 KiB)

Still more allocations than I'd like, but hopefully we can keep fixing those as we go along.
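For context, the definitions behind this benchmark aren't shown in the comment; a minimal setup along the following lines (logsumexp, the derivative helper, and x are all assumptions here) reproduces the shape of the calls.

using Zygote, BenchmarkTools

logsumexp(x) = log(sum(exp.(x)))  # naive form; a numerically stable version
                                  # would subtract maximum(x) first

derivative(f, x) = Zygote.gradient(f, x)[1]  # hypothetical helper matching the call above

x = randn(100);
@btime logsumexp($x)
@btime derivative(logsumexp, $x)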

jrevels commented:

> The most egregious issue is that ForwardDiff's method for exp actually calculates exp(x) twice

Oof, can you open an issue in ForwardDiff for this? That's definitely a perf bug. We should add that to our tests...
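To make the bug concrete, here is a toy dual-number rule (a sketch only, not ForwardDiff's actual implementation) showing how the primal exp(x) can be computed once and reused for the derivative, since d/dx exp(x) = exp(x).

# Minimal dual number; ForwardDiff's real Dual type is far more general.
struct SimpleDual{T<:Real} <: Real
    value::T    # primal value
    partial::T  # derivative with respect to the input
end

function Base.exp(d::SimpleDual)
    ex = exp(d.value)               # compute exp(x) once...
    SimpleDual(ex, d.partial * ex)  # ...and reuse it for the derivative
end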

chriselrod commented:

How can you efficiently reuse computations that are shared between the objective and gradient evaluations?
I'm commenting here because of the theme of optimizations and the reference to exp(x).
Below, g has no custom adjoint, so Zygote differentiates it from scratch; g1 defines the adjoint the naive way, written as you would an ordinary function; g2 uses a global constant for type-stable storage; g3 doesn't try to reuse computations at all; and g4 uses a callable struct to ensure a type-stable closure.
g1 has type-stability problems, but the others all come up clean under @code_warntype.

julia> using StaticArrays, BenchmarkTools

julia> using Zygote: @adjoint, gradient
[ Info: Precompiling Zygote [e88e6eb3-aa80-5325-afca-941959d7151f]
[ Info: Precompiling IRTools [7869d1d1-7146-5819-86e3-90919afe41df]

julia> x = Ref(@SVector randn(4));

julia> const S = (@SMatrix randn(6,4)) |> x -> x' * x;

julia> g(x) = 0.5 * (x[]' * S * x[])
g (generic function with 1 method)

julia> g1(x) = 0.5 * (x[]' * S * x[])
g1 (generic function with 1 method)

julia> g2(x) = 0.5 * (x[]' * S * x[])
g2 (generic function with 1 method)

julia> g3(x) = 0.5 * (x[]' * S * x[])
g3 (generic function with 1 method)

julia> g4(x) = 0.5 * (x[]' * S * x[])
g4 (generic function with 1 method)

julia> @adjoint function g1(x)
           Sx = S*x[]
           0.5 * (x[]' * Sx), Δ -> (Δ * Sx,)
       end

julia> const Sx = Ref(@SVector zeros(4))
Base.RefValue{SArray{Tuple{4},Float64,1,4}}([0.0, 0.0, 0.0, 0.0])

julia> @adjoint function g2(x)
           Sx[] = S*x[]
           0.5 * (x[]' * Sx[]), Δ -> (Δ * Sx[],)
       end

julia> @adjoint function g3(x) #no reuse
            xs = x[]
           0.5 * (xs' * S * xs), Δ -> (Δ * (S * xs),)
       end

julia> struct closure{T}
           data::T
       end

julia> (c::closure)(Δ) = (Δ * c.data,)

julia> @adjoint function g4(x)
           Sx = S * x[]
           c = closure(Sx)
           0.5 * (x[]' * Sx), c
       end

julia> @benchmark gradient(g, $x)
BenchmarkTools.Trial: 
  memory estimate:  1.53 KiB
  allocs estimate:  28
  --------------
  minimum time:     2.797 μs (0.00% GC)
  median time:      2.882 μs (0.00% GC)
  mean time:        4.631 μs (33.58% GC)
  maximum time:     11.214 ms (99.91% GC)
  --------------
  samples:          10000
  evals/sample:     9

julia> @benchmark gradient(g1, $x)
BenchmarkTools.Trial: 
  memory estimate:  400 bytes
  allocs estimate:  14
  --------------
  minimum time:     1.473 μs (0.00% GC)
  median time:      1.520 μs (0.00% GC)
  mean time:        2.871 μs (43.54% GC)
  maximum time:     9.827 ms (99.90% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark gradient(g2, $x)
BenchmarkTools.Trial: 
  memory estimate:  32 bytes
  allocs estimate:  2
  --------------
  minimum time:     64.501 ns (0.00% GC)
  median time:      66.745 ns (0.00% GC)
  mean time:        74.539 ns (7.69% GC)
  maximum time:     32.162 μs (99.66% GC)
  --------------
  samples:          10000
  evals/sample:     978

julia> @benchmark gradient(g3, $x)
BenchmarkTools.Trial: 
  memory estimate:  288 bytes
  allocs estimate:  12
  --------------
  minimum time:     1.927 μs (0.00% GC)
  median time:      1.959 μs (0.00% GC)
  mean time:        2.321 μs (11.07% GC)
  maximum time:     2.577 ms (99.68% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark gradient(g4, $x)
BenchmarkTools.Trial: 
  memory estimate:  576 bytes
  allocs estimate:  15
  --------------
  minimum time:     2.126 μs (0.00% GC)
  median time:      2.355 μs (0.00% GC)
  mean time:        3.713 μs (37.68% GC)
  maximum time:     10.920 ms (99.91% GC)
  --------------
  samples:          10000
  evals/sample:     9

The constant global reference did the best, but it was still much slower than I'd have expected, and it still incurs a couple of allocations.

julia> function fg(x, Δ)
           Sx = S * x[]
           0.5 * (x[]' * Sx), Δ * Sx
       end 
fg (generic function with 1 method)

julia> @benchmark fg($x, 2.3)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     21.280 ns (0.00% GC)
  median time:      21.294 ns (0.00% GC)
  mean time:        21.349 ns (0.00% GC)
  maximum time:     39.808 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     997

EDIT (December 29, 2018):
I updated the above to reflect API changes in Zygote (@grad -> @adjoint) and in StaticArrays (::MArray * ::SVector now yields an ::MVector, where it earlier yielded an ::SVector).

Also, rerunning these benchmarks:

julia> using StaticArrays, BenchmarkTools

julia> using Zygote: @adjoint, gradient

julia> x = Ref(@SVector randn(4));

julia> const S = (@SMatrix randn(6,4)) |> x -> x' * x;

julia> g(x) = 0.5 * (x[]' * S * x[])
g (generic function with 1 method)

julia> g1(x) = 0.5 * (x[]' * S * x[])
g1 (generic function with 1 method)

julia> g2(x) = 0.5 * (x[]' * S * x[])
g2 (generic function with 1 method)

julia> g3(x) = 0.5 * (x[]' * S * x[])
g3 (generic function with 1 method)

julia> g4(x) = 0.5 * (x[]' * S * x[])
g4 (generic function with 1 method)

julia> @adjoint function g1(x)
           Sx = S*x[]
           0.5 * (x[]' * Sx), Δ -> (Δ * Sx,)
       end

julia> const Sx = Ref(@SVector zeros(4))
Base.RefValue{SArray{Tuple{4},Float64,1,4}}([0.0, 0.0, 0.0, 0.0])

julia> @adjoint function g2(x)
           Sx[] = S*x[]
           0.5 * (x[]' * Sx[]), Δ -> (Δ * Sx[],)
       end

julia> @adjoint function g3(x) #no reuse
            xs = x[]
           0.5 * (xs' * S * xs), Δ -> (Δ * (S * xs),)
       end

julia> struct closure{T}
           data::T
       end

julia> (c::closure)(Δ) = (Δ * c.data,)

julia> @adjoint function g4(x)
           Sx = S * x[]
           c = closure(Sx)
           0.5 * (x[]' * Sx), c
       end

julia> @benchmark gradient(g, $x)
BenchmarkTools.Trial: 
  memory estimate:  1.67 KiB
  allocs estimate:  32
  --------------
  minimum time:     1.495 μs (0.00% GC)
  median time:      1.692 μs (0.00% GC)
  mean time:        2.698 μs (35.13% GC)
  maximum time:     6.444 ms (99.91% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark gradient(g1, $x)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     4.077 ns (0.00% GC)
  median time:      4.158 ns (0.00% GC)
  mean time:        4.369 ns (0.00% GC)
  maximum time:     27.442 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark gradient(g2, $x)
BenchmarkTools.Trial: 
  memory estimate:  32 bytes
  allocs estimate:  2
  --------------
  minimum time:     28.375 ns (0.00% GC)
  median time:      29.372 ns (0.00% GC)
  mean time:        41.955 ns (24.99% GC)
  maximum time:     63.403 μs (99.93% GC)
  --------------
  samples:          10000
  evals/sample:     995

julia> @benchmark gradient(g3, $x)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     4.077 ns (0.00% GC)
  median time:      4.108 ns (0.00% GC)
  mean time:        4.230 ns (0.00% GC)
  maximum time:     11.211 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

julia> @benchmark gradient(g4, $x)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     4.077 ns (0.00% GC)
  median time:      4.098 ns (0.00% GC)
  mean time:        4.219 ns (0.00% GC)
  maximum time:     8.346 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

Versus analytical:

julia> function fg(x, Δ)
           Sx = S * x[]
           0.5 * (x[]' * Sx), Δ * Sx
       end
fg (generic function with 1 method)

julia> @benchmark fg($x, 2.3)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     7.742 ns (0.00% GC)
  median time:      8.223 ns (0.00% GC)
  mean time:        8.369 ns (0.00% GC)
  maximum time:     51.007 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     999

Wow, that is good!

EDIT (October 24, 2019):
Things have regressed:

julia> @benchmark gradient(g, $x)
ERROR: Non-differentiable function Core._apply_iterate
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] macro expansion [inlined]
 [3] (::typeof(∂(_apply_iterate)))(::Float64) at /home/chriselrod/.julia/packages/Zygote/bdE6T/src/compiler/interface2.jl:19
 [4] * [inlined]
 [5] (::typeof(∂(*)))(::Float64) at /home/chriselrod/.julia/packages/Zygote/bdE6T/src/compiler/interface2.jl:0
 [6] g [inlined]
 [7] (::typeof(∂(g)))(::Float64) at /home/chriselrod/.julia/packages/Zygote/bdE6T/src/compiler/interface2.jl:0
 [8] (::Zygote.var"#32#33"{typeof(∂(g))})(::Float64)
 [9] gradient(::Function, ::Base.RefValue{SArray{Tuple{4},Float64,1,4}})
 [10] ##core#432(::Base.RefValue{SArray{Tuple{4},Float64,1,4}})
 [11] ##sample#433(::BenchmarkTools.Parameters)
 [12] _run(::BenchmarkTools.Benchmark{Symbol("##benchmark#431")}, ::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Iterators.Pairs{Symbol,Integer,NTuple{4,Symbol},NamedTuple{(:samples, :evals, :gctrial, :gcsample),Tuple{Int64,Int64,Bool,Bool}}})
 [13] (::Base.var"#inner#2"{Base.Iterators.Pairs{Symbol,Integer,NTuple{5,Symbol},NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample),Tuple{Bool,Int64,Int64,Bool,Bool}}},typeof(BenchmarkTools._run),Tuple{BenchmarkTools.Benchmark{Symbol("##benchmark#431")},BenchmarkTools.Parameters}})()
 [14] #invokelatest#1 [inlined]
 [15] #run_result#37 at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:32 [inlined]
 [16] run(::BenchmarkTools.Benchmark{Symbol("##benchmark#431")}, ::BenchmarkTools.Parameters; kwargs::Base.Iterators.Pairs{Symbol,Integer,NTuple{5,Symbol},NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample),Tuple{Bool,Int64,Int64,Bool,Bool}}}) at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:46
 [17] #warmup#42 at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:79 [inlined]
 [18] warmup(::BenchmarkTools.Benchmark{Symbol("##benchmark#431")}) at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:79
 [19] top-level scope at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:213

julia> @benchmark gradient(g1, $x)
BenchmarkTools.Trial: 
  memory estimate:  80 bytes
  allocs estimate:  2
  --------------
  minimum time:     29.164 ns (0.00% GC)
  median time:      29.932 ns (0.00% GC)
  mean time:        38.310 ns (18.49% GC)
  maximum time:     4.436 μs (98.69% GC)
  --------------
  samples:          10000
  evals/sample:     996

julia> @benchmark gradient(g2, $x)
BenchmarkTools.Trial: 
  memory estimate:  48 bytes
  allocs estimate:  2
  --------------
  minimum time:     29.110 ns (0.00% GC)
  median time:      29.498 ns (0.00% GC)
  mean time:        33.915 ns (9.51% GC)
  maximum time:     3.449 μs (98.57% GC)
  --------------
  samples:          10000
  evals/sample:     996

julia> @benchmark gradient(g3, $x)
BenchmarkTools.Trial: 
  memory estimate:  592 bytes
  allocs estimate:  15
  --------------
  minimum time:     772.372 ns (0.00% GC)
  median time:      794.291 ns (0.00% GC)
  mean time:        850.069 ns (4.89% GC)
  maximum time:     24.392 μs (95.61% GC)
  --------------
  samples:          10000
  evals/sample:     148

julia> @benchmark gradient(g4, $x)
BenchmarkTools.Trial: 
  memory estimate:  80 bytes
  allocs estimate:  2
  --------------
  minimum time:     28.715 ns (0.00% GC)
  median time:      29.452 ns (0.00% GC)
  mean time:        38.063 ns (19.60% GC)
  maximum time:     4.677 μs (98.68% GC)
  --------------
  samples:          10000
  evals/sample:     996

MikeInnes commented:

The approach for g1 should be fine, but you might find that Julia boxes the closure reference. Should be fixable if you use a let block.
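A minimal sketch of that suggestion, reusing the g1/S setup from above (and assuming the boxing comes from the captured Sx): a let block gives the pullback its own never-reassigned binding, which Julia can capture without boxing.

@adjoint function g1(x)
    Sx = S * x[]
    back = let Sx = Sx  # fresh immutable binding for the closure
        Δ -> (Δ * Sx,)
    end
    0.5 * (x[]' * Sx), back
end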
