tjake / Jlama

Jlama is a modern LLM inference engine for Java
Apache License 2.0
499 stars 48 forks source link

./download-hf-model.sh won't work on Windows #10

Closed skanga closed 7 months ago

skanga commented 8 months ago

Writing it in Java itself will solve that. So I did. Happy to contribute it so here you go FWIW

package com.github.tjake.jlama.cli;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class DownloadModel {
    /** Optional HuggingFace access token for restricted models; null/empty for public models. */
    private static final String HF_ACCESS_TOKEN = System.getenv("HF_ACCESS_TOKEN");
    /** Root directory under which models are stored. */
    private static final String MODEL_DIR = "models";

    /**
     * Downloads a model's safetensor weights and inference configuration from
     * HuggingFace into {@code ./models/<owner>/<model_name>}.
     *
     * @param args exactly one argument: {@code owner/model_name}, or {@code -h} for help
     * @throws IOException if a required file cannot be downloaded
     */
    public static void main(String[] args) throws IOException {
        // Honor the -h flag advertised by usage(); it previously fell through
        // and was treated as a model name.
        if (args.length == 1 && "-h".equals(args[0])) {
            usage();
            return;
        }
        if (args.length != 1) {
            usage();
            System.exit(1);
        }

        String hfModel = args[0];

        // setRequestProperty("Authorization", value) supplies the header NAME
        // separately, so the value must be only "Bearer <token>". Including
        // "Authorization: " in the value produced a malformed header and broke
        // access to restricted models.
        String authHeader = null;
        if (HF_ACCESS_TOKEN != null && !HF_ACCESS_TOKEN.isEmpty()) {
            authHeader = "Bearer " + HF_ACCESS_TOKEN;
        }

        InputStream modelInfoStream = getResponse("https://huggingface.co/api/models/" + hfModel, authHeader);
        String modelInfo = readInputStream(modelInfoStream);

        if (modelInfo == null) {
            System.out.println("No valid model found or trying to access a restricted model (use HF_ACCESS_TOKEN env. var.)");
            System.exit(1);
        }

        List<String> allFiles = parseFileList(modelInfo);
        if (allFiles.isEmpty()) {
            System.out.println("No valid model found");
            System.exit(1);
        }

        // Keep only safetensor weight files; other weight formats (e.g. *.bin)
        // are not used by the engine and can be many gigabytes.
        List<String> tensorFiles = new ArrayList<>();
        for (String currFile : allFiles) {
            if (currFile.contains("safetensor")) {
                tensorFiles.add(currFile);
            }
        }

        if (tensorFiles.isEmpty()) {
            System.out.println("Model is not available in safetensor format");
            System.exit(1);
        }

        // Download only the safetensor weights plus the inference configuration.
        // (Previously the configs were appended to allFiles and EVERY repository
        // file was downloaded, making the tensorFiles filter dead code.)
        List<String> toDownload = new ArrayList<>(tensorFiles);
        toDownload.addAll(Arrays.asList("config.json", "vocab.json", "tokenizer.json"));

        Path modelDir = Paths.get(MODEL_DIR, hfModel);
        try {
            Files.createDirectories(modelDir);
        } catch (IOException e) {
            System.out.println("Error creating directory: " + modelDir);
            System.exit(1);
        }

        for (String currFile : toDownload) {
            System.out.println("Downloading file: " + modelDir.resolve(currFile));
            downloadFile(hfModel, currFile, authHeader, modelDir.resolve(currFile));
        }

        // tokenizer.model is optional (SentencePiece-based models only); a
        // missing file must not abort an otherwise successful download.
        System.out.println("Downloading file: " + modelDir.resolve("tokenizer.model") + " (if it exists)");
        try {
            downloadFile(hfModel, "tokenizer.model", authHeader, modelDir.resolve("tokenizer.model"));
        } catch (IOException e) {
            System.out.println("tokenizer.model not present, skipping");
        }

        System.out.println("Done! Model downloaded in ./" + MODEL_DIR + "/" + hfModel);
    }

    /** Prints the command-line help text. */
    private static void usage() {
        System.out.println("""
                usage: java DownloadModel [-h] owner/model_name

                This program will download safetensor files and inference configuration from huggingface.
                To download restricted models set the HF_ACCESS_TOKEN environment variable to a valid HF access token.
                To create a token see https://huggingface.co/settings/tokens

                OPTIONS:
                   -h   Show this message

                EXAMPLES:
                    java DownloadModel gpt2-medium
                    java DownloadModel meta-llama/Llama-2-7b-chat-hf""");
    }

    /**
     * Extracts the repository file names from the HF model-info JSON
     * ({@code siblings[].rfilename}).
     *
     * @param modelInfo raw JSON response from the models API
     * @return the file names, or an empty list if the JSON cannot be parsed
     */
    private static List<String> parseFileList(String modelInfo) {
        List<String> fileList = new ArrayList<>();
        try {
            ObjectMapper objectMapper = new ObjectMapper();
            JsonNode rootNode = objectMapper.readTree(modelInfo);
            JsonNode siblingsNode = rootNode.path("siblings");
            if (siblingsNode.isArray()) {
                for (JsonNode siblingNode : siblingsNode) {
                    String rFilename = siblingNode.path("rfilename").asText();
                    fileList.add(rFilename);
                }
            }
        } catch (IOException e) {
            System.out.println("Error parsing JSON: " + e.getMessage());
        }
        return fileList;
    }

    /**
     * Performs an HTTP GET and returns the response body stream.
     *
     * @param urlString  URL to fetch
     * @param authHeader value for the {@code Authorization} header (e.g.
     *                   {@code "Bearer <token>"}), or null for none
     * @return the response stream on HTTP 200, or null on any failure
     *         (a warning is printed instead of propagating the exception)
     */
    public static InputStream getResponse(String urlString, String authHeader) {
        try {
            URL url = new URL(urlString);
            HttpURLConnection connection = (HttpURLConnection) url.openConnection();
            connection.setRequestMethod("GET");

            if (authHeader != null)
                connection.setRequestProperty("Authorization", authHeader);

            int responseCode = connection.getResponseCode();
            if (responseCode == HttpURLConnection.HTTP_OK) {
                return connection.getInputStream();
            }
            // Non-200: surface the code via the catch block below.
            throw new IOException("HTTP response code: " + responseCode);
        } catch (IOException ioe) {
            System.out.println("WARNING: Fetch of URL " + urlString + " failed due to " + ioe);
            return null;
        }
    }

    /**
     * Reads a stream to a string, one line per {@link System#lineSeparator()}.
     * The stream is decoded as UTF-8 (HF API responses are UTF-8 JSON; the
     * previous code used the platform default charset) and closed on return.
     *
     * @param inStream stream to drain, or null
     * @return the contents, or null if {@code inStream} is null
     * @throws IOException if reading fails
     */
    public static String readInputStream(InputStream inStream) throws IOException {
        if (inStream == null) return null;

        StringBuilder stringBuilder = new StringBuilder();
        try (BufferedReader inReader =
                new BufferedReader(new InputStreamReader(inStream, java.nio.charset.StandardCharsets.UTF_8))) {
            String currLine;
            while ((currLine = inReader.readLine()) != null) {
                stringBuilder.append(currLine);
                stringBuilder.append(System.lineSeparator());
            }
        }
        return stringBuilder.toString();
    }

    /**
     * Downloads one repository file to {@code outputPath}, replacing any
     * existing file.
     *
     * @throws IOException if the fetch fails or the copy fails
     */
    private static void downloadFile(String hfModel, String currFile, String authHeader, Path outputPath) throws IOException {
        InputStream inStream = getResponse("https://huggingface.co/" + hfModel + "/resolve/main/" + currFile, authHeader);
        if (inStream == null)
            throw new IOException("WARNING: Fetch of file " + currFile + " failed.");
        // try-with-resources: the previous code leaked the connection stream.
        try (InputStream in = inStream) {
            Files.copy(in, outputPath, StandardCopyOption.REPLACE_EXISTING);
        }
    }
}
skanga commented 8 months ago

Feel free to incorporate this as a "DownloadCommand" within JlamaCli, etc. or leave it as a standalone CLI program

skanga commented 8 months ago

Also for the run-cli.sh stuff I just created a run-cli.bat file like this:

java -server -Dstdout.encoding=UTF-8 -Xmx12g -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 --add-modules=jdk.incubator.vector --add-exports java.base/sun.nio.ch=ALL-UNNAMED --enable-preview --enable-native-access=ALL-UNNAMED -XX:+UnlockDiagnosticVMOptions -XX:CompilerDirectivesFile=./inlinerules.json -XX:+AlignVector -XX:+UseStringDeduplication -XX:+UseCompressedOops -XX:+UseCompressedClassPointers -Dlogback.configurationFile=./conf/logback.xml -jar ./jlama-cli/target/jlama-cli.jar %*
skanga commented 8 months ago

Ah! Looks like it uses some native library and won't work anyways on Windows :-(

14:47:34.937 [main] WARN  c.g.t.j.t.o.TensorOperationsProvider - Error loading native operations
java.lang.UnsatisfiedLinkError: no jlama in java.library.path: c:\java\jdk-21.0.1\bin;C:\WINDOWS\Sun\Java\bin;C:\WINDOWS\system32;C:\WINDOWS;c:\java\jdk-21.0.1\bin;C:\WINDOWS\system32;C:\WINDOWS;C:\WINDOWS\System32\Wbem;C:\WINDOWS\System32\WindowsPowerShell\v1.0\;.
        at java.base/java.lang.ClassLoader.loadLibrary(ClassLoader.java:2458)
        at java.base/java.lang.Runtime.loadLibrary0(Runtime.java:916)
        at java.base/java.lang.System.loadLibrary(System.java:2059)
        at com.github.tjake.jlama.tensor.operations.cnative.RuntimeHelper.<clinit>(RuntimeHelper.java:43)
        at com.github.tjake.jlama.tensor.operations.cnative.NativeSimd.<clinit>(NativeSimd.java:19)
        at com.github.tjake.jlama.tensor.operations.NativeTensorOperations.<clinit>(NativeTensorOperations.java:13)
        at java.base/java.lang.Class.forName0(Native Method)
        at java.base/java.lang.Class.forName(Class.java:421)
        at java.base/java.lang.Class.forName(Class.java:412)
        at com.github.tjake.jlama.tensor.operations.TensorOperationsProvider.pickFastestImplementation(TensorOperationsProvider.java:39)
        at com.github.tjake.jlama.tensor.operations.TensorOperationsProvider.<init>(TensorOperationsProvider.java:31)
        at com.github.tjake.jlama.tensor.operations.TensorOperationsProvider.get(TensorOperationsProvider.java:22)
        at com.github.tjake.jlama.cli.JlamaCli.<clinit>(JlamaCli.java:26)
14:47:34.944 [main] INFO  c.g.tjake.jlama.util.MachineSpec - Machine Vector Spec: AVX_256
14:47:34.945 [main] INFO  c.g.tjake.jlama.util.MachineSpec - Byte Order: LITTLE_ENDIAN
14:47:34.957 [main] INFO  c.g.t.j.t.o.TensorOperationsProvider - Using Panama Vector Operations
14:47:35.982 [main] INFO  c.g.tjake.jlama.model.AbstractModel - Working memory type = F32, Quantized memory type = I8
14:47:37.495 [main] INFO  c.g.t.jlama.model.llama.LlamaModel - Model loaded!
14:47:37.701 [main] INFO  o.jboss.resteasy.resteasy_jaxrs.i18n - RESTEASY002225: Deploying javax.ws.rs.core.Application: class com.github.tjake.jlama.cli.serve.JlamaRestApi
14:47:37.704 [main] INFO  o.jboss.resteasy.resteasy_jaxrs.i18n - RESTEASY002205: Adding provider class org.jboss.resteasy.plugins.providers.jackson.ResteasyJackson2Provider from Application class com.github.tjake.jlama.cli.serve.JlamaRestApi
tjake commented 8 months ago

Very cool thank you @skanga I'll add this to the cli as you suggest

RobertoMalatesta commented 8 months ago

Ah! Looks like it uses some native library and won't work anyways on Windows :-(

Maybe you can try with Windows WSL. Good work @skanga!

R

tjake commented 7 months ago

Hey @skanga I added a modified version of this to the cli. Thanks for the help!

skanga commented 6 months ago

Hey @skanga I added a modified version of this to the cli. Thanks for the help!

Hey thanks @tjake

I have figured out how to build Jlama easily on Windows (natively without WSL). Do you want me to add instructions somewhere? In the readme? Here?

skanga commented 6 months ago

I will just mention here anyway ...

:: Install w64devkit & add to path - https://github.com/skeeto/w64devkit/releases

wget https://github.com/skeeto/w64devkit/releases/download/v1.21.0/w64devkit-fortran-1.21.0.zip
unzip w64devkit-fortran-1.21.0.zip
set PATH=<wherever-you-installed-it>\w64devkit\bin;%PATH%

mvn clean package

run-cli download tjake/llama2-7b-chat-hf-jlama-Q4
run-cli serve models/llama2-7b-chat-hf-jlama-Q4