yangkl96 closed this issue 12 months ago
@yangkl96
From a quick look at your class: you should not use a static NDManager. An NDManager should be closed properly to release its memory.
In the Translator you should use the TranslatorContext to get an NDManager instead of using a static one:
NDManager manager = ctx.getNDManager();
Thanks for the response. I do use TranslatorContext to get an NDManager at the end of processInput:
NDArray matrixGlobal = matrixAll.concat(matrixPos.reshape(matrixPos.size()));
NDManager manager = translatorContext.getNDManager();
retList.add(manager.from(matrix));
retList.add(manager.from(matrixSum));
retList.add(manager.from(matrixGlobal));
retList.add(manager.from(matrixHC));
return retList;
However, this is for a TensorFlow model. Because I wanted to use the set method to modify the NDArrays, I used a PyTorch NDManager at the beginning to create the NDArrays, and then used the TensorFlow NDManager to convert them before returning.
When I instantiate and close a PyTorch NDManager inside processInput, the memory is freed up. Thank you!
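For anyone landing here later, here is a rough sketch of why the static manager kept growing while the scoped one did not. It deliberately does not use DJL: `ToyManager` and `ManagerLifetimeDemo` are hypothetical stand-ins written only for illustration, and the `liveBuffers` counter is an assumption standing in for native memory. In the real Translator the scoped variant would presumably wrap `NDManager.newBaseManager("PyTorch")` in the same kind of try-with-resources block.

```java
import java.util.ArrayList;
import java.util.List;

// Toy stand-in for a manager that owns buffers; NOT the DJL NDManager API.
final class ToyManager implements AutoCloseable {
    // Number of buffers currently owned by any open manager.
    static int liveBuffers = 0;

    private final List<float[]> owned = new ArrayList<>();

    // Allocates a buffer whose lifetime is tied to this manager.
    float[] create(int size) {
        float[] buffer = new float[size];
        owned.add(buffer);
        liveBuffers++;
        return buffer;
    }

    // Releases every buffer this manager still owns.
    @Override
    public void close() {
        liveBuffers -= owned.size();
        owned.clear();
    }
}

final class ManagerLifetimeDemo {
    // Pattern from the report: one static manager that is never closed.
    static final ToyManager STATIC_MANAGER = new ToyManager();

    static float[] processWithStaticManager(int size) {
        float[] scratch = STATIC_MANAGER.create(size);
        scratch[0] = 1f;          // mutate the scratch buffer
        return scratch.clone();   // hand a copy to the caller; scratch stays owned
    }

    // Pattern that resolved the issue: a manager scoped to a single call.
    static float[] processWithScopedManager(int size) {
        try (ToyManager scoped = new ToyManager()) {
            float[] scratch = scoped.create(size);
            scratch[0] = 1f;
            return scratch.clone(); // the copy is made before close() runs
        }
    }

    public static void main(String[] args) {
        for (int i = 0; i < 1000; i++) {
            processWithStaticManager(4);
        }
        System.out.println("static manager: " + ToyManager.liveBuffers); // prints "static manager: 1000"

        STATIC_MANAGER.close();
        for (int i = 0; i < 1000; i++) {
            processWithScopedManager(4);
        }
        System.out.println("scoped manager: " + ToyManager.liveBuffers); // prints "scoped manager: 0"
    }
}
```

The key detail is that the result is copied out before the scoped manager closes, which mirrors converting the PyTorch arrays with the TensorFlow manager before returning them.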
Description
I am using a TensorFlow model for a regression task. I use PyTorch NDArrays because the TensorFlow ones do not support the get and set methods, which I need in my Translator class, but I convert them into TfNDArrays before returning them. During inference, the heap space fills up with PtNDArrays until there is no memory left.
As an aside, the set methods seem to take a very long time, to the point that running the model in Python is still faster than using DJL. I am not sure whether this is a related issue.
Expected Behavior
Somewhat constant memory usage.
Error Message
(Paste the complete error message, including stack trace.)
How to Reproduce?
My custom Translator class: `public class DeepLCTranslator implements Translator<String, Float> { static NDManager ptManager = NDManager.newBaseManager("PyTorch");
My testing script: `public static void main(String[] args) throws IOException, MalformedModelException, ModelNotFoundException, TranslateException, ExecutionException, InterruptedException {
Steps to reproduce
(Paste the commands you ran that produced the error.)
1. 2.
What have you tried to solve it?
Environment Info
Please run the command
./gradlew debugEnv
from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below: