enpasos / muzero

Apache License 2.0

java.lang.IllegalStateException: Split can only be called on a batch containing a batchifier #1

Closed. lukaszkn closed this issue 2 years ago.

lukaszkn commented 2 years ago

Hi @enpasos, I've tried to run the MuZero-TicTacToe code, but it crashes instantly when it tries to train:

[INFO ] 2021-09-28 18:08:08.821 [main] MyLoggingTrainingListener - Load PyTorch Engine Version 1.8.1 in 0.002 ms.
[DEBUG] 2021-09-28 18:08:10.430 [main] NetworkHelper - trainBatch 0
Exception in thread "main" java.lang.IllegalStateException: Split can only be called on a batch containing a batchifier
    at ai.djl.training.dataset.Batch.split(Batch.java:222)
    at ai.djl.training.dataset.Batch.split(Batch.java:196)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:89)
    at ai.enpasos.muzero.agent.fast.model.djl.NetworkHelper.trainAndReturnNumberOfLastTrainingStep(NetworkHelper.java:95)
    at ai.enpasos.muzero.MuZero.run(MuZero.java:83)
    at ai.enpasos.muzero.MuZero.main(MuZero.java:44)

Thanks for any help

enpasos commented 2 years ago

Hi @lukaszkn, this looks like a dependency problem. Is it an option for you to use the following dependencies?

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-auto</artifactId>
    <version>1.9.0-SNAPSHOT</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.13.0-SNAPSHOT</version>
</dependency>

You get these dependencies if you just check out the master branch, build with

mvn clean install -Dmaven.test.skip=true

and run with

mvn exec:java@train

or

java -jar ./target/muzero-0.2.0-SNAPSHOT-jar-with-dependencies.jar

lukaszkn commented 2 years ago

Thanks. 1.9.0-SNAPSHOT doesn't seem to exist, or maybe I'm missing something. I've tried with 1.9.0, but the same java.lang.IllegalStateException is still raised.

mvn clean install -Dmaven.test.skip=true
[INFO] Scanning for projects...
[WARNING]
[WARNING] Some problems were encountered while building the effective model for com.enpasos.muzero:muzero:jar:0.2.0-SNAPSHOT
[WARNING] 'dependencies.dependency.(groupId:artifactId:type:classifier)' must be unique: org.apache.commons:commons-csv:jar -> duplicate declaration of version 1.8 @ line 37, column 21
[WARNING] 'dependencies.dependency.(groupId:artifactId:type:classifier)' must be unique: org.apache.commons:commons-csv:jar -> duplicate declaration of version 1.8 @ line 42, column 21
[WARNING] 'dependencies.dependency.version' for org.jetbrains:annotations:jar is either LATEST or RELEASE (both of them are being deprecated) @ line 197, column 22
[WARNING]
[WARNING] It is highly recommended to fix these problems because they threaten the stability of your build.
[WARNING]
[WARNING] For this reason, future Maven versions might no longer support building such malformed projects.
[WARNING]
[INFO]
[INFO] ---------------------< com.enpasos.muzero:muzero >----------------------
[INFO] Building muzero 0.2.0-SNAPSHOT
[INFO] --------------------------------[ jar ]---------------------------------
[WARNING] The POM for ai.djl.pytorch:pytorch-native-auto:jar:1.9.0-SNAPSHOT is missing, no dependency information available
[INFO] ------------------------------------------------------------------------
[INFO] BUILD FAILURE
[INFO] ------------------------------------------------------------------------
[INFO] Total time:  0.563 s
[INFO] Finished at: 2021-09-29T14:14:10+02:00
[INFO] ------------------------------------------------------------------------
[ERROR] Failed to execute goal on project muzero: Could not resolve dependencies for project 
com.enpasos.muzero:muzero:jar:0.2.0-SNAPSHOT: ai.djl.pytorch:pytorch-native-auto:jar:1.9.0-SNAPSHOT was not
found in https://oss.sonatype.org/content/repositories/snapshots/ during a previous attempt. This failure was
cached in the local repository and resolution is not reattempted until the update interval of djl.ai has elapsed
or updates are forced

enpasos commented 2 years ago

You are right. I checked it: the dependency was coming from my local Maven repo. I fixed it and pushed the fix.

But I understand your problem is still there. I assume you just checked out the code, corrected the dependency bug you found, built, started, and bang. Correct? Could you give me more information about the stack you are running on (OS, GPU, CUDA) and the logging output from the start, something like this:

OS: Windows 10
GPU: Quadro RTX 5000
CUDA: 11.1

LOG:

E:\public\muzero>java -jar ./target/muzero-0.2.0-SNAPSHOT-jar-with-dependencies.jar
WARNING: sun.reflect.Reflection.getCallerClass is not supported. This will impact performance.
[DEBUG] 2021-09-29 17:30:31.263 [main] Engine - Found EngineProvider: PyTorch
[DEBUG] 2021-09-29 17:30:31.265 [main] Engine - Found default engine: PyTorch
[DEBUG] 2021-09-29 17:30:31.273 [main] CudaUtils - Found cudart: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\bin\cudart64_110.dll
[DEBUG] 2021-09-29 17:30:32.885 [main] LibUtils - Using cache dir: C:\Users\matth\.djl.ai\pytorch
[DEBUG] 2021-09-29 17:30:32.886 [main] LibUtils - Loading pytorch library from: C:\Users\matth\.djl.ai\pytorch\1.9.0-cu111-win-x86_64\0.13.0-SNAPSHOT-cu111-djl_torch.dll
[INFO ] 2021-09-29 17:30:33.053 [main] PtEngine - Number of inter-op threads is 8
[INFO ] 2021-09-29 17:30:33.065 [main] PtEngine - Number of intra-op threads is 16
[DEBUG] 2021-09-29 17:30:33.068 [main] Helper - Name: ptModel Parent Name: f195c9f2-e171-4219-88e9-503ead0d02da isOpen: true Resource size: 0
[DEBUG] 2021-09-29 17:30:33.081 [main] Helper - Name: f195c9f2-e171-4219-88e9-503ead0d02da Parent Name: No Parent isOpen: true Resource size: 0
\--- NDManager(503ead0d02da) resource count: 0
[DEBUG] 2021-09-29 17:30:33.119 [main] BaseModel - Try to load model from E:\public\muzero\.\memory\tictactoe\networks\MuZero-TicTacToe-0032.params
[DEBUG] 2021-09-29 17:30:33.146 [main] BaseModel - Loading saved model: MuZero-TicTacToe parameter
[DEBUG] 2021-09-29 17:30:34.823 [main] BaseModel - DJL model loaded successfully
[INFO ] 2021-09-29 17:30:34.825 [main] NetworkHelper - k=0: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:34.849 [main] NetworkHelper - k=1: L2Loss
[INFO ] 2021-09-29 17:30:34.878 [main] NetworkHelper - k=2: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:34.883 [main] NetworkHelper - k=3: L2Loss
[INFO ] 2021-09-29 17:30:34.909 [main] NetworkHelper - k=4: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:34.935 [main] NetworkHelper - k=5: L2Loss
[INFO ] 2021-09-29 17:30:34.959 [main] NetworkHelper - k=6: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:34.982 [main] NetworkHelper - k=7: L2Loss
[INFO ] 2021-09-29 17:30:34.985 [main] NetworkHelper - k=8: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:35.007 [main] NetworkHelper - k=9: L2Loss
[INFO ] 2021-09-29 17:30:35.010 [main] NetworkHelper - k=10: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:35.031 [main] NetworkHelper - k=11: L2Loss
[INFO ] 2021-09-29 17:30:35.063 [main] MyLoggingTrainingListener - Training on: 1 GPUs.
[INFO ] 2021-09-29 17:30:35.064 [main] MyLoggingTrainingListener - Load PyTorch Engine Version 1.9.0 in 0,014 ms.
[DEBUG] 2021-09-29 17:30:35.125 [main] Helper - Name: trainer Parent Name: ptModel isOpen: true Resource size: 0
[DEBUG] 2021-09-29 17:30:35.125 [main] Helper - Name: ptModel Parent Name: f195c9f2-e171-4219-88e9-503ead0d02da isOpen: true Resource size: 185
[DEBUG] 2021-09-29 17:30:35.130 [main] Helper - Name: f195c9f2-e171-4219-88e9-503ead0d02da Parent Name: No Parent isOpen: true Resource size: 0
\--- NDManager(503ead0d02da) resource count: 0
loading ... ./memory/tictactoe/games/buffer72000
[DEBUG] 2021-09-29 17:30:35.606 [main] Helper - Name: ptModel Parent Name: f195c9f2-e171-4219-88e9-503ead0d02da isOpen: true Resource size: 0
[DEBUG] 2021-09-29 17:30:35.606 [main] Helper - Name: f195c9f2-e171-4219-88e9-503ead0d02da Parent Name: No Parent isOpen: true Resource size: 0
\--- NDManager(503ead0d02da) resource count: 0
[DEBUG] 2021-09-29 17:30:35.632 [main] BaseModel - Try to load model from E:\public\muzero\.\memory\tictactoe\networks\MuZero-TicTacToe-0032.params
[DEBUG] 2021-09-29 17:30:35.637 [main] BaseModel - Loading saved model: MuZero-TicTacToe parameter
[DEBUG] 2021-09-29 17:30:35.755 [main] BaseModel - DJL model loaded successfully
[INFO ] 2021-09-29 17:30:35.755 [main] NetworkHelper - k=0: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:35.779 [main] NetworkHelper - k=1: L2Loss
[INFO ] 2021-09-29 17:30:35.806 [main] NetworkHelper - k=2: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:35.834 [main] NetworkHelper - k=3: L2Loss
[INFO ] 2021-09-29 17:30:35.861 [main] NetworkHelper - k=4: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:35.890 [main] NetworkHelper - k=5: L2Loss
[INFO ] 2021-09-29 17:30:35.918 [main] NetworkHelper - k=6: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:35.945 [main] NetworkHelper - k=7: L2Loss
[INFO ] 2021-09-29 17:30:35.973 [main] NetworkHelper - k=8: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:36.001 [main] NetworkHelper - k=9: L2Loss
[INFO ] 2021-09-29 17:30:36.028 [main] NetworkHelper - k=10: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 17:30:36.056 [main] NetworkHelper - k=11: L2Loss
[INFO ] 2021-09-29 17:30:36.060 [main] MyLoggingTrainingListener - Training on: 1 GPUs.
[INFO ] 2021-09-29 17:30:36.083 [main] MyLoggingTrainingListener - Load PyTorch Engine Version 1.9.0 in 0,004 ms.
[DEBUG] 2021-09-29 17:30:36.471 [main] NetworkHelper - trainBatch 0
[DEBUG] 2021-09-29 17:30:39.576 [main] NetworkHelper - trainBatch 1
[DEBUG] 2021-09-29 17:30:41.232 [main] NetworkHelper - trainBatch 2
[DEBUG] 2021-09-29 17:30:42.895 [main] NetworkHelper - trainBatch 3
[DEBUG] 2021-09-29 17:30:44.535 [main] NetworkHelper - trainBatch 4
[DEBUG] 2021-09-29 17:30:46.185 [main] NetworkHelper - trainBatch 5

lukaszkn commented 2 years ago

It's more or less as you described. It crashes shortly after printing this line: NetworkHelper - trainBatch 0

I've just added a catch block in NetworkHelper.java, because otherwise it crashes silently:

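try (Batch batch = getBatch(config, model.getNDManager(), replayBuffer, withSymmetryEnrichment)) {
    log.debug("trainBatch " + m);
    EasyTrain.trainBatch(trainer, batch);
    trainer.step();
} catch (Exception ex) {
    // temporary, for diagnosis only: make the otherwise silent failure visible
    // (ExceptionUtils comes from Apache Commons Lang); note that the exception
    // is swallowed here, so the surrounding training loop keeps running
    log.error(ex.toString() + "\n" + ExceptionUtils.getStackTrace(ex));
}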
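
ExceptionUtils.getStackTrace comes from Apache Commons Lang. If that library is not on the classpath, the JDK alone can render a stack trace to a string, as in this minimal sketch (the class StackTraces and the helper stackTraceOf are hypothetical names, not part of the project):

```java
import java.io.PrintWriter;
import java.io.StringWriter;

public class StackTraces {

    /** Renders the full stack trace of a throwable, exactly as printStackTrace() would print it. */
    static String stackTraceOf(Throwable t) {
        StringWriter sw = new StringWriter();
        try (PrintWriter pw = new PrintWriter(sw)) {
            t.printStackTrace(pw);
        }
        return sw.toString();
    }

    public static void main(String[] args) {
        try {
            throw new IllegalStateException(
                    "Split can only be called on a batch containing a batchifier");
        } catch (IllegalStateException ex) {
            // Same shape as the log.error(...) call above, but written to stderr.
            System.err.println(ex + "\n" + stackTraceOf(ex));
        }
    }
}
```

With SLF4J or Log4j 2, passing the exception as the last argument, e.g. log.error("trainBatch failed", ex), also logs the full stack trace without any helper.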

I wouldn't want to waste your time, so don't worry too much.

OS: Windows 10
GPU: 5x GTX 1070
CUDA: 11.1

LOG:

[DEBUG] 2021-09-29 22:15:25.752 [main] Engine - Found EngineProvider: PyTorch
[DEBUG] 2021-09-29 22:15:25.752 [main] Engine - Found default engine: PyTorch
[DEBUG] 2021-09-29 22:15:25.768 [main] CudaUtils - Found cudart: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\bin\cudart64_110.dll
[DEBUG] 2021-09-29 22:15:25.908 [main] LibUtils - Using cache dir: C:\Users\lynnx\.djl.ai\pytorch
[DEBUG] 2021-09-29 22:15:25.908 [main] LibUtils - Loading pytorch library from: C:\Users\lynnx\.djl.ai\pytorch\1.9.0-cu111-win-x86_64\0.13.0-SNAPSHOT-cu111-djl_torch.dll
[INFO ] 2021-09-29 22:15:26.174 [main] PtEngine - Number of inter-op threads is 2
[INFO ] 2021-09-29 22:15:26.174 [main] PtEngine - Number of intra-op threads is 4
[DEBUG] 2021-09-29 22:15:26.174 [main] Helper - Name: ptModel Parent Name: d02450a7-7b23-42e1-9672-11c223a1b39f isOpen: true Resource size: 0
[DEBUG] 2021-09-29 22:15:26.174 [main] Helper - Name: d02450a7-7b23-42e1-9672-11c223a1b39f Parent Name: No Parent isOpen: true Resource size: 0
\--- NDManager(11c223a1b39f) resource count: 0
[DEBUG] 2021-09-29 22:15:26.189 [main] BaseModel - Try to load model from C:\Projects\muzero\.\memory\tictactoe\networks\MuZero-TicTacToe-0000.params
[DEBUG] 2021-09-29 22:15:26.189 [main] BaseModel - Loading saved model: MuZero-TicTacToe parameter
[DEBUG] 2021-09-29 22:15:30.209 [main] BaseModel - DJL model loaded successfully
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=0: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=1: L2Loss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=2: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=3: L2Loss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=4: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=5: L2Loss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=6: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=7: L2Loss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=8: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=9: L2Loss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=10: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:30.209 [main] NetworkHelper - k=11: L2Loss
[INFO ] 2021-09-29 22:15:30.225 [main] MyLoggingTrainingListener - Training on: 5 GPUs.
[INFO ] 2021-09-29 22:15:30.225 [main] MyLoggingTrainingListener - Load PyTorch Engine Version 1.9.0 in 0.024 ms.
[DEBUG] 2021-09-29 22:15:32.051 [main] Helper - Name: trainer Parent Name: ptModel isOpen: true Resource size: 736
[DEBUG] 2021-09-29 22:15:32.051 [main] Helper - Name: ptModel Parent Name: d02450a7-7b23-42e1-9672-11c223a1b39f isOpen: true Resource size: 921
[DEBUG] 2021-09-29 22:15:32.051 [main] Helper - Name: d02450a7-7b23-42e1-9672-11c223a1b39f Parent Name: No Parent isOpen: true Resource size: 0
\--- NDManager(11c223a1b39f) resource count: 0
loading ... ./memory/tictactoe/games/buffer10000
[DEBUG] 2021-09-29 22:15:38.225 [main] Helper - Name: ptModel Parent Name: d02450a7-7b23-42e1-9672-11c223a1b39f isOpen: true Resource size: 0
[DEBUG] 2021-09-29 22:15:38.225 [main] Helper - Name: d02450a7-7b23-42e1-9672-11c223a1b39f Parent Name: No Parent isOpen: true Resource size: 0
\--- NDManager(11c223a1b39f) resource count: 0
[DEBUG] 2021-09-29 22:15:38.225 [main] BaseModel - Try to load model from C:\Projects\muzero\.\memory\tictactoe\networks\MuZero-TicTacToe-0000.params
[DEBUG] 2021-09-29 22:15:38.241 [main] BaseModel - Loading saved model: MuZero-TicTacToe parameter
[DEBUG] 2021-09-29 22:15:38.678 [main] BaseModel - DJL model loaded successfully
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=0: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=1: L2Loss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=2: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=3: L2Loss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=4: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=5: L2Loss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=6: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=7: L2Loss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=8: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=9: L2Loss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=10: SoftmaxCrossEntropyLoss
[INFO ] 2021-09-29 22:15:38.678 [main] NetworkHelper - k=11: L2Loss
[INFO ] 2021-09-29 22:15:38.678 [main] MyLoggingTrainingListener - Training on: 5 GPUs.
[INFO ] 2021-09-29 22:15:38.678 [main] MyLoggingTrainingListener - Load PyTorch Engine Version 1.9.0 in 0.002 ms.
[DEBUG] 2021-09-29 22:15:40.194 [main] NetworkHelper - trainBatch 0
[ERROR] 2021-09-29 22:15:40.209 [main] NetworkHelper - java.lang.IllegalStateException: Split can only be called on a batch containing a batchifier
java.lang.IllegalStateException: Split can only be called on a batch containing a batchifier
    at ai.djl.training.dataset.Batch.split(Batch.java:222)
    at ai.djl.training.dataset.Batch.split(Batch.java:196)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:89)
    at ai.enpasos.muzero.agent.fast.model.djl.NetworkHelper.trainAndReturnNumberOfLastTrainingStep(NetworkHelper.java:96)
    at ai.enpasos.muzero.MuZero.run(MuZero.java:83)
    at ai.enpasos.muzero.MuZero.main(MuZero.java:44)

enpasos commented 2 years ago

Now I think I see your problem: it's about parallelization. What about starting with one GPU (and switching to more later, if you like)? When you switch to a single GTX 1070, the next issue you will likely run into is GPU memory: the GTX 1070 has 8 GB, while the configuration currently checked in uses about 4 GB in the training phase and 15 GB in the playing phase. To get the playing phase running on the GTX 1070, you need to reduce the number of games played in parallel by changing lines 165 and 166 in MuZeroConfig to:

                .numParallelPlays(500)
                .numPlays(4)
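
A rough way to choose numParallelPlays for a smaller card, assuming (an unverified simplification) that playing-phase GPU memory grows linearly with the number of games played in parallel. Since training and playing currently run one after the other on the same GPU, the playing phase is the one that has to fit. The helper maxParallelPlays and the numbers in main are hypothetical; substitute the value from your own MuZeroConfig, the memory you actually measured, and a budget that leaves some headroom below the card's 8 GB:

```java
public class ParallelPlaysEstimate {

    /**
     * Largest numParallelPlays expected to fit into budgetGb, assuming memory is
     * proportional to the number of parallel games (a simplifying assumption).
     */
    static int maxParallelPlays(int measuredPlays, double measuredGb, double budgetGb) {
        if (measuredPlays <= 0 || measuredGb <= 0 || budgetGb <= 0) {
            throw new IllegalArgumentException("all arguments must be positive");
        }
        double gbPerPlay = measuredGb / measuredPlays;
        return (int) Math.floor(budgetGb / gbPerPlay);
    }

    public static void main(String[] args) {
        // Hypothetical inputs: 1000 parallel games measured at 15 GB, 7 GB budget on an 8 GB card.
        System.out.println(maxParallelPlays(1000, 15.0, 7.0)); // prints 466
    }
}
```

Treat the result as a starting point only and check the real usage (for example with nvidia-smi), since fixed costs such as the model weights do not scale with the number of games.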

lukaszkn commented 2 years ago

Yes, wow, it works! Thanks for your help.

I've added one line in NetworkHelper:

return new DefaultTrainingConfig(loss)
        // added line: train on at most one device (one GPU if any is found, otherwise the CPU)
        .optDevices(Engine.getInstance().getDevices(1))
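
DJL documents Engine.getDevices(maxGpus) as returning min(available GPUs, maxGpus) GPU devices, or a single CPU device when no GPU is available. The JDK-only sketch below models that documented rule with plain strings (the class DeviceSelection and the gpu(i)/cpu() labels are illustrative, not DJL types or DJL's implementation), which shows why getDevices(1) pins training to a single GPU on the 5-GPU machine:

```java
import java.util.ArrayList;
import java.util.List;

public class DeviceSelection {

    /** Models the documented getDevices(maxGpus) rule; not DJL's actual code. */
    static List<String> devices(int availableGpus, int maxGpus) {
        List<String> result = new ArrayList<>();
        int n = Math.min(availableGpus, maxGpus);
        for (int i = 0; i < n; i++) {
            result.add("gpu(" + i + ")");
        }
        if (result.isEmpty()) {
            result.add("cpu()"); // no usable GPU: fall back to a single CPU device
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(devices(5, 1));                 // prints [gpu(0)]
        System.out.println(devices(5, Integer.MAX_VALUE)); // all five GPUs
        System.out.println(devices(0, 1));                 // prints [cpu()]
    }
}
```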

It still crashes when maxGpu is set to > 1, but it's a start. It would be nice to be able to use all GPUs in the future.
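
Why more than one GPU still fails: judging from the stack trace (and, as far as I can tell, DJL 0.13's Batch.split), EasyTrain.trainBatch calls Batch.split to spread each batch over the trainer's devices. With a single device the whole batch can simply be moved to that device, but with several devices the batch has to be cut into slices, and that step is delegated to the batch's batchifier. A batch assembled without a batchifier (the author later describes putting the batch into one tensor) therefore throws exactly this exception. The JDK-only model below simplifies that control flow; the Splitter interface, EVEN_ROWS, and splitForDevices are hypothetical names, not DJL's API:

```java
import java.util.ArrayList;
import java.util.List;

public class BatchSplitModel {

    /** Hypothetical stand-in for a batchifier: knows how to cut a batch into slices. */
    interface Splitter {
        List<List<Integer>> split(List<Integer> batch, int slices);
    }

    /** Near-even split along the batch axis: the first (size % slices) slices get one extra sample. */
    static final Splitter EVEN_ROWS = (batch, slices) -> {
        List<List<Integer>> out = new ArrayList<>();
        int base = batch.size() / slices;
        int extra = batch.size() % slices;
        int start = 0;
        for (int i = 0; i < slices; i++) {
            int end = start + base + (i < extra ? 1 : 0);
            out.add(batch.subList(start, end));
            start = end;
        }
        return out;
    };

    /** Mirrors the described control flow: one device needs no splitter, several devices do. */
    static List<List<Integer>> splitForDevices(List<Integer> batch, Splitter splitter, int devices) {
        if (devices < 1) {
            throw new IllegalArgumentException("need at least one device");
        }
        if (devices == 1) {
            return List.of(batch); // the whole batch goes to the single device
        }
        if (splitter == null) {
            throw new IllegalStateException(
                    "Split can only be called on a batch containing a batchifier");
        }
        return splitter.split(batch, devices);
    }

    public static void main(String[] args) {
        List<Integer> batch = List.of(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        System.out.println(splitForDevices(batch, null, 1));      // one device: fine without a splitter
        System.out.println(splitForDevices(batch, EVEN_ROWS, 5)); // slice sizes 3, 2, 2, 2, 2
        try {
            splitForDevices(batch, null, 5);                      // five devices, no splitter
        } catch (IllegalStateException ex) {
            System.out.println(ex.getMessage());
        }
    }
}
```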

enpasos commented 2 years ago

Congrats! I pushed your constraint to train with only one GPU.

In case you want to use multiple GPUs in parallel here, these are the steps I would suggest:

  1. At the moment, training (in batches) and playing (in batches) are done sequentially on one GPU. Running training and playing in parallel on two GPUs should be quite simple.
  2. Playing in parallel on multiple GPUs should come almost for free with step 1.
  3. For training in parallel on multiple GPUs (where you ran into the problem), one has to look at how batch training is done here. I departed from the standard DJL implementation and put the whole batch into one tensor; it looked much faster to me (this still needs to be verified).
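
Step 1 could be sketched with plain JDK concurrency: one worker thread per GPU, with playing and training submitted as independent tasks instead of running one after the other. This is a minimal sketch under stated assumptions; playGames and trainNetwork are hypothetical placeholders for the project's real play and train loops, and the GPU indices are just integers handed to them:

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TrainAndPlayInParallel {

    // Hypothetical placeholder for self-play on one GPU; returns the number of games played.
    static int playGames(int gpu, int games) {
        return games;
    }

    // Hypothetical placeholder for training on one GPU; returns the number of batches trained.
    static int trainNetwork(int gpu, int batches) {
        return batches;
    }

    /** Runs one round of playing (GPU 0) and training (GPU 1) concurrently and waits for both. */
    static int[] runRound(int games, int batches) {
        ExecutorService pool = Executors.newFixedThreadPool(2); // one thread per GPU
        try {
            CompletableFuture<Integer> play =
                    CompletableFuture.supplyAsync(() -> playGames(0, games), pool);
            CompletableFuture<Integer> train =
                    CompletableFuture.supplyAsync(() -> trainNetwork(1, batches), pool);
            return new int[] {play.join(), train.join()};
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        int[] result = runRound(500, 100);
        System.out.println("played " + result[0] + ", trained " + result[1]);
    }
}
```

A real version would also have to decide when the playing side picks up newly trained weights and how fresh games reach the replay buffer; that synchronization is the part this sketch deliberately leaves out.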

I'm going on a mountain bike tour, so I'm offline until Monday.