keesterbrugge closed this 3 years ago
@keesterbrugge, yes, for now this is the expected behavior. `->dataset` is quite powerful, but it doesn't do everything (yet!). Long term, we are working towards a more polymorphic data manipulation interface.
ok, thx
setup
```clojure
(ns libpython-examples.jax.numpyro-repo-examples.bayesian-regression
  (:require [clojure.tools.deps.alpha.repl :refer [add-lib]]
            [libpython-clj.require :refer [require-python]]
            [libpython-clj.python :refer [py.] :as py]
            [tech.ml.dataset :as ds]
            [tech.ml.dataset.pipeline :as dsp]
            [camel-snake-kebab.core :as csk]))

(require-python '[jax.numpy :as jnp])
(require-python '[jax :refer [random vmap]])
; (require-python '[ax.scipy.special :refer [logsumexp]]) ; errors
(py/from-import jax.scipy.special logsumexp)
(require-python 'numpyro)
(require-python '[numpyro.diagnostics :refer [hpdi]])
(require-python '[numpyro.distributions :as dist])
(py/from-import numpyro handlers)
(require-python '[numpyro.infer :refer [MCMC, NUTS]])
(require-python 'operator)
(py/from-import numpyro.util set_host_device_count)

;; may help stability,
;; see https://github.com/clj-python/libpython-clj/issues/93#issuecomment-611202595
(set_host_device_count 1)
;; try out py/with-gil-stack-rc-context if still stability issues

(def dset
  (ds/->dataset
   "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
   {:separator \;
    :key-fn csk/->kebab-case-keyword}))

(def rng_key (py. random PRNGKey 0))
(def new_rng_keys (py. random split rng_key))
(def num_warmup 1000)
(def num_samples 2000)

(defn model
  "def model(marriage=None, age=None, divorce=None):
       a = numpyro.sample('a', dist.Normal(0., 0.2))
       M, A = 0., 0.
       if marriage is not None:
           bM = numpyro.sample('bM', dist.Normal(0., 0.5))
           M = bM * marriage
       if age is not None:
           bA = numpyro.sample('bA', dist.Normal(0., 0.5))
           A = bA * age
       sigma = numpyro.sample('sigma', dist.Exponential(1.))
       mu = a + M + A
       numpyro.sample('obs', dist.Normal(mu, sigma), obs=divorce)"
  [m]
  (let [marriage (m :marriage)
        divorce (m :divorce)
        median-age-marriage (m :median-age-marriage)
        a (numpyro/sample "a" (dist/Normal 0. 1))
        M (if marriage
            (operator/mul marriage (numpyro/sample "bM" (dist/Normal 0. 0.5)))
            0)
        A (if median-age-marriage
            (operator/mul median-age-marriage (numpyro/sample "bA" (dist/Normal 0. 0.5)))
            0)
        mu (reduce jnp/add [a M A])
        sigma (numpyro/sample "sigma" (dist/Exponential 1.))]
    (numpyro/sample "obs" (dist/Normal mu sigma) :obs divorce)))

(defn map-vals [m f]
  (reduce-kv (fn [m k v] (assoc m k (f v))) {} m))

(def model-input
  (-> dset
      dsp/std-scale
      (select-keys [:marriage :divorce])
      (map-vals jnp/array)))

(def kernel (NUTS model))
(def mcmc (MCMC kernel num_warmup num_samples))
(py. mcmc run (last new_rng_keys) model-input)
```
The value for `samples` behaves a lot like a map, but not quite. I hoped I would be able to turn it directly into a dataset using `->dataset`, but I first need to turn it into a Clojure map. Is this expected behaviour?
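For anyone landing here, a minimal sketch of the "map-like but not a map" situation in plain Clojure. The `java.util.HashMap` below is only a stand-in: it assumes the bridged Python dict is exposed as a `java.util.Map`, and the names `samples-like` and `samples-map` as well as the values are made up for illustration, not taken from the actual MCMC run.

```clojure
;; Hypothetical stand-in for the bridged Python dict: a java.util.Map
;; that supports lookups but is not a persistent Clojure map.
(def samples-like
  (doto (java.util.HashMap.)
    (.put "a" [0.1 0.2 0.3])
    (.put "bM" [0.4 0.5 0.6])))

;; Lookups work, which is why it feels like a map...
(get samples-like "a")

;; ...but functions that check for a Clojure map will not recognise it.
(map? samples-like)

;; Copy the entries into a real Clojure map, turning string keys into keywords.
(def samples-map
  (into {}
        (map (fn [[k v]] [(keyword k) v]))
        samples-like))
```

After this step `samples-map` is an ordinary Clojure map from names to sequences of values, which is the kind of intermediate value the question describes creating before calling `->dataset`. With real JAX arrays as values, an extra conversion of each value may also be needed; that part is not shown here.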