> One potential issue is that the convergence logic would require calls to `ggml_backend_tensor_get` which would make `ggml.c` depend on `ggml_backend.c` (which it currently does not).
Actually, a much bigger issue is that for `ggml_backend_graph_compute` a pointer to a backend is needed.
I don't think that's a problem. ggml-backend was designed to not require many changes to the core ggml code, but since then I think it has become the standard way to use ggml, and it doesn't make much sense to maintain the subset of the API that only works with the CPU backend. We should move all the CPU backend code to a separate file, and make all the core ggml functions explicitly compatible with ggml-backend.
The design looks good to me. Something to consider is that to support multiple GPUs and fallback to CPU for unimplemented ops in the backends, it is necessary to use `ggml_backend_sched`.
I forgot: the current code also has an extension to the GGML backend interface with `memset_tensor` in order to clear specific tensors (since right now I think the only way to do it would be to allocate zeroed memory and invoke `set_tensor`).
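For reference, without such an extension the workaround would look roughly like this (a sketch, not the PR's code): allocate a zeroed host buffer and push it through `ggml_backend_tensor_set`.

```c
// Sketch of the fallback described above: clearing a backend tensor without a memset
// extension by copying a zero-initialized host buffer into it.
#include <stdlib.h>

#include "ggml.h"
#include "ggml-backend.h"

static void clear_tensor_fallback(struct ggml_tensor * t) {
    const size_t nbytes = ggml_nbytes(t);
    void * zeros = calloc(1, nbytes);              // zero-initialized staging buffer in RAM
    ggml_backend_tensor_set(t, zeros, 0, nbytes);  // copy the zeros into the backend buffer
    free(zeros);
}
```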
The Adam optimizer needs to know the current iteration since it does a warmup. I'm currently passing this information via `ggml_tensor.op_params`, but the downside of this approach is that the iteration information is duplicated across all tensors. At the same time, I don't think it would be a good idea to add global state to the forward pass when right now all relevant information is encapsulated in `ggml_tensor`.
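For context, the iteration also enters a standard Adam step through the bias-correction terms; a generic single-parameter sketch (warmup schedule not shown), not the PR's actual kernel:

```c
// Generic Adam update for one parameter; t is the 1-based iteration. The bias-correction
// denominators (1 - beta^t) are what make the update depend on the current iteration.
#include <math.h>
#include <stdint.h>

static float adam_step(float w, float grad, float * m, float * v, int64_t t,
                       float alpha, float beta1, float beta2, float eps) {
    *m = beta1*(*m) + (1.0f - beta1)*grad;               // first moment
    *v = beta2*(*v) + (1.0f - beta2)*grad*grad;          // second moment
    const float mh = *m/(1.0f - powf(beta1, (float) t)); // bias correction needs t
    const float vh = *v/(1.0f - powf(beta2, (float) t)); // bias correction needs t
    return w - alpha*mh/(sqrtf(vh) + eps);
}
```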
I pushed a working prototype for CUDA MNIST training/evaluation (fully connected only). Compared to PyTorch the training on my RTX 3090 is ~45x faster (1.25s vs. 56.58s) but with such a small model you're basically just measuring overhead. The CUDA evaluation is actually slower than the CPU evaluation, presumably because the model is too small to make GPU acceleration worthwhile given the additional overhead.
One issue that I still have is how to handle the combination of GGUF+backends other than CPU. Right now I'm allocating a temporary context that just stores the data in RAM but it feels kind of clunky. Is there a better way to do this?
> One issue that I still have is how to handle the combination of GGUF+backends other than CPU. Right now I'm allocating a temporary context that just stores the data in RAM but it feels kind of clunky. Is there a better way to do this?
Check the way the magika example does this: make a `no_alloc` gguf context, call `ggml_backend_alloc_ctx_tensors`, then load the data from file using `gguf_get_tensor_offset`. This way at least the whole file does not need to be loaded into memory.
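My understanding of that pattern in rough C (untested sketch, error handling omitted; `load_gguf_to_backend` is a made-up helper name):

```c
// Sketch of the no_alloc + ggml_backend_alloc_ctx_tensors approach described above:
// the gguf context only creates tensor metadata, and the data is streamed from the file
// directly into the backend buffer, so the whole file is never resident in RAM at once.
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

static bool load_gguf_to_backend(const char * fname, ggml_backend_t backend, struct ggml_context ** ctx_out) {
    struct ggml_context * ctx = NULL;
    struct gguf_init_params params = {
        /*.no_alloc =*/ true, // create only the tensor metadata, no data in host memory
        /*.ctx      =*/ &ctx,
    };
    struct gguf_context * gctx = gguf_init_from_file(fname, params);
    if (gctx == NULL) {
        return false;
    }

    // allocate all tensors of the context in a buffer of the target backend
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
    GGML_ASSERT(buf != NULL);

    // stream the tensor data from the file into the backend buffer, one tensor at a time
    FILE * f = fopen(fname, "rb");
    const size_t data_offset = gguf_get_data_offset(gctx);
    for (int i = 0; i < (int) gguf_get_n_tensors(gctx); ++i) {
        const char         * name   = gguf_get_tensor_name(gctx, i);
        struct ggml_tensor * t      = ggml_get_tensor(ctx, name);
        const size_t         nbytes = ggml_nbytes(t);

        void * tmp = malloc(nbytes); // small per-tensor staging buffer
        fseek(f, (long)(data_offset + gguf_get_tensor_offset(gctx, i)), SEEK_SET); // use a 64-bit seek in real code
        fread(tmp, 1, nbytes, f);
        ggml_backend_tensor_set(t, tmp, 0, nbytes); // copy into the backend buffer
        free(tmp);
    }
    fclose(f);
    gguf_free(gctx);

    *ctx_out = ctx;
    return true;
}
```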
I think for `ggml_opt_step_adam` the parameter `sched` is not needed. It's essentially a way to adjust the learning rate via a callback, but I think something like this should be done one level further up rather than in the tensors.
Alright, I think this PR now has the minimum (in terms of functionality) where I think merging it would make sense. What is still missing is an update to the documentation and more general CUDA support for `GGML_OP_REPEAT_BACK`. Long-term I would also want a way to define datasets in GGML and a more high-level API (comparable to the current `ggml_opt` one).
> As of right now this PR contains my WIP version that only supports stochastic gradient descent and the CPU backend. The training is ~3x faster than on master (but the overall rate of convergence is worse than fully featured Adam).
Is this still the case? Seems like the PR now implements Adam, but the convergence is still slower compared to `mnist-train-fc.py`. Do you know why that is?
```
# ggml (mnist-fc)
mnist_model_train: epoch 29 start...done, took 0.59s, train_loss=0.058282, train_acc=98.46%, val_loss=0.083196+-0.025347, train_acc=97.63+-0.28%

# py
Epoch [30/30], Step [60000/60000], Loss: 0.0086, Accuracy: 99.91%
```
> Is this still the case? Seems like the PR now implements Adam, but the convergence is still slower compared to `mnist-train-fc.py`. Do you know why that is?
To be clear, I was specifically referring to the speed of epochs, not to the speed of convergence. The speed of epochs in GGML has not significantly changed from adding momentum. However, I added an optimization to the PyTorch script that loads the data asynchronously so the PyTorch baseline to compare against has become comparatively faster. Currently there are the following relevant differences:
I didn't look at this until now but the training loss with PyTorch seems to decrease much faster per epoch than with GGML and I don't think this can be explained with the above factors; I'll investigate.
I get comparable results between PyTorch and GGML in terms of training loss when I add the following two modifications: disable dataset shuffling for PyTorch and set the GGML physical batch size to 1000. The latter is a bug since by definition the physical batch size should have no effect beyond differences in rounding error. For the dataset shuffling I would have intuitively expected that this is only relevant for generalization but it seems that it also improves the rate at which the model gets better on the training set.
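To spell out why the physical batch size should only matter for rounding: with gradient accumulation, the sum of the per-physical-batch gradients has to equal the full-batch gradient. A standalone toy check in plain C (no ggml involved):

```c
// Toy check: the gradient of a summed squared-error loss over the logical batch equals the
// sum of the gradients over the physical batches, up to floating point rounding.
#include <stdio.h>

int main(void) {
    enum { N = 1000, N_PHYS = 100 }; // logical batch of 1000, physical batches of 100
    static float x[N], y[N];
    const float w = 0.5f;
    for (int i = 0; i < N; ++i) { x[i] = 0.001f*i; y[i] = 2.0f*x[i]; }

    // d/dw sum_i (w*x_i - y_i)^2 over the whole logical batch:
    float grad_full = 0.0f;
    for (int i = 0; i < N; ++i) { grad_full += 2.0f*(w*x[i] - y[i])*x[i]; }

    // the same gradient accumulated over physical batches:
    float grad_acc = 0.0f;
    for (int i0 = 0; i0 < N; i0 += N_PHYS) {
        float grad_phys = 0.0f;
        for (int i = i0; i < i0 + N_PHYS; ++i) { grad_phys += 2.0f*(w*x[i] - y[i])*x[i]; }
        grad_acc += grad_phys; // accumulation step
    }

    printf("full: %f accumulated: %f\n", grad_full, grad_acc); // differ only by rounding
    return 0;
}
```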
I figured out the problem: I incorrectly assumed that the `ggml_tensor.grad` pointers would be constant for my implementation. So while the original gradients are being used as input for the `GGML_ADD` tensors that are eventually being used as gradients, they are never incremented and thus remain zero. The accumulation steps prior to the last one are effectively just being discarded. I think the fix will be to do in-place additions in `ggml_compute_backward`.
> I figured out the problem: I incorrectly assumed that the `ggml_tensor.grad` pointers would be constant for my implementation. So while the original gradients are being used as input for the `GGML_ADD` tensors that are eventually being used as gradients, they are never incremented and thus remain zero. The accumulation steps prior to the last one are effectively just being discarded. I think the fix will be to do in-place additions in `ggml_compute_backward`.
So if I understand correctly, the following call is basically a noop atm: `ggml_backend_graph_compute(model.backend, gb_opt);`

The reason is that this call ends up using the gradients from the `gb_opt` graph, which so far haven't been updated. Instead we have been updating the gradients of the `gb_grad` graph.
I tried your idea, which I think is simply to set `inplace = true` in `ggml_add_or_set`:
```diff
diff --git a/src/ggml.c b/src/ggml.c
index de61438..483a3b2 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -18129,7 +18129,7 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg
     if (ggml_hash_contains(zero_table, a)) {
         return b;
     } else {
-        return ggml_add_impl(ctx, a, b, false);
+        return ggml_add_impl(ctx, a, b, true);
     }
 }
```
But it seems we are still missing something, as the training accuracy dropped:
```
mnist_model_train: epoch 29 start...done, took 0.58s, train_loss=0.182481, train_acc=95.08%, val_loss=0.166181+-0.027298, train_acc=96.07+-0.35%
```
I've pushed a WIP fix that works but has bad performance to clarify the problem. The original gradients are initialized with zero and need to be incremented with the sum tensors after each accumulation step to get correct results. Unrelated to the problem with accumulation there are also two other issues: the wrong graph was being copied for `gb_opt`, and the execution of the forward graph is not needed because the backwards graphs include all of its tensors (the latter only matters for performance).
> I tried your idea, which I think is simply to set `inplace = true` in `ggml_add_or_set`:
The problem is the upper branch where the tensor is in the zero table. In that case there needs to be an in-place addition instead of a replacement. But so far I have not been able to make that work so there is likely still some other issue.
Sorry, the supposed fix had two bugs that happened to cancel each other out.
I pushed a proper fix. The correct handling of gradient accumulation needs some extra bookkeeping to track the gradients of parameters and whether they should be accumulated; I added a new tensor flag for this.
Actually, now that I think about it, it would maybe be better to do this via a hash set instead of via tensor modification, since whether or not a gradient should be accumulated is a property of the compute graph rather than of the gradient tensor. But by that logic the existing code in `ggml_build_backward` that modifies the tensors is also bad.
Also: there was some inconsistent use of `ggml_cgraph.nodes[i]->grad` vs. `ggml_cgraph.grads[i]` that was causing problems.
My priorities: while `ggml_backend_sched` would be nice to have, I think it's more important to properly define datasets with functionality such as data shuffling and asynchronous data pre-loading (long-term probably also GGUF support for very large datasets that don't fit in RAM). Using such datasets I would then write a more high-level API that trains a feed-forward neural network given a dataset and compute graph (and optionally labels) as input. In that high-level API I would then start using `ggml_backend_sched`.
Right now I have a prototype for a dataset in user space.
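Roughly what I mean by a user-space dataset (a sketch; the `my_dataset` names are made up and not part of the PR): an index permutation that is reshuffled per epoch, plus a function that gathers a contiguous batch ready for `ggml_backend_tensor_set`.

```c
// Sketch of a user-space dataset with shuffling; names are illustrative only.
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

struct my_dataset {
    float   * data;   // ndata rows of nx floats (inputs)
    float   * labels; // ndata rows of ny floats (labels)
    int64_t * perm;   // permutation of row indices, initialized to 0..ndata-1
    int64_t   ndata;
    int64_t   nx;
    int64_t   ny;
};

// Fisher-Yates shuffle of the index permutation, called once per epoch
static void my_dataset_shuffle(struct my_dataset * ds) {
    for (int64_t i = ds->ndata - 1; i > 0; --i) {
        const int64_t j   = rand() % (i + 1);
        const int64_t tmp = ds->perm[i];
        ds->perm[i] = ds->perm[j];
        ds->perm[j] = tmp;
    }
}

// gather one shuffled batch into contiguous host buffers, ready for ggml_backend_tensor_set
static void my_dataset_get_batch(const struct my_dataset * ds, int64_t ibatch, int64_t nbatch,
                                 float * batch_data, float * batch_labels) {
    for (int64_t i = 0; i < nbatch; ++i) {
        const int64_t idx = ds->perm[ibatch*nbatch + i];
        memcpy(batch_data   + i*ds->nx, ds->data   + idx*ds->nx, ds->nx*sizeof(float));
        memcpy(batch_labels + i*ds->ny, ds->labels + idx*ds->ny, ds->ny*sizeof(float));
    }
}
```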
The issue with `ggml_backend_sched` was more relevant when the plan was to pass a backend to the opt functions. Now that they are ggml ops, it is entirely up to the user code whether to use the scheduler or not.
Re asynchronous data loading: this may already be obvious, but you should look at `ggml_backend` instances as streams, and thus if you want to upload data while something else is running, this should be done by creating a new `ggml_backend` instance and using `ggml_backend_tensor_set_async`. I will make changes that will make this distinction more clear in the future by adding new objects to represent backends and backend devices, and eventually the current `ggml_backend` objects will be renamed to something like `ggml_backend_stream`.
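If I understand the streams analogy correctly, the pattern would look roughly like this (untested sketch with the CUDA backend; for correctness the upload would have to target a tensor that is not read by the currently running graph, e.g. double-buffered inputs):

```c
// Sketch of overlapping data upload with compute by using a second ggml_backend instance
// (i.e. a second stream) purely for ggml_backend_tensor_set_async.
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

static void train_step_overlapped(struct ggml_cgraph * gb, struct ggml_tensor * next_inputs,
                                  const float * next_batch, size_t nbytes,
                                  ggml_backend_t compute_stream, ggml_backend_t copy_stream) {
    // enqueue the upload of the next batch on the copy stream ...
    ggml_backend_tensor_set_async(copy_stream, next_inputs, next_batch, 0, nbytes);

    // ... while the compute stream executes the graph for the current batch
    ggml_backend_graph_compute(compute_stream, gb);

    // make sure the upload has finished before next_inputs is read or next_batch is reused
    ggml_backend_synchronize(copy_stream);
}

// the two streams would be created once, e.g.:
//   ggml_backend_t compute_stream = ggml_backend_cuda_init(0);
//   ggml_backend_t copy_stream    = ggml_backend_cuda_init(0);
```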
The ultimate goal of this PR is to add backend support for numerical optimization, namely Adam and L-BFGS. As of right now the corresponding computations are done by a single thread outside any of the GGML graphs. As a consequence only a single thread is used and only the CPU backend is compatible. I think the correct way to remedy this is to make the optimizers part of the GGML compute graphs. This also fixes some allocation issues where the optimization code allocates extra tensors to hold persistent extra data for the optimizers.
As of right now this PR contains my WIP version that only supports stochastic gradient descent and the CPU backend. The training is ~3x faster than on master (but the overall rate of convergence is worse than fully featured Adam).
The overall design that I envision is that the optimizer is specified when creating the backwards graph. If no optimizer is specified, calculate the gradients without touching the weights. If an optimizer is specified, apply it to all parameters after the gradients have been calculated by adding an extra GGML op on top (this could probably be optimized to overwrite gradients that are no longer needed). During backwards graph creation also specify any extra tensors needed for the optimizer so they can be correctly allocated for all backends. Functions like `ggml_opt` would then mainly call the backwards graph in a loop and check convergence. One potential issue is that the convergence logic would require calls to `ggml_backend_tensor_get`, which would make `ggml.c` depend on `ggml_backend.c` (which it currently does not). If that is a problem the optimization code could maybe be moved to a new file like `ggml-algo.c`.
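To make the "extra op on top of the gradients" idea concrete, here is an illustration using only existing ops, with plain SGD standing in for the dedicated optimizer ops (a sketch, not the PR's implementation):

```c
// Illustration of appending an optimizer update to the backwards graph as regular graph nodes.
// Plain SGD (w -= lr * grad) built from existing ops; the PR adds dedicated ops instead so
// that optimizer state (e.g. Adam moments) can be allocated and handled properly per backend.
#include "ggml.h"

static void add_sgd_to_graph(struct ggml_context * ctx, struct ggml_cgraph * gb,
                             struct ggml_tensor ** params, int n_params, float lr) {
    for (int i = 0; i < n_params; ++i) {
        struct ggml_tensor * w    = params[i];
        struct ggml_tensor * step = ggml_scale(ctx, w->grad, lr); // lr * grad
        struct ggml_tensor * wnew = ggml_sub(ctx, w, step);       // w - lr * grad
        ggml_build_forward_expand(gb, ggml_cpy(ctx, wnew, w));    // write the result back into w
    }
}
```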
If there are issues with my design please let me know early.