clj-python / libpython-clj

Python bindings for Clojure
Eclipse Public License 2.0
1.06k stars 69 forks source link

map destructuring missing kv doesn't seem to bind value nil #126

Closed keesterbrugge closed 3 years ago

keesterbrugge commented 3 years ago

Hi Chris, thx again for this amazing library.

I'm trying out numpyro for some probabilistic programming.

I suspect something is going wrong. When you destructure a map on keywords that are missing, they are bound to the value nil; see:

(defn missing-keys-test-fn
  "Destructures :x, :y and :z out of the argument map and returns them in a
  new map under the same keys. Keys absent from the input are bound to nil."
  [{:keys [x y z]}]
  (zipmap [:x :y :z] [x y z]))

(missing-keys-test-fn {})
;; => {:x nil, :y nil, :z nil}

The fn model takes a map as argument and I bind the value of the :y key. I thought that the following 3 cases would yield the same result, but case 1 fails. Am I missing something?

(ns libpython-examples.jax.numpyro-fail-option-map-nil-value
  (:require [libpython-clj.require :refer [require-python]]
            [libpython-clj.python :refer [py.] :as py]))

(require-python '[jax.numpy :as jnp])
(require-python '[jax :refer [random]])
(require-python 'numpyro)
(require-python '[numpyro.distributions :as dist])
(require-python '[numpyro.infer :refer [MCMC, NUTS]])
(require-python 'operator)

;; Root PRNG key: the result of calling jax's random.PRNGKey with argument 0.
(def rng_key (py. random PRNGKey 0))
;; Result of calling random.split on the root key; its last element is the
;; key handed to `mcmc run` in the cases below.
(def new_rng_keys (py. random split rng_key))
;; Passed positionally (warmup first, then samples) to the MCMC constructor.
(def num_warmup 1000)
(def num_samples 2000)

(defn model
  "Model fn handed to NUTS. Takes a single map argument and destructures :y
  from it (no :or default), samples \"mu\" from Normal(0, 0.2), then samples
  \"obs\" from Normal(mu, 1) with y supplied as the :obs keyword argument."
  [{:keys [y]}]
  (let [mu-prior   (dist/Normal 0. 0.2)
        mu         (numpyro/sample "mu" mu-prior)
        likelihood (dist/Normal mu 1)]
    (numpyro/sample "obs" likelihood :obs y)))

;; Build a NUTS kernel from the Clojure `model` fn, then an MCMC object from
;; that kernel plus the warmup and sample counts.
(def kernel (NUTS model))
(def mcmc (MCMC kernel num_warmup num_samples))

;; CASE 1: this does not work, throws error
(py. mcmc run (last new_rng_keys) {})
;; =>
;; ; Execution error at libpython-clj.python.interpreter/check-error-throw (interpreter.clj:499).
; Traceback (most recent call last):
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 446, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 312, in _single_chain_mcmc
    init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 450, in init
    init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 407, in _init_state
    init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/infer/util.py", line 394, in initialize_model
; 
;     inv_transforms, replay_model, has_enumerate_support, model_trace = _get_model_transforms(
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/infer/util.py", line 265, in _get_model_transforms
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/handlers.py", line 156, in get_trace
    self(*args, **kwargs)
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/primitives.py", line 68, in __call__
    return self.fn(*args, **kwargs)
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/primitives.py", line 68, in __call__
    return self.fn(*args, **kwargs)
  File "/Users/keesterbrugge/opt/anaconda3/lib/python3.8/site-packages/numpyro/primitives.py", line 68, in __call__
    return self.fn(*args, **kwargs)
Exception: java.lang.Exception: KeyError: 'y'
:java.lang.Exception: KeyError: 'y'

 at libpython_clj.python.interpreter$check_error_throw.invokeStatic (interpreter.clj:499)
    libpython_clj.python.interpreter$check_error_throw.invoke (interpreter.clj:497)
    libpython_clj.python.object$wrap_pyobject.invokeStatic (object.clj:173)
    libpython_clj.python.object$wrap_pyobject.invoke (object.clj:120)
    libpython_clj.python.object$eval45745$fn__45746.invoke (object.clj:839)
    libpython_clj.python.protocols$eval40569$fn__40570$G__40560__40579.invoke (protocols.clj:105)
    libpython_clj.python.bridge$py_impl_call_raw.invokeStatic (bridge.clj:336)
    libpython_clj.python.bridge$py_impl_call_raw.invoke (bridge.clj:333)
    libpython_clj.python.bridge$py_impl_call_as.invokeStatic (bridge.clj:344)
    libpython_clj.python.bridge$py_impl_call_as.invoke (bridge.clj:342)
    libpython_clj.python.bridge$generic_python_as_map$py_call__46299.doInvoke (bridge.clj:367)
    clojure.lang.RestFn.invoke (RestFn.java:423)
    libpython_clj.python.bridge$generic_python_as_map$reify__46307.get (bridge.clj:381)
    clojure.lang.RT.getFrom (RT.java:769)

;   clojure.lang.RT.get (RT.java:761)
;     libpython_examples.jax.numpyro_fail_option_map_nil_value$model.invokeStatic (NO_SOURCE_FILE:39)
    libpython_examples.jax.numpyro_fail_option_map_nil_value$model.invoke (NO_SOURCE_FILE:39)
    clojure.lang.AFn.applyToHelper (AFn.java:154)
    clojure.lang.AFn.applyTo (AFn.java:144)
    clojure.core$apply.invokeStatic (core.clj:665)
    clojure.core$apply.invoke (core.clj:660)
    libpython_clj.python.object$make_tuple_fn$reify__45514.pyinvoke (object.clj:510)
    sun.reflect.GeneratedMethodAccessor11.invoke (:-1)
    sun.reflect.DelegatingMethodAccessorImpl.invoke (DelegatingMethodAccessorImpl.java:43)
    java.lang.reflect.Method.invoke (Method.java:498)
    com.sun.jna.CallbackReference$DefaultCallbackProxy.invokeCallback (CallbackReference.java:520)
    com.sun.jna.CallbackReference$DefaultCallbackProxy.callback (CallbackReference.java:551)
    libpython_clj.jna.DirectMapped.PyObject_CallObject (DirectMapped.java:-2)
    libpython_clj.jna.protocols.object$PyObject_CallObject.invokeStatic (object.clj:305)
    libpython_clj.jna.protocols.object$PyObject_CallObject.invoke (object.clj:295)
    libpython_clj.python.object$eval45745$fn__45746.invoke (object.clj:836)
    libpython_clj.python.protocols$eval40569$fn__40570$G__40560__40579.invoke (protocols.clj:105)
    libpython_clj.python.bridge$generic_python_as_jvm$reify__46630.do_call_fn (bridge.clj:675)
    libpython_clj.python.protocols$call_attr_kw.invokeStatic (protocols.clj:132)
    libpython_clj.python.protocols$call_attr_kw.invoke (protocols.clj:127)
    libpython_examples.jax.numpyro_fail_option_map_nil_value$eval121553.invokeStatic (NO_SOURCE_FILE:48)
    libpython_examples.jax.numpyro_fail_option_map_nil_value$eval121553.invoke (NO_SOURCE_FILE:48)
    clojure.lang.Compiler.eval (Compiler.java:7177)
    clojure.lang.Compiler.eval (Compiler.java:7132)
    clojure.core$eval.invokeStatic (core.clj:3214)
    clojure.core$eval.invoke (core.clj:3210)
    clojure.main$repl$read_eval_print__9086$fn__9089.invoke (main.clj:437)
    clojure.main$repl$read_eval_print__9086.invoke (main.clj:437)
    clojure.main$repl$fn__9095.invoke (main.clj:458)
    clojure.main$repl.invokeStatic (main.clj:458)
    clojure.main$repl.doInvoke (main.clj:368)
    clojure.lang.RestFn.invoke (RestFn.java:1523)
    nrepl.middleware.interruptible_eval$evaluate.invokeStatic (interruptible_eval.clj:79)
    nrepl.middleware.interruptible_eval$evaluate.invoke (interruptible_eval.clj:55)
    nrepl.middleware.interruptible_eval$interruptible_eval$fn__935$fn__939.invoke (interruptible_eval.clj:142)
    clojure.lang.AFn.run (AFn.java:22)
    nrepl.middleware.session$session_exec$main_loop__1036$fn__1040.invoke (session.clj:171)
    nrepl.middleware.session$session_exec$main_loop__1036.invoke (session.clj:170)
    clojure.lang.AFn.run (AFn.java:22)
    java.lang.Thread.run (Thread.java:745)

[{:file "jar:file:/Users/keesterbrugge/.m2/repository/clj-python/libpython-clj/1.46/libpython-clj-1.46.jar!/libpython_clj/python/interpreter.clj:499" :fn "check-error-throw"  :method "invokeStatic" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/clj-python/libpython-clj/1.46/libpython-clj-1.46.jar!/libpython_clj/python/interpreter.clj:497" :fn "check-error-throw"  :method "invoke" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/clj-python/libpython-clj/1.46/libpython-clj-1.46.jar!/libpython_clj/python/object.clj:173" :fn "wrap-pyobject"  :method "invokeStatic" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/clj-python/libpython-clj/1.46/libpython-clj-1.46.jar!/libpython_clj/python/object.clj:120" :fn "wrap-pyobject"  :method "invoke" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/clj-python/libpython-clj/1.46/libpython-clj-1.46.jar!/libpython_clj/python/object.clj:839" :fn "eval45745/fn"  :method "invoke" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/clj-python/libpython-clj/1.46/libpython-clj-1.46.jar!/libpython_clj/python/protocols.clj:105" :fn "eval40569/fn/G"  :method "invoke" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/clj-python/libpython-clj/1.46/libpython-clj-1.46.jar!/libpython_clj/python/bridge.clj:675" :fn "generic-python-as-jvm/reify"  :method "do_call_fn" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/clj-python/libpython-clj/1.46/libpython-clj-1.46.jar!/libpython_clj/python/protocols.clj:132" :fn "call-attr-kw"  :method "invokeStatic" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/clj-python/libpython-clj/1.46/libpython-clj-1.46.jar!/libpython_clj/python/protocols.clj:127" :fn "call-attr-kw"  :method "invoke" :flags [:clj]}
 {:file "NO_SOURCE_FILE" :line 48 :fn "eval121553"  :method "invokeStatic" :flags [:project :repl :clj]}
 {:file "NO_SOURCE_FILE" :line 48 :fn "eval121553"  :method "invoke" :flags [:dup :project :repl :clj]}
 {:file "Compiler.java" :line 7177 :method "eval" :flags [:tooling :java]}
 {:file "Compiler.java" :line 7132 :method "eval" :flags [:dup :tooling :java]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/core.clj:3214" :fn "eval"  :method "invokeStatic" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/core.clj:3210" :fn "eval"  :method "invoke" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:437" :fn "repl/read-eval-print/fn"  :method "invoke" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:437" :fn "repl/read-eval-print"  :method "invoke" :flags [:dup :clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:458" :fn "repl/fn"  :method "invoke" :flags [:clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:458" :fn "repl"  :method "invokeStatic" :flags [:dup :clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/org/clojure/clojure/1.10.1/clojure-1.10.1.jar!/clojure/main.clj:368" :fn "repl"  :method "doInvoke" :flags [:clj]}
 {:file "RestFn.java" :line 1523 :method "invoke" :flags [:java]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.6.0/nrepl-0.6.0.jar!/nrepl/middleware/interruptible_eval.clj:79" :fn "evaluate"  :method "invokeStatic" :flags [:tooling :clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.6.0/nrepl-0.6.0.jar!/nrepl/middleware/interruptible_eval.clj:55" :fn "evaluate"  :method "invoke" :flags [:tooling :clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.6.0/nrepl-0.6.0.jar!/nrepl/middleware/interruptible_eval.clj:142" :fn "interruptible-eval/fn/fn"  :method "invoke" :flags [:tooling :clj]}
 {:file "AFn.java" :line 22 :method "run" :flags [:java]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.6.0/nrepl-0.6.0.jar!/nrepl/middleware/session.clj:171" :fn "session-exec/main-loop/fn"  :method "invoke" :flags [:tooling :clj]}
 {:file "jar:file:/Users/keesterbrugge/.m2/repository/nrepl/nrepl/0.6.0/nrepl-0.6.0.jar!/nrepl/middleware/session.clj:170" :fn "session-exec/main-loop"  :method "invoke" :flags [:tooling :clj]}
 {:file "AFn.java" :line 22 :method "run" :flags [:java]}
 {:file "Thread.java" :line 745 :method "run" :flags [:java]}]

;; CASE 2: explicit :y key with nil value DOES work
(py. mcmc run (last new_rng_keys) {:y nil})

(defn model2
  "Same model as `model`, except that the destructuring carries an explicit
  `:or {y nil}` default for y. Samples \"mu\" from Normal(0, 0.2), then
  samples \"obs\" from Normal(mu, 1) with y supplied as the :obs keyword
  argument."
  [{:keys [y] :or {y nil}}]
  (let [mu-prior   (dist/Normal 0. 0.2)
        mu         (numpyro/sample "mu" mu-prior)
        likelihood (dist/Normal mu 1)]
    (numpyro/sample "obs" likelihood :obs y)))

;; Rebind `kernel` and `mcmc` so that they are built from model2, replacing
;; the earlier definitions of the same names.
(def kernel (NUTS model2))
(def mcmc (MCMC kernel num_warmup num_samples))

; CASE 3: now this works too
(py. mcmc run (last new_rng_keys) {})
jjtolton commented 3 years ago

Hi there! That's a lot of domain specific code -- are you able to narrow down the problem a little?

cnuernber commented 3 years ago

The root of this is that Clojure maps and Python dicts define lookup differently. A Map's get method returns nil by definition if the key isn't found, whereas a Python dict raises KeyError. I guess the implementation of the Map bridge for Python objects is incorrect: it should catch that exception and return nil.

jjtolton commented 3 years ago

Interesting. My only concern would be that changing the Map bridge globally might have some surprising behavior. KeyError is supposed to be raised with a missing key in Python, since None is a valid value :thinking:.

That being said, it would be much more convenient from the Clojure perspective. I'm trying to think whether there would be any unanticipated consequences of this.