Closed aminnasiri closed 1 year ago
I think you hit the same issue as: https://github.com/deepjavalibrary/djl/issues/2144
This error indicates that backward propagation is touching the parameters, which are protected in inference mode and cannot be updated. Look at this issue: https://github.com/deepjavalibrary/djl/issues/2144. Consider the solution there.
Or, with the PyTorch engine, you can also look at this transfer learning example, which trains an embedding block: https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java
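For readers who do not want to dig through the linked example: as far as I can tell, it loads the pretrained block with a `trainParam` option so that the PyTorch engine loads the parameters as trainable tensors instead of inference tensors. Below is a minimal sketch of that idea, not a confirmed fix for this project. The class name `LoadTrainableEmbedding` and the `modelUrl` argument are hypothetical placeholders, and the option value is an assumption taken from that example.

```java
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;

import java.io.IOException;

/** Hypothetical helper: loads a pretrained embedding so its parameters can be trained. */
public final class LoadTrainableEmbedding {

    private LoadTrainableEmbedding() {}

    /**
     * @param modelUrl placeholder for the URL of the traced PyTorch embedding model
     */
    public static ZooModel<NDList, NDList> load(String modelUrl)
            throws IOException, ModelNotFoundException, MalformedModelException {
        Criteria<NDList, NDList> criteria =
                Criteria.builder()
                        .setTypes(NDList.class, NDList.class)
                        .optModelUrls(modelUrl)
                        .optEngine("PyTorch")
                        .optProgress(new ProgressBar())
                        // Assumption: same option the transfer learning example sets,
                        // asking the engine to load parameters in training mode.
                        .optOption("trainParam", "true")
                        .build();
        return criteria.loadModel();
    }
}
```

The block returned by `load(...).getBlock()` would then be added to the `SequentialBlock` in place of the embedding that was loaded in inference mode.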
Thanks @frankfliu & @KexinFeng. This approach is working fine for this issue, so I'm gonna close it. https://github.com/deepjavalibrary/djl/issues/2144#issuecomment-1309112693
Description
I developed an example of rank classification using BERT on the Amazon Review dataset, based on this guide. It worked fine with the Apache MXNet engine, but it throws an EngineException on the PyTorch engine.
Expected Behavior
I expect the same example to work on the PyTorch engine as well.
Error Message
Error: ai.djl.engine.EngineException: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
    at ai.djl.pytorch.jni.PyTorchLibrary.torchNNLinear(Native Method)
    at ai.djl.pytorch.jni.JniUtils.linear(JniUtils.java:1189)
    at ai.djl.pytorch.engine.PtNDArrayEx.linear(PtNDArrayEx.java:390)
    at ai.djl.nn.core.Linear.linear(Linear.java:183)
    at ai.djl.nn.core.Linear.forwardInternal(Linear.java:88)
    at ai.djl.nn.AbstractBaseBlock.forwardInternal(AbstractBaseBlock.java:126)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
    at ai.djl.nn.SequentialBlock.forwardInternal(SequentialBlock.java:209)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
    at ai.djl.training.Trainer.forward(Trainer.java:175)
    at ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:122)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
    at com.thinksky.classification.TrainModel.trainModel(TrainModel.java:96)
    at com.thinksky.classification.TrainModel_ClientProxy.trainModel(Unknown Source)
    at com.thinksky.PredictResource.executeQuery(PredictResource.java:104)
    at com.thinksky.PredictResource_VertxInvoker_executeQuery_315faff7c0f6c7728fd1c92cfb1b39aa7f024059.invokeBean(Unknown Source)
    at io.quarkus.vertx.runtime.EventConsumerInvoker.invoke(EventConsumerInvoker.java:41)
    at io.quarkus.vertx.runtime.VertxRecorder$3$1.handle(VertxRecorder.java:135)
    at io.quarkus.vertx.runtime.VertxRecorder$3$1.handle(VertxRecorder.java:105)
    at io.vertx.core.impl.ContextInternal.dispatch(ContextInternal.java:264)
    at io.vertx.core.eventbus.impl.MessageConsumerImpl.dispatch(MessageConsumerImpl.java:177)
    at io.vertx.core.eventbus.impl.HandlerRegistration$InboundDeliveryContext.execute(HandlerRegistration.java:137)
    at io.vertx.core.eventbus.impl.DeliveryContextBase.next(DeliveryContextBase.java:72)
    at io.vertx.core.eventbus.impl.DeliveryContextBase.dispatch(DeliveryContextBase.java:43)
    at io.vertx.core.eventbus.impl.HandlerRegistration.dispatch(HandlerRegistration.java:98)
    at io.vertx.core.eventbus.impl.MessageConsumerImpl.deliver(MessageConsumerImpl.java:183)
    at io.vertx.core.eventbus.impl.MessageConsumerImpl.doReceive(MessageConsumerImpl.java:168)
    at io.vertx.core.eventbus.impl.HandlerRegistration.lambda$receive$0(HandlerRegistration.java:49)
    at io.netty.util.concurrent.AbstractEventExecutor.runTask(AbstractEventExecutor.java:174)
    at io.netty.util.concurrent.AbstractEventExecutor.safeExecute(AbstractEventExecutor.java:167)
    at io.netty.util.concurrent.SingleThreadEventExecutor.runAllTasks(SingleThreadEventExecutor.java:470)
    at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:569)
    at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:997)
    at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
    at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
    at java.base/java.lang.Thread.run(Thread.java:1589)
How to Reproduce?
This is kind of the same project
Steps to reproduce
mvn quarkus:dev
curl -X GET http://localhost:8080/predict/model
What have you tried to solve it?
Set these properties
Environment Info
OS: macOS
JDK: Java 19
Dependencies