-
-
Notifications
You must be signed in to change notification settings - Fork 216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use DifferentiationInterface for AD in Implicit Solvers #2567
base: master
Are you sure you want to change the base?
Conversation
In order for this to be completely done we'll need a DI equivalent for the SparseDiffTools Is a good way to do this to make an extension in SciMLOperators for DifferentiationInterface that will have something like a @ChrisRackauckas @oscardssmith @gdalle Any thoughts? |
Yes |
@avik-pal might already have one? |
Awesome work @jClugstor, thanks! Ping me when this is ready for a first round of DI-specific review.
Just to be clear, this wasn't possible before? So is this the first time that Enzyme can be used out-of-the-box to solve ODEs?
Another option, which requires a bit more work (and is probably not worth it) would be to make SparseDiffTools compatible with the sparsity API of ADTypes v1. I think it might allow a more seamless upgrade. See e.g. JuliaDiff/SparseDiffTools.jl#298 for the detection aspect, and there should be a similar issue for the coloring aspect. Speaking of SparseDiffTools, it still has an edge over DI when combined with FiniteDiff. The PR JuliaDiff/FiniteDiff.jl#191 could fix that, maybe @oscardssmith would be willing to take another look?
Agreed, preparation is a one-time cost so I don't think we should worry too much (at least in the prototype stage).
What do you mean by unexpected sparse things?
We may also want to involve @oschulz and his AutoDiffOperators package to avoid duplication of efforts? As a side note, DifferentiationInterface only has two dependencies: ADTypes and LinearAlgebra. For packages that use it extensively, I think it's reasonable to make it a full dep instead of a weakdep. |
@@ -7,6 +7,8 @@ version = "1.3.0" | |||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | |||
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | |||
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" | |||
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" | |||
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does Enzyme need to become a dependency? This adds significant install overhead, but if AutoEnzyme
is to be the new default AD then it makes sense
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, probably doesn't need to be a dependency unless we're committing to having it be the default.
@@ -25,6 +29,7 @@ ADTypes = "1.11" | |||
ArrayInterface = "7" | |||
DiffEqBase = "6" | |||
DiffEqDevTools = "2.44.4" | |||
DifferentiationInterface = "0.6.23" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DifferentiationInterface = "0.6.23" | |
DifferentiationInterface = "0.6.28" |
the other deps are also missing compat bounds?
alg, autodiff = AutoForwardDiff(chunksize = cs)) | ||
function prepare_ADType(alg::AutoFiniteDiff, prob, u0, p, standardtag) | ||
# If the autodiff alg is AutoFiniteDiff, prob.f.f isa FunctionWrappersWrapper, | ||
# and fdtype is complex, fdtype needs to change to something not complex |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that DI does not explicitly support complex numbers yet. What I mean by that is that we forward things to the backend as much as possible, so if the backend does support complex numbers then it will probably work, but there are no tests or hard API guarantees on that. See JuliaDiff/DifferentiationInterface.jl#646 for the discussion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also note that some differentiation operators are not defined unambiguously for complex numbers (e.g. the derivative for complex input)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Enzyme has an explicit variant of modes for complex numbers, that it probably would be wise to similarly wrap here (by default it will instead err warning about ambiguity if a function returns a complex number otherwise): https://enzyme.mit.edu/julia/stable/api/#EnzymeCore.ReverseHolomorphic . @gdalle I'm not sure DI supports this yet? so perhaps that means you may need to just call Enzyme.jacobian / autodiff directly in that case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jClugstor can you maybe specify where we will encounter complex numbers by filling the following table?
derivative | jacobian | |
---|---|---|
complex inputs possible | yes / no | yes / no |
complex outputs possible | yes / no | yes / no |
When there are both complex inputs and complex outputs, that's where we run into trouble because we cannot represent derivatives as a single scalar. In that case, the differentiation operators are not clearly defined (the Jacobian matrix is basically twice as big as it should be) so we would need to figure out what convention the ODE solvers need (see https://discourse.julialang.org/t/taking-complex-autodiff-seriously-in-chainrules/39317).
@wsmoses I understand your concern, but I find it encouraging that DI actually allowed Enzyme to be used here for the first time (or at least so I've been told). This makes me think that the right approach is to handle complex numbers properly in DI instead of introducing a special case for Enzyme?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure adding proper complex number support to DI would be great, but a three line change here to use in-spec Complex support when there's already overloads for other ADTypes feels reasonable?
e.g. something like
function jacobian(f, x::AbstractArray{<:Complex}, integrator::WhatevertheTypeIs{<:AutoEnzyme})
Enzyme.jacobian(ReverseHolomorphic, f, x)
end
from the discussion in JuliaDiff/DifferentiationInterface.jl#646 I think DI complex support is a much thornier issue. In particular, various tools have different conventions (e.g. jax vs pytorch pick different conjugates of what is propagated). So either DI needs to make a choice and shim/force all tools to use it (definitely doable), and then user code must be converted to that convention (e.g. a separate shim on the user side). For example, suppose DI picked a different conjugate from forwarddiff.jl. DI could write its shim once in forward diff to convert which is reasonable. But suppose one was defining a custom rule within ForwardDiff and the code called DI somewhere, now that user code needs to conditionally do a different the shim to conjugate which feels kind of nasty to be put everywhere (in contrast to a self consistent assumption). I suppose the other alternative is for DI to not pick a convention, but that again prevents users from using since it's not possible to know whether they get the correct value for them -- and worse, they won't know when they need to do a conversion or not.
Thus, if complex support is desired, a three line patch where things are explicitly supported seems okay (at least until the DI story is figured out)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that for now, this change seems to do the job (although it raises the question of consistency with the other backends that are handled via DI). But what will happen if the function in question is not holomorphic? That's the thorniest part of the problem, and that's why I wanted to inquire a bit more as to what kind of functions we can expect. Perhaps @jClugstor or @ChrisRackauckas can tell us more?
In any case, I have started a discussion on Discourse to figure out the right conventions: https://discourse.julialang.org/t/choosing-a-convention-for-complex-numbers-in-differentiationinterface/124433
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also note that the Enzyme-specific fix only handles dense Jacobians, not sparse Jacobians (which are one of the main reasons to use DI in the first place)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I can't really tell you much about the complex number support, other than previously only ForwardDiff or FiniteDiff were used, so when someone used an implicit solver on a complex problem, their conventions were used I guess. Also just wanted to note that the code this comment is on is just making sure that the FiniteDiff fdtype isn't complex if the function is a function wrapper and doesn't have to do with complex numbers through the solver in general.
Add a dispatch to https://github.com/SciML/NonlinearSolve.jl/blob/master/lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl#L115 |
As far as I know this is the first time Enzyme has been used for the implicit solvers yes. |
@avik-pal I noticed that the constructors for your |
the prepare_jvp and prepare_vjp functions assume a 2/3 arg function for oop/iip respectively, that won't hold for ordinarydiffeq |
AutoDiffOperators is definitely open for any extensions/changes that would be necessary, and could also be moved from oschulz to a GitHub Julia org. |
@oschulz I think if we want to use it in OrdinaryDiffEq we should make sure that AutoDiffOperators adheres to the |
The way it is currently implemented, AutoDiffOperators is not opinionated there. To use using AutoDiffOperators, ADTypes, LinearMaps
using DifferentiationInterface
import ForwardDiff, Enzyme
ad = AutoEnzyme()
f(x) = x.^2 + reverse(x)
y, J_op = with_jacobian(f, x, LinearMap, ad) whereas
instantiates the full Jacobian in memory. So SciMLOperators could be supported via y, J_op = with_jacobian(f, x, SciMLOperators.FunctionOperator, ad) (or This way, different opererator "backends" can be implemented via Pkg extensions and the user can select which operator type they want. |
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
This is at a point where we can do stuff like this:
and it actually uses sparsity detection and greedy jacobian coloring plus Enzyme to compute the Jacobians.
Some things I'm unsure about:
The current behavior is to use Jacobian coloring and SparseDiffTools by default. In order to keep that up, we have to wrap any ADType given in an AutoSparse unless it's already an AutoSparse. This does change the ADType that the user entered to be wrapped in an AutoSparse, which feels weird to me. Maybe there should be an option to just directly use the ADType entered, but by default we wrap it into an AutoSparse? I'm not sure.
The biggest issue is that the way the sparsity detectors work with DI is by using operator overloading (both TracerSparsityDetector and SymbolicsSparsityDetector do), but that's an issue when using AutoSpecialilzation, because of the FunctionWrappers. The solution I found was to just unwrap the function in the preparation process. I'm not sure what performance implication this will have, but I don't think it should do much, since the preparation should be run just once.
There's still pieces in here that use raw SparseDiffTools, (build_J_W) that I haven't looked in to how to convert to DI yet.
I may need to fix some of the versions.
There are some places that are getting sparse things where it's not expected.