deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Keep running out of memory when trying rmsprop for matrix factorization #1179

Closed: hoodiney closed this issue 3 years ago

hoodiney commented 3 years ago

Description

Hello, I'm new to DJL and am trying to use RMSProp for matrix factorization. My code closely follows the RMSProp code in chapter 11 of "Dive into Deep Learning". Whenever I run it, the program quickly exhausts my RAM (or my GPU memory when running on the GPU). My TensorFlow implementation in Python runs fine. I've tried switching from CPU to GPU and switching the engine from MXNet to PyTorch, but neither helped.

Expected Behavior

The program shouldn't keep consuming more memory, since the cost is calculated once per iteration and can then be discarded.

Error Message

Exception in thread "main" ai.djl.engine.EngineException: [enforce fail at ..\..\c10\core\CPUAllocator.cpp:75] data. DefaultCPUAllocator: not enough memory: you tried to allocate 29333056 bytes. Buy new RAM!
    at ai.djl.pytorch.jni.PyTorchLibrary.torchBackward(Native Method)
    at ai.djl.pytorch.jni.JniUtils.backward(JniUtils.java:1340)
    at ai.djl.pytorch.engine.PtGradientCollector.backward(PtGradientCollector.java:46)
    at ai.djl.pytorch.engine.PtGradientCollector.backward(PtGradientCollector.java:31)
    at MFRMSProp.mfRMSProp(MFRMSProp.java:78)
    at MFRMSProp.main(MFRMSProp.java:86)

How to Reproduce?

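import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.GradientCollector;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

public class MFRMSProp {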
    private NDArray M;
    private NDArray W;
    private NDArray H;
    private NDArray squareW;
    private NDArray squareH;
    private int nodeNum = 2708;
    private int maxIter = 5000;
    private float baseEps = (float) 1e-10;
    private int vecSize = 100;
    private float learnRate = (float) 0.015;
    private float decay = (float) 0.9;

    private float[][] MfromFile() throws IOException {
        float[][] arrayM = new float[nodeNum][nodeNum];
        // try-with-resources closes the reader even if parsing fails
        try (BufferedReader bufferedReader = new BufferedReader(new FileReader("./M.txt"))) {
            String line;
            int iter = 0;
            while ((line = bufferedReader.readLine()) != null) {
                String[] vals = line.split(" ");
                for (int j = 0; j < nodeNum; j++) {
                    arrayM[iter][j] = Float.parseFloat(vals[j]);
                }
                iter++;
            }
        }
        return arrayM;
    }

    private void initMatrix() throws IOException {
        NDManager manager = NDManager.newBaseManager();
        M = manager.create(MfromFile());
        W = manager.randomNormal(0f, 1f, new Shape(nodeNum, vecSize), DataType.FLOAT32);
        H = manager.randomNormal(0f, 1f, new Shape(vecSize, nodeNum), DataType.FLOAT32);
        squareW = manager.zeros(new Shape(nodeNum, vecSize), DataType.FLOAT32);
        squareH = manager.zeros(new Shape(vecSize, nodeNum), DataType.FLOAT32);
    }

    private void rmsProp(NDList params, NDList states) {
        float gamma = decay;
        float eps = baseEps;
        for (int i = 0; i < params.size(); i++) {
            NDArray param = params.get(i);
            NDArray state = states.get(i);
            // Update parameter and state
            // state = gamma * state + (1 - gamma) * param.gradient^2
            state.muli(gamma).addi(param.getGradient().square().mul(1 - gamma));
            // param -= lr * param.gradient / sqrt(state + eps)
            // param.subi(param.getGradient().mul(learnRate).div(state.add(eps).sqrt()));
            // Note: sub() returns a new NDArray and leaves param unchanged, so this result is discarded
            param.sub(param.getGradient().mul(learnRate).div(state.add(eps).sqrt()));
        }
    }

    public void mfRMSProp() throws IOException {
        initMatrix();
        System.out.println("matrix initialized");
        W.setRequiresGradient(true);
        H.setRequiresGradient(true);
        NDArray cost;
        NDList params = new NDList(W, H);
        NDList states = new NDList(squareW, squareH);
        for(int i = 0; i < maxIter; i++) {
            System.out.println(i);
            try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
                cost = M.sub(W.dot(H)).pow(2).sum().add(W.norm()).add(H.norm());
                gc.backward(cost);
            }
            rmsProp(params, states);
        }
    }

    public static void main(String[] args) throws IOException {
        MFRMSProp mfrmsProp = new MFRMSProp();
        mfrmsProp.mfRMSProp();
    }
}
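
For reference, the update rule that rmsProp above is meant to apply (as stated in its comments) can be sketched in plain Java on float arrays, without DJL. This is only an illustration: the class name RmsPropSketch and the toy objective f(x) = x^2 are assumptions made for the example, while the hyperparameter values are taken from the fields in the code above.

```java
import java.util.Arrays;

class RmsPropSketch {
    /**
     * One RMSProp step, updating param and state in place:
     *   state = gamma * state + (1 - gamma) * grad^2
     *   param = param - lr * grad / sqrt(state + eps)
     */
    static void step(float[] param, float[] grad, float[] state,
                     float lr, float gamma, float eps) {
        for (int i = 0; i < param.length; i++) {
            state[i] = gamma * state[i] + (1 - gamma) * grad[i] * grad[i];
            param[i] -= lr * grad[i] / (float) Math.sqrt(state[i] + eps);
        }
    }

    public static void main(String[] args) {
        // Toy problem (an assumption for illustration): minimise f(x) = x^2, whose gradient is 2x.
        float[] x = {1.0f};
        float[] s = {0.0f};
        for (int i = 0; i < 200; i++) {
            float[] g = {2 * x[0]};
            step(x, g, s, 0.015f, 0.9f, 1e-10f);
        }
        System.out.println(Arrays.toString(x));
    }
}
```

Note that the parameter is modified in place here, which corresponds to subi rather than sub in the NDArray version.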

Steps to reproduce

  1. Paste the code above into your IDE and run it; the problem appears after a few epochs.

What have you tried to solve it?

  1. Changed the engine from MXNet to PyTorch.
  2. Tried the same code on Windows 10 and Linux.
  3. Tried both the CPU and GPU versions.

The problem still exists even when I remove the gc.backward(cost) call.

Environment Info

Please run the command ./gradlew debugEnv from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below:

------------------ CUDA -----------------
[DEBUG] - Found cudart: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\bin\cudart64_110.dll
GPU Count: 1
CUDA: 110
ARCH: 75
GPU(0) memory used: 1051262976 bytes

----------------- Engines ---------------
Default Engine: MXNet
[DEBUG] - Using cache dir: C:\Users\xuduox\.djl.ai\mxnet
[DEBUG] - Loading mxnet library from: C:\Users\xuduox\.djl.ai\mxnet\1.8.0-cu110mkl-win-x86_64\mxnet.dll
Default Device: null
PyTorch: 2
[DEBUG] - Using cache dir: C:\Users\xuduox\.djl.ai\pytorch
[WARN ] - No matching cuda flavor for win found: cu110.
[DEBUG] - Loading pytorch library from: C:\Users\xuduox\.djl.ai\pytorch\1.9.0-cpu-win-x86_64\0.13.0-SNAPSHOT-cpu-djl_torch.dll
[INFO ] - Number of inter-op threads is 6
[INFO ] - Number of intra-op threads is 12
MXNet: 0
XGBoost: 10
TensorFlow: 3

--------------- Hardware --------------
Available processors (cores): 12
Byte Order: LITTLE_ENDIAN
Free memory (bytes): 153743080
Maximum memory (bytes): 4261412864
Total memory available to JVM (bytes): 266338304
Heap committed: 266338304
Heap nonCommitted: 34013184

frankfliu commented 3 years ago

@hoodiney

Please try the following:

    public void mfRMSProp() throws IOException {
        initMatrix();
        System.out.println("matrix initialized");
        W.setRequiresGradient(true);
        H.setRequiresGradient(true);
        NDArray cost;
        NDList params = new NDList(W, H);
        NDList states = new NDList(squareW, squareH);
        for(int i = 0; i < maxIter; i++) {
            System.out.println(i);
            try (NDManager temp = NDManager.newBaseManager()) {
                params.tempAttach(temp);
                states.tempAttach(temp);
                M.tempAttach(temp);
                try (GradientCollector gc = Engine.getInstance().newGradientCollector())  {
                    cost = M.sub(W.dot(H)).pow(2).sum().add(W.norm()).add(H.norm());
                    gc.backward(cost);
                }
                rmsProp(params, states);
                temp.ret(params);
                temp.ret(states);
                temp.ret(M);
            }
        }
    }

hoodiney commented 3 years ago

@frankfliu Thanks very much for the timely response! I couldn't find a newBaseNDManager() method in the DJL 0.12.0 API, so I used newBaseManager() instead. My code ran for more epochs than before but eventually still ran out of memory.

frankfliu commented 3 years ago

@hoodiney, oops, I missed NDArray M. I have updated the code above. Please try it again.

hoodiney commented 3 years ago

@frankfliu It worked! Thanks a lot. May I ask where I can learn about the cause of the bug and get a better understanding of the tempAttach and ret functions, so that I can fix this kind of problem myself next time? I took a look at the Javadoc but am still a bit confused. Thanks again for the help!

frankfliu commented 3 years ago

@hoodiney You can find the memory management document here: http://docs.djl.ai/docs/development/memory_management.html
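
As a rough mental model of why the fix in this thread helps, the following plain-Java toy may be useful. It is not DJL code and not DJL's implementation: Scope, Buffer and op() are hypothetical stand-ins, and the tempAttach/ret semantics are a simplification assumed for illustration. The idea it sketches is that the result of an operation is owned by the same scope as its operand, so intermediates created in a loop pile up in a long-lived scope unless the operands are temporarily moved into a short-lived one that is closed every iteration.

```java
import java.util.HashSet;
import java.util.Set;

class ScopeSketch {
    /** Hypothetical stand-in for a manager: owns buffers and frees them on close. */
    static final class Scope implements AutoCloseable {
        final Set<Buffer> attached = new HashSet<>();

        /** Hands a temporarily attached buffer back to its previous owner (simplified). */
        void ret(Buffer b) {
            attached.remove(b);
            b.owner = b.previousOwner;
            b.owner.attached.add(b);
        }

        /** Frees everything still attached to this scope. */
        @Override
        public void close() {
            for (Buffer b : attached) {
                b.freed = true;
            }
            attached.clear();
        }
    }

    /** Hypothetical stand-in for a native array, owned by exactly one scope at a time. */
    static final class Buffer {
        Scope owner;
        Scope previousOwner;
        boolean freed;

        Buffer(Scope owner) {
            this.owner = owner;
            owner.attached.add(this);
        }

        /** Stand-in for any array operation: the result joins the operand's current scope. */
        Buffer op() {
            return new Buffer(owner);
        }

        /** Moves this buffer into temp, remembering where it came from. */
        void tempAttach(Scope temp) {
            previousOwner = owner;
            owner.attached.remove(this);
            owner = temp;
            temp.attached.add(this);
        }
    }

    /** Every intermediate stays attached to the long-lived scope. */
    static int leaky(int iterations) {
        Scope base = new Scope();
        Buffer w = new Buffer(base);
        for (int i = 0; i < iterations; i++) {
            w.op();
        }
        return base.attached.size();
    }

    /** Intermediates are attached to a per-iteration scope and freed when it closes. */
    static int scoped(int iterations) {
        Scope base = new Scope();
        Buffer w = new Buffer(base);
        for (int i = 0; i < iterations; i++) {
            try (Scope temp = new Scope()) {
                w.tempAttach(temp);
                w.op();      // intermediate is attached to temp
                temp.ret(w); // w goes back to base before temp closes
            }
        }
        return base.attached.size();
    }

    public static void main(String[] args) {
        System.out.println("leaky:  " + leaky(1000));
        System.out.println("scoped: " + scoped(1000));
    }
}
```

In this toy model, leaky(1000) leaves 1001 buffers attached to the long-lived scope, while scoped(1000) leaves only the original one. The memory management document linked above is the authoritative description of how NDManager actually behaves.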

hoodiney commented 3 years ago

@frankfliu Thanks very much for the kind help! That indeed helped me a lot.