FStarLang / FStar

A Proof-oriented Programming Language
https://www.fstar-lang.org
Apache License 2.0

Confusing behavior with refinements #3282

Open mtzguido opened 1 month ago

mtzguido commented 1 month ago

Not sure I can call this a bug, but:

You would expect this lemma to be easily provable, like this:

type r = int

let rec bigsum (m : nat) (n : nat {m <= n}) (f : (i:nat { m <= i /\ i < n } -> r)) : Tot r =
  if m = n then 0 else f (n-1) + bigsum m (n-1) f

let pointwise_add (f g : 'a -> r) : 'a -> r = fun x -> f x + g x

let rec bigsum_pointwise_add
  (m : nat) (n : nat {m <= n})
  (f : (i:nat { m <= i /\ i < n } -> r))
: Lemma (bigsum m n (pointwise_add f f) == bigsum m n f + bigsum m n f) =
  if m = n then () else (
    bigsum_pointwise_add m (n-1) f
  )

However, this fails ("could not prove post-condition") no matter how much of the proof one spells out. Even asserting the recursive call's postcondition right after the call fails:

let rec bigsum_pointwise_add
  (m : nat) (n : nat {m <= n})
  (f : (i:nat { m <= i /\ i < n } -> r))
: Lemma (bigsum m n (pointwise_add f f) == bigsum m n f + bigsum m n f) =
  if m = n then () else (
    bigsum_pointwise_add m (n-1) f;
    assert (bigsum m (n-1) (pointwise_add f f) == bigsum m (n-1) f + bigsum m (n-1) f);
    admit()
  )
* Error 19 at Bug.fst(16,4-16,10):
  - Assertion failed
  - The SMT solver could not prove the query. Use --query_stats for more
    details.
  - See also Bug.fst(16,11-16,86)

The catch is that the implicit argument of pointwise_add ('a) is instantiated differently in the two places. In the recursive call, the domain of f is refined to i between m and n-1 instead of between m and n, so one would expect this to work:

let rec bigsum_pointwise_add
  (m : nat) (n : nat {m <= n})
  (f : (i:nat { m <= i /\ i < n } -> r))
: Lemma (bigsum m n (pointwise_add f f) == bigsum m n f + bigsum m n f) =
  if m = n then () else (
    bigsum_pointwise_add m (n-1) f;
    assert (bigsum m (n-1) (pointwise_add #(i:nat { (m <= i /\ i < n-1) }) f f) == bigsum m (n-1) f + bigsum m (n-1) f);
    admit()
  )

Alas, that also fails with the same error. The reason is that F* matched the refinements of both occurrences of f and computed their meet, i.e., their conjunction. So, finally, this works:

let rec bigsum_pointwise_add
  (m : nat) (n : nat {m <= n})
  (f : (i:nat { m <= i /\ i < n } -> r))
: Lemma (bigsum m n (pointwise_add f f) == bigsum m n f + bigsum m n f) =
  if m = n then () else (
    bigsum_pointwise_add m (n-1) f;
    assert (bigsum m (n-1) (pointwise_add #(i:nat { (m <= i /\ i < n-1) /\ (m <= i /\ i < n-1) }) f f) == bigsum m (n-1) f + bigsum m (n-1) f);
    admit()
  )

The error messages throughout this debugging are pretty unhelpful. Beyond that, I think the last failed attempt should work: the two refinements are logically equivalent (which F* can prove), so the refinement types should be considered equal.
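To make the "which F* can prove" part concrete, a lemma along these lines should be discharged by the SMT solver with no proof hints. It simply copies the single refinement and the conjoined "meet" from the attempts above; the module and lemma names are hypothetical, and this is a sketch rather than a checked test from the F* repository.

```fstar
module RefinementEquiv

(* Left side: the refinement as written in the failed attempt.
   Right side: the conjunction F* computed when taking the meet
   of the refinements of the two occurrences of f. *)
let refinements_equiv (m n i : nat)
  : Lemma ((m <= i /\ i < n - 1) <==> ((m <= i /\ i < n - 1) /\ (m <= i /\ i < n - 1)))
  = () (* propositional reasoning only; SMT closes it directly *)
```

Proving the predicates equivalent is not the same as F* treating the two refinement types as equal, which is exactly the gap this issue points at.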

(To prove the actual lemma, I think it has to be restated like this:

let rec bigsum_pointwise_add
  (m' n' : nat)
  (m : nat{m >= m'}) (n : nat {m <= n /\ n <= n'})
  (f : (i:nat { m' <= i /\ i < n' } -> r))
: Lemma (bigsum m n (pointwise_add f f) == bigsum m n f + bigsum m n f) =
  if m = n then () else (
    bigsum_pointwise_add m' n' m (n-1) f;
    ()
  )

)

gebner commented 3 weeks ago

We briefly talked about this last week. One possible solution would be to use total functions instead, that is:

let rec bigsum (m : nat) (n : nat) (f : nat -> r) : Tot r =
  if m >= n then 0 else f (n-1) + bigsum m (n-1) f

This is also what the on_range predicate does in Pulse. In practice this is not a big restriction, since you can write bigsum m n (fun i -> if m <= i && i < n then f i else 0) instead, so that any range precondition on f still type checks.
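Putting the suggestion together, here is a sketch of how the original lemma might look once bigsum takes a total function. It assumes a file named BigsumTotal.fst; the helper extend is hypothetical and just packages the guarded lambda from the comment above, and the proof has not been checked against a specific F* release.

```fstar
module BigsumTotal

type r = int

(* Sums f (n-1) + ... + f m; an empty range (m >= n) sums to 0. *)
let rec bigsum (m : nat) (n : nat) (f : nat -> r) : Tot r =
  if m >= n then 0 else f (n-1) + bigsum m (n-1) f

let pointwise_add (f g : 'a -> r) : 'a -> r = fun x -> f x + g x

(* Hypothetical helper: extend a function defined only on [m, n)
   to all of nat by returning 0 outside that range. *)
let extend (m : nat) (n : nat) (f : (i:nat{m <= i /\ i < n}) -> r) : nat -> r =
  fun i -> if m <= i && i < n then f i else 0

(* f is total on nat, so the implicit 'a of pointwise_add is nat at
   every use; the recursive call's postcondition then talks about the
   same function as the goal, and no refinement meet is involved. *)
let rec bigsum_pointwise_add (m n : nat) (f : nat -> r)
  : Lemma (bigsum m n (pointwise_add f f) == bigsum m n f + bigsum m n f)
  = if m >= n then () else bigsum_pointwise_add m (n-1) f
```

A range-restricted f from the original statement would then be passed as bigsum m n (extend m n f), keeping the range check inside the function instead of in its type.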