## Description
Warning messages from PyTorch's LSTM.

## Expected Behavior
No warning; the LSTM parameters should be compacted into a single contiguous chunk of memory.
## Error Message
```
[W RNN.cpp:982] Warning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (function _cudnn_rnn)
```
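For reference, the warning comes from native PyTorch's cuDNN RNN path, not from DJL itself. A minimal sketch in plain PyTorch (assumption: DJL's pytorch-engine drives the same underlying module, and as far as I can tell DJL does not expose `flatten_parameters()`):

```python
import torch

# A small LSTM matching the DJL repro below: input size 1, state size 2, 1 layer.
lstm = torch.nn.LSTM(input_size=1, hidden_size=2, num_layers=1)

# After certain operations (e.g. loading a state dict, or moving the module
# between devices), the per-gate weight tensors may no longer share one
# contiguous buffer; on CUDA, cuDNN then emits the warning above on each call.
# flatten_parameters() recompacts the weights into a single contiguous chunk
# (it is effectively a no-op on CPU):
lstm.flatten_parameters()

out, (h, c) = lstm(torch.zeros(1, 1, 1))  # (seq_len, batch, input_size)
print(out.shape)  # (seq_len, batch, hidden_size)
```

In pure PyTorch the fix is to call `flatten_parameters()` once after the weights have been rearranged; the question is how to trigger the equivalent compaction from DJL.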
## How to Reproduce?
Steps to reproduce:
```java
import java.io.IOException;
import java.util.List;

import ai.djl.Model;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.recurrent.LSTM;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.translate.StackBatchifier;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;

public class Main {
    public static void main(String[] args) throws IOException, TranslateException {
        final var model = Model.newInstance("tempmodel");
        model.setBlock(LSTM.builder().setNumLayers(1).setStateSize(2).build());
        try (var trainer = new Trainer(model,
                new DefaultTrainingConfig(Loss.l2Loss()).optOptimizer(Optimizer.adam().build()))) {
            // Fit one epoch on a single zero input / ones label batch;
            // the cuDNN warning is printed during training.
            EasyTrain.fit(trainer, 1, new Dataset() {
                @Override
                public Iterable<Batch> getData(NDManager manager) throws IOException, TranslateException {
                    final var batchifier = new StackBatchifier();
                    final var input = manager.zeros(new Shape(1, 1, 1));
                    final var label = manager.ones(new Shape(1, 1, 2));
                    return List.of(new Batch(manager, new NDList(input), new NDList(label),
                            1, batchifier, batchifier, 0, 0));
                }

                @Override
                public void prepare(Progress progress) throws IOException, TranslateException {
                }
            }, null);
        }
    }
}
```
## What have you tried to solve it?
1. Checked the parameter structure of DJL, but was overwhelmed.
## Environment Info
- OS: Windows
- Dependencies: pytorch-engine@0.26.0, pytorch-native-cu121@2.1.1, api@0.26.0
- Build: Maven (pom.xml)