deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Pytorch EngineException: Inference tensors cannot be saved for backward #2159

Closed aminnasiri closed 1 year ago

aminnasiri commented 1 year ago

Description

I developed an example of Rank Classification using BERT on the Amazon Review dataset, based on this guide. It works fine with the Apache MXNet engine, but it throws an EngineException with the PyTorch engine.

Expected Behavior

I expect the example to work with the PyTorch engine as well.

Error Message

Error: ai.djl.engine.EngineException: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
    at ai.djl.pytorch.jni.PyTorchLibrary.torchNNLinear(Native Method)
    at ai.djl.pytorch.jni.JniUtils.linear(JniUtils.java:1189)
    at ai.djl.pytorch.engine.PtNDArrayEx.linear(PtNDArrayEx.java:390)
    at ai.djl.nn.core.Linear.linear(Linear.java:183)
    at ai.djl.nn.core.Linear.forwardInternal(Linear.java:88)
    at ai.djl.nn.AbstractBaseBlock.forwardInternal(AbstractBaseBlock.java:126)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
    at ai.djl.nn.SequentialBlock.forwardInternal(SequentialBlock.java:209)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
    at ai.djl.training.Trainer.forward(Trainer.java:175)
    at ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:122)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
    at com.thinksky.classification.TrainModel.trainModel(TrainModel.java:96)
    at com.thinksky.classification.TrainModel_ClientProxy.trainModel(Unknown Source)
    at com.thinksky.PredictResource.executeQuery(PredictResource.java:104)
    at com.thinksky.PredictResource_VertxInvoker_executeQuery_315faff7c0f6c7728fd1c92cfb1b39aa7f024059.invokeBean(Unknown Source)
    at io.quarkus.vertx.runtime.EventConsumerInvoker.invoke(EventConsumerInvoker.java:41)
    at io.quarkus.vertx.runtime.VertxRecorder$3$1.handle(VertxRecorder.java:135)
    at io.quarkus.vertx.runtime.VertxRecorder$3$1.handle(VertxRecorder.java:105)
    at io.vertx.core.impl.ContextInternal.dispatch(ContextInternal.java:264)
    at io.vertx.core.eventbus.impl.MessageConsumerImpl.dispatch(MessageConsumerImpl.java:177)
    at io.vertx.core.eventbus.impl.HandlerRegistration$InboundDeliveryContext.execute(HandlerRegistration.java:137)
    at io.vertx.core.eventbus.impl.DeliveryContextBase.next(DeliveryContextBase.java:72)
    at io.vertx.core.eventbus.impl.DeliveryContextBase.dispatch(DeliveryContextBase.java:43)
    at io.vertx.core.eventbus.impl.HandlerRegistration.dispatch(HandlerRegistration.java:98)
    at io.vertx.core.eventbus.impl.MessageConsumerImpl.deliver(MessageConsumerImpl.java:183)
    at io.vertx.core.eventbus.impl.MessageConsumerImpl.doReceive(MessageConsumerImpl.java:168)
    at io.vertx.core.eventbus.impl.HandlerRegistration.lambda$receive$0(HandlerRegistration.java:49)
    at io.netty.util.concurrent.AbstractEventExecutor.runTask(AbstractEventExecutor.java:174)
    at io.netty.util.concurrent.AbstractEventExecutor.safeExecute(AbstractEventExecutor.java:167)
    at io.netty.util.concurrent.SingleThreadEventExecutor.runAllTasks(SingleThreadEventExecutor.java:470)
    at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:569)
    at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:997)
    at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
    at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
    at java.base/java.lang.Thread.run(Thread.java:1589)

How to Reproduce?

This is kind of the same project

Steps to reproduce

  1. Run the project: mvn quarkus:dev
  2. Call training endpoint: curl -X GET http://localhost:8080/predict/model

What have you tried to solve it?

Set these properties

Environment Info

OS: macOS
JDK: Java 19

Dependencies

<dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-engine</artifactId>
</dependency>
<dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-native-cpu</artifactId>
      <classifier>osx-x86_64</classifier>
      <version>1.12.1</version>
      <scope>runtime</scope>
</dependency>
<dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-jni</artifactId>
      <version>1.12.1-0.19.0</version>
      <scope>runtime</scope>
</dependency>

frankfliu commented 1 year ago

I think you hit the same issue as: https://github.com/deepjavalibrary/djl/issues/2144

KexinFeng commented 1 year ago

This error indicates that backward propagation is affecting the model parameters, which are protected in inference mode and cannot be updated. See https://github.com/deepjavalibrary/djl/issues/2144 and consider the solution there.

Or, with the PyTorch engine, you can also look at this transfer learning example for training an embedding block: https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java

aminnasiri commented 1 year ago

Thanks @frankfliu & @KexinFeng. The approach in https://github.com/deepjavalibrary/djl/issues/2144#issuecomment-1309112693 works for this issue, so I'm going to close it.