Closed JohannesGaessler closed 1 week ago
How do you plan to support multiple GPUs? Currently this interface is taking a `ggml_backend`, so I think that would limit it to a single GPU.
For some reason `ggml_backend_sched` didn't work correctly, but I wasn't able to nail down why because the MNIST example is already relatively complex. My plan is to implement tests for this interface, do the transition to `ggml_backend_sched`, and use those tests for debugging. (I'm flexible regarding whether that should already be done in this PR or later.)
I mentioned some possible solutions to the problem with `ggml_backend_sched` in the previous PR. I think that whatever solution is finally implemented, it will need to be factored into the design early on, because it is going to fundamentally affect the way the tensors are allocated, and it is not a detail that can be ignored until the last moment.
I added tests for the new optimization interface. I'll do the transition towards `ggml_backend_sched` next and then I'll try to get this PR into a state where it can be merged.
I adapted the new optimization interface to use `ggml_backend_sched` (see also discussion in https://github.com/ggerganov/ggml/pull/999). The tests I wrote seem to work correctly, but the MNIST training results using CUDA have for some reason become worse vs. master, so there still seem to be issues. The interface should be mostly stable now though.
With some changes it can be used with BLAS and Metal. On M3 Max with BLAS it takes just 3 seconds to train, compared to 15 seconds with 3090 Ti CUDA or ~9 seconds with 13900k CPU.
ggml_opt_new_fit: epoch 0029:
train: [========================================| data=057000/057000, loss=0.011858+-0.000376, accuracy=99.86+-0.02%, t=00:00:00, ETA=00:00:00]
val: [========================================| data=003000/003000, loss=0.065221+-0.011987, accuracy=97.80+-0.27%, t=00:00:00, ETA=00:00:00]
ggml_opt_new_fit: training took 00:00:03
diff --git a/examples/mnist/mnist-common.h b/examples/mnist/mnist-common.h
index 6e2d235..c2a4464 100644
--- a/examples/mnist/mnist-common.h
+++ b/examples/mnist/mnist-common.h
@@ -134,6 +134,17 @@ struct mnist_model {
devices.push_back(dev);
}
+ // add accel devices
+ for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+ if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
+ ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+ GGML_ASSERT(backend);
+ backends.push_back(backend);
+ devices.push_back(dev);
+ }
+ }
+
ggml_backend_dev_t dev_cpu = ggml_backend_dev_by_name("CPU");
GGML_ASSERT(dev_cpu);
ggml_backend_t backend_cpu = ggml_backend_dev_init(dev_cpu, nullptr);
@@ -151,12 +162,17 @@ struct mnist_model {
if (backends.size() == 1) {
fprintf(stderr, "%s: using %s (%s) backend\n",
__func__, ggml_backend_name(backends[0]), ggml_backend_dev_description(devices[0]));
- } else if (backends.size() == 2) {
- fprintf(stderr, "%s: using %s (%s) backend with %s (%s) fallback\n",
- __func__, ggml_backend_name(backends[0]), ggml_backend_dev_description(devices[0]),
- ggml_backend_name(backends[1]), ggml_backend_dev_description(devices[1]));
} else {
- GGML_ASSERT(false);
+
+ fprintf(stderr, "%s: using %s (%s) backend with fallbacks: ",
+ __func__, ggml_backend_name(backends[0]), ggml_backend_dev_description(devices[0]));
+ for (size_t i = 1; i < backends.size(); ++i) {
+ fprintf(stderr, "%s (%s)", ggml_backend_name(backends[i]), ggml_backend_dev_description(devices[i]));
+ if (i + 1 < backends.size()) {
+ fprintf(stderr, ", ");
+ }
+ }
+ fprintf(stderr, "\n");
}
{
diff --git a/src/ggml-metal.m b/src/ggml-metal.m
index fb2efc6..a9f35c7 100644
--- a/src/ggml-metal.m
+++ b/src/ggml-metal.m
@@ -3285,6 +3285,12 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
return ctx->all_data;
}
+static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ memset((char *)tensor->data + offset, value, size);
+
+ UNUSED(buffer);
+}
+
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
memcpy((char *)tensor->data + offset, data, size);
@@ -3318,7 +3324,7 @@ static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_get_base,
/* .init_tensor = */ NULL,
- /* .memset_tensor = */ NULL,
+ /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
I changed the MNIST code slightly to a version that I think is simpler. Am I right in assuming that it's unproblematic to initialize two backends for the same device and to then pass those backends to the same instance of `ggml_backend_sched`?
> compared to 15 seconds with 3090 Ti CUDA

The performance seems very poor, but since the model is so small that is basically just a measure of the overhead. I seem to remember that you are using WSL2, so maybe that has to do with it? On my machines (all running native Linux) I see the following performance:
| Hardware | Time [s] |
|---|---|
| RTX 4090 | 3 |
| RTX 3090 | 4 |
| P40 | 5 |
| Epyc 7742 | 9 |
| RX 6800 | 14 |
| Ryzen 5950X (65W power limit) | 17 |
| Xeon E5-2683 v4 | 21 |
| Thinkpad T16 Gen 1 | 30 |
Notably the RX 6800 is also performing much worse than the P40 even though with llama.cpp the performance is very similar.
> I changed the MNIST code slightly to a version that I think is simpler. Am I right in assuming that it's unproblematic to initialize two backends for the same device and to then pass those backends to the same instance of `ggml_backend_sched`?

It may waste some resources and make graph splitting a bit slower, but not much. Generally I don't think it is very useful to have multiple GPU backends; the CPU backend is usually a better fallback since the cost of copying the state is lower.
> I seem to remember that you are using WSL2 so maybe that has to do with it?

Kernel launch overhead is higher on Windows (it's the same reason `-sm row` is so slow). I assume that it would be possible to remove nearly all of that using CUDA graphs, since it is the same graph being evaluated many times.
I removed the use of GGML graph exports from the MNIST example. In its current state the feature is fundamentally incompatible because it relies on statically allocated CPU tensors (also it would be necessary to mess with the internals of the optimization context).
Currently the optimization interface works by making the user statically allocate the model weights and inputs, and defining the computation of the outputs without allocation. The optimization context then statically allocates tensors for e.g. the optimizer momenta and defines the backward pass without allocation. The unallocated tensors are then given to `ggml_backend_sched`. I think the correct way to reintroduce graph exports would be to selectively save the data of only those tensors that were statically allocated by the user and to save only the metadata for the other forward graph tensors. The logic that would be needed to minimize disk space is essentially the same that I have already implemented for allocating the forward/backward graphs.
After this is merged, can all the "opt" functions from ggml.h/ggml.c be removed, or are any of them still used? I am moving all the CPU backend specific code to a separate file, and it would be easier if I could just remove these functions, since they only work with the CPU backend.
About the graph exports - I don't think these are used, it seems that it was an experimental feature that never really took off. It may be better to remove these functions entirely. cc @ggerganov.
Yes, everything related to graph export should be removed.
> After this is merged, can all the "opt" functions from ggml.h/ggml.c be removed, or are any of them still used?

Actually, my plan was to nail down the features for the new interface, then remove the old ggml_opt functions and rename the new interface to ggml_opt (I just thought it would be easier that way). The ggml_opt functionality on master can already be removed ahead of time I think.
I fixed gradient accumulation and I think that this PR is now feature complete and just needs the ggml_opt_new -> ggml_opt transition. @slaren since you are currently also doing something where the old optimization interface would be removed, how should we coordinate this?
I am almost done with the change, I was planning to open a PR later tonight. It's moving code around so there will be merge conflicts, but it should be fairly straightforward to resolve them since I am not changing the functions that you are modifying here.
Unless I'm forgetting something I now have all features that I was targeting for this PR. After https://github.com/ggerganov/ggml/pull/1006 is merged all that is left to do is to rebase the code and change the prefix from ggml_opt_new to ggml_opt.
Minor patch to clear some compile warnings with clang:
diff --git a/src/ggml-opt.cpp b/src/ggml-opt.cpp
index ec9bccd..a1fb512 100644
--- a/src/ggml-opt.cpp
+++ b/src/ggml-opt.cpp
@@ -635,7 +635,7 @@ void ggml_opt_epoch_callback_progress_bar(
const int64_t t_eta_m = t_eta_s / 60;
t_eta_s -= t_eta_m * 60;
- fprintf(stderr, "| data=%06ld/%06ld, loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, t=%02ld:%02ld:%02ld, ETA=%02ld:%02ld:%02ld]\r",
+ fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
if (ibatch == ibatch_max) {
@@ -712,7 +712,7 @@ void ggml_opt_fit(
t_total_s -= t_total_h * 3600;
const int64_t t_total_m = t_total_s / 60;
t_total_s -= t_total_m * 60;
- fprintf(stderr, "%s: training took %02ld:%02ld:%02ld\n", __func__, t_total_h, t_total_m, t_total_s);
+ fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
}
ggml_opt_free(opt_ctx);
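For context on the patch above: `%ld` assumes that `int64_t` is `long`, which does not hold on every platform, while the `PRId64` macro from `<cinttypes>` expands to the matching conversion specifier. A minimal sketch of the same pattern, where the helper name `format_hms` is hypothetical and not part of ggml:

```cpp
#include <cassert>
#include <cinttypes> // PRId64
#include <cstdint>
#include <cstdio>
#include <string>

// Formats a duration in seconds as HH:MM:SS.
// Hypothetical helper; only the PRId64 usage mirrors the patch.
std::string format_hms(int64_t t_s) {
    assert(t_s >= 0);
    const int64_t t_h = t_s / 3600;
    t_s -= t_h * 3600;
    const int64_t t_m = t_s / 60;
    t_s -= t_m * 60;

    char buf[64];
    // String literal concatenation splices the macro into the format string.
    snprintf(buf, sizeof(buf), "%02" PRId64 ":%02" PRId64 ":%02" PRId64, t_h, t_m, t_s);
    return buf;
}
```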
1. Should we put me in `ggml-opt.h` as the maintainer for this module? I check llama.cpp and ggml issues for things relevant to my contributions at least once a day, but maybe this would make it easier, for downstream problems, to identify who would be the right person to ask for help?
2. `test2.c` and `test3.c` can I think be removed. There is a similar test for convergence in `test-opt.cpp`, and combined with the tests in `test-backend-ops.cpp` any bug that would cause those tests to fail would already be covered.
3. What is the reasoning behind e.g. `typedef struct ggml_backend * ggml_backend_t;`? Is it just to make the code more readable? Should something similar be done for the new structs I added?

> 3. What is the reasoning behind e.g. `typedef struct ggml_backend * ggml_backend_t;`?

The reasoning is that they are opaque types and it is not relevant to the user whether they are structs or not. This is done with all the structs that are hidden from user code.
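The opaque-handle pattern described here can be sketched as follows. All `my_dataset` names are hypothetical and only illustrate the idea; they are not the structs added in this PR.

```cpp
#include <cassert>
#include <cstddef>

// --- what a public header would expose (hypothetical names) ---
struct my_dataset;                        // forward declaration only, layout stays hidden
typedef struct my_dataset * my_dataset_t; // opaque handle that user code passes around

my_dataset_t my_dataset_init(size_t ndata);
size_t       my_dataset_size(my_dataset_t ds);
void         my_dataset_free(my_dataset_t ds);

// --- what the implementation file would contain ---
struct my_dataset {
    size_t ndata; // can change freely without breaking user code
};

my_dataset_t my_dataset_init(size_t ndata) {
    return new my_dataset{ndata};
}

size_t my_dataset_size(my_dataset_t ds) {
    assert(ds != nullptr);
    return ds->ndata;
}

void my_dataset_free(my_dataset_t ds) {
    delete ds;
}
```

User code only ever sees `my_dataset_t` and the three functions, so whether the handle points to a struct is an implementation detail.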
> Should we put me in `ggml-opt.h` as the maintainer for this module?

It's up to you. We can also add a CODEOWNERS file where maintainers can add themselves if they would like to be notified for PRs.

> `test2.c` and `test3.c` can I think be removed.

Sounds good.
I noticed that the carriage return for the progress bar only results in the expected animation-like behavior if the progress bar is short enough to fit the terminal, otherwise it only returns to the point where the line is broken and spams the terminal with one new line for each minibatch. I just reduced the size of the progress bar but maybe there is a better solution.
I don't think there is an easy way to do that better. You could try to find the width of the terminal, but there isn't an easy multi-platform way to do that, it would require different implementations for every OS. It should be ok to assume that the terminal has at least 80 columns.
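A minimal sketch of the "assume at least 80 columns" approach: the bar width is derived from whatever space is left after the statistics text, so the whole line stays below the column limit and `\r` keeps working. The function name and signature are hypothetical and not taken from `ggml-opt.cpp`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>

// Builds "[=====     | <stats>]" so that the line fits into `columns` characters.
// Hypothetical helper for illustration only.
std::string progress_line(int64_t idata, int64_t idata_max, const std::string & stats, size_t columns = 80) {
    assert(idata_max > 0 && idata >= 0 && idata <= idata_max);
    // Fixed characters: '[' before the bar, "| " after it, ']' at the end.
    const size_t overhead = 1 + 2 + stats.size() + 1;
    // Leave one column free so the cursor never wraps onto the next line.
    const size_t budget    = columns > 0 ? columns - 1 : 0;
    const size_t bar_width = budget > overhead ? budget - overhead : 0;
    const size_t nfilled   = size_t(idata) * bar_width / size_t(idata_max);

    std::string line = "[";
    line.append(nfilled, '=');
    line.append(bar_width - nfilled, ' ');
    line += "| " + stats + "]";
    return line;
}
```

The caller would print the result followed by `\r`. If the statistics text alone exceeds the budget, the bar shrinks to zero width but the line can still wrap, so the text itself would also need to stay short.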
Is the `mnist-train` example supposed to run correctly with the Metal backend? The loss does not decrease atm:
Edit: this patch makes it work:
diff --git a/src/ggml-metal.m b/src/ggml-metal.m
index 9a045f8..6ada796 100644
--- a/src/ggml-metal.m
+++ b/src/ggml-metal.m
@@ -829,12 +829,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
return false;
}
case GGML_OP_NONE:
+ return false;
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
case GGML_OP_CONCAT:
+ return true;
case GGML_OP_ADD:
+ return false;
case GGML_OP_SUB:
case GGML_OP_ACC:
case GGML_OP_MUL:
So it likely means that the `GGML_OP_ADD` operator in the Metal backend is incorrect for some cases. I will take a look into this. However, I am not yet sure why `GGML_OP_NONE` also needs to return false.
It worked when I tried it a few days ago, not sure what changed. This is the output with the build that I tried:
When running fully on the CPU or fully on the CUDA backend, the results are good. But if I change the matrix multiplications to run on the CPU and the rest on the GPU, it no longer works:
diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
index e68e405..4e95a18 100644
--- a/src/ggml-cuda.cu
+++ b/src/ggml-cuda.cu
@@ -2976,6 +2976,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
+ return false;
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
This may be a hint as to what is going wrong in the scheduling.
I think I see the problem, in the first evaluation the inputs are correctly set and copied to the splits:
## SPLIT #10: Metal # 2 inputs: [node_18 ( 976K)] [node_20 ( 976K)]
node # 21 ( MUL): node_21 ( 976K) [Metal ]: Metal#node_18#0 ( 976K) [ NULL ] Metal#node_20#0 ( 976K) [ NULL ]
## SPLIT #11: BLAS # 0 inputs:
node # 22 ( OUT_PROD): node_22 ( 1M) [ BLAS ]: images ( 1M) [Metal ] node_21 ( 976K) [Metal ]
## SPLIT #12: Metal # 1 inputs: [node_22 ( 1M)]
node # 23 ( ADD): (view) ( 1M) [Metal ]: leaf_4 ( 1M) [Metal ] Metal#node_22#0 ( 1M) [ NULL ]
## SPLIT #13: CPU # 0 inputs:
node # 24 (REPEAT_BAC): node_24 ( 1K) [ CPU ]: node_21 ( 976K) [Metal ]
## SPLIT #14: Metal # 1 inputs: [node_24 ( 1K)]
node # 25 ( ADD): (view) ( 1K) [Metal ]: leaf_5 ( 1K) [Metal ] Metal#node_24#0 ( 1K) [ NULL ]
However, in subsequent evaluations, the inputs are no longer there:
## SPLIT #10: Metal # 0 inputs:
node # 21 ( MUL): node_21 ( 976K) [Metal ]: Metal#node_18#0 ( 976K) [Metal ] Metal#node_20#0 ( 976K) [Metal ]
## SPLIT #11: BLAS # 0 inputs:
node # 22 ( OUT_PROD): node_22 ( 1M) [ BLAS ]: images ( 1M) [Metal ] node_21 ( 976K) [Metal ]
## SPLIT #12: Metal # 0 inputs:
node # 23 ( ADD): (view) ( 1M) [Metal ]: leaf_4 ( 1M) [Metal ] Metal#node_22#0 ( 1M) [Metal ]
## SPLIT #13: CPU # 0 inputs:
node # 24 (REPEAT_BAC): node_24 ( 1K) [ CPU ]: node_21 ( 976K) [Metal ]
## SPLIT #14: Metal # 0 inputs:
node # 25 ( ADD): (view) ( 1K) [Metal ]: leaf_5 ( 1K) [Metal ] Metal#node_24#0 ( 1K) [Metal ]
I think this is a consequence of reusing the tensors, since `ggml_backend_sched` modifies the `src` of the tensors when it needs to copy a tensor to a different backend. If you then attempt to reuse that tensor without restoring the sources, well then everything breaks down in many ways.
Does training work correctly with this patch?
diff --git a/examples/mnist/mnist-common.h b/examples/mnist/mnist-common.h
index a871d03..6de9dad 100644
--- a/examples/mnist/mnist-common.h
+++ b/examples/mnist/mnist-common.h
@@ -18,7 +18,7 @@
// The logical batch size determines how many datapoints are used for a gradient update.
// The physical batch size determines how many datapoints are processed in parallel, larger values utilize compute better but need more memory.
#define MNIST_NBATCH_LOGICAL 1000
-#define MNIST_NBATCH_PHYSICAL 500
+#define MNIST_NBATCH_PHYSICAL 1000
static_assert(MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL == 0, "MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL != 0");
static_assert(MNIST_NTRAIN % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NBATCH_LOGICAL != 0");
diff --git a/examples/mnist/mnist-train.cpp b/examples/mnist/mnist-train.cpp
index a61dd05..4e8a638 100644
--- a/examples/mnist/mnist-train.cpp
+++ b/examples/mnist/mnist-train.cpp
@@ -33,7 +33,7 @@ int main(int argc, char ** argv) {
mnist_model_build(model);
- mnist_model_train(model, dataset, /*nepoch =*/ 30, /*val_split =*/ 0.05f);
+ mnist_model_train(model, dataset, /*nepoch =*/ 30, /*val_split =*/ 0.00f);
mnist_model_save(model, argv[2]);
}
It should make it so that there are no reallocations during training.
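For readers unfamiliar with the logical vs. physical batch size distinction in the patch above: with a logical batch of 1000 and a physical batch of 500, gradients from two physical batches are accumulated before one optimizer step. A toy sketch of that arithmetic for a one-parameter least-squares model (the model and all function names are assumptions for illustration, not the MNIST code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Gradient of 0.5*(w*x - y)^2 w.r.t. w, summed (not averaged) over data[i0, i1).
double grad_sum(const std::vector<double> & x, const std::vector<double> & y, double w, size_t i0, size_t i1) {
    double g = 0.0;
    for (size_t i = i0; i < i1; ++i) {
        g += (w*x[i] - y[i]) * x[i];
    }
    return g;
}

// Processes one logical batch as nbatch_logical/nbatch_physical physical batches.
// Returns the averaged gradient that a single optimizer step would consume.
double accumulated_grad(const std::vector<double> & x, const std::vector<double> & y, double w,
                        size_t nbatch_logical, size_t nbatch_physical) {
    assert(nbatch_physical > 0 && nbatch_logical % nbatch_physical == 0);
    assert(x.size() == y.size() && x.size() >= nbatch_logical);
    double acc = 0.0; // gradient accumulator, cleared once per logical batch
    for (size_t i0 = 0; i0 < nbatch_logical; i0 += nbatch_physical) {
        acc += grad_sum(x, y, w, i0, i0 + nbatch_physical);
    }
    return acc / double(nbatch_logical);
}
```

The result does not depend on the physical batch size. Setting the physical batch equal to the logical batch, as the patch does, removes the accumulation steps from the picture.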
Yes, with this patch when using the Metal backend, the results look like what slaren posted earlier.
If we take a step back, the reason there are these issues is that for training the ability to execute only part of a compute graph is needed. So if the logic for `ggml_backend_sched` were to be extended in such a way that allows you to skip part of the graph, it would be sufficient to statically allocate the backward graph with the optimizer step.
Or maybe just use a callback to fetch a new compute graph after all. It's not my preferred interface but it would probably cause the least amount of issues (as long as it's clearly communicated to users that they need to actually re-create the entire graph on each call).
You could also move the optimizer step to a different graph allocated using a different `ggml_backend_sched_t` instance.
That would maybe work for the optimizer step but the forward and backward pass need to be allocated simultaneously so that the graph allocator can correctly determine which tensors from the forward pass need to be kept for the backward pass. But also for the validation set only the forward pass should be executed.
Did you find any issues with the idea of making a deep copy of the graph?
I just feel like it's going to be relatively complicated and require a disproportionate amount of effort to debug and maintain. Just setting pointers to null is I think simple enough but I want to avoid messing around with graph construction and allocation too much if I can help it; I'll try to do an implementation with a callback for graph construction, I think I found a reasonably simple way to do it.
Copying the graphs automatically could make the API nicer to use, though. Here is a rough implementation of a graph deep copy in case you want to use it:
// requires <map>, <functional> and <cstring>
std::map<ggml_tensor *, ggml_tensor *> tensor_map;
// std::function instead of auto so that the lambda can call itself recursively
std::function<ggml_tensor *(ggml_tensor *)> map_tensor = [&](ggml_tensor * tensor) -> ggml_tensor * {
    if (tensor == NULL) {
        return NULL; // unused src slot
    }
    auto it = tensor_map.find(tensor);
    if (it != tensor_map.end()) {
        return it->second; // already copied
    }
    ggml_tensor * new_tensor;
    if (tensor->data == NULL) {
        new_tensor = ggml_dup_tensor(ctx, tensor);
        // set other fields not copied by ggml_dup_tensor
        new_tensor->op = tensor->op;
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            new_tensor->nb[i] = tensor->nb[i];
        }
        new_tensor->flags = tensor->flags;
        memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params));
        strcpy(new_tensor->name, tensor->name);
        // deep copy src tensors
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            new_tensor->src[i] = map_tensor(tensor->src[i]);
        }
    } else {
        // tensor is pre-allocated, no need to copy
        new_tensor = tensor;
    }
    tensor_map[tensor] = new_tensor;
    return new_tensor;
};
// copy the nodes in graph order; their sources are copied recursively on demand
for (int i = 0; i < ggml_graph_n_nodes(graph); i++) {
    ggml_graph_add_node(new_graph, map_tensor(ggml_graph_node(graph, i)));
}
I picked my WIP implementation for graph deep copies back up again but I think the approach is conceptually flawed. The fundamental problem is that the deep copies modify the source tensors of statically allocated tensors and that therefore only the first copy is easy and after that you would have to start piling on extra logic that adds a lot of complexity. At that point I think it would be easier to patch the split issue with the current code by remembering the original source tensors and restoring them prior to allocating a new graph.
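The alternative mentioned here, remembering the original source tensors and restoring them before a graph is allocated again, could look roughly like the sketch below. `toy_tensor` is a stand-in for `ggml_tensor` that keeps only the `src` array, and `src_backup` is a hypothetical helper, not code from this PR.

```cpp
#include <array>
#include <cassert>
#include <utility>
#include <vector>

constexpr int TOY_MAX_SRC = 4; // stand-in for GGML_MAX_SRC

// Minimal stand-in for ggml_tensor: only the source pointers matter here.
struct toy_tensor {
    std::array<toy_tensor *, TOY_MAX_SRC> src{};
};

// Remembers the original sources of a set of graph nodes so that they can be
// restored after a scheduler-like component has redirected them to copies.
struct src_backup {
    std::vector<std::pair<toy_tensor *, std::array<toy_tensor *, TOY_MAX_SRC>>> saved;

    void save(const std::vector<toy_tensor *> & nodes) {
        saved.clear();
        for (toy_tensor * node : nodes) {
            assert(node != nullptr);
            saved.emplace_back(node, node->src);
        }
    }

    void restore() const {
        for (const auto & entry : saved) {
            entry.first->src = entry.second;
        }
    }
};
```

The intended usage would be to call `save` once after building the graph and `restore` before each new allocation.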
To address that I think that you could just duplicate the static tensors, and set the same buffer and data. I tried this as a quick proof of concept, and it works somewhat, but there is something wrong because the loss decreases at a much slower rate, but I couldn't find the issue. Anyway here it is in case it is useful.
The bugs in your code seem to be complementary with the bugs that I had in mine; I think I can get a modified version of your code to work correctly. One of the problems is that Adam optimizers have a warmup based on the number of iterations. In GGML the number of iterations is currently stored in `ggml_tensor.op_params`. So if you execute a copy of the tensor those iterations do not carry over. I think the correct solution is to store the number of iterations in a GGML tensor instead. If we go with an implementation based on graph deep copies I think we will need to strictly enforce that `ggml_tensor.op_params` is not modified after tensor creation.
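To illustrate why a lost iteration count matters, here is a sketch of one textbook AdamW update for a single scalar weight. This is the standard formulation with commonly used default hyperparameters, which are assumptions here and not necessarily what GGML's kernels use. The iteration count enters only through the bias-correction terms.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// First and second moment estimates for one weight.
struct adamw_state {
    double m = 0.0;
    double v = 0.0;
};

// One AdamW step; iter is the 1-based number of optimizer steps taken so far.
double adamw_step(double w, double grad, adamw_state & s, int64_t iter,
                  double alpha = 1e-3, double beta1 = 0.9, double beta2 = 0.999,
                  double eps = 1e-8, double wd = 0.0) {
    assert(iter >= 1);
    s.m = beta1*s.m + (1.0 - beta1)*grad;
    s.v = beta2*s.v + (1.0 - beta2)*grad*grad;
    // Bias correction: compensates for m and v being initialized to zero.
    const double mh = s.m / (1.0 - std::pow(beta1, double(iter)));
    const double vh = s.v / (1.0 - std::pow(beta2, double(iter)));
    return w*(1.0 - alpha*wd) - alpha*mh/(std::sqrt(vh) + eps);
}
```

If the counter stays at 1 on a second step while `m` and `v` carry over, the correction factors are too large and the step size is overestimated.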
In the future I think a way to improve this would be to add support for graph plans to `ggml_backend_sched`, so that you could create a plan from a graph, and it would store a copy of the splits and all the other processing it does, and then you could execute it any number of times with very low overhead. The implementation would also take advantage of graph plans in the backends, which would allow using CUDA graphs. Moving the iterations counter to a tensor instead of storing it in `op_params` would be necessary for this to work properly, which I think would be a good change: the tensor struct should be considered constant during evaluation, otherwise many assumptions break.
I pushed a version with graph copies where MNIST training seems to work (the tests are failing because there currently isn't a solution for retrieving the gradients). One issue was that you need to set the gradients in a second pass over the nodes while creating the deep copy. Also the graph reset code had to be adapted to work correctly with gradient accumulation when not directly handling a fully allocated graph (the gradients never get cleared otherwise). I think both of these issues can be fixed in a relatively simple way by refactoring the code to store the mapping from a node to its gradient in `ggml_cgraph` (this is one of my long-term goals anyways).
For the optimizer iteration I found that storing the iteration value in a tensor is problematic because that value makes much more sense to manage in host code than device code. But `ggml_backend_sched` automatically copies that data to other backends. The solution that I'm currently using is to pass an `int64_t` pointer to `ggml_opt_step_adamw` and to store that pointer in `ggml_tensor.op_params`. The code in `ggml-opt.cpp` then increments the value that the pointer points to after each eval.
Training works with Metal now.
> But `ggml_backend_sched` automatically copies that data to other backends. The solution that I'm currently using is to pass an `int64_t` pointer to `ggml_opt_step_adamw` and to store that pointer in `ggml_tensor.op_params`. The code in `ggml-opt.cpp` then increments the value that the pointer points to after each eval.

The problem with this is that it may be complicated to make it work with CUDA graphs in the future, since the host code is not going to be executed on every evaluation. I think this could be considered a bug in `ggml_backend_sched`. It should be simple to fix, but it will require adding a list of ops that have output parameters, so that the copy is only performed if the tensor is both pre-allocated or tagged as an output, and it is being copied to be used as an output parameter of an operation.
If the host code is not guaranteed to be executed the approach needs to be fundamentally different. I think the simplest way to do it would be to allocate a small CPU buffer and to calculate all dynamic parameters needed for the optimizer step at the start of the epoch. This is very fast so it's fine to do it single-threaded on the CPU. The buffer with the dynamic parameters can then be set as an input for the optimizer step and it's fine to just copy it to device memory. More generally, if you let the user define a callback for setting the optimizer parameters dynamically that also enables e.g. dynamically varying the learning rate.
I pushed a version that stores the optimization parameters in the graph. They are allocated in a CPU buffer and written at the start of an eval. The CPU backend can use the parameters directly. I changed the CUDA backend to expect a device buffer with parameters instead of passing the parameters as kernel arguments. To change optimization parameters from their defaults users need to pass a custom function that calculates them.
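A rough sketch of the callback idea described above: a user-supplied function computes the optimizer parameters on the host, and the result is written into a small buffer that the optimizer step reads as a regular input. The struct layout, the names, and the decay schedule are all hypothetical and do not reflect the actual `ggml-opt.h` interface.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <functional>

// Hypothetical parameter block; the layout is illustrative, not GGML's.
struct opt_params {
    float alpha;  // learning rate
    float beta1h; // 1/(1 - beta1^iter), bias correction for the first moment
    float beta2h; // 1/(1 - beta2^iter), bias correction for the second moment
};

using opt_params_fn = std::function<opt_params(int64_t iter)>;

// Defaults with the iteration-dependent factors precomputed on the host.
opt_params default_params(int64_t iter) {
    assert(iter >= 1);
    const float beta1 = 0.9f;
    const float beta2 = 0.999f;
    opt_params p;
    p.alpha  = 1e-3f;
    p.beta1h = 1.0f / (1.0f - std::pow(beta1, float(iter)));
    p.beta2h = 1.0f / (1.0f - std::pow(beta2, float(iter)));
    return p;
}

// Example user callback: halve the learning rate every 1000 iterations.
opt_params decayed_params(int64_t iter) {
    opt_params p = default_params(iter);
    p.alpha *= std::pow(0.5f, float((iter - 1) / 1000));
    return p;
}

// Host-side step run before each evaluation; buf stands in for the small
// CPU buffer that would then be copied to the device like any other input.
void write_params(const opt_params_fn & get_params, int64_t iter, float * buf) {
    assert(buf != nullptr);
    const opt_params p = get_params(iter);
    buf[0] = p.alpha;
    buf[1] = p.beta1h;
    buf[2] = p.beta2h;
}
```

Because the device code only reads the buffer, nothing in the graph itself has to change between evaluations, which is the property that matters for graph plans and CUDA graphs.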
Edit: no changes to `ggml_backend_sched` were necessary.
I did some fixup; from my end this PR would be ready to merge. There is still the issue of refactoring the GGML code in such a way that `ggml_tensor.grad` can be removed, but that should not result in any externally visible changes. This is what I'll do next, but for me it doesn't matter whether this is done in this PR or another one, so I'll just go with whatever variant is easier for you to review.
I adapted the MNIST example README and while doing so I noticed that the convolutional model can now be trained with partial CUDA support which is faster than CPU only.
I pushed a refactor of the code around gradients. `ggml_tensor.grad` has been removed; the tensor <-> gradient mapping is now entirely defined by the compute graph. I don't have a good understanding of the gradient checkpointing code and I am not aware of any other code (particularly tests) that uses it, so I decided to remove it since it is now probably broken (I think something like this should be part of `ggml-opt` anyways). I removed `test-grad0` and `test1` since my impression is that those tests don't provide enough utility to justify maintaining them (but I have no problem adapting them if someone disagrees).
Currently, to get the gradient or gradient accumulator for a tensor there is a loop over `ggml_cgraph.nodes`, which results in $O(N)$ time to get a gradient (accumulator) and $O(N^2)$ time to build the backward graph. It may make sense to use a hash table here. This would be a relatively simple change in C++ but a bit tedious in C. So before touching that particular part of the code I would like to ask what the long-term plans for `ggml.c` are in terms of which language to use.
Actually, I think the hash table should be part of `ggml_cgraph`, but since that struct is defined in `ggml-impl.h` and presumably that header should be C-compatible, wouldn't that imply that a C implementation is necessary anyways?
The goal is to progressively port the code to C++, but modifying `ggml_cgraph` to require C++ may require updating a lot of code. If you only need a map from `ggml_tensor *` to node index, that should be possible to do with a small addition to the hash table that is already in `ggml_cgraph`: just add an array of values with the same size as the hash set and set it in `ggml_visit_parents`, such as `value[ggml_hash(node)] = index`.
I did an implementation using hash sets but I realized that with the current GGML hash sets building the backward pass would still take quadratic time. If a tensor is not contained in the hash set, `ggml_hash_find` will iterate over the entire set (and that will always be slower than the current implementation). To get a linear runtime for building the backward pass the hash set would need an explicit value for unused cells that causes the function to return when encountered. But since the backward pass is only constructed once per instance of `ggml_opt_ctx` I think this optimization would not be worthwhile and would just add unnecessary complexity. I would merge the backward pass construction as-is and revisit the issue if at a later date it turns out that optimization is needed.
> If a tensor is not contained in the hash set `ggml_hash_find` will iterate over the entire set (and that will always be slower than the current implementation).

Is this correct? It should only iterate until the first empty slot. That's just the way of dealing with collisions, but if the table is correctly sized, the number of collisions will be very close to zero.
You are absolutely right, looking at the code again it seems I missed part of the condition in the while loop.
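The two points from this exchange, a value array parallel to the hash set and a lookup that stops at the first empty slot, can be sketched with a toy open-addressing map. This is a generic illustration with hypothetical names, not GGML's hash set implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy pointer -> index map using open addressing with linear probing.
struct ptr_index_map {
    std::vector<const void *> keys;   // nullptr marks an empty slot
    std::vector<int>          values; // same size as keys, values[i] belongs to keys[i]

    explicit ptr_index_map(size_t size) : keys(size, nullptr), values(size, -1) {
        assert(size > 0);
    }

    size_t hash(const void * p) const {
        return (reinterpret_cast<uintptr_t>(p) >> 4) % keys.size();
    }

    // Returns the slot holding key, or the first empty slot on its probe
    // sequence, or keys.size() if the table is full and lacks the key.
    size_t find(const void * key) const {
        const size_t h = hash(key);
        size_t i = h;
        while (keys[i] != nullptr && keys[i] != key) {
            i = (i + 1) % keys.size();
            if (i == h) {
                return keys.size(); // visited every slot
            }
        }
        return i;
    }

    void insert(const void * key, int value) {
        assert(key != nullptr);
        const size_t i = find(key);
        assert(i != keys.size()); // table must not be full
        keys[i]   = key;
        values[i] = value;
    }

    // Returns the stored index, or -1 if key is absent.
    int get(const void * key) const {
        const size_t i = find(key);
        return (i != keys.size() && keys[i] == key) ? values[i] : -1;
    }
};
```

A lookup for a missing key ends at the first empty slot, so with a sparsely filled table both hits and misses take close to constant time.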
This PR adapts the training code from the MNIST example into GGML with the goal of establishing a new interface for training models. The goal is to provide downstream projects with a common, more high-level interface for training that can be tested and debugged more easily. The general design is procedural and relies on the definition of data structures for optimization contexts, datasets, and results.
As of right now essentially only feed-forward classifiers are supported. I put the code into a new file `ggml-opt.cpp` with a corresponding new header `ggml-opt.h`. One reason for this is that I am using some C++ functionality that is not performance critical but convenient. Another reason is that with the current GGML code there is no need to mess around with the internals of a GGML graph, so I think it makes sense to split off functionality that is only going to be used by a subset of the userbase into a separate header (also the general vibe from what I can tell is that people find `ggml.c` hard to navigate due to its size).
There is still a lot to do but I would like to get feedback on the interface early if possible. In particular, one thing that is still missing is testing code for the new interface. For now the prefix that I am using for the new interface is `ggml_opt_new`; I plan to change this to `ggml_opt` and remove the old `ggml_opt` code prior to merging.