pytorch lightning Fail #263
It appears to me that this happens when you pass in a function that doesn't have its source code, so the autodiff system can't autodiff it. Keep in mind that functions defined via Clojure are translated to Python as opaque C function pointers, so they won't be autodifferentiable.
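A minimal sketch of what "opaque" means here, assuming the py setup from the code below and Python's stdlib inspect module; the alias pyinspect and the name clj-callable are illustrative:

(require-python '[inspect :as pyinspect])

;; A Clojure fn bridged into Python via the same helper used below;
;; CPython sees it as a C-level callable with no attached source.
(def clj-callable (py/make-tuple-instance-fn (fn [self x] x)))

;; inspect.getsource has nothing to retrieve, so it throws.
(try
  (pyinspect/getsource clj-callable)
  (catch Exception e
    (println "no Python source available:" (ex-message e))))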
@cnuernber Thank you for the awesome project!
The following code stops with an error:
(ns sample.core
(:require
[libpython-clj2.require :refer [require-python]]
[libpython-clj2.python :as py :refer [py. py.- py.. py* py**]]
))
(py/initialize! :python-executable "/Users/tani/Documents/libpython-clj/.venv/bin/python")
(require-python
'os
'[torch.nn :refer [Linear ReLU Sequential]]
'[torch.nn.functional :refer [mse_loss]]
'[torch.optim :as o :refer [Adam]]
'[torch.utils.data :refer [DataLoader]]
'[lightning :refer [LightningModule Trainer]]
'[torchvision.datasets :refer [MNIST]]
'[torchvision.transforms :refer [ToTensor]])
;; nn.Linear takes (in-features, out-features); the decoder has to mirror the
;; encoder back from 3 -> 64 -> 28*28, otherwise the shapes don't line up.
(def encoder (Sequential (Linear (* 28 28) 64) (ReLU) (Linear 64 3)))
(def decoder (Sequential (Linear 3 64) (ReLU) (Linear 64 (* 28 28))))
(def LitModel
(py/create-class
"LitModel" [LightningModule]
{"__init__"
(py/make-tuple-instance-fn
(fn [self encoder decoder]
(py. LightningModule __init__ self)
(py/set-attr! self "encoder" encoder)
(py/set-attr! self "decoder" decoder)
nil))
"training_step"
(py/make-tuple-instance-fn
(fn [self batch batch_idx]
(let [head (py/get-item batch 0)
x (py. head view (py. head size 0) -1)
z (py. self encoder x)
xhat (py. self decoder z)
loss (mse_loss xhat x)]
loss)))
"configure_optimizers"
(py/make-tuple-instance-fn
(fn [self]
(py** Adam (py. self parameters) {:lr 0.0001})))}))
(def model (LitModel encoder decoder))
(def dataset (MNIST (os/getcwd) :download true :transform (ToTensor)))
(def train_loader (DataLoader dataset))
(def trainer (Trainer :limit_train_batches 2 :max_epochs 1 :logger []))
(py. trainer fit (py/as-python model) train_loader)
I asked the Grok chatbot about this. It told me to create a subclass of …
Another way is to replace PyTorch Lightning with plain PyTorch.
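A rough sketch of that second route, reusing Adam, mse_loss, Sequential, and train_loader from the code above; the name autoencoder and the two-batch limit (mirroring :limit_train_batches 2) are illustrative, and the loop pulls batches through Python's builtins iter/next rather than assuming the DataLoader is seqable from Clojure:

(require-python '[builtins :as pybuiltin])

;; Chain the two modules so a single parameters() call covers both.
(def autoencoder (Sequential encoder decoder))
(def optimizer (Adam (py. autoencoder parameters) :lr 1e-4))

;; Hand-written training loop: no Lightning, so nothing has to introspect
;; or re-wrap Clojure-defined methods.
(def batch-iter (pybuiltin/iter train_loader))
(dotimes [_ 2]
  (let [batch (pybuiltin/next batch-iter)
        x     (py/get-item batch 0)
        x     (py. x view (py. x size 0) -1)
        loss  (mse_loss (py. autoencoder __call__ x) x)]
    (py. optimizer zero_grad)
    (py. loss backward)
    (py. optimizer step)))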