Closed murphyk closed 1 year ago
On 20 Mar 2023, at 23:38, Kevin P Murphy wrote:
Tutorial: 9 (JAX)
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial9/AE_CIFAR10.html
I have to set num_workers=1 for the PyTorch data loaders; otherwise the code that computes the embeddings (used in https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial9/AE_CIFAR10.html#Finding-visually-similar-images) fails on a GPU Colab runtime.
Also, I had to comment out jax.jit in the encode function to avoid the 'flax + jax don't mix' error.
import jax
import numpy as np

def embed_imgs(trainer, data_loader):
    # Encode every image in the data loader and return (images, embeddings) as NumPy arrays.
    img_list, embed_list = [], []

    # @jax.jit  # commented out to avoid the 'flax + jax don't mix' error
    def encode(imgs):
        return trainer.model_bd.encoder(imgs)

    # for imgs, _ in tqdm(data_loader, desc="Encoding images", leave=False):
    for imgs, _ in data_loader:
        z = encode(imgs)
        z = jax.device_get(z)        # copy embeddings from device to host
        imgs = jax.device_get(imgs)  # copy images from device to host
        img_list.append(imgs)
        embed_list.append(z)
    return (np.concatenate(img_list, axis=0), np.concatenate(embed_list, axis=0))
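For reference, a common pattern that keeps jit for the encoding step, sketched here under the assumption that the unbound Flax encoder module and its parameter dict can be obtained from the trainer (the argument names are placeholders, and this is not necessarily how the notebook resolves the error), is to jit a pure apply() call instead of a method of the bound module:

import jax

def make_encode_fn(encoder, encoder_params):
    # `encoder` is assumed to be the unbound flax.linen encoder module and
    # `encoder_params` its parameter dict. Jitting a pure apply() call avoids
    # passing a bound Flax module through jax.jit.
    @jax.jit
    def encode(imgs):
        return encoder.apply({"params": encoder_params}, imgs)
    return encode

How to obtain the unbound encoder and the matching parameter subtree depends on how the model is set up in the notebook, so both arguments are left as placeholders here.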
Thanks! I fixed the jit-flax issue in 5f7828ea71d43862554882106cf9c3b36b44ab88, but I wasn't able to reproduce the issue with the data loaders. Did you check a clean run-through of the notebook after fixing the jit-flax issue? I have noticed that the PyTorch data loaders can sometimes have issues with JAX when they use multiple workers and the worker subprocesses fail to shut down cleanly. That is fixed by persistent_workers=True for the training loader, but I didn't use it for the test/validation loaders since it is usually not needed there.
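As a minimal sketch of the loader configuration described above, assuming a torchvision CIFAR10 dataset (the dataset path, transform, and batch size are placeholders, not the notebook's exact values):

import torch.utils.data as data
from torchvision import datasets, transforms

# Placeholder dataset setup; the notebook uses its own transform/collate for JAX.
transform = transforms.ToTensor()
train_set = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)

train_loader = data.DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,   # keep worker subprocesses alive between epochs
)
test_loader = data.DataLoader(
    test_set,
    batch_size=256,
    shuffle=False,
    num_workers=1,             # the workaround the reporter used on GPU Colab
)

Note that persistent_workers=True requires num_workers > 0; it simply keeps the worker subprocesses alive between epochs instead of killing and respawning them.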
Closing due to inactivity; feel free to reopen if the error persists.