deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Sorry, how can I load my own BERT QA model? #1753

Closed GaiserChan closed 2 years ago

GaiserChan commented 2 years ago

```java
@Configuration
public class BertQaInferenceConfiguration {

    static final Logger logger = LoggerFactory.getLogger(BertQaInferenceConfiguration.class);

    @Bean
    public Criteria<QAInput, String> nlpCriteria() {
        Criteria<QAInput, String> result = Criteria.builder()
                .optApplication(Application.NLP.QUESTION_ANSWER)
                .setTypes(QAInput.class, String.class)
                .optFilter("backbone", "bert")
                .optEngine("PyTorch")
                .optOption("mapLocation", "true")
                .optDevice(Device.cpu())
                .optProgress(new ProgressBar())
                .optModelPath(Paths.get("build/input/models"))
                .optModelName("bi_encoder.pt")
                .build();
        logger.info(result.toString());
        return result;
    }

    @Bean
    public ZooModel<QAInput, String> nlpModel(@Qualifier("nlpCriteria") Criteria<QAInput, String> criteria)
            throws ModelNotFoundException, MalformedModelException, IOException {
        return criteria.loadModel();
    }

    @Bean(destroyMethod = "close", value = "nlpPredictor")
    @Scope(value = "prototype", proxyMode = ScopedProxyMode.INTERFACES)
    public Predictor<QAInput, String> nlpPredictor(ZooModel<QAInput, String> nlpModel) {
        return new Predictor<>(nlpModel, nlpModel.getTranslator(), Device.gpu(), true);
    }

    @Bean
    public Supplier<Predictor<QAInput, String>> nlpPredictorProvider(ZooModel<QAInput, String> nlpModel) {
        return nlpModel::newPredictor;
    }
}
```

I run the code above to load my own model, but it does not work; the following exception is thrown:

```
Caused by: org.springframework.beans.factory.UnsatisfiedDependencyException: Error creating bean with name 'nlpPredictorProvider' defined in class path resource [com/fusionbank/base/ekyc/configuration/BertQaInferenceConfiguration.class]: Unsatisfied dependency expressed through method 'nlpPredictorProvider' parameter 0; nested exception is org.springframework.beans.factory.BeanCreationException: Error creating bean with name 'nlpModel' defined in class path resource [com/fusionbank/base/ekyc/configuration/BertQaInferenceConfiguration.class]: Bean instantiation via factory method failed; nested exception is org.springframework.beans.BeanInstantiationException: Failed to instantiate [ai.djl.repository.zoo.ZooModel]: Factory method 'nlpModel' threw exception; nested exception is ai.djl.repository.zoo.ModelNotFoundException: No matching model with specified Input/Output type found.
    at org.springframework.beans.factory.support.ConstructorResolver.createArgumentArray(ConstructorResolver.java:798)
    at org.springframework.beans.factory.support.ConstructorResolver.instantiateUsingFactoryMethod(ConstructorResolver.java:539)
    at org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory.instantiateUsingFactoryMethod(AbstractAutowireCapableBeanFactory.java:1338)
    at org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory.createBeanInstance(AbstractAutowireCapableBeanFactory.java:1177)
    at org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory.doCreateBean(AbstractAutowireCapableBeanFactory.java:557)
    at org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory.createBean(AbstractAutowireCapableBeanFactory.java:517)
    at org.springframework.beans.factory.support.AbstractBeanFactory.lambda$doGetBean$0(AbstractBeanFactory.java:323)
    at org.springframework.beans.factory.support.DefaultSingletonBeanRegistry.getSingleton(DefaultSingletonBeanRegistry.java:222)
    at org.springframework.beans.factory.support.AbstractBeanFactory.doGetBean(AbstractBeanFactory.java:321)
    at org.springframework.beans.factory.support.AbstractBeanFactory.getBean(AbstractBeanFactory.java:207)
    at org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory.resolveBeanByName(AbstractAutowireCapableBeanFactory.java:454)
    at org.springframework.context.annotation.CommonAnnotationBeanPostProcessor.autowireResource(CommonAnnotationBeanPostProcessor.java:543)
    at org.springframework.context.annotation.CommonAnnotationBeanPostProcessor.getResource(CommonAnnotationBeanPostProcessor.java:513)
    at org.springframework.context.annotation.CommonAnnotationBeanPostProcessor$ResourceElement.getResourceToInject(CommonAnnotationBeanPostProcessor.java:653)
    at org.springframework.beans.factory.annotation.InjectionMetadata$InjectedElement.inject(InjectionMetadata.java:224)
    at org.springframework.beans.factory.annotation.InjectionMetadata.inject(InjectionMetadata.java:116)
    at org.springframework.context.annotation.CommonAnnotationBeanPostProcessor.postProcessProperties(CommonAnnotationBeanPostProcessor.java:334)
    ... 30 common frames omitted
Caused by: org.springframework.beans.factory.BeanCreationException: Error creating bean with name 'nlpModel' defined in class path resource [com/fusionbank/base/ekyc/configuration/BertQaInferenceConfiguration.class]: Bean instantiation via factory method failed; nested exception is org.springframework.beans.BeanInstantiationException: Failed to instantiate [ai.djl.repository.zoo.ZooModel]: Factory method 'nlpModel' threw exception; nested exception is ai.djl.repository.zoo.ModelNotFoundException: No matching model with specified Input/Output type found.
    at org.springframework.beans.factory.support.ConstructorResolver.instantiate(ConstructorResolver.java:656)
    at org.springframework.beans.factory.support.ConstructorResolver.instantiateUsingFactoryMethod(ConstructorResolver.java:636)
    at org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory.instantiateUsingFactoryMethod(AbstractAutowireCapableBeanFactory.java:1338)
    at org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory.createBeanInstance(AbstractAutowireCapableBeanFactory.java:1177)
    at org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory.doCreateBean(AbstractAutowireCapableBeanFactory.java:557)
    at org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory.createBean(AbstractAutowireCapableBeanFactory.java:517)
    at org.springframework.beans.factory.support.AbstractBeanFactory.lambda$doGetBean$0(AbstractBeanFactory.java:323)
    at org.springframework.beans.factory.support.DefaultSingletonBeanRegistry.getSingleton(DefaultSingletonBeanRegistry.java:222)
    at org.springframework.beans.factory.support.AbstractBeanFactory.doGetBean(AbstractBeanFactory.java:321)
    at org.springframework.beans.factory.support.AbstractBeanFactory.getBean(AbstractBeanFactory.java:202)
    at org.springframework.beans.factory.config.DependencyDescriptor.resolveCandidate(DependencyDescriptor.java:276)
    at org.springframework.beans.factory.support.DefaultListableBeanFactory.doResolveDependency(DefaultListableBeanFactory.java:1287)
    at org.springframework.beans.factory.support.DefaultListableBeanFactory.resolveDependency(DefaultListableBeanFactory.java:1207)
    at org.springframework.beans.factory.support.ConstructorResolver.resolveAutowiredArgument(ConstructorResolver.java:885)
    at org.springframework.beans.factory.support.ConstructorResolver.createArgumentArray(ConstructorResolver.java:789)
    ... 46 common frames omitted
Caused by: org.springframework.beans.BeanInstantiationException: Failed to instantiate [ai.djl.repository.zoo.ZooModel]: Factory method 'nlpModel' threw exception; nested exception is ai.djl.repository.zoo.ModelNotFoundException: No matching model with specified Input/Output type found.
    at org.springframework.beans.factory.support.SimpleInstantiationStrategy.instantiate(SimpleInstantiationStrategy.java:185)
    at org.springframework.beans.factory.support.ConstructorResolver.instantiate(ConstructorResolver.java:651)
    ... 60 common frames omitted
Caused by: ai.djl.repository.zoo.ModelNotFoundException: No matching model with specified Input/Output type found.
    at ai.djl.repository.zoo.Criteria.loadModel(Criteria.java:178)
    at com.fusionbank.base.ekyc.configuration.BertQaInferenceConfiguration.nlpModel(BertQaInferenceConfiguration.java:50)
    at com.fusionbank.base.ekyc.configuration.BertQaInferenceConfiguration$$EnhancerBySpringCGLIB$$55b0440a.CGLIB$nlpModel$0(<generated>)
    at com.fusionbank.base.ekyc.configuration.BertQaInferenceConfiguration$$EnhancerBySpringCGLIB$$55b0440a$$FastClassBySpringCGLIB$$e827f3ac.invoke(<generated>)
    at org.springframework.cglib.proxy.MethodProxy.invokeSuper(MethodProxy.java:244)
    at org.springframework.context.annotation.ConfigurationClassEnhancer$BeanMethodInterceptor.intercept(ConfigurationClassEnhancer.java:363)
    at com.fusionbank.base.ekyc.configuration.BertQaInferenceConfiguration$$EnhancerBySpringCGLIB$$55b0440a.nlpModel(<generated>)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:567)
    at org.springframework.beans.factory.support.SimpleInstantiationStrategy.instantiate(SimpleInstantiationStrategy.java:154)
    ... 61 common frames omitted
```
frankfliu commented 2 years ago

@GaiserChan You need to provide a Translator to load your own model:

```java
PtBertQATranslator translator = PtBertQATranslator.builder()
        .optTokenizerName("distilbert")
        .toLowerCase(true)
        .build();
Criteria<QAInput, String> result = Criteria.builder()
        .setTypes(QAInput.class, String.class)
        .optEngine("PyTorch")
        .optOption("mapLocation", "true")
        .optModelPath(Paths.get("build/input/models"))
        .optModelName("bi_encoder")
        .optTranslator(translator)
        .build();
```
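For reference, here is a sketch of how the Spring configuration from the question could be rearranged around such a criteria. This is only an illustration, not the library's or the reporter's final code: it assumes the DJL PyTorch model zoo (which provides PtBertQATranslator) is on the classpath, the import path may differ by DJL version, and the builder calls simply mirror the snippet above (with the tokenizer-name method spelled `optTokenizerName`).

```java
import java.io.IOException;
import java.nio.file.Paths;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.pytorch.zoo.nlp.qa.PtBertQATranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class BertQaInferenceConfiguration {

    @Bean
    public Criteria<QAInput, String> nlpCriteria() {
        // Attach the translator so Criteria can match the local model
        // against the declared QAInput/String input and output types.
        PtBertQATranslator translator = PtBertQATranslator.builder()
                .optTokenizerName("distilbert")
                .toLowerCase(true)
                .build();
        return Criteria.builder()
                .setTypes(QAInput.class, String.class)
                .optEngine("PyTorch")
                .optOption("mapLocation", "true")
                .optModelPath(Paths.get("build/input/models"))
                .optModelName("bi_encoder")
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();
    }

    @Bean
    public ZooModel<QAInput, String> nlpModel(Criteria<QAInput, String> criteria)
            throws ModelNotFoundException, MalformedModelException, IOException {
        return criteria.loadModel();
    }

    @Bean(destroyMethod = "close")
    public Predictor<QAInput, String> nlpPredictor(ZooModel<QAInput, String> nlpModel) {
        // newPredictor() reuses the translator registered through the criteria,
        // so there is no need to construct a Predictor by hand.
        return nlpModel.newPredictor();
    }
}
```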
KexinFeng commented 2 years ago

@GaiserChan Here is an example of how the translator is implemented: https://pub.towardsai.net/deploy-huggingface-nlp-models-in-java-with-deep-java-library-e36c635b2053

You can also look into the example ai/djl/examples/inference/BertQaInference.java
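If you write a translator by hand instead of using PtBertQATranslator, it implements `ai.djl.translate.Translator<QAInput, String>`. Below is a bare skeleton of that interface only, as a sketch; the actual BERT tokenization and answer-span decoding are what the linked article and the BertQaInference example walk through.

```java
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

public class MyBertQaTranslator implements Translator<QAInput, String> {

    @Override
    public NDList processInput(TranslatorContext ctx, QAInput input) {
        // Tokenize input.getQuestion() and input.getParagraph() and build the
        // model's input NDArrays (e.g. token ids, attention mask) here.
        throw new UnsupportedOperationException("tokenization omitted in this sketch");
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        // Decode the start/end logits back into the answer span here.
        throw new UnsupportedOperationException("decoding omitted in this sketch");
    }

    @Override
    public Batchifier getBatchifier() {
        // Stack single inputs into a batch dimension before inference.
        return Batchifier.STACK;
    }
}
```

A custom translator like this is passed to the criteria the same way as above, via `optTranslator(new MyBertQaTranslator())`.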

frankfliu commented 2 years ago

Feel free to reopen this issue if you still have questions.