clj-python / libpython-clj

Python bindings for Clojure
Eclipse Public License 2.0

How to `**` apply an object as kwargs #209

Open den1k opened 2 years ago

den1k commented 2 years ago

How can this part of this Hugging Face example

outputs = model(**encoding, labels=torch.LongTensor([1]))

be translated to libpython-clj?
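
For reference, here is a minimal Python sketch of what the line in question does, with no torch or transformers needed. The `model` function and the contents of `encoding` are hypothetical stand-ins, not the real Hugging Face objects. The point is only that `**` unpacks any mapping into keyword arguments, and an explicit keyword such as `labels=` joins that same set.

```python
from collections import UserDict

# Hypothetical stand-in for the tokenizer output: a dict-like object,
# not the real transformers BatchEncoding.
encoding = UserDict({"input_ids": [101, 102], "attention_mask": [1, 1]})


def model(input_ids=None, attention_mask=None, labels=None):
    """Hypothetical stand-in for the model call; echoes what it received."""
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


# Form used in the question: unpack the mapping, then add one more keyword.
direct = model(**encoding, labels=[1])

# Equivalent form: build one merged kwargs dict first, then unpack it.
kwargs = {**encoding, "labels": [1]}
merged = model(**kwargs)
```

Seen this way, a translation only has to deliver one merged keyword map to the callable.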

jjtolton commented 2 years ago

(require '[libpython-clj2.python :as py])
(require '[libpython-clj2.require :refer [import-python]])
(import-python) ;; gives us access to python builtins
;; py** has similar syntax to ** in Python, last argument is considered **kwargs
;; py.. is similar in syntax to (..   ) syntax in Clojure
;; note that python datatypes need to be wrapped when using the helper macros,
;; that's why I'm employing (python/list [1]) instead of a Clojure vector of [1]
(def outputs (py/py** model :labels (py/py.. torch (LongTensor (python/list [1]))) encoding))

jjtolton commented 2 years ago

If you post the Python import statements you used, I can show you how to translate them as well. I omitted them from the snippet above.

den1k commented 2 years ago

Thank you @jjtolton! I tried your example, but it did not work. I think the issue is that py** expects a method name right after the object, and :labels is not a method on model.

I was able to invoke it in this cumbersome way:

(libpython-clj2.python.fn/call-kw
  model
  []
  (py/as-map (py/set-attr! encoding :labels (torch/LongTensor [1]))))
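
In Python terms, `call-kw` with an empty positional vector and a keyword map corresponds to `model(*[], **kwargs)`. The sketch below mirrors that with a hypothetical `call_kw` helper (not a libpython-clj function) and a hypothetical `model` stand-in. It also shows a detail worth checking in the workaround: on a plain `UserDict`, assigning an attribute does not add a key, so the value does not appear when the object is unpacked as keywords. Whether the same applies to the real encoding object under `set-attr!` is an assumption that has not been verified here.

```python
from collections import UserDict


def call_kw(f, args, kwargs):
    """Hypothetical helper: call f with a positional list and a keyword mapping."""
    return f(*args, **kwargs)


def model(input_ids=None, labels=None):
    """Hypothetical stand-in for the model; echoes its keyword arguments."""
    return {"input_ids": input_ids, "labels": labels}


encoding = UserDict({"input_ids": [101, 102]})

# Attribute assignment stores the value on the object, not in the mapping.
encoding.labels = [1]
via_attr = call_kw(model, [], encoding)

# Item assignment adds a real key, which ** unpacking then picks up.
encoding["labels"] = [1]
via_item = call_kw(model, [], encoding)
```

If the labels matter for the result (for example, to get a loss back), it may be worth confirming that the key is present in the map that reaches the model.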

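Re: import statements, here's the entire code example:

;; requires assumed for require-python and the unqualified py. / py.- macros below
(require '[libpython-clj2.require :refer [require-python]])
(require '[libpython-clj2.python :as py :refer [py. py.-]])

(require-python '[transformers :bind-ns])
(require-python '[torch :bind-ns])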

(def tokenizer
  (py. (py.- transformers BertTokenizer)
       from_pretrained "bert-base-uncased"))

(def model
  (py. (py.- transformers "BertForNextSentencePrediction") from_pretrained "bert-base-uncased"))

(let [prompt        "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
      next-sequence "The sky is blue due to the shorter wavelength of blue light."
      encoding      (py/$c tokenizer prompt next-sequence :return_tensors "pt")]
  (libpython-clj2.python.fn/call-kw
    model
    []
    (py/as-map (py/set-attr! encoding :labels (torch/LongTensor [1])))))

jjtolton commented 1 year ago

Ahhh, that's a tricky situation I did not consider. model is a raw Python object with a dynamic __call__ method, which the py/py{*} macros do not account for. The low-level code you used works, but the experience could be improved with friendlier, more idiomatic APIs. When I get time I'll make a PR to make this a better experience. Thanks for digging this up, and I'm glad you found a solution.

jjtolton commented 1 year ago

@den1k Looking back on this, the solution is:

(py/py** model __call__ :labels (py/py.. torch (LongTensor (python/list [1]))) encoding)

Not sure why I didn't think of that earlier.
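
As a rough illustration of why naming `__call__` helps: in Python, calling an object is sugar for invoking the `__call__` method defined on its class, so passing the keywords to `__call__` as an ordinary method call gives the same result. A minimal sketch, using a hypothetical `TinyModel` class in place of the real model:

```python
class TinyModel:
    """Hypothetical stand-in for a model object that is callable via __call__."""

    def __call__(self, *args, **kwargs):
        # Delegate to forward, similar in spirit to module-style objects.
        return self.forward(*args, **kwargs)

    def forward(self, input_ids=None, labels=None):
        return {"input_ids": input_ids, "labels": labels}


model = TinyModel()
encoding = {"input_ids": [101, 102]}

# Calling the object and calling its __call__ method are equivalent.
called_directly = model(**encoding, labels=[1])
called_as_method = model.__call__(**encoding, labels=[1])
```

That gives a method-oriented macro a method name to dispatch on, with the keyword map passed last.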