tjake / Jlama

Jlama is a modern LLM inference engine for Java
Apache License 2.0
669 stars 62 forks

Error while trying to generate from the given prompt #61

Open: vaibhav-singh-bi opened this issue 1 month ago

vaibhav-singh-bi commented 1 month ago

Hello Jake,

I started playing around with Jlama for a pet project, and I keep hitting the exception below. Could you please take a look? I am using the latest Jlama and LangChain4j libraries with the latest Spring Boot libraries, running on an Apple M1 Pro. I am calling it like this:

var chatModel = JlamaChatModel.builder()
                .modelName("tjake/TinyLlama-1.1B-Chat-v1.0-Jlama-Q4")
                .modelCachePath(Path.of("<my local machine path>"))
                .temperature(0.2f) // for focused and deterministic answer
                .build();

var aiMessage = chatModel.generate(prompt.toUserMessage()).content();

ERROR o.a.c.c.C.[.[.[.[dispatcherServlet] - Servlet.service() for servlet [dispatcherServlet] in context with path [] threw exception [Request processing failed: java.lang.ArrayIndexOutOfBoundsException: Index 65536 out of bounds for length 65536] with root cause
java.lang.ArrayIndexOutOfBoundsException: Index 65536 out of bounds for length 65536
    at com.github.tjake.jlama.model.CausalSelfAttention.lambda$forward$6(CausalSelfAttention.java:255)
    at java.base/java.util.Optional.ifPresent(Optional.java:178)
    at com.github.tjake.jlama.model.CausalSelfAttention.forward(CausalSelfAttention.java:237)
    at com.github.tjake.jlama.model.TransformerBlock.forward(TransformerBlock.java:99)
    at com.github.tjake.jlama.model.AbstractModel.forward(AbstractModel.java:235)
    at com.github.tjake.jlama.model.AbstractModel.forward(AbstractModel.java:209)
    at com.github.tjake.jlama.model.AbstractModel.generate(AbstractModel.java:471)
    at dev.langchain4j.model.jlama.JlamaChatModel.generate(JlamaChatModel.java:117)
    at dev.langchain4j.model.jlama.JlamaChatModel.generate(JlamaChatModel.java:70)
    at dev.langchain4j.model.chat.ChatLanguageModel.generate(ChatLanguageModel.java:44)
    at com.zaprit.platform.service.compass.service.ClassificationService.classify(ClassificationService.java:100)
    at com.zaprit.platform.service.compass.controller.ClassificationController.classify(ClassificationController.java:45)
    at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
    at java.base/java.lang.reflect.Method.invoke(Method.java:580)
    at org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:255)
    at org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:188)
    at org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:118)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:926)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:831)
    at org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
    at org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:1089)
    at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:979)
    at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:1014)
    at org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:914)
    at jakarta.servlet.http.HttpServlet.service(HttpServlet.java:590)
    at org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:885)
    at jakarta.servlet.http.HttpServlet.service(HttpServlet.java:658)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:195)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:140)
    at org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:51)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:164)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:140)
    at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:100)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:164)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:140)
    at org.springframework.web.filter.FormContentFilter.doFilterInternal(FormContentFilter.java:93)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:164)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:140)
    at org.springframework.web.filter.ServerHttpObservationFilter.doFilterInternal(ServerHttpObservationFilter.java:113)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:164)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:140)
    at org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:201)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:164)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:140)
    at org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:167)
    at org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:90)
    at org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:483)
    at org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:115)
    at org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:93)
    at org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:74)
    at org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:344)
    at org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:384)
    at org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:63)
    at org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:905)
    at org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1741)
    at org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:52)
    at org.apache.tomcat.util.threads.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1190)
    at org.apache.tomcat.util.threads.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:659)
    at org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:63)
    at java.base/java.lang.Thread.run(Thread.java:1583)

tjake commented 1 month ago

Hello!

It looks like you reached the model's context length. That can happen with this model, since it sometimes goes off the rails (even at a temperature of 0.2).

I can (and should) add a clearer error for this.
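A minimal sketch of the kind of bounds check that would replace the bare exception with a readable message. It models a per-position cache as a plain array with one slot per token position; the class PositionCache and its methods are hypothetical and are not how Jlama's CausalSelfAttention is actually implemented. Without the check, writing at position == capacity fails with the same kind of ArrayIndexOutOfBoundsException seen in the stack trace above.

```java
// Hypothetical, simplified stand-in for a per-position cache (NOT Jlama's code):
// one slot per token position, sized to the model's context length.
public class PositionCache {
    private final float[] slots;

    public PositionCache(int contextLength) {
        this.slots = new float[contextLength];
    }

    public int contextLength() {
        return slots.length;
    }

    // Stores a value for the token at 'position'. Instead of letting the array
    // access throw ArrayIndexOutOfBoundsException, fail with a descriptive message.
    public void put(int position, float value) {
        if (position < 0 || position >= slots.length) {
            throw new IllegalStateException(
                "Token position " + position + " is beyond the model context length of "
                    + slots.length + " (valid positions are 0.." + (slots.length - 1)
                    + "); use a shorter prompt or a lower maxTokens");
        }
        slots[position] = value;
    }

    public static void main(String[] args) {
        PositionCache cache = new PositionCache(4);
        for (int pos = 0; pos < cache.contextLength(); pos++) {
            cache.put(pos, pos * 1.0f); // positions 0..3 fit
        }
        try {
            cache.put(4, 4.0f); // one past the end, like "Index 65536 ... length 65536"
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

The point of the design is only that the failure names the real cause (the context window is full) rather than an array index.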

tjake commented 1 month ago

In the meantime, set a lower maxTokens value.
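As a rough illustration of why a lower maxTokens helps: the prompt tokens plus the generated tokens have to fit inside the model's context window, so the generation budget can be capped at whatever room the prompt leaves. The TokenBudget helper below is a hypothetical sketch, not part of Jlama or LangChain4j, and the numbers in main are made up for illustration.

```java
// Hypothetical helper (not a Jlama/LangChain4j API): caps the number of tokens
// to generate so that prompt + generated tokens never exceed the context length.
public final class TokenBudget {
    private TokenBudget() {}

    public static int safeMaxTokens(int contextLength, int promptTokens, int requestedMaxTokens) {
        if (contextLength <= 0 || promptTokens < 0 || requestedMaxTokens < 0) {
            throw new IllegalArgumentException(
                "contextLength must be positive and token counts non-negative");
        }
        int remaining = contextLength - promptTokens;
        if (remaining <= 0) {
            // The prompt alone already fills the window; nothing can be generated.
            throw new IllegalArgumentException(
                "Prompt of " + promptTokens + " tokens already fills the context length of "
                    + contextLength);
        }
        // Never ask for more tokens than the window has room for.
        return Math.min(requestedMaxTokens, remaining);
    }

    public static void main(String[] args) {
        // Illustrative numbers only.
        System.out.println(safeMaxTokens(2048, 300, 4096)); // capped to 1748
        System.out.println(safeMaxTokens(2048, 300, 512));  // 512 already fits
    }
}
```

The capped value would then go to the chat model's maxTokens setting that the reply above refers to; check the LangChain4j Jlama integration for the exact builder method. A small cap also stops a model that drifts off topic from generating all the way to the end of its window.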