google-research / dex-lang

Research language for array processing in the Haskell/ML family
BSD 3-Clause "New" or "Revised" License

Add a Multivariate Normal distribution and expand the GP regression example. #1332

Open · emilyfertig opened 1 year ago

emilyfertig commented 1 year ago

Could someone please help me define the lower-triangular chol_cov in stats-tests? Everything I've tried throws either a type error or a compiler bug (let me know if you want me to file an issue for the latter). Thanks!

axch commented 1 year ago

I can't reproduce the compiler bug. All three of the examples you have in stats-tests.dx seem to work fine:

chol_cov_mat : Fin 2=>Fin 2=>Float = [[0.2, 0.], [-0.3, 0.1]]
chol_cov1 = for i:(Fin 2). for j:(..i). chol_cov_mat[i, inject(to=Fin 2, j)]
chol_cov2 = for i:(Fin 2). for j:(..i). chol_cov_mat[i, (ordinal j)@_]
chol_cov3 = for i:(Fin 2). for j:(..i). chol_cov_mat[i, inject j]

Perhaps you need to run dex clean? The make just-build target doesn't clean by default because cleaning triggers a (slow-ish) recompilation of the prelude. But if you don't clean after the Dex source changes, you can sometimes get weird incompatibilities.

As for the type error from trying to write a triangular literal directly, that was broken intentionally by PR https://github.com/google-research/dex-lang/pull/1207, because the machinery to support the feature was too heavy. The recommended idiom now is to just use coerce_table, thus:

chol_cov : (i:Fin 2)=>(..i)=>Float = [
  coerce_table _ [0.2],
  coerce_table _ [-0.3, 0.1]]
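For readers coming from array libraries without dependent types: a table of type (i:Fin 2)=>(..i)=>Float is ragged. Row i holds entries for j = 0 through i, which is why the literal above has one entry in row 0 and two in row 1. It therefore stores exactly the lower triangle (diagonal included) of a 2x2 matrix. The NumPy sketch below is only an analogy. It is not Dex and says nothing about how Dex stores these tables internally, and the helpers lower_triangle and to_dense are hypothetical names. It mirrors what chol_cov1, chol_cov2 and chol_cov3 compute (keep chol_cov_mat[i, j] for j <= i) and checks that the row-by-row literal describes the same triangle.

```python
import numpy as np

# Dense matrix with the same values as chol_cov_mat in the thread.
chol_cov_mat = np.array([[0.2, 0.0],
                         [-0.3, 0.1]])

def lower_triangle(mat):
    """Hypothetical helper: row i keeps entries j = 0..i (like `for j:(..i)`)."""
    n = mat.shape[0]
    return [[float(mat[i, j]) for j in range(i + 1)] for i in range(n)]

def to_dense(rows):
    """Hypothetical helper: expand a ragged lower triangle into a dense n x n
    matrix, filling the entries above the diagonal with zeros."""
    n = len(rows)
    dense = np.zeros((n, n))
    for i, row in enumerate(rows):
        # Row i of the triangular type has exactly i + 1 entries.
        if len(row) != i + 1:
            raise ValueError(f"row {i} must have {i + 1} entries, got {len(row)}")
        dense[i, : i + 1] = row
    return dense

# Analogue of the literal built with coerce_table: one list per row.
chol_cov = [[0.2], [-0.3, 0.1]]

print(lower_triangle(chol_cov_mat))                   # [[0.2], [-0.3, 0.1]]
print(np.allclose(to_dense(chol_cov), chol_cov_mat))  # True
```

In the Dex version, the type checker enforces the row lengths. In the Python analogue, to_dense has to check them at run time.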

emilyfertig commented 1 year ago

Thanks -- I ran dex clean and defining chol_cov works for me now. I'm seeing the compiler bug below. Let me know if you have any ideas for a workaround (or if I should file a GitHub issue).


-- multivariate normal                                        (
                                                              (
chol_cov : (i:Fin 2)=>(..i)=>Float = [                        (
  coerce_table _ [0.2],                                       (
  coerce_table _ [-0.3, 0.1]]                                 (
                                                              (
-- chol_cov_mat : Fin 2=>Fin 2=>Float = [[0.2, 0.], [-0.3, 0. (
-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, inj (
                                                              (
loc : (Fin 2=>Float) = [1., 2.]                               (
draw (MultivariateNormal loc chol_cov) (new_key 0) :: (Fin 2= (
> [0.706645, 2.599938]                                        | > Compiler bug!
                                                              > > Please report this at github.com/google-research/dex-lang/i
                                                              > >
                                                              > > Unexpected table: chol_co
                                                              > > CallStack (from HasCallStack):
                                                              > >   error, called at src/lib/Simplify.hs:570:22 in dex-0.1.0.
                                                              (
ln (density (MultivariateNormal [1., 1] chol_cov) [0.5, 0.5]) (
> True                                                        | > Compiler bug!
                                                              > > Please report this at github.com/google-research/dex-lang/i
                                                              > >
                                                              > > Unexpected table: chol_co
                                                              > > CallStack (from HasCallStack):
                                                              > >   error, called at src/lib/Simplify.hs:570:22 in dex-0.1.0.

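For reference, here is the standard way a lower-triangular Cholesky factor L (covariance Sigma = L L^T) parameterizes a multivariate normal, which is the math the two tests above exercise. To sample, draw z from a standard normal and return loc + L z. To evaluate the log density, solve L y = x - loc by forward substitution, then compute log p(x) = -0.5 * y^T y - sum_i log L_ii - (n/2) log(2 pi). The NumPy sketch below is an assumption about what MultivariateNormal computes, not the PR's Dex code. It also will not reproduce the Dex output [0.706645, 2.599938], because NumPy's random generator is unrelated to new_key 0.

```python
import numpy as np

def mvn_draw(loc, chol, rng):
    """Sample x = loc + L z with z ~ N(0, I), so that Cov[x] = L L^T."""
    z = rng.standard_normal(len(loc))
    return np.asarray(loc, dtype=float) + chol @ z

def mvn_log_density(loc, chol, x):
    """Log density of N(loc, L L^T) at x, assuming chol is lower triangular
    with a strictly positive diagonal (a valid Cholesky factor)."""
    diff = np.asarray(x, dtype=float) - np.asarray(loc, dtype=float)
    n = diff.shape[0]
    # Forward substitution: solve L y = diff one row at a time.
    y = np.zeros(n)
    for i in range(n):
        y[i] = (diff[i] - chol[i, :i] @ y[:i]) / chol[i, i]
    # -0.5 * log|Sigma| equals -sum(log L_ii) because |Sigma| = prod(L_ii)^2.
    return -0.5 * y @ y - np.sum(np.log(np.diag(chol))) - 0.5 * n * np.log(2 * np.pi)

# Same values as chol_cov and loc in the thread, stored as dense arrays.
chol_cov = np.array([[0.2, 0.0],
                     [-0.3, 0.1]])
loc = np.array([1.0, 2.0])

sample = mvn_draw(loc, chol_cov, np.random.default_rng(0))
print(sample.shape)  # (2,)
print(mvn_log_density([1.0, 1.0], chol_cov, [0.5, 0.5]))
```

Working with L directly avoids forming and inverting Sigma. The triangular solve and the log-determinant each need only the factor, which is one reason to pass a triangular table such as chol_cov.
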
axch commented 1 year ago

Looks like a legit bug, which I minimized and refiled as #1333, but I don't know how to fix it.