dranger003 / llama.cpp-dotnet

Minimal C# bindings for llama.cpp + .NET core library with API host/client.
MIT License
66 stars · 7 forks

revisiting the load/save state #28

Open HarvieKrumpet opened 2 days ago

HarvieKrumpet commented 2 days ago

I am a bit of a broken record on this. I have tried a few times to do this on my own within your framework, and I got this far with it.

I get stuck because this is not all that is needed for a true state capture. I am also supposed to be saving the values from the decode loop:

 llama_decode(ctx, batch);
    n_past += batch.n_tokens;

I would be almost duplicating your inner loop to get at these values, and I am sure you could stash those n_past values away in an array if I wanted to save the state without doing this as a second pass. You are also using functions that I have not seen other implementations use, so at my level I am unsure what I am doing here.

Maybe you are at a comfortable point here and want to tackle it. I know you are expecting gerganov to roll this into some unified system, but I am not holding my breath for it. I use up nearly 4k tokens for just my system messages nowadays, so I have a large inherent delay that I want to cache away with the load/save state system. I am not even sure if I implemented the plain load/save state properly.

Thanks,

public void LoadState(User _User, string _Name) {
    if (_User?.Mo?.ModelName is string Model) {
        LLM? llm = GetModel(Model);
        if (llm != null) {
            // Read the saved state blob and hand it to llama_state_set_data to restore the context
            string filename = IAmandaServer.StateBase + _Name + ".State";
            if (File.Exists(filename)) {
                byte[] State = File.ReadAllBytes(filename);
                Native.llama_state_set_data(llm.Engine.ContextNativeHandle, State, (nuint)State.Length);
            }
        }
    }
}

public void SaveState(User _User, string _Name) {
    if (_User?.Mo?.ModelName is string Model) {
        LLM? llm = GetModel(Model);
        if (llm != null) {
            nint ctx = llm.Engine.ContextNativeHandle;
            // Query the state size first, then copy the state blob out (this is context state, not model weights)
            uint StateSize = (uint)Native.llama_state_get_size(ctx);
            byte[] data = new byte[StateSize];
            Native.llama_state_get_data(ctx, data, StateSize);
            string filename = IAmandaServer.StateBase + _Name + ".State";
            File.WriteAllBytes(filename, data);
        }
    }
}
dranger003 commented 2 days ago

Ah, I see what you mean. For now, I added a raw sample to test state load/save. I will need to think about this; maybe I will introduce a new LlmPrompt ctor signature to support state loading - not sure yet.