FluxML / ZygoteRules.jl

MIT License

Add `clamptype` mechanism, to project into cotangent space #16

Closed: mcabbott closed this 3 years ago

mcabbott commented 3 years ago

This adds a mechanism to constrain tangents based on the type of the input. The initial goals are to ensure that real numbers do not accidentally acquire complex gradients, and that some structured array types from LinearAlgebra are preserved.
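A minimal sketch of the idea, assuming a hypothetical helper `clamptype_sketch(T, dx)` (an illustrative name, not the PR's actual code or API) that receives only the type `T` of the primal input and its cotangent `dx`. A real input drops the imaginary part of its gradient, and a `Diagonal` input keeps only the diagonal of a dense gradient.

```julia
using LinearAlgebra

# Hypothetical sketch, not the PR's implementation: restrict a cotangent `dx`
# using only the *type* `T` of the primal input it belongs to.

# Element level: a real input cannot have a complex gradient, so keep the real part.
clamp_eltype(::Type{<:Real}, dx::Number) = real(dx)
clamp_eltype(::Type, dx) = dx  # otherwise leave unchanged

# Scalars.
clamptype_sketch(::Type{T}, dx::Number) where {T<:Number} = clamp_eltype(T, dx)

# Generic arrays: clamp each element according to the input's element type.
clamptype_sketch(::Type{<:AbstractArray{T}}, dx::AbstractArray) where {T} =
    clamp_eltype.(T, dx)

# Structured arrays: a Diagonal input gets a Diagonal cotangent,
# so off-diagonal entries of a dense cotangent are dropped.
clamptype_sketch(::Type{<:Diagonal{T}}, dx::AbstractMatrix) where {T} =
    Diagonal(clamp_eltype.(T, diag(dx)))

# Anything else passes through untouched.
clamptype_sketch(::Type, dx) = dx
```

Because each method dispatches on `T` alone, the projection is chosen without ever looking at the primal value itself.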

Edit: the bulk of the implementation is now at https://github.com/FluxML/Zygote.jl/pull/965; this PR is just a stub (which ought to be non-breaking). I've moved the discussion there.

willtebbutt commented 3 years ago

I generally really like this, but I wonder whether there's a better name available.

It's a bit verbose, but might something like `project_onto_cotangent_space` be more informative?

My other question is whether there are situations in which type information alone is insufficient. For example, to know how to clamp the cotangent of a `SparseMatrixCSC` you need its precise sparsity pattern, which is only available at runtime, so this seems like a case where you need the object itself.
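To illustrate the sparsity point, here is a hypothetical sketch (`project_sparse_sketch` is an illustrative name, not an existing function in Zygote or ChainRules) that needs the runtime value `x`, not just `typeof(x)`: it keeps the entries of a dense cotangent only at the positions `x` actually stores. Two sparse matrices with different patterns share the same type, so a rule that sees only the type cannot tell them apart.

```julia
using SparseArrays

# Hypothetical sketch: project a dense cotangent `dx` onto the sparsity
# pattern of the primal input `x`. The pattern is runtime data, so this
# needs the value `x`; `typeof(x)` alone would not be enough.
function project_sparse_sketch(x::SparseMatrixCSC, dx::AbstractMatrix)
    size(dx) == size(x) ||
        throw(DimensionMismatch("cotangent size $(size(dx)) does not match input size $(size(x))"))
    rows, cols, _ = findnz(x)                        # stored positions of x
    vals = [dx[i, j] for (i, j) in zip(rows, cols)]  # cotangent entries at those positions
    return sparse(rows, cols, vals, size(x)...)
end

# Both inputs are SparseMatrixCSC{Float64, Int}, but their patterns differ.
a = sparse([1, 2], [1, 2], [1.0, 3.0], 2, 2)  # stores the diagonal
b = sparse([1], [2], [1.0], 2, 2)             # stores only entry (1, 2)
dx = [10.0 20.0; 30.0 40.0]
```

Here `project_sparse_sketch(a, dx)` and `project_sparse_sketch(b, dx)` give different results even though `typeof(a) == typeof(b)`.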

On a separate note, what are the considerations for doing this here rather than in ChainRules? It's something I've been thinking about (see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/286, though I've not gotten very far with it yet), so I'd be keen to see it happen in ChainRules. Once Zygote uses ChainRules types, the rule-definition tools here could simply use this functionality from ChainRules, so it seems plausible that we would want to implement it there.

mcabbott commented 3 years ago

Can I suggest we take the discussion to the Zygote PR? I've now put most of the write-up there.

Done, thanks! Could delete this.