phlippe / uvadlc_notebooks

Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2023
https://uvadlc-notebooks.readthedocs.io/en/latest/
MIT License
2.59k stars · 590 forks

dataloader issues with jax tutorial 9 #89

Closed murphyk closed 1 year ago

murphyk commented 1 year ago

Tutorial: 9

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial9/AE_CIFAR10.html

I have to set num_workers=1 for the PyTorch dataloaders, otherwise the code that computes embeddings (used in https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial9/AE_CIFAR10.html#Finding-visually-similar-images) fails on a GPU Colab.

Also, I had to comment out jax.jit in the encode function to avoid the error "flax + jax don't mix".

def embed_imgs(trainer, data_loader):
    # Encode all images in the data_loader using the model; return both images and encodings
    img_list, embed_list = [], []

    #@jax.jit
    def encode(imgs):
        return trainer.model_bd.encoder(imgs)

    for imgs, _ in data_loader:
    #for imgs, _ in tqdm(data_loader, desc="Encoding images", leave=False):
        z = encode(imgs)
        z = jax.device_get(z)
        imgs = jax.device_get(imgs)
        img_list.append(imgs)
        embed_list.append(z)
    return (np.concatenate(img_list, axis=0), np.concatenate(embed_list, axis=0))
aoibhinncrtai commented 1 year ago

On 20 Mar 2023, at 23:38, Kevin P Murphy @.***> wrote:


phlippe commented 1 year ago

Thanks! I fixed the jit-flax issue in 5f7828ea71d43862554882106cf9c3b36b44ab88, but I wasn't able to reproduce the issue with the dataloaders. Did you check a clean run-through of the notebook after fixing the jit-flax issue? I've noticed that the PyTorch data loaders can sometimes have issues with JAX when using multiple workers and the subprocesses fail to be killed. That is fixed by persistent_workers=True for training, but I didn't use it for test/valid since it's usually not needed there.
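A minimal sketch of the persistent_workers setting mentioned above, using a toy TensorDataset in place of the tutorial's CIFAR10 loaders (persistent_workers requires num_workers > 0):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset standing in for CIFAR10: 8 samples, batch size 4
ds = TensorDataset(torch.arange(8, dtype=torch.float32).unsqueeze(1),
                   torch.zeros(8))
loader = DataLoader(ds, batch_size=4, num_workers=2,
                    persistent_workers=True)  # workers survive between epochs

n_batches = 0
for _ in range(2):  # two epochs; workers are reused rather than respawned
    for xb, yb in loader:
        n_batches += 1
```

With persistent_workers=True the worker subprocesses are kept alive across epochs instead of being torn down and restarted, which avoids the failed-subprocess-cleanup issue described above.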

phlippe commented 1 year ago

Closing due to inactivity, feel free to reopen if error persists.