google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0
2.84k stars 229 forks

rnn classifies mnist #645

Open never-to-never opened 1 year ago

never-to-never commented 1 year ago

test.txt I am using an LSTM to classify MNIST data, and the network's loss does not converge at all. Is the RNN provided by the framework correct? I give the script that runs

IanQS commented 1 year ago

1) Please paste your code as opposed to attaching it as a file, especially if the code is short.

2) Why are you using an RNN?

3) From looking at your code, it doesn't seem like you're really using the time component. Are you sure that in your preprocessing you're replicating the data over the time axis? A

never-to-never commented 1 year ago
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

sequence_length = 28
input_size = 28
hidden_size = 128
num_classes = 10
batch_size = 128
num_epochs = 30
learning_rate = 0.001

train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor(), download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

def unroll_net(seqs: jax.Array):
    core = hk.LSTM(128)
    batch_size = seqs.shape[1]
    outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
    return hk.Linear(10)(outs[-1]), state

model = hk.transform(unroll_net)

rng = jax.random.PRNGKey(428)
opt = optax.adam(1e-3)

@jax.jit
def loss(params, x, y):
    pred, _ = model.apply(params, None, x)
    return jnp.mean(jnp.square(pred - y))

@jax.jit
def accuracy(predict, target):
    return jnp.sum(jnp.argmax(predict, axis=1) == jnp.argmax(target, axis=1))

@jax.jit
def update(step, params, opt_state, x, y):
    l, grads = jax.value_and_grad(loss)(params, x, y)
    grads, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, grads)
    return l, params, opt_state

train_ds = iter(train_loader)
valid_ds = iter(test_loader)
sample_x, _ = next(train_ds)
sample_x = sample_x.reshape(sequence_length, -1, input_size)
sample_x = jnp.asarray(sample_x)
params = model.init(rng, sample_x)
opt_state = opt.init(params)
length = len(train_ds)

for step in range(length-1):
    if step % 10 == 0:
        x, y = next(valid_ds)
        x = x.reshape(sequence_length, -1, input_size)
        x = jnp.asarray(x)
        y = jnp.asarray(y)
        y = jnp.array(y[:, None] == jnp.arange(10), jnp.float32)
        print("Step {}: valid loss {}".format(step, loss(params, x, y)))
    x, y = next(train_ds)
    x = x.reshape(sequence_length, -1, input_size)
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    y = jnp.array(y[:, None] == jnp.arange(10), jnp.float32)
    train_loss, params, opt_state = update(step, params, opt_state, x, y)
    if step % 10 == 0:
        print("Step {}: train loss {}".format(step, train_loss))

Here is the full code.
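
On the time-axis question raised above: the script calls `x.reshape(sequence_length, -1, input_size)` on each batch. A minimal sketch of how that differs from an explicit transpose is given below. It assumes the loader yields batches shaped (batch, channel, height, width), which is what `transforms.ToTensor()` normally produces for MNIST, and it uses NumPy with a small made-up batch of 4 images as a stand-in for the real data. Whether this is the actual cause of the non-converging loss is not confirmed in this thread.

```python
import numpy as np

# Hypothetical stand-in for one MNIST batch from the DataLoader:
# shape (batch, channel, height, width) = (4, 1, 28, 28).
batch = np.arange(4 * 1 * 28 * 28, dtype=np.float32).reshape(4, 1, 28, 28)

# What the script does: reinterpret the flat buffer as (time, batch, features).
# No data is moved, so rows from different images end up interleaved.
reshaped = batch.reshape(28, -1, 28)

# Time-major layout that keeps each image intact: drop the channel axis,
# then swap the batch and row axes so that axis 0 indexes image rows.
time_major = np.transpose(batch[:, 0], (1, 0, 2))

# Both results have the same shape, which is why the problem is easy to miss.
print(reshaped.shape, time_major.shape)

# Row t of image b should sit at index [t, b].
print(np.array_equal(time_major[5, 2], batch[2, 0, 5]))
print(np.array_equal(reshaped[5, 2], batch[2, 0, 5]))
```

With the transpose, step t of the unroll sees row t of every image in the batch. With the plain reshape, a single sequence mixes rows taken from several different images.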

Ekundayo39283 commented 2 months ago

test.txt I am using an LSTM to classify MNIST data, and the network's loss does not converge at all. Is the RNN provided by the framework correct? I give the script that runs

The error is likely caused by a mismatch between PyTorch tensors and JAX arrays.

The train_loader and test_loader yield PyTorch tensors, while the model and loss functions expect JAX arrays, so the tensors need to be converted before they are passed in. Use jnp.array() to convert PyTorch tensors to JAX arrays.
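
A minimal sketch of that conversion step is shown below. It assumes the batch is first brought to host memory as NumPy arrays (for a CPU PyTorch tensor `t`, `t.numpy()` does this); the zero-filled `host_batch` and the `host_labels` array are made-up placeholders, not real MNIST data. The script posted above already calls `jnp.asarray` on `x` and `y` before the model is applied, so this may not be the whole story.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Placeholder data standing in for `images.numpy()` and `labels.numpy()`
# taken from one DataLoader batch (assumed shapes, not real MNIST values).
host_batch = np.zeros((128, 1, 28, 28), dtype=np.float32)
host_labels = np.arange(128) % 10

# Convert the host arrays to JAX arrays before calling jitted functions.
x = jnp.asarray(host_batch)

# One-hot encode the integer labels, matching the script's target format.
y = jax.nn.one_hot(jnp.asarray(host_labels), 10)

print(isinstance(x, jax.Array), x.shape, y.shape)
```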