aapopajunen opened this issue 5 months ago
I can confirm this. You also need to clean up inside `Value::backward()`:
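```cpp
#include <functional>
#include <memory>
#include <unordered_set>
#include <vector>

// Helper method
void Value::setBackward(const std::function<void()>& f) {
    this->_backward = f;
}

void Value::backward() {
    // Build a topological order of the graph rooted at this node.
    std::vector<std::shared_ptr<Value>> topo;
    std::unordered_set<std::shared_ptr<Value>> visited;
    std::function<void(const std::shared_ptr<Value>&)> build_topo =
        [&](const std::shared_ptr<Value>& v) {
            if (visited.find(v) == visited.end()) {
                visited.insert(v);
                for (const auto& child : v->prev) {
                    build_topo(child);
                }
                topo.push_back(v);
            }
        };
    build_topo(shared_from_this());

    // Seed the output gradient and propagate in reverse topological order.
    grad = 1.0;
    for (auto it = topo.rbegin(); it != topo.rend(); ++it) {
        const auto& v = *it;
        v->_backward();
    }

    // CLEANUP: drop the graph edges and replace every closure with a no-op,
    // which releases the shared_ptrs the old closures captured.
    for (auto& node : topo) {
        // Clear prev directly (as build_topo reads it); calling clear() on a
        // getter that returns a copy would leave the real edges untouched.
        node->prev.clear();
        node->setBackward([]() {});
    }
}
```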
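To see why the cleanup loop helps, here is a minimal, self-contained sketch. `Node` and `make_sum` are hypothetical stand-ins, not the project's actual `Value` API; the closure captures `out` by `shared_ptr`, which is the pattern described in this issue. The returned `std::weak_ptr` only observes `out`, so it reports whether `out` was actually freed.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical, stripped-down stand-in for Value: only the members
// needed to reproduce the ownership cycle.
struct Node {
    double grad = 0.0;
    std::vector<std::shared_ptr<Node>> prev;
    std::function<void()> _backward = [] {};
};

// Builds out = a + b with a closure that strongly captures `out`
// (out -> _backward -> out). If `cleanup` is true, it mimics the
// CLEANUP loop by clearing the edges and replacing the closure.
std::weak_ptr<Node> make_sum(bool cleanup) {
    auto a = std::make_shared<Node>();
    auto b = std::make_shared<Node>();
    auto out = std::make_shared<Node>();
    out->prev = {a, b};
    out->_backward = [out, a, b] {  // strong self-capture
        a->grad += out->grad;
        b->grad += out->grad;
    };
    out->grad = 1.0;
    out->_backward();
    if (cleanup) {
        out->prev.clear();
        out->_backward = [] {};  // old closure (and its copy of out) is destroyed
    }
    return out;  // the local `out` is released when the function returns
}
```

`make_sum(false).expired()` is `false`: the closure still owns `out`, so the node leaks. `make_sum(true).expired()` is `true`: once the closure is replaced, the last owner goes away and `out` is freed.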
The following snippet highlights the issue:
The issue seems to be with the variables captured by the `_backward` lambda: this leads to a situation where the lambda stored inside `out` shares ownership of `out`. So `_backward` is not destroyed until `out` is destroyed, and `out` is not destroyed until `_backward` is destroyed; the two form a reference cycle, so nothing is ever destroyed. To fix this, the lambda should only take weak ownership of `out`.
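A minimal sketch of the weak-ownership fix, again using a hypothetical `Node` type and `add` function rather than the real `Value` class: the closure captures a `std::weak_ptr` to `out` and calls `lock()` when it runs, so `out` no longer keeps itself alive. The children `a` and `b` are still captured strongly, which is fine because they never own `out`.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical minimal node type (not the project's actual Value class).
struct Node {
    double grad = 0.0;
    std::vector<std::shared_ptr<Node>> prev;
    std::function<void()> _backward = [] {};
};

// out = a + b, with a closure that holds only a weak_ptr to `out`,
// so there is no strong out -> _backward -> out cycle.
std::shared_ptr<Node> add(const std::shared_ptr<Node>& a,
                          const std::shared_ptr<Node>& b) {
    auto out = std::make_shared<Node>();
    out->prev = {a, b};
    std::weak_ptr<Node> weak_out = out;
    out->_backward = [weak_out, a, b] {
        // lock() returns an empty shared_ptr if `out` is already gone.
        if (auto o = weak_out.lock()) {
            a->grad += o->grad;
            b->grad += o->grad;
        }
    };
    return out;
}
```

During `backward()` the `topo` vector holds a `shared_ptr` to every node, so `lock()` should succeed while gradients propagate. With this capture style, the closure no longer forms a cycle with its own node, so the CLEANUP loop should not be needed just to break it.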