ggerganov / ggml

Tensor library for machine learning
MIT License

ggml/ex: ref. CEL, ggml_backend_sched for MNIST #976

Closed. JohannesGaessler closed this 1 month ago.

JohannesGaessler commented 1 month ago

This PR refactors the CPU implementation of cross entropy loss to avoid false sharing (from the partial sums being in the same cache line) as well as potential issues with the loop variables being int. It also adds ggml_backend_sched support for the MNIST example.
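For context, a minimal sketch of the two fixes described above (an illustration under assumptions, not the actual ggml kernel): per-thread partial sums are each padded to a full cache line so threads do not invalidate each other's cache lines, and loop variables use int64_t so large tensors do not overflow an int.

#include <stdint.h>

#define CACHE_LINE_SIZE 64

// one partial sum per thread, each in its own cache line to avoid false sharing
struct padded_sum {
    float sum;
    char  padding[CACHE_LINE_SIZE - sizeof(float)];
};

// hypothetical per-thread accumulation: thread 'ith' of 'nth' sums its share of
// the per-row losses into its own cache line
static void accumulate_loss(struct padded_sum * partials, const float * losses,
                            int64_t nrows, int ith, int nth) {
    float sum = 0.0f;
    for (int64_t ir = ith; ir < nrows; ir += nth) { // int64_t avoids overflow for large tensors
        sum += losses[ir];
    }
    partials[ith].sum = sum;
}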

JohannesGaessler commented 1 month ago

There seems to be an issue with CPU MNIST training where the validation loss is much too low.

JohannesGaessler commented 1 month ago

I think I figured out the problem:

enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
    if (!sched->is_reset && !sched->is_alloc) {
        ggml_backend_sched_reset(sched);
    }

    if (!sched->is_alloc) {
        // if a graph was already allocated, this branch is skipped and the
        // 'graph' argument is never looked at again
        if (!ggml_backend_sched_alloc_graph(sched, graph)) {
            return GGML_STATUS_ALLOC_FAILED;
        }
    }

    // computes the splits created during allocation, i.e. the allocated graph
    return ggml_backend_sched_compute_splits(sched);
}

In mnist-common.cpp the call ggml_backend_sched_alloc_graph(model.backend_sched, gb_opt) allocates a graph that does the forward pass, the backward pass, and the optimizer step. But because this sets is_alloc to true, the above code in ggml-backend.c then essentially discards whatever graph you pass in and executes the graph that was used for allocation. I assume this is not the intended behavior; what do you think would be the correct way to fix this?
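Roughly, the pattern looks like this (a paraphrase of the MNIST example, not verbatim code from mnist-common.cpp):

// allocate once for the full training graph; this sets sched->is_alloc = true
ggml_backend_sched_alloc_graph(model.backend_sched, gb_opt);

// later, e.g. when only the forward pass is wanted for validation:
// is_alloc is still true, so the scheduler ignores gf and runs the
// splits built from gb_opt instead
ggml_backend_sched_graph_compute(model.backend_sched, gf);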

slaren commented 1 month ago

You need to call ggml_backend_sched_reset when changing graphs.
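For illustration, a sketch of that call order against the public ggml-backend API (error handling omitted):

// training step: allocate and run the full graph
ggml_backend_sched_reset(sched);
ggml_backend_sched_alloc_graph(sched, gb_opt);
ggml_backend_sched_graph_compute(sched, gb_opt);

// switching to a different graph requires another reset so that the
// scheduler re-allocates for the new graph instead of reusing the old splits
ggml_backend_sched_reset(sched);
ggml_backend_sched_alloc_graph(sched, gf);
ggml_backend_sched_graph_compute(sched, gf);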

JohannesGaessler commented 1 month ago

Adding calls to ggml_backend_sched_reset breaks CUDA training. I'm not sure why; it probably has to do with my code expecting tensor data to persist after resetting and allocating a different graph. How about this: extend the logic of ggml_backend_sched_graph_compute so that when a subgraph of the allocated graph is passed (i.e. the nodes of the subgraph are a prefix of the nodes of the graph that was allocated), only the nodes in the subgraph are executed.
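As an illustration of the idea, such a prefix check might look something like this (hypothetical, not existing ggml code):

static bool graph_is_node_prefix(const struct ggml_cgraph * sub, const struct ggml_cgraph * full) {
    if (sub->n_nodes > full->n_nodes) {
        return false;
    }
    for (int i = 0; i < sub->n_nodes; ++i) {
        if (sub->nodes[i] != full->nodes[i]) {
            return false;
        }
    }
    return true;
}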

slaren commented 1 month ago

I don't think that's feasible without adding too much complexity; ggml_backend_sched_graph_compute doesn't work with the original graph directly, it works with a heavily processed version of it that is split into multiple parts per backend. Can you point me to the code that does that?

JohannesGaessler commented 1 month ago

The instances where I think a reallocation could be causing problems:

Do you think it would be feasible to allocate separate graphs for the forward pass, the backward pass, and the optimizer step?

slaren commented 1 month ago

Everything can be reallocated between graph evaluations; the goal of ggml-alloc (ggml_backend_sched just uses ggml-alloc internally) is to reuse the same buffers across different graph evaluations without requiring new allocations. Fundamentally it is not possible to use tensors from a previous graph evaluation in the next graph with the same ggml-alloc: the tensors of the new graph will overwrite the tensors of the previous graph (or worse, the buffer they were allocated in may have been freed if ggml-alloc needs to increase its buffer size).
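In other words (a hedged sketch with made-up tensor and graph names, not code from the example):

ggml_backend_sched_alloc_graph(sched, graph_a);
ggml_backend_sched_graph_compute(sched, graph_a);
float * logits = (float *) logits_tensor->data;   // valid right after this evaluation

ggml_backend_sched_reset(sched);
ggml_backend_sched_alloc_graph(sched, graph_b);   // buffers are reused (or resized/freed)
// 'logits' may now alias memory belonging to graph_b's tensors, or point
// into a freed buffer; it must not be read anymore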

I can see two ways to handle this:

JohannesGaessler commented 1 month ago

> separate graphs for the forward pass, the backward pass, and the optimizer step

Actually, that won't work. Both the forward and backward pass need to be in the same graph because the backward pass may need to use data from the forward pass that cannot be overwritten until then.

slaren commented 1 month ago

I think you could have a separate graph and ggml-alloc (or sched) for the optimizer only. To do so, you would have to flag all the inputs to the optimizer as graph outputs.
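A minimal sketch of that approach (hypothetical names; grad_w and grad_b stand for whatever tensors the optimizer step reads):

// flag the optimizer inputs as graph outputs before allocation so that
// ggml-alloc keeps their data instead of reusing the memory
ggml_set_output(grad_w);
ggml_set_output(grad_b);

// forward + backward pass with one scheduler
ggml_backend_sched_alloc_graph(sched_train, gb);
ggml_backend_sched_graph_compute(sched_train, gb);

// optimizer step with a separate graph and scheduler, reading grad_w / grad_b
ggml_backend_sched_alloc_graph(sched_opt, g_opt);
ggml_backend_sched_graph_compute(sched_opt, g_opt);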

JohannesGaessler commented 1 month ago

I was not able to nail down the exact problem with moderate effort. For now I'll just revert the MNIST changes; once I have a higher-level API for training I'll adapt and expand test-opt (comparing tensor changes against analytical expectations) and debug the use of ggml_backend_sched with that.