Comments (6)
Not sure if that would win over a properly-optimised dual number implementation, but it's worth exploring.
I don't know where the code for this kernel is, but if your broadcast kernel is sufficiently static and has higher input arity than output arity, then you should be able to beat ForwardDiff there. It's when the kernel is dynamic that forward-mode shines.
Note, though, that if the input/output arity ratio is close enough to 1 (e.g. ~3) then LLVM might optimize the ForwardDiff calculation enough that it basically becomes the same as the reverse-mode computation. Usually it's a combination of scalar-level vectorization and reduction via sparsity (e.g. LLVM figures out that something is 0 and thus not worth computing) that achieves this.
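A rough sketch of the arity argument (a hypothetical example, not the kernel under discussion): for f(x) = sum(exp, x), with n inputs and one output, a single reverse-mode pass yields the whole gradient, and the forward sweep's exp.(x) can be reused as the gradient itself; forward mode would instead need one dual-seeded evaluation per input.

```julia
# Hypothetical example of the input/output arity argument: n inputs, 1 output.
f(x) = sum(exp, x)

# Reverse mode: one pass. The forward sweep's exp.(x) is reused as the
# gradient, since d/dxᵢ sum(exp, x) = exp(xᵢ).
function f_and_grad(x)
    e = exp.(x)          # computed once, shared by value and gradient
    return sum(e), e
end

# Forward mode would re-run f once per input with a dual seed, so its
# cost scales with the input arity n.
x = randn(4)
y, g = f_and_grad(x)
```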
from zygote.jl.
I can do you one better than an issue -- JuliaDiff/ForwardDiff.jl#357
from zygote.jl.
Most of these are done now:
julia> @btime logsumexp(x)
954.130 ns (1 allocation: 896 bytes)
5.141724249721802
julia> @btime derivative(logsumexp, x)
1.262 μs (21 allocations: 2.19 KiB)
Still more allocations than I'd like, but hopefully we can keep fixing those as we go along.
from zygote.jl.
The most egregious issue is that ForwardDiff's method for exp actually calculates exp(x) twice
Oof, can you open an issue in ForwardDiff for this? That's definitely a perf bug. We should add that to our tests...
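For reference, here is a minimal dual-number sketch of what the fix looks like (illustrative only, not ForwardDiff's actual implementation): the primal exp(x) can be computed once and reused as the derivative, since d/dx exp(x) = exp(x).

```julia
# Minimal dual-number sketch (illustrative; not ForwardDiff's code).
struct Dual{T<:Real} <: Real
    val::T   # primal value
    der::T   # tangent / derivative
end
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)

# exp(x) is computed exactly once and reused for both value and derivative.
function Base.exp(a::Dual)
    e = exp(a.val)
    return Dual(e, e * a.der)
end

derivative(f, x) = f(Dual(x, one(x))).der
derivative(exp, 2.0)  # ≈ exp(2.0)
```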
from zygote.jl.
How can you efficiently reuse computations in common between the objective and gradient evaluations?
I'm commenting here because of the theme of optimizations, and reference to exp(x).
Below, g relies on Zygote, g1 does it the naive way, written like you would a function, g2 uses a global constant for type-stable storage, g3 doesn't try to save on computations at all, and g4 uses a struct to try and ensure a type-stable closure. g1 has type stability problems, but the others all come up clean with @code_warntype.
julia> using StaticArrays, BenchmarkTools
julia> using Zygote: @adjoint, gradient
[ Info: Precompiling Zygote [e88e6eb3-aa80-5325-afca-941959d7151f]
[ Info: Precompiling IRTools [7869d1d1-7146-5819-86e3-90919afe41df]
julia> x = Ref(@SVector randn(4));
julia> const S = (@SMatrix randn(6,4)) |> x -> x' * x;
julia> g(x) = 0.5 * (x[]' * S * x[])
g (generic function with 1 method)
julia> g1(x) = 0.5 * (x[]' * S * x[])
g1 (generic function with 1 method)
julia> g2(x) = 0.5 * (x[]' * S * x[])
g2 (generic function with 1 method)
julia> g3(x) = 0.5 * (x[]' * S * x[])
g3 (generic function with 1 method)
julia> g4(x) = 0.5 * (x[]' * S * x[])
g4 (generic function with 1 method)
julia> @adjoint function g1(x)
Sx = S*x[]
-0.5 * (x[]' * Sx), Δ -> (Δ * Sx,)
end
julia> const Sx = Ref(@SVector zeros(4))
Base.RefValue{SArray{Tuple{4},Float64,1,4}}([0.0, 0.0, 0.0, 0.0])
julia> @adjoint function g2(x)
Sx[] = S*x[]
-0.5 * (x[]' * Sx[]), Δ -> (Δ * Sx[],)
end
julia> @adjoint function g3(x) #no reuse
xs = x[]
-0.5 * (xs' * S * xs), Δ -> (Δ * (S * xs),)
end
julia> struct closure{T}
data::T
end
julia> (c::closure)( Δ ) = ( Δ * c.data, )
julia> @adjoint function g4(x)
Sx = S * x[]
c = closure(Sx)
-0.5 * (x[]' * Sx), c
end
julia> @benchmark gradient(g, $x)
BenchmarkTools.Trial:
memory estimate: 1.53 KiB
allocs estimate: 28
--------------
minimum time: 2.797 μs (0.00% GC)
median time: 2.882 μs (0.00% GC)
mean time: 4.631 μs (33.58% GC)
maximum time: 11.214 ms (99.91% GC)
--------------
samples: 10000
evals/sample: 9
julia> @benchmark gradient(g1, $x)
BenchmarkTools.Trial:
memory estimate: 400 bytes
allocs estimate: 14
--------------
minimum time: 1.473 μs (0.00% GC)
median time: 1.520 μs (0.00% GC)
mean time: 2.871 μs (43.54% GC)
maximum time: 9.827 ms (99.90% GC)
--------------
samples: 10000
evals/sample: 10
julia> @benchmark gradient(g2, $x)
BenchmarkTools.Trial:
memory estimate: 32 bytes
allocs estimate: 2
--------------
minimum time: 64.501 ns (0.00% GC)
median time: 66.745 ns (0.00% GC)
mean time: 74.539 ns (7.69% GC)
maximum time: 32.162 μs (99.66% GC)
--------------
samples: 10000
evals/sample: 978
julia> @benchmark gradient(g3, $x)
BenchmarkTools.Trial:
memory estimate: 288 bytes
allocs estimate: 12
--------------
minimum time: 1.927 μs (0.00% GC)
median time: 1.959 μs (0.00% GC)
mean time: 2.321 μs (11.07% GC)
maximum time: 2.577 ms (99.68% GC)
--------------
samples: 10000
evals/sample: 10
julia> @benchmark gradient(g4, $x)
BenchmarkTools.Trial:
memory estimate: 576 bytes
allocs estimate: 15
--------------
minimum time: 2.126 μs (0.00% GC)
median time: 2.355 μs (0.00% GC)
mean time: 3.713 μs (37.68% GC)
maximum time: 10.920 ms (99.91% GC)
--------------
samples: 10000
evals/sample: 9
The constant global reference did the best, but was still much slower than I'd have expected, featuring a few allocations.
julia> function fg(x, Δ)
Sx = S * x[]
-0.5 * (x[]' * Sx), Δ * Sx
end
fg (generic function with 1 method)
julia> @benchmark fg($x, 2.3)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 21.280 ns (0.00% GC)
median time: 21.294 ns (0.00% GC)
mean time: 21.349 ns (0.00% GC)
maximum time: 39.808 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 997
EDIT (December 29 2018):
I updated the above to reflect API changes in Zygote (@grad -> @adjoint) and in StaticArrays (::MArray * ::SVector now returns ::MVector, where it earlier returned ::SVector).
Also, rerunning these benchmarks:
julia> using StaticArrays, BenchmarkTools
julia> using Zygote: @adjoint, gradient
julia> x = Ref(@SVector randn(4));
julia> const S = (@SMatrix randn(6,4)) |> x -> x' * x;
julia> g(x) = 0.5 * (x[]' * S * x[])
g (generic function with 1 method)
julia> g1(x) = 0.5 * (x[]' * S * x[])
g1 (generic function with 1 method)
julia> g2(x) = 0.5 * (x[]' * S * x[])
g2 (generic function with 1 method)
julia> g3(x) = 0.5 * (x[]' * S * x[])
g3 (generic function with 1 method)
julia> g4(x) = 0.5 * (x[]' * S * x[])
g4 (generic function with 1 method)
julia> @adjoint function g1(x)
Sx = S*x[]
-0.5 * (x[]' * Sx), Δ -> (Δ * Sx,)
end
julia> const Sx = Ref(@SVector zeros(4))
Base.RefValue{SArray{Tuple{4},Float64,1,4}}([0.0, 0.0, 0.0, 0.0])
julia> @adjoint function g2(x)
Sx[] = S*x[]
-0.5 * (x[]' * Sx[]), Δ -> (Δ * Sx[],)
end
julia> @adjoint function g3(x) #no reuse
xs = x[]
-0.5 * (xs' * S * xs), Δ -> (Δ * (S * xs),)
end
julia> struct closure{T}
data::T
end
julia> (c::closure)( Δ ) = ( Δ * c.data, )
julia> @adjoint function g4(x)
Sx = S * x[]
c = closure(Sx)
-0.5 * (x[]' * Sx), c
end
julia> @benchmark gradient(g, $x)
BenchmarkTools.Trial:
memory estimate: 1.67 KiB
allocs estimate: 32
--------------
minimum time: 1.495 μs (0.00% GC)
median time: 1.692 μs (0.00% GC)
mean time: 2.698 μs (35.13% GC)
maximum time: 6.444 ms (99.91% GC)
--------------
samples: 10000
evals/sample: 10
julia> @benchmark gradient(g1, $x)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 4.077 ns (0.00% GC)
median time: 4.158 ns (0.00% GC)
mean time: 4.369 ns (0.00% GC)
maximum time: 27.442 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
julia> @benchmark gradient(g2, $x)
BenchmarkTools.Trial:
memory estimate: 32 bytes
allocs estimate: 2
--------------
minimum time: 28.375 ns (0.00% GC)
median time: 29.372 ns (0.00% GC)
mean time: 41.955 ns (24.99% GC)
maximum time: 63.403 μs (99.93% GC)
--------------
samples: 10000
evals/sample: 995
julia> @benchmark gradient(g3, $x)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 4.077 ns (0.00% GC)
median time: 4.108 ns (0.00% GC)
mean time: 4.230 ns (0.00% GC)
maximum time: 11.211 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
julia> @benchmark gradient(g4, $x)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 4.077 ns (0.00% GC)
median time: 4.098 ns (0.00% GC)
mean time: 4.219 ns (0.00% GC)
maximum time: 8.346 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
Versus analytical:
julia> function fg(x, Δ)
Sx = S * x[]
-0.5 * (x[]' * Sx), Δ * Sx
end
fg (generic function with 1 method)
julia> @benchmark fg($x, 2.3)
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 7.742 ns (0.00% GC)
median time: 8.223 ns (0.00% GC)
mean time: 8.369 ns (0.00% GC)
maximum time: 51.007 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 999
Wow, that is good!
EDIT (October 24, 2019):
Things have regressed:
julia> @benchmark gradient(g, $x)
ERROR: Non-differentiable function Core._apply_iterate
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] macro expansion [inlined]
[3] (::typeof(∂(_apply_iterate)))(::Float64) at /home/chriselrod/.julia/packages/Zygote/bdE6T/src/compiler/interface2.jl:19
[4] * [inlined]
[5] (::typeof(∂(*)))(::Float64) at /home/chriselrod/.julia/packages/Zygote/bdE6T/src/compiler/interface2.jl:0
[6] g [inlined]
[7] (::typeof(∂(g)))(::Float64) at /home/chriselrod/.julia/packages/Zygote/bdE6T/src/compiler/interface2.jl:0
[8] (::Zygote.var"#32#33"{typeof(∂(g))})(::Float64)
[9] gradient(::Function, ::Base.RefValue{SArray{Tuple{4},Float64,1,4}})
[10] ##core#432(::Base.RefValue{SArray{Tuple{4},Float64,1,4}})
[11] ##sample#433(::BenchmarkTools.Parameters)
[12] _run(::BenchmarkTools.Benchmark{Symbol("##benchmark#431")}, ::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Iterators.Pairs{Symbol,Integer,NTuple{4,Symbol},NamedTuple{(:samples, :evals, :gctrial, :gcsample),Tuple{Int64,Int64,Bool,Bool}}})
[13] (::Base.var"#inner#2"{Base.Iterators.Pairs{Symbol,Integer,NTuple{5,Symbol},NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample),Tuple{Bool,Int64,Int64,Bool,Bool}}},typeof(BenchmarkTools._run),Tuple{BenchmarkTools.Benchmark{Symbol("##benchmark#431")},BenchmarkTools.Parameters}})()
[14] #invokelatest#1 [inlined]
[15] #run_result#37 at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:32 [inlined]
[16] run(::BenchmarkTools.Benchmark{Symbol("##benchmark#431")}, ::BenchmarkTools.Parameters; kwargs::Base.Iterators.Pairs{Symbol,Integer,NTuple{5,Symbol},NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample),Tuple{Bool,Int64,Int64,Bool,Bool}}}) at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:46
[17] #warmup#42 at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:79 [inlined]
[18] warmup(::BenchmarkTools.Benchmark{Symbol("##benchmark#431")}) at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:79
[19] top-level scope at /home/chriselrod/.julia/packages/BenchmarkTools/7aqwe/src/execution.jl:213
julia> @benchmark gradient(g1, $x)
BenchmarkTools.Trial:
memory estimate: 80 bytes
allocs estimate: 2
--------------
minimum time: 29.164 ns (0.00% GC)
median time: 29.932 ns (0.00% GC)
mean time: 38.310 ns (18.49% GC)
maximum time: 4.436 μs (98.69% GC)
--------------
samples: 10000
evals/sample: 996
julia> @benchmark gradient(g2, $x)
BenchmarkTools.Trial:
memory estimate: 48 bytes
allocs estimate: 2
--------------
minimum time: 29.110 ns (0.00% GC)
median time: 29.498 ns (0.00% GC)
mean time: 33.915 ns (9.51% GC)
maximum time: 3.449 μs (98.57% GC)
--------------
samples: 10000
evals/sample: 996
julia> @benchmark gradient(g3, $x)
BenchmarkTools.Trial:
memory estimate: 592 bytes
allocs estimate: 15
--------------
minimum time: 772.372 ns (0.00% GC)
median time: 794.291 ns (0.00% GC)
mean time: 850.069 ns (4.89% GC)
maximum time: 24.392 μs (95.61% GC)
--------------
samples: 10000
evals/sample: 148
julia> @benchmark gradient(g4, $x)
BenchmarkTools.Trial:
memory estimate: 80 bytes
allocs estimate: 2
--------------
minimum time: 28.715 ns (0.00% GC)
median time: 29.452 ns (0.00% GC)
mean time: 38.063 ns (19.60% GC)
maximum time: 4.677 μs (98.68% GC)
--------------
samples: 10000
evals/sample: 996
from zygote.jl.
The approach for g1 should be fine, but you might find that Julia boxes the closure reference. Should be fixable if you use a let block.
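Concretely, the let trick looks something like this (a sketch outside of Zygote; value_and_pullback is a hypothetical stand-in for g1's adjoint, with S and x passed explicitly):

```julia
# Sketch of the `let` trick: give the closure a fresh, never-reassigned
# local binding, so Julia can capture it without boxing.
function value_and_pullback(S, x)
    Sx = S * x
    back = let Sx = Sx          # fresh binding scoped to the closure
        Δ -> (Δ * Sx,)
    end
    return 0.5 * (x' * Sx), back
end

S = [2.0 0.0; 0.0 2.0]
x = [1.0, 3.0]
y, back = value_and_pullback(S, x)   # y == 10.0
back(1.0)                            # ([2.0, 6.0],)
```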
from zygote.jl.