deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

During inference, heap space is filled by PtNDArrays until no more memory is available #2766

Closed yangkl96 closed 12 months ago

yangkl96 commented 12 months ago

Description

I am using a TensorFlow model for a regression task. I use PyTorch NDArrays because the TensorFlow ones do not support the get and set methods, which I need in my Translator class, but I convert them to TensorFlow NDArrays before returning. During inference, the heap space fills with PtNDArrays until no memory is left.

As an aside, the set calls seem to take very long, to the point that running the model in Python is still faster than using DJL, but I am not sure whether this is related to the PtNDArray issue.
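The slow set calls are plausibly explained by each `NDArray.set(NDIndex, …)` crossing the JNI boundary into native code. A common workaround (a sketch, not part of the original report; `OneHotBuffer` and its arguments are made-up names) is to do the per-element work in a plain Java `float[]` and hand the whole buffer to the engine once:

```java
// Sketch: avoid per-element NDArray.set calls by filling a plain float[]
// on the JVM side and creating the NDArray in a single call afterwards.
public class OneHotBuffer {

    /** Builds a row-major (position, symbol) one-hot buffer, like matrixHC above. */
    static float[] oneHot(String sequence, String alphabet, int paddingLength) {
        int width = alphabet.length();
        float[] buf = new float[paddingLength * width];
        for (int i = 0; i < sequence.length() && i < paddingLength; i++) {
            int col = alphabet.indexOf(sequence.charAt(i));
            if (col >= 0) {
                buf[i * width + col] = 1f; // pure JVM work, no JNI crossing
            }
        }
        // With DJL this buffer would then become one native array, e.g.:
        // NDArray matrixHC = manager.create(buf, new Shape(paddingLength, width));
        return buf;
    }

    public static void main(String[] args) {
        float[] m = oneHot("KR", "KRPTNAQVSGILCMHFYWED", 60);
        System.out.println(m[0]);      // row 0, col 0 ("K") -> 1.0
        System.out.println(m[21]);     // row 1, col 1 ("R") -> 1.0
        System.out.println(m.length);  // 60 * 20 = 1200
    }
}
```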

Expected Behavior

Somewhat constant memory usage.


How to Reproduce?

My custom Translator class:

```java
public class DeepLCTranslator implements Translator<String, Float> {

    static NDManager ptManager = NDManager.newBaseManager("PyTorch");

    @Override
    public NDList processInput(TranslatorContext translatorContext, String sequence) {
        HashMap<String, Integer> dictIndex = new HashMap<>();
        String atoms = "CHNOSP";
        for (int i = 0; i < atoms.length(); i++) {
            dictIndex.put(atoms.substring(i, i + 1), i);
        }

        HashMap<String, Integer> dictAA = new HashMap<>();
        String AAs = "KRPTNAQVSGILCMHFYWED";
        for (int i = 0; i < AAs.length(); i++) {
            dictAA.put(AAs.substring(i, i + 1), i);
        }

        NDList retList = new NDList();
        int paddingLength = 60;

        int seqLen = sequence.length();
        if (seqLen > paddingLength) {
            sequence = sequence.substring(0, paddingLength);
            seqLen = sequence.length();
        }

        NDArray matrix = ptManager.zeros(new Shape(paddingLength, dictIndex.size()));
        NDArray matrixAll = ptManager.zeros(new Shape(dictIndex.size() + 1));
        matrixAll.set(new NDIndex(matrixAll.getShape().get(0) - 1), seqLen);
        NDArray matrixSum = ptManager.zeros(new Shape(paddingLength / 2, dictIndex.size()));
        NDArray matrixHC = ptManager.zeros(new Shape(paddingLength, dictAA.size()));
        NDArray matrixPos = ptManager.zeros(new Shape(8, dictIndex.size()));

        // matrix, matrixAll, matrixSum: atom counts per position / overall / per position pair
        for (int i = 0; i < seqLen; i++) {
            Map<String, Integer> positionComposition = StdAAComp.stdAaComp.get(sequence.substring(i, i + 1));
            for (Map.Entry<String, Integer> entry : positionComposition.entrySet()) {
                String atom = entry.getKey();
                if (dictIndex.containsKey(atom)) {
                    int val = entry.getValue();
                    int dictIndexVal = dictIndex.get(atom);
                    matrix.set(new NDIndex(i, dictIndexVal), val);
                    matrixAll.set(new NDIndex(dictIndexVal),
                            matrixAll.getFloat(dictIndexVal) + val);
                    // integer division already floors for non-negative i
                    matrixSum.set(new NDIndex(i / 2, dictIndexVal),
                            matrixSum.getFloat(i / 2, dictIndexVal) + val);
                }
            }
        }

        // matrixPos: atom counts at the first four positions
        int[] positionsPos = new int[] {0, 1, 2, 3};
        for (int p : positionsPos) {
            Map<String, Integer> positionComposition = StdAAComp.stdAaComp.get(sequence.substring(p, p + 1));
            for (Map.Entry<String, Integer> entry : positionComposition.entrySet()) {
                String atom = entry.getKey();
                if (dictIndex.containsKey(atom)) {
                    int val = entry.getValue();
                    int dictIndexVal = dictIndex.get(atom);
                    matrixPos.set(new NDIndex(p, dictIndexVal), val);
                }
            }
        }

        // ...and at the last four positions
        int[] positionsNeg = new int[] {-1, -2, -3, -4};
        for (int pn : positionsNeg) {
            Map<String, Integer> positionComposition = StdAAComp.stdAaComp.get(
                    sequence.substring(seqLen + pn, seqLen + pn + 1));
            for (Map.Entry<String, Integer> entry : positionComposition.entrySet()) {
                String atom = entry.getKey();
                if (dictIndex.containsKey(atom)) {
                    int val = entry.getValue();
                    int dictIndexVal = dictIndex.get(atom);
                    matrixPos.set(new NDIndex(8 + pn, dictIndexVal), val);
                }
            }
        }

        // matrixHC: one-hot amino-acid encoding
        for (int i = 0; i < sequence.length(); i++) {
            matrixHC.set(new NDIndex(i, dictAA.get(sequence.substring(i, i + 1))), 1);
        }

        NDArray matrixGlobal = matrixAll.concat(matrixPos.reshape(matrixPos.size()));
        NDManager manager = translatorContext.getNDManager();
        retList.add(manager.from(matrix));
        retList.add(manager.from(matrixSum));
        retList.add(manager.from(matrixGlobal));
        retList.add(manager.from(matrixHC));

        return retList;
    }

    @Override
    public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
        return ndList.singletonOrThrow().getFloat(0);
    }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}
```

My testing script:

```java
public static void main(String[] args) throws IOException, MalformedModelException,
        ModelNotFoundException, TranslateException, ExecutionException, InterruptedException {

    Path modelDir = Paths.get("C:/Users/yangkl/Downloads/proteomics/fastRTprediction/deeplc");
    Criteria<String, Float> criteria =
            Criteria.builder()
                    .setTypes(String.class, Float.class)
                    .optEngine("TensorFlow")
                    .optProgress(new ProgressBar())
                    .optModelPath(modelDir)
                    .optModelName("saved_model")
                    .optTranslator(new DeepLCTranslator())
                    .build();

    Engine.getEngine("PyTorch");

    ZooModel<String, Float> model = criteria.loadModel();

    ArrayList<String> peptides = new PeptideGenerator().generatePeptides();
    Predictor<String, Float> predictor = model.newPredictor();
    for (String peptide : peptides) {
        predictor.predict(peptide);
    }
    predictor.close();
}
```


What have you tried to solve it?

  1. Clearing NDArrays explicitly with .close()
  2. Running with batchPredict and predict

frankfliu commented 12 months ago

@yangkl96

From a quick look at your class: you should not use a static NDManager. An NDManager must be closed properly to release memory.
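The lifecycle point here can be illustrated with a plain-Java analogy (this mock `Manager` class is not DJL API): native arrays belong to the manager that created them and are only freed when that manager closes, so a static manager that is never closed accumulates allocations for the life of the process.

```java
import java.util.ArrayList;
import java.util.List;

public class ManagerDemo {

    /** Mock of a resource manager: everything it allocates lives until close(). */
    static class Manager implements AutoCloseable {
        final List<float[]> owned = new ArrayList<>();

        float[] zeros(int n) {
            float[] a = new float[n];
            owned.add(a); // the manager keeps a reference, so GC cannot reclaim it
            return a;
        }

        int liveCount() {
            return owned.size();
        }

        @Override
        public void close() {
            owned.clear(); // releases every allocation at once
        }
    }

    static final Manager STATIC_MANAGER = new Manager(); // never closed: leaks

    public static void main(String[] args) {
        for (int call = 0; call < 1000; call++) {
            STATIC_MANAGER.zeros(60 * 6); // one "inference call" per iteration
        }
        System.out.println(STATIC_MANAGER.liveCount()); // grows without bound: 1000

        try (Manager perCall = new Manager()) { // scoped: freed on close
            perCall.zeros(60 * 6);
        }
        // nothing allocated inside the try block survives it
    }
}
```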

frankfliu commented 12 months ago

In the Translator you should use the TranslatorContext's NDManager instead of a static one:

```java
NDManager manager = ctx.getNDManager();
```

See example: https://github.com/deepjavalibrary/djl/blob/master/api/src/main/java/ai/djl/modality/audio/translator/SpeechRecognitionTranslator.java#L34

yangkl96 commented 12 months ago

Thanks for the response. I do use the TranslatorContext to get an NDManager at the end of processInput:

```java
NDArray matrixGlobal = matrixAll.concat(matrixPos.reshape(matrixPos.size()));
NDManager manager = translatorContext.getNDManager();
retList.add(manager.from(matrix));
retList.add(manager.from(matrixSum));
retList.add(manager.from(matrixGlobal));
retList.add(manager.from(matrixHC));
return retList;
```

However, this is a TensorFlow model. Because I wanted to use the set method to modify the NDArrays, I used a PyTorch NDManager at the beginning to create them, and then used the TensorFlow NDManager to convert them before returning.

When I instead instantiate and close a PyTorch NDManager inside processInput, the memory is freed up. Thank you!
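A minimal sketch of that fix, assuming DJL's `NDManager.newBaseManager(String)` and `NDManager.from(NDArray)` APIs and the variable names from the Translator earlier in the thread; the PyTorch manager is opened per call and closed by try-with-resources once the arrays have been transferred to the TensorFlow manager:

```java
@Override
public NDList processInput(TranslatorContext ctx, String sequence) {
    NDManager tfManager = ctx.getNDManager(); // owned by the predictor, do not close
    NDList retList = new NDList();
    // Scoped PyTorch manager: everything it creates is freed when it closes.
    try (NDManager ptManager = NDManager.newBaseManager("PyTorch")) {
        NDArray matrix = ptManager.zeros(new Shape(60, 6));
        // ... the per-element set(...) calls from the original translator ...
        retList.add(tfManager.from(matrix)); // copies into the TensorFlow manager
    } // closing ptManager releases every PtNDArray created above
    return retList;
}
```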