10-zin / cpp-micrograd

A c/c++ implementation of micrograd: a tiny autograd engine with neural net on top.
MIT License

Memory leaks in c++ impl. #3

Open aapopajunen opened 5 months ago

aapopajunen commented 5 months ago

The following snippet highlights the issue:

int main() {
  {
    auto a = std::make_shared<Value>(1);
    auto b = std::make_shared<Value>(2);
    auto c = a + b;
  }
  // a, b and c are not destroyed ...
}

The issue seems to be with the variables captured by the _backward lambda:

std::shared_ptr<Value> Value::operator+(const std::shared_ptr<Value>& other) {
    auto out_prev = std::unordered_set<std::shared_ptr<Value>>{shared_from_this(), other};

    auto out = std::make_shared<Value>(data + other->data, out_prev, "+");

    out->_backward = [this, other, out] { // We're capturing out, hence the lambda shares ownership of out!
        grad += out->grad;
        other->grad += out->grad;
    };
    return out;
}

This leads to a situation where the lambda stored inside out shares ownership of out. So _backward is not destroyed until out is destroyed, and out is not destroyed until _backward is destroyed, hence nothing is destroyed.

To fix this, the lambda should not take shared ownership. A raw (non-owning) pointer is enough here, since the lambda lives inside out and can never outlive it:

std::shared_ptr<Value> Value::operator+(const std::shared_ptr<Value>& other) {
    auto out_prev = std::unordered_set<std::shared_ptr<Value>>{shared_from_this(), other};

    auto out = std::make_shared<Value>(data + other->data, out_prev, "+");
    Value* weak_ref = out.get();

    out->_backward = [this, other, weak_ref] {
        grad += weak_ref->grad;
        other->grad += weak_ref->grad;
    };
    return out;
}
fernandotenorio commented 2 months ago

I can confirm this. You also need to clean up inside Value::backward():

// Helper method
void Value::setBackward(const std::function<void()>& f) {
    this->_backward = f;
}

void Value::backward() {
    std::vector<std::shared_ptr<Value>> topo;
    std::unordered_set<std::shared_ptr<Value>> visited;

    std::function<void(const std::shared_ptr<Value>&)> build_topo = [&](const std::shared_ptr<Value>& v) {

        if (visited.find(v) == visited.end()) {
            visited.insert(v);

            for (const auto& child : v->prev) {
                build_topo(child);
            }
            topo.push_back(v);
        }
    };

    build_topo(shared_from_this());
    grad = 1.0;

    for (auto it = topo.rbegin(); it != topo.rend(); ++it) {
        const auto& v = *it;
        v->_backward();
    }

    //CLEANUP
    for (auto& node : topo) {
        node->get_prev().clear();
        node->setBackward([](){});
    }
}