Closed hoodiney closed 3 years ago
@hoodiney
Please try the following:
```java
public void mfRMSProp() throws IOException {
    initMatrix();
    System.out.println("matrix initialized");
    W.setRequiresGradient(true);
    H.setRequiresGradient(true);
    NDArray cost;
    NDList params = new NDList(W, H);
    NDList states = new NDList(squareW, squareH);
    for (int i = 0; i < maxIter; i++) {
        System.out.println(i);
        // Use a short-lived manager per iteration so all intermediate
        // NDArrays are freed when it closes.
        try (NDManager temp = NDManager.newBaseManager()) {
            params.tempAttach(temp);
            states.tempAttach(temp);
            M.tempAttach(temp);
            try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
                cost = M.sub(W.dot(H)).pow(2).sum().add(W.norm()).add(H.norm());
                gc.backward(cost);
            }
            rmsProp(params, states);
            // Return the arrays we still need to their original manager
            // before temp closes and frees everything attached to it.
            temp.ret(params);
            temp.ret(states);
            temp.ret(M);
        }
    }
}
```
@frankfliu Thanks very much for the timely response! I couldn't find the newBaseNDManager() method in version 0.12.0 of DJL's API, so I used newBaseManager() instead. My code lasted more epochs than before, but eventually it still ran out of memory.
@hoodiney, ooh, I missed NDArray M. I have updated the code above. Please try it again.
@frankfliu It worked! Thanks a lot. May I ask where I can learn about the cause of the bug and gain a better understanding of the tempAttach and ret functions, so that I can fix this kind of problem by myself next time? I took a look at the Javadoc but am still a bit confused. Thanks again for the help!
@hoodiney You can find the memory management document here: http://docs.djl.ai/docs/development/memory_management.html
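To summarize the idea behind the fix: every NDArray is owned by an NDManager, and closing a manager frees everything attached to it. tempAttach moves arrays into a short-lived manager for the duration of one iteration, and ret moves the survivors back out before that manager closes. The following is a plain-Java analogy of that ownership pattern with no DJL dependency; the Scope and Resource names are made up for illustration and are not part of the DJL API.

```java
import java.util.ArrayList;
import java.util.List;

// Made-up analogy for DJL's NDManager scoping: resources attached to a
// scope are freed when the scope closes, unless they are "returned"
// (ret) to an outer owner first -- mirroring tempAttach/ret in DJL.
class Scope implements AutoCloseable {
    private final List<Resource> owned = new ArrayList<>();

    // Like NDList.tempAttach(temp): the scope takes ownership.
    void attach(Resource r) {
        owned.add(r);
    }

    // Like NDManager.ret(...): the resource survives close().
    void ret(Resource r) {
        owned.remove(r);
    }

    // Like closing the per-iteration NDManager: everything still
    // attached is freed here.
    @Override
    public void close() {
        for (Resource r : owned) {
            r.freed = true;
        }
    }
}

class Resource {
    boolean freed = false;
}
```

This is why the original loop leaked: every intermediate NDArray produced by sub, pow, sum, and add stayed attached to a long-lived manager, so nothing was freed between iterations.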
@frankfliu Thanks very much for the kind help! That indeed helped me a lot.
Description
Hello, I'm new to DJL and am currently trying RMSProp for matrix factorization. My code is written largely according to the RMSProp code in chapter 11 of "Dive into Deep Learning". The problem is that whenever I run the code, the program quickly consumes all my RAM (or GPU memory when running on GPU). My TensorFlow implementation in Python works fine. I've tried switching from CPU to GPU and switching the engine from MXNet to PyTorch, but neither helped.
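For reference, the RMSProp update itself is cheap and stateless apart from the running average of squared gradients, which is why memory use should stay flat across iterations. A minimal sketch of the update rule on plain Java arrays (the names rmsPropStep, lr, gamma, and eps are illustrative, not from the issue's code):

```java
// Stand-alone sketch of the RMSProp update rule, independent of DJL.
// state[i] keeps a decaying average of squared gradients; the step size
// for each parameter is scaled by 1 / (sqrt(state[i]) + eps).
class RmsPropSketch {
    static void rmsPropStep(double[] param, double[] state, double[] grad,
                            double lr, double gamma, double eps) {
        for (int i = 0; i < param.length; i++) {
            state[i] = gamma * state[i] + (1 - gamma) * grad[i] * grad[i];
            param[i] -= lr * grad[i] / (Math.sqrt(state[i]) + eps);
        }
    }
}
```

Note that both param and state are updated in place, so one buffer per parameter matrix (W, H and their squared-gradient states) is all the persistent memory the algorithm needs.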
Expected Behavior
The program shouldn't keep consuming memory, since the cost is calculated once and should be discarded in every iteration.
Error Message
How to Reproduce?
Steps to reproduce
What have you tried to solve it?
gc.backward(cost)
Environment Info
Please run the command
./gradlew debugEnv
from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below.