google-research / dex-lang

Research language for array processing in the Haskell/ML family
BSD 3-Clause "New" or "Revised" License

Enforce the constraints implied by the new type-parameter role system #1139

Open apaszke opened 1 year ago

apaszke commented 1 year ago

See the description of #1138 for the two invariants.

dougalm commented 1 year ago

Specifically, quoting #1138:

  1. instances must be parametric over all data arguments. You can't define an instance like Ix (Fin 10). It has to be {n:Nat} (Ix (Fin n)) instead.
  2. type constructor parameters must be one of type/data/dict, so we can't have, e.g., functions as type constructor parameters.

2 is blocked on having a Data constraint: #1150.

1 can be done any time. One way to do it: when we see an instance, like ixbad : Ix (Fin 10) or ixgood : (n:Int) -> Ix (Fin n), we instantiate it with fresh inference vars and unify the result with the class applied to skolem parameters. So for ixbad, we'd end up unifying Fin 10 with Fin skolem, which would fail. Whereas for ixgood, we'd unify Fin ?1 with Fin skolem, which would succeed with ?1 -> skolem.
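The check described above can be sketched as a toy model in Python. This is only an illustration of the idea, not how the Dex compiler implements it: the term representation (Lit, Skolem, Var, App), the helper names, and the assumption that every argument of the type constructor is a data argument are all made up for this sketch, and the unifier omits an occurs check.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(frozen=True)
class Lit:      # concrete data value, e.g. the 10 in `Fin 10`
    value: int


@dataclass(frozen=True)
class Skolem:   # rigid variable: unifies only with itself (or an inference var)
    name: str


@dataclass(frozen=True)
class Var:      # inference variable such as ?1, or a binder before instantiation
    name: str


@dataclass(frozen=True)
class App:      # type constructor applied to arguments, e.g. `Fin n`
    con: str
    args: tuple


def resolve(t, subst):
    # Follow the substitution until we hit a non-variable or an unbound variable.
    while isinstance(t, Var) and t.name in subst:
        t = subst[t.name]
    return t


def unify(a, b, subst: Dict[str, object]) -> Optional[Dict[str, object]]:
    # Returns an extended substitution on success, None on failure.
    a, b = resolve(a, subst), resolve(b, subst)
    if a == b:
        return subst
    if isinstance(a, Var):
        return {**subst, a.name: b}
    if isinstance(b, Var):
        return {**subst, b.name: a}
    if isinstance(a, App) and isinstance(b, App):
        if a.con != b.con or len(a.args) != len(b.args):
            return None
        for x, y in zip(a.args, b.args):
            subst = unify(x, y, subst)
            if subst is None:
                return None
        return subst
    return None  # e.g. a literal against a skolem


def instantiate(binders: List[str], head):
    # Replace each instance binder with a fresh inference variable ?1, ?2, ...
    fresh = {b: Var(f"?{i}") for i, b in enumerate(binders, start=1)}

    def go(t):
        if isinstance(t, Var):
            return fresh.get(t.name, t)
        if isinstance(t, App):
            return App(t.con, tuple(go(x) for x in t.args))
        return t

    return go(head)


def check_instance(binders: List[str], head: App):
    # Unify the instantiated head with the same constructor applied to skolems.
    inst = instantiate(binders, head)
    target = App(inst.con, tuple(Skolem(f"sk{i}") for i in range(len(inst.args))))
    return unify(inst, target, {})


# ixbad : Ix (Fin 10)
ixbad = check_instance([], App("Fin", (Lit(10),)))
# ixgood : (n:Int) -> Ix (Fin n)
ixgood = check_instance(["n"], App("Fin", (Var("n"),)))

print(ixbad)   # None, i.e. the instance is rejected
print(ixgood)  # a substitution mapping ?1 to the skolem, i.e. accepted
```

In this model an instance is accepted exactly when unification returns a substitution. An instance that reuses one binder for two data arguments would also be rejected, because the two distinct skolems cannot be made equal.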

duvenaud commented 1 year ago

I don't think this is important, but FYI I found a use for functions as type constructor parameters. Specifically, for defining Poisson processes:

data PoissonProcess = 
  MkPoissonProcess (Float -> LogSpace Float) (Float) (Float)   -- rate function, t0, t1

instance Random PoissonProcess (List Float)
  draw = \(MkPoissonProcess rate t0 t1) k.
    sample_poisson_process rate t0 t1 k

I was surprised that this works, and I guess I was right to be, since it sounds like you're planning to disallow this?

In any case, I wouldn't block anything to save this little piece of code.

If you're curious, here's the entire example, which runs:

import stats

'# Poisson Processes

def get_last {n a} (xs:n=>a) : Maybe a =
  case size n > 0 of
    True -> Just xs.(unsafe_from_ordinal _ (unsafe_nat_diff (size n) 1))
    False -> Nothing

def logspace_1d_cumsum (n:Type) [Ix n]
    (rate: Float -> LogSpace Float) (t0:Float) (t1:Float) :
    (n=>Float & n=>LogSpace Float & n=>LogSpace Float & LogSpace Float) =  -- rates, cumulative rates, total
  rate_eval_locations = linspace n t0 t1
  dt = log $ (t1 - t0) / n_to_f (size n)  -- -1?
  log_rate_evals = for i. (Exp dt) * rate rate_eval_locations.i
  cs = cumsum log_rate_evals  -- todo: log cumsum exp?
  total = case get_last cs of
    Just total -> total
    Nothing -> zero
  (rate_eval_locations, log_rate_evals, cs, total)

def sample_poisson_process (rate: Float -> LogSpace Float) (t0:Float) (t1:Float) (key:Key) : List Float =
  (rate_eval_locations, log_rate_evals, cs, total) =
    logspace_1d_cumsum (Fin 100) rate t0 t1

  N = draw (Poisson (ls_to_f total)) key

  locs = for i.
    ix = categorical (map ln log_rate_evals) (ixkey key i)
    rate_eval_locations.ix  -- todo: add interpolation?

  AsList N locs

def poisson_process_density (rate: Float -> LogSpace Float) (t0:Float) (t1:Float) (xs:List Float) : LogSpace Float =
  (rate_eval_locations, log_rate_evals, cs, total) =
    logspace_1d_cumsum (Fin 100) rate t0 t1

  (AsList N xtab) = xs
  poisson_pmf = density (Poisson (ls_to_f total)) N

  loc_density = sum for i.
    bin_ix = from_just $ search_sorted rate_eval_locations xtab.i
    log_rate_evals.bin_ix

  loc_density + poisson_pmf

data PoissonProcess = 
  MkPoissonProcess (Float -> LogSpace Float) (Float) (Float)

instance Random PoissonProcess (List Float)
  draw = \(MkPoissonProcess rate t0 t1) k.
    sample_poisson_process rate t0 t1 k

instance Dist PoissonProcess (List Float) Float
  density = \(MkPoissonProcess rate t0 t1) x.
    todo

import plot

def example_rate (x:Float) : (LogSpace Float) =
  Exp (-2.0 - 3.0 * sin (3.0 * x))

t0 = -2.0
t1 = 1.4
num_evals = 100
eval_spots = (Unit | Fin (unsafe_nat_diff num_evals 1))
rate_eval_locations = linspace eval_spots t0 t1
rate_evals = map (\x. (ls_to_f $ example_rate x) / n_to_f num_evals) rate_eval_locations

:html show_plot $ xy_plot rate_eval_locations rate_evals

pp = MkPoissonProcess example_rate t0 t1
(AsList _ vs) : List Float = (draw pp (new_key 0))
vs

def myfunc (r:Float) : Float =
  example_rate = \x:Float. Exp (x - 10.0 * sin (3.0 * x))
  t0 = -2.0
  t1 = 3.4
  pp = MkPoissonProcess example_rate t0 t1
  (AsList _ vs) : List Float = (draw pp (new_key 0))
  vs.(0@_)

grad myfunc 1.0

dougalm commented 1 year ago

Oh, that's definitely allowed! In your example, the function is an argument to the data constructor, MkPoissonProcess, rather than the type constructor, PoissonProcess.
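To make the distinction concrete, here is a small Python model of invariant 2. It is a sketch under assumptions, not Dex's actual representation: the DataDef record, the kind labels, and the Wrapper definition are hypothetical and exist only for illustration. The point is that the restriction looks at the type constructor's parameters and never at the data constructor's fields.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Invariant 2 from #1138: type constructor parameters must be type/data/dict.
ALLOWED_PARAM_KINDS = {"type", "data", "dict"}


@dataclass
class DataDef:
    name: str                        # type constructor name
    params: List[Tuple[str, str]]    # type constructor parameters: (name, kind)
    con_name: str                    # data constructor name
    fields: List[str]                # kinds of data constructor arguments (unrestricted)


def bad_params(d: DataDef) -> List[str]:
    # Only the type constructor's parameters are checked; fields are ignored.
    return [name for name, kind in d.params if kind not in ALLOWED_PARAM_KINDS]


# PoissonProcess takes no type constructor parameters; the rate function is a
# field of the data constructor MkPoissonProcess, so nothing is flagged.
poisson = DataDef("PoissonProcess", [], "MkPoissonProcess",
                  ["function", "data", "data"])

# Hypothetical definition that takes a function as a type constructor parameter.
wrapper = DataDef("Wrapper", [("f", "function")], "MkWrapper", ["data"])

print(bad_params(poisson))  # no offending parameters
print(bad_params(wrapper))  # the parameter f is flagged
```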

duvenaud commented 1 year ago

Thanks for explaining, I hadn't understood the distinction between 'type constructor' and 'data constructor'.