Open BlinkDL opened 1 year ago
This version of inference is even simpler. https://github.com/harrisonvanderbyl/rwkvstic/blob/master/src/rwkvstic/agnostic/agnosticRwkv.py
@BlinkDL Yes! This looks like a great suggestion.
The full RWKV model & text generation in 150 lines:
https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py
This is a slower version but shows the extreme simplicity of RWKV :)
And latest 14B ctx4096 is very good :) https://huggingface.co/BlinkDL/rwkv-4-pile-14b/blob/main/RWKV-4-Pile-14B-20230228-ctx4096-test663.pth
Chat record from a user:
cryscan 1:02:26 What is your interesting news today?
Eloise 1:03:13 I have a dream last night. I dreamed that I would get an audition for a movie in the USA. It was an amazing experience and my most impressive job. But the fact was, I couldn't be selected in the final round because of my physical limitations and some people thought that it was unfair. That's what made me downhearted for a few days, but later I found another thing that could interest me.
cryscan 1:03:41 What is that another thing?
Eloise 1:03:50 The country and culture of a small but unique Island, the best place on earth, Iceland. I have a dream to go to Iceland and I would like to live there, maybe become a movie director someday.
cryscan 1:06:17 A movie director?
Eloise 1:06:26 Yes, but that doesn't mean I can't be an actress too. Actually, I think it is not hard to be an actress and a director both. It just takes different skills. But having experience in directing a movie would help me learn the acting part well.
cryscan 1:07:12 Yes, they are correlated.
Eloise 1:07:19 I am sure they are. When I learned that there was a possibility for me to get an audition in Iceland, I knew it would be hard, but I wanted to try it because of the reasons above.
cryscan 1:08:22 So... Do you have any chance to know what is the movie about?
Eloise 1:08:30 That's right! I didn't know anything about it until a couple of days ago. But I was so excited when I knew the movie's title was called Eon. It means 'Hour' in Icelandic language.
cryscan 1:09:39 Sounds like a movie about time.
Eloise 1:09:51 Yes, it's exactly that! I've never seen an Icelandic movie before and the director said it would be the first movie to be shot there. It's a bit difficult for me because Icelandic is not my native language. But I'll try my best to learn how to speak Icelandic in order to get the part in the movie.
cryscan 1:10:48 Do you have any lines?
Eloise 1:10:53 Well, it is a little bit different. They can just send some scripts to the potential actors and ask them to read them out loud.
cryscan 1:11:17 And what about you?
Eloise 1:11:23 I'm not sure. Maybe the director will invite me to go to Iceland for a few days and learn the local language. I'll try my best to do that.
cryscan 1:12:43 I can see that you are working hard on it!
Eloise 1:12:47 I hope I will do well! I think that my acting skill could be very helpful in the audition.
cryscan 1:13:20 When do you plan to go to Iceland?
Eloise 1:13:25 I am planning to go there in a month or so. As soon as I know about the time, I'll let you know.
cryscan 1:13:44 Ok! Good luck then.
Eloise 1:13:47 Thanks! I'll talk to you later. Bye!
Hi all! Looks like there is no ggml implementation of RWKV yet, and nobody has announced that they would like to implement it.
Today I've started trying to port RWKV to ggml. I'm a noob in both C/C++ and ggml, so it may take time, but I'm determined and motivated by the possibility of running the 14B model with quantization. I base my impl on the llama.cpp version of ggml.
I plan to publish the resulting code regardless of whether it works or not -- I think it would be easier for others to fix existing code, rather than writing from scratch.
Can't tell how much time it would take tho, completely unsure.
Here's an example of how I did FFN part of RWKV block (does not work now):
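struct ggml_tensor * sigmoid(ggml_context * ctx, struct ggml_tensor * x) {
    // ggml has no native sigmoid, but silu(x) = x * sigmoid(x),
    // so silu(x) / x gives sigmoid(x) for x != 0 (at x == 0 it evaluates 0 / 0).
    // The division must use the original x, not the silu result.
    struct ggml_tensor * s = ggml_silu(ctx, x);
    return ggml_div(ctx, s, x);
}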
...
// FFN/channel mixing
{
    // self.layer_norm(x, self.w.blocks[i].ln2)
    struct ggml_tensor * x0 = layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias);
    // state[5 * i + 0]
    int32_t offset_in_bytes = (5 * i + 0) * n_embd * 4;
    struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embd, offset_in_bytes);
    // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
    // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
    struct ggml_tensor * xk = ggml_add(
        ctx,
        ggml_mul(ctx, x0, layer.ffn_time_mix_k),
        ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_k))
    );
    struct ggml_tensor * xr = ggml_add(
        ctx,
        ggml_mul(ctx, x0, layer.ffn_time_mix_r),
        ggml_mul(ctx, x_prev, ggml_sub(ctx, ones, layer.ffn_time_mix_r))
    );
    // state[5 * i + 0] = x
    ggml_cpy(ctx, x0, x_prev);
    // r = torch.sigmoid(rw @ xr)
    struct ggml_tensor * r = sigmoid(
        ctx,
        ggml_mul_mat(ctx, layer.ffn_receptance, xr)
    );
    // k = torch.square(torch.relu(kw @ xk))
    struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(
        ctx,
        // todo not works
        ggml_mul_mat(ctx, layer.ffn_key, xk)
    ));
    // r * (vw @ k)
    // todo x0 = ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k));
    // x = x + self.channel_mixing(...)
    x = ggml_add(ctx, x, x0);
}
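For reference while debugging the ggml graph, the same FFN/channel-mixing step can be written in a few lines of NumPy, following the formulas quoted in the comments above. This is only a sketch: the function name, argument order, and the toy sizes `n_embd` and `n_ffn` are assumptions made for illustration, and the layer norm is assumed to have been applied to `x` already.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def channel_mixing(x, x_prev, time_mix_k, time_mix_r, kw, vw, rw):
    """One FFN/channel-mixing step for a single token.

    x      -- layer-normalised input for the current token
    x_prev -- the same quantity for the previous token (state[5 * i + 0])
    Returns (output, new_state); the caller adds output to the residual.
    """
    xk = x * time_mix_k + x_prev * (1.0 - time_mix_k)
    xr = x * time_mix_r + x_prev * (1.0 - time_mix_r)
    r = sigmoid(rw @ xr)                      # receptance gate
    k = np.square(np.maximum(kw @ xk, 0.0))   # squared ReLU
    return r * (vw @ k), x                    # new state is the current x

# Toy sizes, chosen only for the example.
n_embd, n_ffn = 4, 16
rng = np.random.default_rng(0)
x = rng.standard_normal(n_embd)
x_prev = np.zeros(n_embd)
time_mix_k = rng.uniform(size=n_embd)
time_mix_r = rng.uniform(size=n_embd)
kw = rng.standard_normal((n_ffn, n_embd))
vw = rng.standard_normal((n_embd, n_ffn))
rw = rng.standard_normal((n_embd, n_embd))

out, new_state = channel_mixing(x, x_prev, time_mix_k, time_mix_r, kw, vw, rw)
print(out.shape, new_state.shape)
```

Comparing intermediate values such as `xk`, `r` and `k` against the ggml tensors one at a time is usually the quickest way to find which node of the graph diverges.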
@saharNooby
Happy to see that you will give this a try! Make sure to update to the latest changes that I just pushed.
It should be relatively easy to add a ggml_sigmoid() implementation to ggml - open a PR if you feel like it.
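Until a native operator exists, the silu-based workaround rests on the identity silu(x) = x * sigmoid(x). The short NumPy check below illustrates it, together with the one input where the division breaks down; the helper names are made up for this example and are not part of ggml.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    # silu(x) = x * sigmoid(x)
    return x * sigmoid(x)

x = np.array([-4.0, -0.5, 0.25, 3.0])

# Dividing silu(x) by the ORIGINAL x recovers sigmoid(x).
via_silu = silu(x) / x
print(np.allclose(via_silu, sigmoid(x)))  # True

# At exactly x == 0 the workaround computes 0 / 0, which is NaN,
# whereas the true value is sigmoid(0) == 0.5.
zero = np.array([0.0])
with np.errstate(invalid="ignore", divide="ignore"):
    at_zero = silu(zero) / zero
print(np.isnan(at_zero[0]))  # True
```

A dedicated sigmoid operator avoids both the extra division and the undefined value at zero.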
As a starting point, if you manage to convert the tensors to ggml format and just load them successfully in a C++ program, it would be a great help. See the gpt-2 and gpt-j examples in this repo and their associated conversion scripts. You can open a draft PR so we can more easily keep track of the progress.
Let me know if you get stuck somewhere.
We are currently wrapping up some stuff in the llama.cpp repo and after we are ready, I think we will focus on implementing more models. And RWKV is definitely on the radar. Will let you know if I or someone else starts working on it too.
Great work @saharNooby I am training "Raven"-series models (RWKV on alpaca+codealpaca+guanaco+...) Gradio Demo: https://huggingface.co/spaces/BlinkDL/Raven-RWKV-7B
@ggerganov Hi again! I'm a little confused about dimension/value order in ggml. Consider these 2 snippets:
PyTorch code for multiplying a matrix by a vector and serializing the matrix:
import torch
import numpy as np
# (2 rows, 3 columns)
x = torch.tensor([
[0.8012, 0.0138, 0.6916],
[0.2435, 0.3322, 0.4037]
], dtype=torch.float32)
# 1st row is written, then 2nd row is written
# Prints [0.8012 0.0138 0.6916 0.2435 0.3322 0.4037]
print(np.frombuffer(x.numpy().tobytes(), dtype=np.single))
# (3)
y0 = torch.tensor([0.4699, 0.1103, 0.9175], dtype=torch.float32)
z0 = torch.matmul(x, y0)
# Prints (2) [1.0125, 0.5215]
print(z0.shape, z0)
Its ggml equivalent:
// Not in ggml, I made it myself -- may not be correct
void ggml_set_f32_2d(struct ggml_tensor * tensor, int i, int j, float value) {
    RWKV_ASSERT(tensor->n_dims == 2, "Not a 2D tensor");
    RWKV_ASSERT(tensor->type == GGML_TYPE_F32, "Unsupported data type");
    *(float *) ((char *) tensor->data + j * tensor->nb[1] + i * tensor->nb[0]) = value;
}
// (2 rows, 3 columns)
struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2);
ggml_set_f32_2d(x, 0, 0, 0.8012F);
ggml_set_f32_2d(x, 1, 0, 0.0138F);
ggml_set_f32_2d(x, 2, 0, 0.6916F);
ggml_set_f32_2d(x, 0, 1, 0.2435F);
ggml_set_f32_2d(x, 1, 1, 0.3322F);
ggml_set_f32_2d(x, 2, 1, 0.4037F);
// Prints [0.801200 0.013800 0.691600 0.243500 0.332200 0.403700]
printf(
    "[%f %f %f %f %f %f]\n",
    *(float *) ((char *) x->data + 4 * 0),
    *(float *) ((char *) x->data + 4 * 1),
    *(float *) ((char *) x->data + 4 * 2),
    *(float *) ((char *) x->data + 4 * 3),
    *(float *) ((char *) x->data + 4 * 4),
    *(float *) ((char *) x->data + 4 * 5)
);
// (3)
struct ggml_tensor * y0 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3);
ggml_set_f32_1d(y0, 0, 0.4699F);
ggml_set_f32_1d(y0, 1, 0.1103F);
ggml_set_f32_1d(y0, 2, 0.9175F);
struct ggml_tensor * y0_new_shape = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 1);
struct ggml_tensor * y0_reshaped = ggml_reshape(ctx, y0, y0_new_shape);
struct ggml_tensor * z0 = ggml_mul_mat(ctx, x, y0_reshaped);
compute_graph(ctx, z0);
// Prints (2, 1) [1.012549, 0.521457]
printf("z0 (%d, %d) [%f, %f]\n", z0->ne[0], z0->ne[1], ggml_get_f32_1d(z0, 0), ggml_get_f32_1d(z0, 1));
Both snippets output the same result after matmul, which is good. But I was required to swap the dimensions of the x matrix in ggml, otherwise ggml_mul_mat did not work.
Is there anything wrong with the second snippet? Do PyTorch and ggml store values differently? I also noticed that in llama.cpp, when converting the model, dimensions are reversed, but data is left untouched -- looks related. Any details/explanations would be very welcome :)
BTW, here is the repo where I work on rwkv.cpp.
Current status: converted model can be loaded; inference code runs, but I miss element-wise max and exp in ggml -- I'll be implementing them later.
+1 for RWKV
@saharNooby
> Do PyTorch and ggml store values differently? I also noticed that in llama.cpp, when converting the model, dimensions are reversed, but data is left untouched -- looks related. Any details/explanations would be very welcome :)
From what I understood from messing around with ggml, dimensions are indeed reversed when comparing to PyTorch. ggml lists the dimensions of a multi-dimensional tensor starting from the innermost one (the row length), while PyTorch lists the innermost dimension last, so its last dimension is the one that indexes elements in a row (probably because that's also how NumPy indexes multidimensional arrays).
In other words, ggml stores tensor dimensions in an array of element counts called ne[^0]. This means that ne[0] is the number of elements in the first dimension (i.e. elements in a row[^1]), ne[1] is the number of elements in the second dimension (i.e. elements in a column), and so on.
[^0]: See how the ne of the ggml_tensor struct is defined in ggml.h
[^1]: Not to be confused with the number of rows, which can be calculated by multiplying the number of elements of the other dimensions together
What might be confusing initially is that a 2D tensor (ndims = 2) with dimensions set as ne = { 3, 4 }[^2] has 3 elements in a row, which means it has 3 columns, and it also has 4 elements in a column, which means it has 4 rows[^3]. This is reversed when compared to the (number of rows, number of columns) notation that PyTorch might be using. (I don't have much experience with PyTorch, but it seems that way in the docs)
[^2]: Like in ggml_new_tensor_2d
[^3]: If you have read the footnote about the number of rows[^1], you know that the number of rows here is ne[1] * ne[2] * ne[3], so 4 * 1 * 1 = 4
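The ne = { 3, 4 } example can be worked through numerically. The sketch below assumes a contiguous F32 tensor (4 bytes per element) and derives the byte strides, the number of rows, and the byte offset of an element in the same way as the ggml_set_f32_2d helper posted earlier in the thread. The Python function names are hypothetical, not ggml API.

```python
def strides_from_ne(ne, type_size=4):
    """Byte strides (nb) for a contiguous tensor whose ne[0] is the row length."""
    nb = [type_size]
    for n in ne[:-1]:
        nb.append(nb[-1] * n)
    return nb

def byte_offset(nb, i, j):
    """Offset of element i within row j: j * nb[1] + i * nb[0]."""
    return j * nb[1] + i * nb[0]

ne = [3, 4, 1, 1]            # 3 elements per row, 4 rows
nb = strides_from_ne(ne)
n_rows = ne[1] * ne[2] * ne[3]

print(nb)                    # [4, 12, 48, 48]
print(n_rows)                # 4
print(byte_offset(nb, 2, 1)) # 20: last element of the second row
```

Moving one step along dimension 0 advances by a single element, while moving one step along dimension 1 skips a whole row of ne[0] elements.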
Arguably, ggml's way might be better (or at least more consistent) at representing dimensions of a tensor (at least in C) than PyTorch's way, since it seems easier to refer to dimensions relative to 0 than dimensions relative to the last dimension (again, at least in C, unlike in Python where the -1 index is the last dimension).
I hope this is helpful.
@compilade Thanks for the explanation, very helpful indeed!
To summarize:
In PyTorch, (2, 3) means "2 rows, 3 columns".
In ggml, (3, 2) means "3 elements in a row, 2 elements in a column".
Both of these represent the same tensor shape fundamentally, so data layout in memory is also the same. The confusion came from PyTorch listing the number of rows first, whereas ggml lists the number of elements in a row first.
This explains why when converting PyTorch tensors to ggml we need to reverse order of dims, but keep the data as is -- there are no differences in memory format between PyTorch/ggml, just in meaning of the elements in the shape.
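A small NumPy check of this conclusion, using the same 2x3 matrix as in the snippets above: the shape tuple is reversed to obtain the ggml-style dimensions, the bytes are left untouched, and indexing with "position in row first, row second" reads back the same values. The `ggml_get` helper is a hypothetical stand-in written for this example, not a ggml function.

```python
import numpy as np

# PyTorch/NumPy view: (2 rows, 3 columns), stored row by row.
x = np.array([
    [0.8012, 0.0138, 0.6916],
    [0.2435, 0.3322, 0.4037],
], dtype=np.float32)

# Conversion step: reverse the dims, keep the data as is.
ne = tuple(reversed(x.shape))                      # (3, 2)
flat = np.frombuffer(x.tobytes(), dtype=np.float32)

def ggml_get(flat, ne, i, j):
    """Read element i of row j from a contiguous buffer with row length ne[0]."""
    return flat[j * ne[0] + i]

# Every element matches when the two indices are swapped.
same = all(
    ggml_get(flat, ne, i, j) == x[j, i]
    for j in range(ne[1])
    for i in range(ne[0])
)
print(ne, same)  # (3, 2) True
```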
Now we have FP32, FP16 and Q4_0/Q4_1 inference working, and a Python wrapper:
model = rwkv_cpp.RWKVModel(r'bin\Release\rwkv.dll', r'C:\rwkv.cpp-169M.bin')
logits, state = None, None
for token in [1, 2, 3]:
    logits, state = model.eval(token, state)
    print(f'Output logits: {logits}')
# Don't forget to free memory after you've done working with the model
model.free()
Please add INT8 too :)
@saharNooby Please remember to keep some tensors in fp32.
Can check https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py
You only need INT8 / INT4 for these matrix weights: "if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x or 'head.weight' in x:"
Basically: if (len(w[x].shape) == 2) and ('emb' not in x):
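The rule above can be expressed as a small predicate over tensor names and shapes. In this sketch the tensor names and shapes are illustrative examples only (they mimic the naming style quoted above, with toy sizes), and `should_quantize` is a hypothetical helper rather than a function from ChatRWKV or rwkv.cpp.

```python
def should_quantize(name, shape):
    """Quantize only 2D matrix weights, and never the embedding."""
    return len(shape) == 2 and 'emb' not in name

# Illustrative tensors: name -> shape (toy values, not a real checkpoint).
tensors = {
    'emb.weight': (50277, 768),
    'blocks.0.ln1.weight': (768,),
    'blocks.0.att.time_mix_k': (1, 1, 768),
    'blocks.0.att.key.weight': (768, 768),
    'blocks.0.att.output.weight': (768, 768),
    'blocks.0.ffn.value.weight': (768, 3072),
    'head.weight': (50277, 768),
}

for name, shape in tensors.items():
    kind = 'quantize' if should_quantize(name, shape) else 'keep as float'
    print(f'{name:32s} {kind}')
```

With these example names, the layer norm weights, the time-mix vectors and the embedding stay in floating point, while the key, output, value and head matrices are selected for quantization.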
@BlinkDL Thanks for the suggestion! I quantize only 2D tensors -- looks like this covers all weights you've named.
But embedding matrix currently is also quantized. Is there a specific reason to not quantize it? It would take 500+ MB in FP16 for 14B model -- pretty large.
(BTW, Issues and Discussions are open in rwkv.cpp repo -- you can create new issue/discussion here if you want)
@ggerganov Since RWKV on ggml basically works now, I think "Example of RWKV inference" can be removed from the ggml roadmap.
There are only links to your repos in the roadmap, so I did not create a PR to add my own link here -- it would look a little out of place.
Though I would appreciate it if you added a link to rwkv.cpp :)
> But embedding matrix currently is also quantized. Is there a specific reason to not quantize it? It would take 500+ MB in FP16 for 14B model -- pretty large.
Because there is no computation involved in embedding :)
So it's better to use the more accurate embedding.
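Both points can be checked with a little arithmetic. The sketch below assumes a vocabulary of 50277 tokens and an embedding width of 5120 for the 14B model (these two numbers are assumptions for the example; they are not stated in this thread), and uses a tiny toy matrix to show that an embedding lookup is just a row read with no arithmetic on the weights.

```python
import numpy as np

# Assumed sizes for the 14B model (not stated in the thread).
n_vocab, n_embd = 50277, 5120

fp16_bytes = n_vocab * n_embd * 2
print(f'FP16 embedding: {fp16_bytes / 1e6:.1f} MB')  # FP16 embedding: 514.8 MB

# Toy embedding matrix: looking up a token only copies one row out,
# so there is no matrix multiplication to speed up with low-bit weights.
rng = np.random.default_rng(0)
emb = rng.standard_normal((8, 4)).astype(np.float16)
token = 5
vector = emb[token]
print(np.array_equal(vector, emb[5]))  # True
```

Since only one row is touched per token, quantizing the embedding would save memory but not compute, and any rounding error goes straight into the first layer's input.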
Hey, I just saw all this great progress - amazing!
Yes, the dimensions in ggml are swapped compared to Python - I realised this at a later stage in the development and didn't feel like changing.
I'll take a detailed look in the next days when I get some free time and see if I can give any suggestions for improvements.
@saharNooby Yes - will add the link to your repo. Will try to do so tomorrow
Can try this for INT4: compute "mx my rx ry" as in https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py
Basically: rescale all rows & columns of w --> compute INT4 x @ w --> rescale result.
Probably you only need rx & ry, and you can compute them using max(abs(w)).
And probably only need them for att.output.weight (maybe ffn.value.weight too).
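One possible reading of this suggestion, written as a NumPy sketch: take per-row and per-column scales from max(abs(w)), quantize the rescaled matrix to a signed 4-bit range, and undo the scaling around the matmul. The layout (x of shape (n_in,), w of shape (n_in, n_out)), the symmetric [-7, 7] range, and the function names are assumptions of this example; it does not reproduce the exact mx/my/rx/ry code in ChatRWKV.

```python
import numpy as np

def rescale_and_quantize(w, levels=7, eps=1e-8):
    """Return a low-bit matrix q plus column scales rx and row scales ry."""
    ry = np.maximum(np.abs(w).max(axis=1), eps)        # per row, shape (n_in,)
    w_rows = w / ry[:, None]
    rx = np.maximum(np.abs(w_rows).max(axis=0), eps)   # per column, shape (n_out,)
    w_scaled = w_rows / rx[None, :]                    # all entries now in [-1, 1]
    q = np.clip(np.round(w_scaled * levels), -levels, levels).astype(np.int8)
    return q, rx, ry

def matmul_rescaled(x, q, rx, ry, levels=7):
    """Approximate x @ w: scale the input by ry, multiply, rescale by rx."""
    return ((x * ry) @ (q.astype(np.float64) / levels)) * rx

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 64))
x = rng.standard_normal(64)

q, rx, ry = rescale_and_quantize(w)
exact = x @ w
approx = matmul_rescaled(x, q, rx, ry)
rel_err = np.linalg.norm(approx - exact) / np.linalg.norm(exact)
print(f'relative error: {rel_err:.3f}')
```

The rescaling itself is lossless: ((x * ry) @ w_scaled) * rx equals x @ w up to floating-point rounding, so all of the error comes from rounding w_scaled to a handful of levels. Scaling rows and columns separately keeps a few large entries from forcing a coarse step size on the whole matrix.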
Hi, I am the dev of https://github.com/BlinkDL/ChatRWKV and it is an RNN (so faster and saves VRAM) that can match transformer performance (and already scaled to 14B params, more to come).
Let me know if you will be interested in supporting it :) The inference is very simple: https://github.com/BlinkDL/ChatRWKV/blob/main/src/model_run.py