FStarLang / FStar

A Proof-oriented Programming Language
https://www.fstar-lang.org
Apache License 2.0

Bad type inference when using subtyping with monads #3309

Open TWal opened 4 weeks ago

TWal commented 4 weeks ago

The problem is described in the comments of the following code:

// Two monads, with returns and binds
assume val m1: Type0 -> Type0
assume val m2: Type0 -> Type0

assume val return1: #a:Type0 -> a -> m1 a
assume val return2: #a:Type0 -> a -> m2 a

assume val (let?): #a:Type0 -> #b:Type0 -> m2 a -> (a -> m2 b) -> m2 b
assume val (let*?): #a:Type0 -> #b:Type0 -> m1 (m2 a) -> (a -> m1 (m2 b)) -> m1 (m2 b)

// Suppose we have a monadic function that returns a `nat`
// (a refined type, namely a subtype of `int`)
assume val f: unit -> m2 nat

// And we want to lift it in a double monadic function that returns an `int`
val g: unit -> m1 (m2 int)

// The naive lifting with `return1` fails as expected:
// return1 (f ()) has type m1 (m2 nat),
// which cannot be converted to m1 (m2 int),
// (well-known problem of sub-typing under an inductive)
[@@ expect_failure]
let g () =
  return1 (f ())

// The standard fix with inductive subtyping is to unwrap and re-wrap,
// but this fails:
[@@ expect_failure]
let g () =
  // We would hope that x here has type `nat`
  let*? x = return1 (f ()) in
  // we can do the subtyping on x, and return x as an int
  return1 (return2 x)

// It doesn't work, because the last line implies that `x` must have type `int`,
// therefore `return1` is used with `a = m2 int`,
// which doesn't work because `f` returns an `m2 nat`

// We now make the implicit argument of `return1` explicit, but it still fails:
[@@ expect_failure]
let g () =
  // We would hope that x here has type `nat`
  let*? x = return1 #(m2 nat) (f ()) in
  // we can do the subtyping on x, and return x as an int
  return1 (return2 x)

// It still fails, because `x` has type `nat`, so obviously `return1 (return2 x)` has type `m1 (m2 nat)`!
// Adding a coercion finally fixes things
let g () =
  // We would hope that x here has type `nat`
  let*? x = return1 #(m2 nat) (f ()) in
  // we can do the subtyping on x, and return x as an int
  return1 (return2 (x <: int))

// If we are not lucky enough to have a type abbreviation like `nat`,
// but an explicit refinement, the workaround looks like this:
val g': unit -> m1 (m2 int)
let g' () =
  //                      vvvvvvvv (ugh)
  let*? x = return1 #(m2 (_:int{_})) (f ()) in
  return1 (return2 (x <: int))

// Another workaround, with a utility function

val unrefine:
  #a:Type0 -> #p:(a -> Type0) ->
  x:a{p x} ->
  a
let unrefine #a #p x = x

val g'': unit -> m1 (m2 int)
let g'' () =
  let*? x = return1 (f ()) in
  return1 (return2 (unrefine x))

I feel like it would be reasonable to expect the second attempt at defining g (unwrapping and re-wrapping the inductive) to work. Why isn't that the case?

In the line let*? x = return1 (f ()) in, the implicit must be resolved as nat, because anything else would lead to a type error. In the line return1 (return2 x), the implicit must be resolved as int for the same reason. I don't understand why type inference makes different choices.

mtzguido commented 3 weeks ago

Thanks Théophile. I think this is a minimization of the same problem:

val bind : #a:Type -> #b:Type -> a -> (a -> b) -> b
let bind x f = f x

let test (x:nat) =
  bind x (fun n ->
  assert (n >= 0); n)

The assert fails since the implicit for bind is inferred to be int.
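For comparison, here is a sketch of the same minimization with the first implicit of bind instantiated by hand, mirroring the return1 #(m2 nat) workaround from the original report. The name test_explicit is made up for illustration, and the snippet has not been checked against a particular F* version:

```fstar
val bind : #a:Type -> #b:Type -> a -> (a -> b) -> b
let bind x f = f x

// Hypothetical variant of `test`: with `#a` fixed to `nat`,
// the function argument is checked with `n : nat`,
// so the refinement is available to the assert.
let test_explicit (x:nat) =
  bind #nat x (fun n ->
  assert (n >= 0); n)
```

The point is only that the failure goes away once inference no longer gets to pick the instantiation of the implicit.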

mtzguido commented 3 weeks ago

Dumping some thoughts:

1- This does not work since we do not accumulate subtyping constraints (i.e. we try to solve deferred constraints) when checking an abstraction, so the implicit defaults to int since that is how it's used in the body.

2- Even if we did, we would get the following two constraints in the unifier: nat <: ?u and ?u <: int. There's no guarantee that we solve to nat instead of int. Perhaps we can do a pass to collect these constraints and solve to the most restrictive one, but it's also unclear that's the right choice...
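To make the second point concrete: both nat and int satisfy the pair of constraints nat <: ?u and ?u <: int, which is why neither solution is forced. A small sketch, reusing the bind from the minimization above with the implicits written out by hand (the names as_int and as_nat are made up for illustration, and this has not been run against a particular F* version):

```fstar
val bind : #a:Type -> #b:Type -> a -> (a -> b) -> b
let bind x f = f x

// ?u := int. `x` is accepted by subtyping (nat <: int),
// but inside the function `n` is only known to be an `int`.
let as_int (x:nat) : int =
  bind #int #int x (fun n -> n)

// ?u := nat. This also satisfies both constraints,
// and here `n` keeps its refinement, so the assert is provable.
let as_nat (x:nat) : int =
  bind #nat #int x (fun n -> assert (n >= 0); n)
```

Only the second instantiation lets the body rely on n >= 0, which is the "most restrictive" choice mentioned above.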

Minimized a bit more:

let test (x:nat) =
  let f = (fun n -> assert (n >= 0); n) in
  f x
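A sketch of how this last minimization behaves once the binder is annotated, so that nothing is left for inference to default. The name test_annotated is made up for illustration, and the snippet is untested:

```fstar
let test_annotated (x:nat) =
  // The annotation fixes the type of `n` to `nat` up front,
  // so the assert can use the refinement.
  let f = (fun (n:nat) -> assert (n >= 0); n) in
  f x
```

This sidesteps the issue rather than fixing it: the question in this thread is whether the unannotated version should infer nat on its own.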