clj-python / libpython-clj

Python bindings for Clojure
Eclipse Public License 2.0
1.08k stars 68 forks source link

->dataset not accepting pyobject that behaves like clojure map. #127

Closed keesterbrugge closed 3 years ago

keesterbrugge commented 4 years ago
setup

```clojure
;; Bayesian regression on the WaffleDivorce data using NumPyro (via
;; libpython-clj) and tech.ml.dataset. Evaluating this namespace top to bottom
;; loads the CSV, defines the model, and runs NUTS/MCMC.
(ns libpython-examples.jax.numpyro-repo-examples.bayesian-regression
  (:require [clojure.tools.deps.alpha.repl :refer [add-lib]]
            [libpython-clj.require :refer [require-python]]
            [libpython-clj.python :refer [py.] :as py]
            [tech.ml.dataset :as ds]
            [tech.ml.dataset.pipeline :as dsp]
            [camel-snake-kebab.core :as csk]))

;; Python-side modules used below.
(require-python '[jax.numpy :as jnp])
(require-python '[jax :refer [random vmap]])
; (require-python '[ax.scipy.special :refer [logsumexp]]) ; errors
;; NOTE(review): the commented-out form above names `ax.scipy.special`, not
;; `jax.scipy.special` -- possibly a typo and the cause of the error; confirm.
(py/from-import jax.scipy.special logsumexp)
(require-python 'numpyro)
(require-python '[numpyro.diagnostics :refer [hpdi]])
(require-python '[numpyro.distributions :as dist])
(py/from-import numpyro handlers)
(require-python '[numpyro.infer :refer [MCMC, NUTS]])
(require-python 'operator)
(py/from-import numpyro.util set_host_device_count)

;; may help stability,
;; see https://github.com/clj-python/libpython-clj/issues/93#issuecomment-611202595
(set_host_device_count 1)
;; try out py/with-gil-stack-rc-context if still stability issues

;; Source data: semicolon-separated CSV; column names become kebab-case keywords.
(def dset
  (ds/->dataset "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
                {:separator \;
                 :key-fn csk/->kebab-case-keyword}))

;; PRNG key for JAX, split so the MCMC run gets its own key.
(def rng_key (py. random PRNGKey 0))
(def new_rng_keys (py. random split rng_key))

;; Sampler settings passed to MCMC below.
(def num_warmup 1000)
(def num_samples 2000)

(defn model
  "Clojure port of the following NumPyro model. Takes a single map-like
  argument `m` and reads :marriage, :divorce and :median-age-marriage from it.

  def model(marriage=None, age=None, divorce=None):
      a = numpyro.sample('a', dist.Normal(0., 0.2))
      M, A = 0., 0.
      if marriage is not None:
          bM = numpyro.sample('bM', dist.Normal(0., 0.5))
          M = bM * marriage
      if age is not None:
          bA = numpyro.sample('bA', dist.Normal(0., 0.5))
          A = bA * age
      sigma = numpyro.sample('sigma', dist.Exponential(1.))
      mu = a + M + A
      numpyro.sample('obs', dist.Normal(mu, sigma), obs=divorce)"
  [m]
  (let [marriage            (m :marriage)
        divorce             (m :divorce)
        median-age-marriage (m :median-age-marriage)
        ;; NOTE(review): the Python reference in the docstring uses
        ;; Normal(0., 0.2) for `a`; this port uses scale 1 -- confirm intended.
        a     (numpyro/sample "a" (dist/Normal 0. 1))
        ;; Each slope term is only sampled when its predictor is present;
        ;; otherwise the term contributes 0.
        M     (if marriage
                (operator/mul marriage
                              (numpyro/sample "bM" (dist/Normal 0. 0.5)))
                0)
        A     (if median-age-marriage
                (operator/mul median-age-marriage
                              (numpyro/sample "bA" (dist/Normal 0. 0.5)))
                0)
        mu    (reduce jnp/add [a M A])
        sigma (numpyro/sample "sigma" (dist/Exponential 1.))]
    (numpyro/sample "obs" (dist/Normal mu sigma) :obs divorce)))

(defn map-vals
  "Returns a new map with the keys of `m` and each value replaced by `(f value)`."
  [m f]
  (reduce-kv (fn [m k v] (assoc m k (f v))) {} m))

;; Standard-scale the dataset, keep only the columns the model is given here,
;; and convert each column with jnp/array.
(def model-input
  (-> dset
      dsp/std-scale
      (select-keys [:marriage :divorce])
      (map-vals jnp/array)))

;; Build the NUTS kernel and the MCMC driver, then run it with the last of the
;; split PRNG keys and the prepared input map.
(def kernel (NUTS model))
(def mcmc (MCMC kernel num_warmup num_samples))
(py. mcmc run (last new_rng_keys) model-input)
```

The value of `samples` behaves a lot like a map, but not quite. I hoped I would be able to turn it directly into a dataset using `->dataset`, but I first need to turn it into a Clojure map (e.g. with `(into {} samples)`). Is this expected behaviour?

;; Fetch the samples from the MCMC object by calling its Python method
;; `get_samples`.
(def samples (py. mcmc get_samples))

;; The value prints with a Python-dict-style repr, keyed by the sample names.
samples
;; => {'a': DeviceArray([ 0.01160221,  0.10573955,  0.00896238, ..., -0.17218843,
;;                  0.16966356,  0.1265155 ], dtype=float32), 'bM': DeviceArray([0.35984573, 0.26309696, 0.07760117, ..., 0.22615762,
;;                 0.66590315, 0.4897392 ], dtype=float32), 'sigma': DeviceArray([1.1117138 , 0.926471  , 0.8273705  , ..., 1.1577476 ,
;;                 0.95815885, 0.9083632 ], dtype=float32)}

;; `keys` works on it and yields the keys as strings.
(keys samples)
;; => ("a" "bM" "sigma")

;; It is still reported as a Python object, not a Clojure map.
(type samples)
;; => :pyobject

;; It can be invoked like a map to look up a key.
(samples "a")
;; => [ 0.01160221  0.10573955  0.00896238 ... -0.17218843  0.16966356
;;      0.1265155 ]

;; `vals` works as well.
(vals samples)
;; => ([ 0.01160221  0.10573955  0.00896238 ... -0.17218843  0.16966356
;;      0.1265155 ]
;;     [0.35984573 0.26309696 0.07760117 ... 0.22615762 0.66590315 0.4897392 ]
;;     [1.1117138  0.926471   0.8273705  ... 1.1577476  0.95815885 0.9083632 ])

;; Pouring it into {} produces a real Clojure map with the same string keys.
(into {} samples)
;; => {"a" [ 0.01160221  0.10573955  0.00896238 ... -0.17218843  0.16966356
;;      0.1265155 ],
;;     "bM" [0.35984573 0.26309696 0.07760117 ... 0.22615762 0.66590315 0.4897392 ],
;;     "sigma" [1.1117138  0.926471   0.8273705  ... 1.1577476  0.95815885 0.9083632 ]}

;; Passing the bridged Python object straight to ->dataset fails: per the trace
;; below, the mapseq parsing path ends up calling `nth` on the bridge type,
;; which does not support it (UnsupportedOperationException).
(ds/->dataset samples)
; => error: 
(comment 
; ; Execution error (UnsupportedOperationException) at tech.ml.dataset.parse.mapseq$map$reify$reify__65638/next (mapseq.clj:28).
; nth not supported on this type: bridge$generic_python_as_jvm$reify__47315
  [{:file "RT.java" :line 991 :method "nthFrom" :flags [:java]}
   {:file "RT.java" :line 940 :method "nth" :flags [:java]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/core.clj:28" :fn "map/reify/reify"  :method "next" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/techascent/tech.ml.dataset/4.04/tech.ml.dataset-4.04.jar!/tech/ml/dataset/parse/spreadsheet.clj:283" :fn "process-spreadsheet-rows"  :method "invokeStatic" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/techascent/tech.ml.dataset/4.04/tech.ml.dataset-4.04.jar!/tech/ml/dataset/parse/spreadsheet.clj:267" :fn "process-spreadsheet-rows"  :method "invoke" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/techascent/tech.ml.dataset/4.04/tech.ml.dataset-4.04.jar!/tech/ml/dataset/parse/mapseq.clj:77" :fn "mapseq->dataset"  :method "invokeStatic" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/techascent/tech.ml.dataset/4.04/tech.ml.dataset-4.04.jar!/tech/ml/dataset/parse/mapseq.clj:50" :fn "mapseq->dataset"  :method "invoke" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/techascent/tech.ml.dataset/4.04/tech.ml.dataset-4.04.jar!/tech/ml/dataset/base.clj:938" :fn "->dataset"  :method "invokeStatic" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/techascent/tech.ml.dataset/4.04/tech.ml.dataset-4.04.jar!/tech/ml/dataset/base.clj:783" :fn "->dataset"  :method "invoke" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/techascent/tech.ml.dataset/4.04/tech.ml.dataset-4.04.jar!/tech/ml/dataset/base.clj:943" :fn "->dataset"  :method "invokeStatic" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/techascent/tech.ml.dataset/4.04/tech.ml.dataset-4.04.jar!/tech/ml/dataset/base.clj:783" :fn "->dataset"  :method "invoke" :flags [:clj]}
   {:file "NO_SOURCE_FILE" :line 134 :fn "eval75618"  :method "invokeStatic" :flags [:project :repl :clj]}
   {:file "NO_SOURCE_FILE" :line 134 :fn "eval75618"  :method "invoke" :flags [:dup :project :repl :clj]}
   {:file "Compiler.java" :line 7177 :method "eval" :flags [:tooling :java]}
   {:file "Compiler.java" :line 7132 :method "eval" :flags [:dup :tooling :java]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/core.clj:3214" :fn "eval"  :method "invokeStatic" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/core.clj:3210" :fn "eval"  :method "invoke" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.8.2/nrepl-0.8.2.jar!/nrepl/middleware/interruptible_eval.clj:87" :fn "evaluate/fn/fn"  :method "invoke" :flags [:tooling :clj]}
   {:file "AFn.java" :line 152 :method "applyToHelper" :flags [:java]}
   {:file "AFn.java" :line 144 :method "applyTo" :flags [:java]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/core.clj:665" :fn "apply"  :method "invokeStatic" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/core.clj:1973" :fn "with-bindings*"  :method "invokeStatic" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/core.clj:1973" :fn "with-bindings*"  :method "doInvoke" :flags [:dup :clj]}
   {:file "RestFn.java" :line 425 :method "invoke" :flags [:java]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.8.2/nrepl-0.8.2.jar!/nrepl/middleware/interruptible_eval.clj:87" :fn "evaluate/fn"  :method "invoke" :flags [:tooling :clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:437" :fn "repl/read-eval-print/fn"  :method "invoke" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:437" :fn "repl/read-eval-print"  :method "invoke" :flags [:dup :clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:458" :fn "repl/fn"  :method "invoke" :flags [:clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:458" :fn "repl"  :method "invokeStatic" :flags [:dup :clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:368" :fn "repl"  :method "doInvoke" :flags [:clj]}
   {:file "RestFn.java" :line 1523 :method "invoke" :flags [:java]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.8.2/nrepl-0.8.2.jar!/nrepl/middleware/interruptible_eval.clj:84" :fn "evaluate"  :method "invokeStatic" :flags [:tooling :clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.8.2/nrepl-0.8.2.jar!/nrepl/middleware/interruptible_eval.clj:56" :fn "evaluate"  :method "invoke" :flags [:tooling :clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.8.2/nrepl-0.8.2.jar!/nrepl/middleware/interruptible_eval.clj:152" :fn "interruptible-eval/fn/fn"  :method "invoke" :flags [:tooling :clj]}
   {:file "AFn.java" :line 22 :method "run" :flags [:java]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.8.2/nrepl-0.8.2.jar!/nrepl/middleware/session.clj:202" :fn "session-exec/main-loop/fn"  :method "invoke" :flags [:tooling :clj]}
   {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.8.2/nrepl-0.8.2.jar!/nrepl/middleware/session.clj:201" :fn "session-exec/main-loop"  :method "invoke" :flags [:tooling :clj]}
   {:file "AFn.java" :line 22 :method "run" :flags [:java]}
   {:file "Thread.java" :line 745 :method "run" :flags [:java]}])

;; Workaround: copy the entries into a Clojure map first; ->dataset then
;; succeeds and yields a 2000-row, 3-column dataset (output below).
(ds/->dataset (into {} samples))

;; => _unnamed [2000 3]:
;;    
;;    |           a |         bM |      sigma |
;;    |-------------|------------|------------|
;;    |  0.01160221 | 0.35984573 | 1.11171377 |
;;    |  0.10573955 | 0.26309696 | 0.92647099 |
;;    |  0.00896238 | 0.07760117 | 0.82737052 |
;;    | -0.13939470 | 0.20474765 | 0.89149576 |
;;    | -0.09953939 | 0.43031254 | 0.91204131 |
;;    | -0.15807903 | 0.12200018 | 1.12738907 |
;;    |  0.17177662 | 0.46327570 | 1.15788937 |
;;    | -0.03298758 | 0.42113632 | 0.94018894 |
;;    |  0.03134891 | 0.29801247 | 0.92503810 |
;;    |  0.12909520 | 0.41303188 | 0.89625633 |
;;    | -0.14704621 | 0.26977766 | 1.04158199 |
;;    | -0.23374042 | 0.36004829 | 0.98456824 |
;;    |  0.29089901 | 0.29492304 | 0.96671295 |
;;    | -0.10400254 | 0.24859585 | 1.02428293 |
;;    | -0.15271212 | 0.30198431 | 0.84687734 |
;;    | -0.05826668 | 0.52216291 | 0.96207654 |
;;    | -0.04622043 | 0.38314888 | 0.94919014 |
;;    | -0.11454899 | 0.44680309 | 0.88769615 |
;;    |  0.01220503 | 0.39206502 | 0.95826173 |
;;    | -0.06571017 | 0.31993464 | 1.01031780 |
;;    |  0.02975514 | 0.16330233 | 0.96139914 |
;;    |  0.04367065 | 0.57259941 | 0.91072547 |
;;    | -0.16276032 | 0.23627713 | 0.92811120 |
;;    | -0.23922886 | 0.30689222 | 1.05620122 |
;;    | -0.02052609 | 0.61347049 | 1.01348138 |
jjtolton commented 3 years ago

@keesterbrugge, yes -- for now, this is the expected behavior. `->dataset` is quite powerful, but it doesn't do everything (yet!). Long term, we are working towards a more polymorphic data manipulation interface.

keesterbrugge commented 3 years ago

ok, thx