leanprover / lean4

Lean 4 programming language and theorem prover
https://lean-lang.org
Apache License 2.0

Simp and structure projection problem #1041

Closed lecopivo closed 2 years ago

lecopivo commented 2 years ago

I have a simp lemma that behaves differently on the projection Prod.fst : α × β → α and on a custom function my_fst : α × β → α.

The lemma says that the adjoint of a composition is the composition of the adjoints: adjoint (λ x => f (g x)) = adjoint g ∘ adjoint f. Interestingly, the simplifier applies this lemma to adjoint Prod.fst with g = id, but it does not apply it to adjoint my_fst.

What is going on? Why are Prod.fst and my_fst treated differently by the simplifier? Is this intended behavior? Is it an unfortunate consequence of another reasonable choice?

MWE:


variable {α β γ} [Inhabited α] [Inhabited β] [Inhabited γ]

constant adjoint : (α → β) → (β → α)

@[simp]
theorem adjoint_of_comp (f : β → γ) (g : α → β) :
  adjoint (λ x => f (g x)) = adjoint g ∘ adjoint f 
  := 
  by sorry

@[simp]
theorem adjoint_of_fst :
  adjoint (Prod.fst : α × β → α) = fun x : α => (x, default) 
  := by sorry

example : adjoint (Prod.fst : α × β → α) = fun x : α => (x, default) 
  :=
by
  -- This applies `adjoint_of_comp` for some reason
  simp (config := { singlePass := true })

  simp -- infinite recursion
  done

constant my_fst : α × β → α 

@[simp]
theorem adjoint_of_my_fst :
  adjoint (my_fst : α × β → α) = fun x : α => (x, default) 
  := by sorry

example : adjoint (my_fst : α × β → α) = fun x : α => (x, default) 
  :=
by
  simp
  done
leodemoura commented 2 years ago

The key problem here is the simp theorem adjoint_of_comp. It requires higher-order matching, which is only approximated in Lean. We should have a linter in the future that generates a warning whenever this kind of theorem is marked with simp. When we try to apply it in the first example, we have to solve the following constraint

  adjoint (fun x => ?f (?g x))  =?= adjoint (@Prod.fst α β)

This constraint is successfully solved using the following derivation

  fun x => ?f (?g x)  =?= @Prod.fst α β    
  fun x => ?f (?g x)  =?= fun x => @Prod.fst α β x   -- by eta
  ?f (?g x) =?= @Prod.fst α β x                      -- isDefEq for lambda
  ?f =?= @Prod.fst α β;   ?g x =?= x                 -- by first-order approximation
  ?f := @Prod.fst α β;  ?g := fun x => x

We use an index for selecting theorems to be applied by simp. The "key" for adjoint_of_comp is adjoint <other>. We use other for terms such as lambdas. The retrieval in this index is performed modulo reducible constants. Prod.fst is reducible. Thus, when looking for theorems to apply at adjoint (@Prod.fst α β), we unfold Prod.fst, obtain adjoint (fun x => ...), which produces the key adjoint <other>, and the theorem adjoint_of_comp is retrieved. Note that simp loops because it can keep applying adjoint_of_comp using g := fun x => x. In your second example, my_fst is a constant and cannot be reduced. So, the theorem adjoint_of_comp is never even tried.
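For instance, the unfolded form is easy to see directly; the following standalone check (a sketch, independent of the declarations above) holds by eta:

-- `Prod.fst` is a structure projection: unfolding it yields a lambda,
-- which is exactly what produces the key `adjoint <other>` in the index
example {α β : Type} : (Prod.fst : α × β → α) = fun p => p.1 := rfl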

You should avoid marking theorems such as adjoint_of_comp with [simp]. You can write a "first-order" version that is only applied when Function.comp is used for function composition. That is, the theorem will only be used if the user explicitly used Function.comp.

@[simp]
theorem adjoint_of_comp (f : β → γ) (g : α → β) :
  adjoint (f ∘ g) = adjoint g ∘ adjoint f
  :=
  by sorry
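As a quick check of the intended behavior (a sketch, assuming the MWE declarations above with this first-order lemma replacing the higher-order one):

example (f : β → γ) (g : α → β) :
  adjoint (f ∘ g) = adjoint g ∘ adjoint f :=
by simp  -- fires only because `∘` is explicit

example : adjoint (Prod.fst : α × β → α) = fun x : α => (x, default) :=
by simp  -- `adjoint_of_fst` applies; no infinite recursion anymore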
lecopivo commented 2 years ago

Thank you very much for the detailed explanation. I remember you mentioned somewhere else that theorems like adjoint_of_comp are not simp-safe, but I didn't understand why. Now it makes more sense.

Unfortunately, my project currently depends on Lean's ability to do higher-order matching and on simp theorems like adjoint_of_comp in their original formulation. I don't see how I can use only the "first-order" version.

I will probably make a specialized tactic which will interleave simp with rewrites by unsafe theorems like adjoint_of_comp. I have encountered a similar problem before and this solution seems to work reasonably well.

leodemoura commented 2 years ago

Unfortunately, my project currently depends on Lean's ability to do higher-order matching and on simp theorems like adjoint_of_comp in their original formulation. I don't see how I can use only the "first-order" version.

I see. Could you please describe the use-cases where you want adjoint_of_comp to succeed? We may be able to tweak the implementation a bit to support them.

I think @gebner and @digama0 have worked around this kind of limitation before while developing Mathlib and may have nice tricks.

gebner commented 2 years ago

cc @Vierkantor, who just gave a talk at MCM about more or less exactly this problem.

lecopivo commented 2 years ago

I'm building an automatic/symbolic differentiation library.

To give an example, let's compute the gradient of the squared norm (⟪x,x⟫ is the inner product on X):

∇ (λ x : X => 2 * ⟪x,x⟫)  =  (λ x : X => 4 * x)

The gradient is defined in terms of the differential δ and the adjoint † as ∇ f = λ x => (δ f x)† 1 for f : X → ℝ.
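In Lean terms, the setup might look like the following sketch (reusing adjoint and the variable context from the MWE above; gradient and one are hypothetical names, with one standing in for the literal 1):

variable {X ℝ : Type} [Inhabited X] [Inhabited ℝ] (one : ℝ)

constant δ : (α → β) → (α → α → β)  -- the differential

-- ∇ f = λ x => (δ f x)† 1, with `adjoint` playing the role of `†`
def gradient (f : X → ℝ) : X → X :=
  fun x => adjoint (δ f x) one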

Step 1: differential

The first step is to compute the differential. For x₀ dx : X:

 δ (λ x => 2 * ⟪x,x⟫) x₀ dx = 2 * (⟪dx,x₀⟫ + ⟪x₀,dx⟫)

This is done with the following steps:

δ (λ x => 2 * ⟪x,x⟫) x₀ dx

rule: δ (λ x => f (g x)) x₀ dx = δ f (g x₀) (δ g x₀ dx)

δ (HMul.hMul 2) (⟪x₀,x₀⟫) (δ (λ x => ⟪x,x⟫) x₀ dx)

rule: δ (λ (x : X) => f (g1 x) (g2 x)) x₀ dx = δ f (g1 x₀) (δ g1 x₀ dx) (g2 x₀) + δ (f (g1 x₀)) (g2 x₀) (δ g2 x₀ dx)

δ (HMul.hMul 2) (⟪x₀,x₀⟫) (δ (⟪·,·⟫) x₀ (δ (λ x => x) x₀ dx) x₀ + δ (⟪x₀,·⟫) x₀ (δ (λ x => x) x₀ dx))

rule: δ fun x => x ==> fun x dx => dx

δ (HMul.hMul 2) (⟪x₀,x₀⟫) (δ (⟪·,·⟫) x₀ dx x₀ + δ (⟪x₀,·⟫) x₀ dx)

rule: if f is linear then δ f x dx = f dx; here HMul.hMul 2, ⟪·,·⟫ and ⟪x₀,·⟫ are linear functions

2 * (⟪dx,x₀⟫ + ⟪x₀,dx⟫)
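Stated as Lean lemmas, the rules used in this step might look as follows (continuing the sketch above; IsLin is a hypothetical linearity predicate):

variable {β₁ β₂ : Type} [Inhabited β₁] [Inhabited β₂] [Add γ]

axiom δ_of_comp (f : β → γ) (g : α → β) (x₀ dx : α) :
  δ (fun x => f (g x)) x₀ dx = δ f (g x₀) (δ g x₀ dx)

axiom δ_of_binop (f : β₁ → β₂ → γ) (g₁ : α → β₁) (g₂ : α → β₂) (x₀ dx : α) :
  δ (fun x => f (g₁ x) (g₂ x)) x₀ dx
    = δ f (g₁ x₀) (δ g₁ x₀ dx) (g₂ x₀) + δ (f (g₁ x₀)) (g₂ x₀) (δ g₂ x₀ dx)

constant IsLin : (α → β) → Prop  -- hypothetical linearity predicate

axiom δ_of_linear (f : α → β) (h : IsLin f) (x dx : α) : δ f x dx = f dx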

Step 2: adjoint

The second step is to compute the adjoint

(λ dx : X => 2 * (⟪dx,x₀⟫ + ⟪x₀,dx⟫))† 1 = 4 * x₀

It is computed with the following steps:

(λ dx : X => 2 * (⟪dx,x₀⟫ + ⟪x₀,dx⟫))†

rule: (λ x => f (g x))† = g† ∘ f†

((λ dx : X => (⟪dx,x₀⟫ + ⟪x₀,dx⟫))† ∘ (HMul.hMul 2)†)

rule: (λ x => f (g1 x) (g2 x))† = (uncurry HAdd.hAdd) ∘ (Prod.map g1† g2†) ∘ (uncurry f)†

uncurry HAdd.hAdd ∘ Prod.map (⟪·,x₀⟫)† (⟪x₀,·⟫)† ∘ (uncurry HAdd.hAdd)† ∘ (HMul.hMul 2)†

atomic adjoints:

  inner product:         (⟪·,x₀⟫)† = λ s => s * x₀
  addition:              (uncurry HAdd.hAdd)† = λ x => (x,x)
  scalar multiplication: (HMul.hMul 2)† = HMul.hMul 2

uncurry HAdd.hAdd ∘ Prod.map (·*x₀) (·*x₀) ∘ (λ x => (x,x)) ∘ (HMul.hMul 2)

Now evaluate this at 1, do some basic simplification, and you should get 4 * x₀.
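For reference, the second rule used in this step could be stated in Lean like this (continuing the sketch; the composition rule is just adjoint_of_comp from the MWE, and uncurry is defined inline since I'm not assuming a library name for it):

def uncurry (f : α → β → γ) : α × β → γ := fun p => f p.1 p.2

-- rule: (λ x => f (g1 x) (g2 x))† = uncurry (+) ∘ Prod.map g1† g2† ∘ (uncurry f)†
axiom adjoint_of_binop [Add α] (f : β₁ → β₂ → γ) (g₁ : α → β₁) (g₂ : α → β₂) :
  adjoint (fun x => f (g₁ x) (g₂ x))
    = uncurry HAdd.hAdd ∘ Prod.map (adjoint g₁) (adjoint g₂) ∘ adjoint (uncurry f)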


If it would help, I can build a small self-contained example that does these kinds of rewrites.


One particularly problematic rule is this one:

adjoint (λ x => f (g x) b) = adjoint g ∘ adjoint (λ y => f y b)

which has a tendency to loop infinitely, as the pattern matches on adjoint (λ y => f y b) with g = (λ x => x). My solution is to exclude this theorem from simp and then have a specialized tactic that interleaves simp with this rewrite; see the code.
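For reference, this rule as a Lean statement (a sketch in the thread's notation; ε is just a fresh type variable for the parameter b):

axiom adjoint_of_comp_parm {ε : Type} (f : β → ε → γ) (g : α → β) (b : ε) :
  adjoint (fun x => f (g x) b) = adjoint g ∘ adjoint (fun y => f y b)
-- loops: the right-hand side matches the left-hand pattern again with g := fun x => x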

lecopivo commented 2 years ago

Just to point out: the first step in both computations is to apply the rule for how the differential or adjoint acts on function composition. Without it, the computation/simplification can't be broken down into smaller pieces.

lecopivo commented 2 years ago

I just want to expand on my "I don't see how I can use only the 'first-order' version."

For a very long time, almost a year, I was attempting what I'm doing now with the "first-order" versions of those theorems. I have a tactic that can eliminate simply typed lambdas in favor of SKI combinators and their variants.

The problem is that proving linearity or differentiability of the resulting expressions is really difficult, especially when dealing with higher-order functions. And all those theorems I'm using are usually conditioned on the differentiability or linearity of certain functions.

When I realized that Lean can do higher-order matching and I can work with lambdas directly, I ditched SKI combinators and finally started making some progress.

lecopivo commented 2 years ago

cc @Vierkantor, who just gave a talk at MCM about more or less exactly this problem.

Was it recorded? I would be interested in hearing it.

leodemoura commented 2 years ago

@lecopivo What about adding instances of adjoint_of_comp that are simp "friendly"? Example:

-- @[simp] -- Remove problematic `simp` annotation
theorem adjoint_of_comp (f : β → γ) (g : α → β) :
  adjoint (λ x => f (g x)) = adjoint g ∘ adjoint f 
  := 
  by sorry

@[simp]
theorem adjoint_of_mul_comp [Mul β] (b : β) (g : α → β) : adjoint (λ x => b * g x) = adjoint g ∘ adjoint (HMul.hMul b) :=
 adjoint_of_comp (HMul.hMul b) g

If you have a finite number of them, then this approach may work.
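As a quick sanity check (a sketch, assuming the declarations above), the specialized lemma fires without any higher-order matching:

example [Mul β] (b : β) (g : α → β) :
  adjoint (fun x => b * g x) = adjoint g ∘ adjoint (HMul.hMul b) :=
by simp  -- rewrites with `adjoint_of_mul_comp`; nothing loops since `adjoint_of_comp` is no longer simp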

lecopivo commented 2 years ago

I don't think that would scale. I can go over all my unit tests and print out all the uses of adjoint_of_comp and the like. I don't believe its uses can be shrunk to some small set of theorems with specialized f or g.

One of my aims is to build a flexible machine learning framework where you can easily define a neural network. A network is just a composition of layers, and for each type of layer you would have to provide a theorem like that. I want the framework to be really flexible and to allow users to easily define new layers. It would kind of defeat the purpose if a user had to provide these specialized theorems.

lecopivo commented 2 years ago

I went through the computation of the gradient of (Δx/(2*m)) * ∥p∥² + (Δx * k/2) * (∑ i, ∥x[i] - x[i-1]∥²) w.r.t. x and p, and the uses of the *_of_comp theorems are:

Differential of composition: δ (λ x => f (g x)) x₀ dx = δ f (g x₀) (δ g x₀ dx)

  1. f = HAdd.hAdd
  2. f = HAdd.hAdd c for some constant c
  3. f = HMul.hMul c for some constant c
  4. f = (sum : (Fin n -> ℝ) -> ℝ)
  5. f = (λ x => ∥x∥²)

Adjoint of composition: (λ x => f (g x))† = g† ∘ f†

  1. f = HMul.hMul c for some constant c

It looks like f is always fairly simple, or what I would call an 'atomic' differentiable function. For these 'atomic' functions I have to write a differentiation rule manually, e.g. δ (λ x : X => ∥x∥²) = λ x dx : X => 2*⟪x, dx⟫. So it might be feasible to write down the composition rule too. The slight problem is that I have around 10 different variants of the composition rule, but that can be automated.
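For example, entry 5 in the differential list above would get a specialized composition rule along these lines (a sketch; normSq is a hypothetical stand-in for ∥·∥², with δ and the variables as sketched earlier):

constant normSq : X → ℝ  -- hypothetical stand-in for `∥·∥²`

@[simp]
axiom δ_normSq_comp (g : α → X) (x₀ dx : α) :
  δ (fun x => normSq (g x)) x₀ dx = δ normSq (g x₀) (δ g x₀ dx)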

If it turns out that f is always simple/atomic in applications, that would indeed be an interesting discovery. However, I'm a bit skeptical. What I really want to do is symbolic variational calculus, i.e. differentiation of higher-order functions, and there I expect f might not always be some 'atomic' function but can very easily be some lambda's bound variable. Unfortunately, I do not yet have interesting enough examples to decide whether that is the case.

lecopivo commented 2 years ago

I wrote a small self-contained differentiation test for lambda expressions. There are three main simp theorems, D_I, D_K and D_S, corresponding to the three basic combinators. These theorems should in theory be sufficient to differentiate any lambda expression (plus some theorems about addition).

I'm really surprised; simp performs really well. A few months ago, when I was writing the first version of the library, simp was unable to solve D_diag or D_parm automatically.

constant D (f : α → β) : α → α → β := λ x _ => f x  -- just an arbitrary definition to provide Inhabited

variable {α β β₁ β₂ γ δ : Type}
variable [Inhabited α] [Inhabited β] [Inhabited γ] [Inhabited δ] [Inhabited β₁] [Inhabited β₂]
variable [Add α] [Add β] [Add γ] [Add δ] [Add β₁] [Add β₂]

instance {α β : Type} [Add β] : Add (α → β) := ⟨λ f g x => f x + g x⟩

 -- `default` plays the role of zero
@[simp]
axiom add_default {α} [Inhabited α] [Add α] (x : α) : x + default = x
@[simp]
axiom default_add {α} [Inhabited α] [Add α] (x : α) : default + x = x
-- derivative in zero direction is zero
@[simp]
axiom D_default (f : α → β) (x) : D f x default = default

@[simp]
theorem default_eval {α} [Inhabited β] (x : α) : (default : α → β) x = default :=
by
  unfold default; unfold instInhabitedForAll_1; unfold default; simp; done

-- Basic combinators
@[simp]
axiom D_I         
  : D (λ x : α => x) = λ x dx => dx 
@[simp]
axiom D_K (y : β) 
  : D (λ x : α => y) = λ x dx => default 
@[simp]
axiom D_S (f : α → β → γ) (g : α → β) 
  : D (λ x => f x (g x)) = λ x dx => D (f x) (g x) (D g x dx) + D f x dx (g x)

set_option trace.Meta.Tactic.simp.rewrite true

theorem D_B (f : β → γ) (g : α → β) 
  : D (λ x => f (g x)) = λ x dx => D f (g x) (D g x dx) :=
by 
  simp; done

theorem D_W (f : α → α → β)
  : D (λ x => f x x) = λ x dx => D (f x) x dx + D f x dx x :=
by
  simp; done

theorem D_diag (f : β₁ → β₂ → γ) (g₁ : α → β₁) (g₂ : α → β₂)
  : D (λ x => f (g₁ x) (g₂ x)) = λ x dx => D (f (g₁ x)) (g₂ x) (D g₂ x dx) + D f (g₁ x) (D g₁ x dx) (g₂ x) := 
by
  simp; done

theorem D_parm (f : α → β → γ) (g : β₁ → α) (b : β)
  : D (λ x => f (g x) b) = λ x dx => D f (g x) (D g x dx) b :=
by
  simp; done

In proving D_diag and D_parm, D_S is used with ?f = (fun x => f (g₁ x)) and ?f = (fun x => f). These are not "atomic" expressions as in the previous example.

lecopivo commented 2 years ago

Here are some tests for derivatives of higher-order functions. Things get a bit more complicated. For example, in the proof of D_diag_g₁_g₂ the theorem D_S is used with ?f = (fun g₁ => f (g₁ x)).

I think this kills the idea of generating simp-"friendly" variants for each "atomic" function. You can have, for example, ?f = (fun g₁ => sin (exp (cos (g₁ x)))). Therefore you would have to generate an infinite number of D_S variants.

The following works when appended to the previous example:

theorem D_partial (f : α → β → γ)
  : D (λ x y => f x y) = λ x dx y => D (λ x' => f x' y) x dx :=
by
  simp; done

theorem D_K_y 
  : D (λ (y : β) (x : α) => y) = λ y dy x => dy :=
by 
  rw[D_partial];
  simp; done

theorem D_S_g_x (f : α → β → γ)
  : D (λ (g : α → β) (x : α) => f x (g x)) = λ g dg x => D (f x) (g x) (dg x) :=
by 
  rw[D_partial]; simp; done

theorem D_S_f_g_x 
  : D (λ (f : α → β → γ) (g : α → β) (x : α) => f x (g x)) = λ f df g x => df x (g x) :=
by 
  -- conv => lhs; rw[D_partial]; rw[D_partial];
  rw[D_partial]; funext f df g; rw[D_partial];
  simp; done

theorem D_diag_g₂_x (f : β₁ → β₂ → γ) (g₁ : α → β₁)
  : D (λ (g₂ : α → β₂) (x : α) => f (g₁ x) (g₂ x)) = λ g₂ dg₂ x => D (f (g₁ x)) (g₂ x) (dg₂ x) := 
by
  rw[D_partial]
  simp; done

theorem D_diag_g₁_g₂ (f : β₁ → β₂ → γ)
  : D (λ (g₁ : α → β₁) (g₂ : α → β₂) (x : α) => f (g₁ x) (g₂ x)) = λ g₁ dg₁ g₂ x => D f (g₁ x) (dg₁ x) (g₂ x) := 
by
  -- conv => lhs; rw[D_partial]; rw[D_partial];
  rw[D_partial]; funext g₁ dg₁ g₂ x; rw[D_partial];
  simp; done
lecopivo commented 2 years ago

One odd thing is that the second rewrite in conv => lhs; rw[D_partial]; rw[D_partial] fails.

To make it succeed, it is necessary to introduce arguments either with enter[g₁, dg₁, g₂, x] or funext g₁ dg₁ g₂ x. This looks like a bug.

lecopivo commented 2 years ago

I'm actually really happy with how the higher-order matching works right now, and I do not mind that it is just an approximation!

So back to the original question: I think I have managed to understand where the problem is, and I do not see a neat solution to it. Luckily, modifying priorities works and I will settle for that; it's not ideal and probably fragile.

Here is an example in the context of the differentiation above. If the priority is just mid, then D_S is applied indefinitely.

@[simp mid+1]
axiom D_fst : D (Prod.fst : α × β → α) = λ (a,b) (da,db) => da

example : D (Prod.fst : α × β → α) = λ (a,b) (da,db) => da :=
by
  simp; done
digama0 commented 2 years ago

cc @Vierkantor, who just gave a talk at MCM about more or less exactly this problem.

Was it recorded? I would be interested in hearing it.

They were recorded, although the recordings are still pending public release due to some legal issues that will hopefully be straightened out eventually.

leodemoura commented 2 years ago

I am planning to close this issue soon. If you object, say something.