1) Please paste your code rather than attaching it as a file, especially when the code is short.
2) Why are you using an RNN?
3) From looking at your code, it doesn't seem like you're actually using the time component. Are you sure that your preprocessing replicates the data over the time axis? A sketch of what that might look like follows.
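For reference, here is a minimal sketch of what that preprocessing could look like, assuming the intent is to feed each of the 28 image rows as one time step. The to_sequence helper is hypothetical, not part of the posted script; the key point is that a raw reshape of a (batch, 1, 28, 28) tensor only reinterprets memory, whereas moving the time axis to the front requires an explicit transpose:

import jax.numpy as jnp
import torch

def to_sequence(x: torch.Tensor) -> jnp.ndarray:
    # (batch, 1, 28, 28) -> (batch, 28, 28): drop the channel axis
    x = x.squeeze(1)
    # (batch, 28, 28) -> (28, batch, 28): time-major, as hk.dynamic_unroll expects
    x = x.permute(1, 0, 2)
    # Convert the CPU tensor to a JAX array
    return jnp.asarray(x.numpy())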
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
import torchvision.transforms as transforms

# Hyperparameters: each 28x28 MNIST image is treated as a sequence of
# 28 rows with 28 features per step.
sequence_length = 28
input_size = 28
hidden_size = 128
num_classes = 10
batch_size = 128
num_epochs = 30
learning_rate = 0.001

# MNIST via the PyTorch data pipeline.
train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

def unroll_net(seqs: jax.Array):
    """Unrolls an LSTM over seqs of shape [T, B, F] and classifies the final output."""
    core = hk.LSTM(hidden_size)
    batch_size = seqs.shape[1]
    outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
    # Project the last time step's output to class logits.
    return hk.Linear(num_classes)(outs[-1]), state

model = hk.transform(unroll_net)
rng = jax.random.PRNGKey(428)
opt = optax.adam(learning_rate)

@jax.jit
def loss(params, x, y):
    # Mean squared error between logits and one-hot targets.
    pred, _ = model.apply(params, None, x)
    return jnp.mean(jnp.square(pred - y))

@jax.jit
def accuracy(predict, target):
    # Number of correct predictions in the batch.
    return jnp.sum(jnp.argmax(predict, axis=1) == jnp.argmax(target, axis=1))

@jax.jit
def update(step, params, opt_state, x, y):
    l, grads = jax.value_and_grad(loss)(params, x, y)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return l, params, opt_state

train_ds = iter(train_loader)
valid_ds = iter(test_loader)

# Initialise parameters from one sample batch.
sample_x, _ = next(train_ds)
sample_x = sample_x.reshape(sequence_length, -1, input_size)
sample_x = jnp.asarray(sample_x)
params = model.init(rng, sample_x)
opt_state = opt.init(params)

# Single pass over the training set (num_epochs is defined but unused here).
length = len(train_loader)  # len() works on the loader, not on the iterator
for step in range(length - 1):
    if step % 10 == 0:
        x, y = next(valid_ds)
        x = x.reshape(sequence_length, -1, input_size)
        x = jnp.asarray(x)
        y = jnp.asarray(y)
        y = jnp.array(y[:, None] == jnp.arange(10), jnp.float32)  # one-hot targets
        print("Step {}: valid loss {}".format(step, loss(params, x, y)))
    x, y = next(train_ds)
    x = x.reshape(sequence_length, -1, input_size)
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    y = jnp.array(y[:, None] == jnp.arange(10), jnp.float32)
    train_loss, params, opt_state = update(step, params, opt_state, x, y)
    if step % 10 == 0:
        print("Step {}: train loss {}".format(step, train_loss))
Here is the full code.

test.txt: I use an LSTM to classify MNIST data and find that the loss of the network cannot converge at all. Is the RNN provided by the framework correct? I have attached the script that I run.
The error in the code is likely due to a mismatch between PyTorch tensors and JAX arrays. The train_loader and test_loader yield PyTorch tensors, while the model and loss functions expect JAX arrays. You need to convert the PyTorch tensors to JAX arrays before passing them to the model and loss functions; jnp.array() (or jnp.asarray()) performs this conversion.
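As a small illustration of that conversion inside the data loop (names taken from the posted script; the .numpy() call assumes the tensors live on the CPU, as they do here):

for images, labels in train_loader:
    # PyTorch tensors -> JAX arrays before touching the JAX model
    x = jnp.asarray(images.numpy())
    y = jnp.asarray(labels.numpy())
    # ... then reshape x, one-hot y, and call update(...) as in the script above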