Description
When I run inference with pre-trained embeddings while training an NLP model on multiple devices, a NullPointerException is thrown because Predictor is not designed to work with multiple devices.
Error Message
[INFO ] - Load MXNet Engine Version 1.7.0 in 0.181 ms.
[INFO ] - forward P50: 3.519 ms, P90: 3.519 ms
[INFO ] - training-metrics P50: 0.048 ms, P90: 0.048 ms
[INFO ] - backward P50: 1.721 ms, P90: 1.721 ms
Exception in thread "main" java.lang.NullPointerException
	at ai.djl.training.ParameterStore.getValue(ParameterStore.java:105)
	at ai.djl.nn.core.Embedding.opInputs(Embedding.java:257)
	at ai.djl.nn.core.Embedding.forward(Embedding.java:162)
	at ai.djl.nn.Block.forward(Block.java:118)
	at ai.djl.inference.Predictor.predict(Predictor.java:117)
	at ai.djl.inference.Predictor.batchPredict(Predictor.java:144)
	at ai.djl.inference.Predictor.predict(Predictor.java:112)
	at ai.djl.modality.nlp.embedding.ModelZooTextEmbedding.embedText(ModelZooTextEmbedding.java:57)
	at ai.djl.examples.training.TrainSentimentAnalysis$EmbeddingDataManager.getData(TrainSentimentAnalysis.java:277)
	at ai.djl.training.Trainer.trainBatch(Trainer.java:159)
	at ai.djl.examples.training.util.TrainingUtils.fit(TrainingUtils.java:36)
	at ai.djl.examples.training.TrainSentimentAnalysis.runExample(TrainSentimentAnalysis.java:133)
	at ai.djl.examples.training.TrainSentimentAnalysis.main(TrainSentimentAnalysis.java:89)
	Suppressed: java.lang.IllegalArgumentException: Metric name not found: step
		at ai.djl.metric.Metrics.percentile(Metrics.java:135)
		at ai.djl.training.listener.LoggingTrainingListener.onTrainingEnd(LoggingTrainingListener.java:167)
		at ai.djl.training.Trainer.lambda$close$5(Trainer.java:349)
		at java.util.ArrayList.forEach(ArrayList.java:1257)
		at ai.djl.training.Trainer.close(Trainer.java:349)
		at ai.djl.examples.training.TrainSentimentAnalysis.runExample(TrainSentimentAnalysis.java:159)
		... 1 more
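To make the suspected failure mode concrete, here is a hypothetical toy model in plain Java. The class `PerDeviceParameterStore`, its methods, and the device names are illustrative assumptions, not DJL code, and this sketch does not confirm why DJL's `ParameterStore.getValue` actually fails. It models a parameter cache that keeps one copy per device it was set up for. Looking up a device outside that set auto-unboxes a `null` index and throws a NullPointerException, similar in spirit to the top frame of the trace above.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical toy model, NOT DJL's actual ParameterStore: it only illustrates how a
// per-device parameter cache can throw a NullPointerException when it is queried for
// a device it was never initialized with.
public class PerDeviceParameterStore {
    // Maps a device name (e.g. "gpu(0)") to the slot holding that device's parameter copy.
    private final Map<String, Integer> deviceIndex = new HashMap<>();
    private final List<float[]> copies = new ArrayList<>();

    public PerDeviceParameterStore(float[] weights, String... devices) {
        for (String device : devices) {
            deviceIndex.put(device, copies.size());
            copies.add(weights.clone()); // one independent copy per known device
        }
    }

    // Map.get returns null for an unknown device; unboxing that null to int throws an NPE.
    public float[] getValue(String device) {
        int index = deviceIndex.get(device);
        return copies.get(index);
    }

    // Embedding-style lookup of a single weight by token index.
    public float embed(String device, int tokenIndex) {
        return getValue(device)[tokenIndex];
    }

    public static void main(String[] args) {
        PerDeviceParameterStore store =
                new PerDeviceParameterStore(new float[] {0.1f, 0.2f, 0.3f}, "gpu(0)", "gpu(1)");
        System.out.println(store.embed("gpu(1)", 2)); // known device: prints 0.3
        try {
            store.embed("cpu()", 2); // a device the store was never set up for
        } catch (NullPointerException e) {
            System.out.println("NullPointerException: no parameter copy for cpu()");
        }
    }
}
```

In this toy, the failure goes away if the cache is populated for every device the caller may pass, or if inputs are moved to a known device before the lookup; how either option maps onto DJL's Predictor is left open here.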