clj-python / libpython-clj

Python bindings for Clojure
Eclipse Public License 2.0
1.09k stars 68 forks source link

pytorch lightning Fail #263

Open niitsuma opened 5 months ago

niitsuma commented 5 months ago

The following code stops with:

AttributeError: 'builtin_function_or_method' object has no attribute '__code__'. Did you mean: '__call__'?

(ns pytorchlightning.core
  (:gen-class)

  (:require 
   [libpython-clj2.python :as py
    :refer
    [ py. py.- 
     as-jvm
     set-attr!
     get-item
     ->py-tuple
     ->py-list
     ]]

   [libpython-clj2.require :refer [require-python]]
   ))

;;(py/initialize!)

(require-python
 '[builtins :as python]
 '[torch]
 '[torch.nn :as nn :refer [Linear]]
 '[torch.nn.functional :refer [mse_loss]]
 '[torch.utils.data :refer [DataLoader Dataset]]
 '[torch.optim]
 '[pytorch_lightning]
 )

;; Atom holding the LitModel instance; starts as nil and is populated
;; by -main via reset!.
(defonce model (atom nil))

;; LitModel: a Python class built with py/create-class that subclasses
;; pytorch_lightning/LightningModule and owns a single (Linear 1 1) layer.
;; NOTE(review): the surrounding issue reports an AttributeError when an
;; instance of this class is handed to Trainer.fit; the methods here are
;; Clojure fns wrapped by py/make-tuple-instance-fn -- confirm how Lightning
;; introspects them.
(def LitModel
  (let [;; Shared body of the training/validation/test steps: MSE between
        ;; element 1 of the batch and the forward pass over element 0.
        step-loss (fn [self batch _batch-idx]
                    (let [expected  (get-item batch 1)
                          predicted (py. self forward (get-item batch 0))]
                      (mse_loss expected predicted)))
        ;; Each call builds a separate instance fn around step-loss, so the
        ;; three step methods remain distinct objects as in the original.
        step-method #(py/make-tuple-instance-fn step-loss)]
    (py/create-class
     "LitModel"
     [pytorch_lightning/LightningModule]
     {"__init__"
      (py/make-tuple-instance-fn
       (fn [this]
         ;; Run the parent initializer before attaching the layer.
         (py. pytorch_lightning/LightningModule __init__ this)
         (set-attr! this "layer" (Linear 1 1))
         ;; Return nil explicitly from __init__.
         nil))

      "forward"
      (py/make-tuple-instance-fn
       (fn [this input] (py. this layer input))
       :arg-converter as-jvm
       :method-name "forward")

      "training_step"   (step-method)
      "validation_step" (step-method)
      "test_step"       (step-method)

      "configure_optimizers"
      (py/make-tuple-instance-fn
       (fn [this]
         ;; Plain SGD over the module's parameters with lr 0.02.
         (torch.optim/SGD (py. this parameters) :lr 0.02)))})))

;; SimpleDataset: a Python class (subclass of torch.utils.data Dataset) built
;; with py/create-class. It stores the object passed to __init__ under the
;; "data" attribute and serves items as a 2-tuple of float32 tensors.
(def SimpleDataset
  (let [;; Builds a one-element float32 tensor from position `column` of the
        ;; item found at `idx` in the instance's "data" attribute.
        column-tensor (fn [self idx column]
                        (let [row (get-item (py.- self data) idx)]
                          (torch/tensor [(get-item row column)]
                                        :dtype torch/float32)))]
    (py/create-class
     "SimpleDataset"
     [Dataset]
     {"__init__"
      (py/make-tuple-instance-fn
       (fn [this data]
         (set-attr! this "data" data)
         ;; Return nil explicitly from __init__.
         nil))

      "__len__"
      (py/make-tuple-instance-fn
       (fn [this] (python/len (py.- this data))))

      "__getitem__"
      (py/make-tuple-instance-fn
       (fn [this idx]
         ;; Tuple of (tensor of column 0, tensor of column 1) for item idx.
         (->py-tuple [(column-tensor this idx 0)
                      (column-tensor this idx 1)])))})))

;; SimpleDataModule: a Python class (subclass of
;; pytorch_lightning/LightningDataModule) built with py/create-class.
;; It keeps the object passed to __init__ under the "data" attribute and
;; exposes train/val/test dataloaders that all wrap that same data in a
;; SimpleDataset with batch_size 2 and no shuffling.
(def SimpleDataModule
  (let [;; Returns a fresh instance fn that builds the DataLoader; called once
        ;; per dataloader method so each method gets its own callable.
        loader-method
        (fn []
          (py/make-tuple-instance-fn
           (fn [self]
             (DataLoader (SimpleDataset (py.- self data))
                         :batch_size 2
                         :shuffle false))))]
    (py/create-class
     "SimpleDataModule" [pytorch_lightning/LightningDataModule]
     {"__init__"
      (py/make-tuple-instance-fn
       (fn [self data]
         ;; Run the parent initializer first, mirroring LitModel's __init__,
         ;; so base-class state is set up before our own attributes are added.
         (py. pytorch_lightning/LightningDataModule __init__ self)
         (py/set-attr! self "data" data)
         ;; Return nil explicitly from __init__.
         nil))

      "train_dataloader" (loader-method)
      "val_dataloader"   (loader-method)
      "test_dataloader"  (loader-method)})))

(defn -main
  "Entry point: stores a fresh LitModel in the `model` atom, builds a small
  in-memory dataset and DataLoader, creates a Trainer (max_epochs 10) and
  calls `fit` on it. Returns the value of the `fit` call.

  `args` (command-line arguments) are accepted but not used."
  [& args]
  (reset! model (LitModel))
  ;; Local bindings instead of `def` inside the function body: `def` would
  ;; intern namespace-global vars as a side effect of running -main.
  (let [data         [[1.0 3.0] [2.0 5.0] [3.0 7.0] [4.0 9.0] [5.0 11.0]]
        data-pylist  (py/->py-list data)
        train-set    (SimpleDataset data-pylist)
        train-loader (DataLoader train-set :batch_size 2 :shuffle false)
        trainer      (pytorch_lightning/Trainer :max_epochs 10)
        ;; Still constructed, as in the original; only used by the
        ;; alternative fit call kept in the comment below.
        datamodule   (SimpleDataModule data-pylist)]

    ;; Debug aid: run a single training step by hand on the first batch.
    ;; (let [first-batch (py/get-item (python/list train-loader) 0)]
    ;;   (println (py. @model training_step first-batch 0)))

    ;; Alternative reported in the issue as failing as well:
    ;; (py. trainer fit @model datamodule)
    (py. trainer fit @model train-loader)))
cnuernber commented 2 months ago

It appears to me that this happens when you pass in a function that doesn't have its source code available, so the autodiff system can't differentiate it. Keep in mind that functions defined in Clojure are exposed to Python as opaque C function pointers, so they won't be autodifferentiable.