deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai

Inference tensors cannot be saved for backward. #2144

Closed SuperMaskv closed 1 year ago

SuperMaskv commented 2 years ago

I don't know if this is a bug. I'm trying to follow the official tutorial using the PyTorch engine.

Here is my code and exception.

String modelUrl = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
Criteria<NDList, NDList> criteria = Criteria.builder()
    .optApplication(Application.NLP.WORD_EMBEDDING)
    .setTypes(NDList.class, NDList.class)
    .optModelUrls(modelUrl)
    .optProgress(new ProgressBar())
    .build();

ZooModel<NDList, NDList> embedding = criteria.loadModel();
Predictor<NDList, NDList> embedder = embedding.newPredictor();
SequentialBlock classifier = new SequentialBlock()
    .add(
            ndList -> {
                NDArray data = ndList.singletonOrThrow();
                long batchSize = data.getShape().get(0);
                long maxLen = data.getShape().get(1);
                NDList inputs = new NDList();
                inputs.add(data.toType(DataType.INT64, false));
                inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                inputs.add(data.getManager().arange(maxLen).toType(DataType.INT64, false).broadcast(data.getShape()));
                try {
                    return embedder.predict(inputs);
                } catch (TranslateException e) {
                    throw new RuntimeException(e);
                }
            }
    )
    .add(Linear.builder().setUnits(768).build())
    .add(Activation::relu)
    .add(Dropout.builder().optRate(0.2f).build())
    .add(Linear.builder().setUnits(5).build())
    .addSingleton(nd -> nd.get(":,0"));

Model model = Model.newInstance("review_classification");
model.setBlock(classifier);

DefaultVocabulary vocabulary = DefaultVocabulary.builder()
    .addFromTextFile(embedding.getArtifact("vocab.txt"))
    .optUnknownToken("[UNK]")
    .build();

int maxTokenLen = 64;
int batchSize = 8;
int limit = Integer.MAX_VALUE;

BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
CsvDataset awsDataset = getDataset(batchSize, tokenizer, maxTokenLen, limit);
RandomAccessDataset[] datasets = awsDataset.randomSplit(7, 3);
RandomAccessDataset trainDataset = datasets[0];
RandomAccessDataset evalDataset = datasets[1];

SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
listener.setSaveModelCallback(
    trainer -> {
        TrainingResult result = trainer.getTrainingResult();
        Model trainerModel = trainer.getModel();
        float acc = result.getValidateEvaluation("Accuracy");
        trainerModel.setProperty("Accuracy", String.format("%.5f", acc));
        trainerModel.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
    }
);

DefaultTrainingConfig trainingConfig = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
    .addEvaluator(new Accuracy())
    .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
    .addTrainingListeners(listener);

int epoch = 2;
Trainer trainer = model.newTrainer(trainingConfig);
trainer.setMetrics(new Metrics());
Shape shape = new Shape(batchSize, maxTokenLen);
trainer.initialize(shape);
EasyTrain.fit(trainer, epoch, trainDataset, evalDataset);
System.out.println(trainer.getTrainingResult());
model.save(Paths.get("build/model"), "aws-review-rank");
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 6
[main] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 6
[main] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: cpu().
[main] INFO ai.djl.training.listener.LoggingTrainingListener - Load PyTorch Engine Version 1.12.1 in 0.079 ms.
Exception in thread "main" ai.djl.engine.EngineException: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
    at ai.djl.pytorch.jni.PyTorchLibrary.torchNNLinear(Native Method)
    at ai.djl.pytorch.jni.JniUtils.linear(JniUtils.java:1189)
    at ai.djl.pytorch.engine.PtNDArrayEx.linear(PtNDArrayEx.java:390)
    at ai.djl.nn.core.Linear.linear(Linear.java:183)
    at ai.djl.nn.core.Linear.forwardInternal(Linear.java:88)
    at ai.djl.nn.AbstractBaseBlock.forwardInternal(AbstractBaseBlock.java:126)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
    at ai.djl.nn.SequentialBlock.forwardInternal(SequentialBlock.java:209)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:91)
    at ai.djl.training.Trainer.forward(Trainer.java:175)
    at ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:122)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
    at cn.amberdata.misc.djl.rankcls.Main.main(Main.java:114)

And here are my dependencies.

        <!-- https://mvnrepository.com/artifact/ai.djl/api -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.19.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/org.slf4j/slf4j-simple -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>1.7.36</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-engine -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.19.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/ai.djl/basicdataset -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.19.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/ai.djl/model-zoo -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.19.0</version>
        </dependency>
siddvenk commented 2 years ago

Thanks for pointing out this bug to us - this is an issue in the NoopTranslator when using the PyTorch engine. We will take a look and determine the best fix.

For now, you can adjust your code as follows to get around the exception:

  1. Define your own custom translator that extends the NoopTranslator and overrides the processOutput method

    public static final class MyTranslator extends NoopTranslator {
    
        @Override
        public NDList processOutput(TranslatorContext ctx, NDList input) {
            return new NDList(
                input.stream().map(ndArray -> ndArray.duplicate()).collect(Collectors.toList()));
        }
    }
  2. When you are building the Criteria object, add the optTranslator builder method and pass in the custom translator
    Criteria<NDList, NDList> criteria = Criteria.builder()
    .optApplication(Application.NLP.WORD_EMBEDDING)
    .setTypes(NDList.class, NDList.class)
    .optModelUrls(modelUrl)
    .optProgress(new ProgressBar())
    .optTranslator(new MyTranslator()) // Add this line
    .build();
frankfliu commented 2 years ago

@SuperMaskv

The problem comes from return embedder.predict(inputs);. The predict() call assumes inference only, so it sets inference mode. Instead of calling embedder.predict() manually, you can add the block into your classifier, something like:

SequentialBlock classifier = new SequentialBlock()
    .add(
            ndList -> {
                NDArray data = ndList.singletonOrThrow();
                long batchSize = data.getShape().get(0);
                long maxLen = data.getShape().get(1);
                NDList inputs = new NDList();
                inputs.add(data.toType(DataType.INT64, false));
                inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                inputs.add(data.getManager().arange(maxLen).toType(DataType.INT64, false).broadcast(data.getShape()));
                return inputs;
            })
    .add(embedder.getBlock());
SuperMaskv commented 2 years ago

@frankfliu Thanks for helping. I have tried your solution, and a new problem comes up when initializing parameters. Here is the exception stack trace:

Exception in thread "main" java.lang.NullPointerException: No parameter shape has been set
    at java.util.Objects.requireNonNull(Objects.java:228)
    at ai.djl.nn.Parameter.initialize(Parameter.java:181)
    at ai.djl.nn.AbstractBaseBlock.initialize(AbstractBaseBlock.java:182)
    at ai.djl.nn.SequentialBlock.initializeChildBlocks(SequentialBlock.java:224)
    at ai.djl.nn.AbstractBaseBlock.initialize(AbstractBaseBlock.java:184)
    at ai.djl.training.Trainer.initialize(Trainer.java:117)
    at cn.amberdata.misc.djl.rankcls.Main.main(Main.java:72)

It seems that DJL does not find a suitable implementation of the prepare function in AbstractBaseBlock.


@siddvenk The code in the tutorial works fine with your solution, thanks a lot.

adepase commented 1 year ago

I'm at exactly the same point. I cannot follow the fix for #2173 because it freezes the original distilbert model, while I need to fine-tune all the layers. @frankfliu's setup seemed good to me, since it just puts the model block in the sequence and nothing seems to prevent it from taking part in the backward pass. But the last NullPointerException from @SuperMaskv blocks the initialize. It correctly initializes the input of the SequentialBlock, then it finds a PtSymbolBlock. In AbstractBaseBlock.initialize() it checks:

if (!isInitialized()) {
    // setShape for all params
    prepare(inputShapes);
}

Here we have some issues:

  1. isInitialized checks:

for (Parameter param : getParameters().values()) {
    if (!param.isInitialized()) {
        return false;
    }
}

and they are initialized (but it checks for array == null, not shape == null; maybe one is a consequence of the other, but this is a different check, and shape == null is what will cause the problem, as we'll see in a moment). So it doesn't execute prepare(inputShapes).

But later, in AbstractBaseBlock.initialize() in the loop:

for (Parameter parameter : getDirectParameters().values()) {
    parameter.initialize(manager, dataType);
}

it finds that the first parameter has shape null, so it throws the error. BTW: I don't know if getDirectParameters() and getParameters() return the same result.

  2. Let's suppose that isInitialized() is false (I commented out the check to simulate it): sadly, prepare() is not abstract but empty, and PtSymbolBlock doesn't implement it, so it cannot set the shape of the parameters.

Since I understand that shapes should be the input/output of every layer, it seems to me they must be set correctly somewhere, given that the same block works when frozen (and it does, as in the AmazonRating example).

Maybe there is a way to use that code (wherever it may be)? I hope my comment helps in solving this issue without creating confusion.

KexinFeng commented 1 year ago

@adepase Could you provide the code you have tried? Especially your implementation of not freezing the embedding layer?

The problems you ran into here are not too different from the example examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java, where not freezing the embedding layer is enabled. My idea is similar to that example, but I wasn't able to reproduce your error, so if you could provide code where the error is reproducible, that would be helpful.

Please take a look at the PR below to see if it works for you.

adepase commented 1 year ago

@KexinFeng Sure, it's almost the same as the one @SuperMaskv provided before, with the incremental corrections, but let me put it here for a refreshed, summarized view:

package it.algaware.mrjvs.djl;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.io.input.BOMInputStream;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.engine.Engine;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.Classifications.Classification;
import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.LinearTracker;
import ai.djl.training.tracker.Tracker;
import ai.djl.training.tracker.WarmUpTracker;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.TranslateException;

public class Test2 {

    public static void main(String[] args) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {

        System.out.println("You are using: " + Engine.getInstance().getEngineName() + " Engine");
        System.out.println("GPU count: " + Engine.getInstance().getGpuCount());
        Device[] pass1 = Engine.getInstance().getDevices();
        System.out.println("GPU: "+pass1[0]);

        getIntentAll(2, "C:\\clinc150_uci\\train",
                "C:\\clinc150_uci\\test", "feature", "label");

    }

    public static void getIntentAll(int epoch, String trainingPath, String validationPath, String feature, String label) throws ModelNotFoundException, MalformedModelException, IOException {
        // let's find intents, first
        ArrayList<String> intents = new ArrayList<>();

        File file = new File(trainingPath);
        Stream<String> stream = null;
        int trainingSetSize = 0;
        try {
            stream = Files.lines(file.toPath(), StandardCharsets.UTF_8);
            Iterator<String> iterator = stream.iterator();
            boolean headerFound = false;
            while (iterator.hasNext()) {
                trainingSetSize+=1;
                String line = iterator.next();
                if (headerFound) {
                    String[] values = line.split("[|]");
                    if (!intents.contains(values[0].toLowerCase())) {
                        intents.add(values[0].toLowerCase());
                    }
                } else {
                    headerFound = true;
                }
            }
        } finally {
            if (stream != null) {
                stream.close();
                stream = null;
            }
        }

        file = new File(validationPath);
        try {
            stream = Files.lines(file.toPath());
            Iterator<String> iterator = stream.iterator();
            boolean headerFound = false;
            while (iterator.hasNext()) {
                String line = iterator.next();
                if (headerFound) {
                    String[] values = line.split("[|]");
                    if (!intents.contains(values[0].toLowerCase())) {
                        intents.add(values[0].toLowerCase());
                    }
                } else {
                    headerFound = true;
                }
            }
        } finally {
            if (stream != null) {
                stream.close();
                stream = null;
            }
        }

        System.out.println("found "+intents.size()+" intents");

        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optApplication(Application.NLP.WORD_EMBEDDING)
                .setTypes(NDList.class, NDList.class)
                //.optModelUrls(modelUrls)
                .optModelPath(Paths.get("build/pytorch/traced_distilbert_wikipedia_uncased"))
                .optProgress(new ProgressBar())
                .optTranslator(new MyNoopTranslator())
                .build();
        ZooModel<NDList, NDList> embedding = criteria.loadModel();

        System.out.println("Model Loaded");

        DefaultVocabulary vocabulary = DefaultVocabulary.builder()
                .addFromTextFile(embedding.getArtifact("vocab.txt"))
                .optUnknownToken("[UNK]")
                .build();

        int maxTokenLength = 128; 
        int batchSize = 32;
        int limit = Integer.MAX_VALUE;

        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
        CsvDataset trainingDataset;
        CsvDataset validationDataset;
        try {
            trainingDataset = getIntentDataset(batchSize, tokenizer, maxTokenLength, limit, trainingPath, feature, label, intents, true);

            RandomAccessDataset trainingSet = trainingDataset;
            RandomAccessDataset validationSet = null;
            if (validationPath != null) {
                validationDataset = getIntentDataset(batchSize, tokenizer, maxTokenLength, limit, validationPath, feature, label, intents, false);
                validationSet = validationDataset;
            } else {
                RandomAccessDataset[] datasets = trainingDataset.randomSplit(7, 3);
                trainingSet = datasets[0];
                validationSet = datasets[1];
            }

            System.out.println("Vocabulary and dataset prepared");

            SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
            listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model2 = trainer.getModel();
                    // track for accuracy and loss
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model2.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model2.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });

            int epochStep = trainingSetSize/batchSize+1;
            LinearTracker lit = LinearTracker.builder()
                    .setBaseValue(0.00115f)
                    .optSlope(-0.00003f/epochStep) 
                    .optMinValue(0.00004f) // reached at epoch 32
                    .build();

            WarmUpTracker warmuptr = Tracker.warmUp()
                    .optWarmUpSteps(epoch/10*epochStep)
                    .optWarmUpBeginValue(0.001f)
                    .setMainTracker(lit)
                    .build();

            Optimizer adam = Adam.builder().optLearningRateTracker(warmuptr).build();

            Loss loss = Loss.softmaxCrossEntropyLoss();
            DefaultTrainingConfig config = new DefaultTrainingConfig(loss) // loss type
            .optOptimizer(adam) // Optimizer
            .addEvaluator(new Accuracy())
            .optDevices(Engine.getInstance().getDevices(1)) // train using single GPU
            .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
            .addTrainingListeners(listener);
            System.out.println("Training config set up");

            Block classifier = new SequentialBlock()
                    .add(embedding.getBlock())      
                    // classification layer
                    .add(Dropout.builder().optRate(0.5f).build())
                    .add(Linear.builder().setUnits(728).build()) // pre classifier
                    .add(Activation::relu)
                    .add(Dropout.builder().optRate(0.5f).build())
                    .add(Linear.builder().setUnits(intents.size()).build()) // as many output as intents are
                    .addSingleton(nd -> nd.get(":,0")); // Take [CLS] as the head
            Model model = Model.newInstance("IntentClassification");
            model.setBlock(classifier);
            System.out.println("Embedding generated");

            Trainer trainer = model.newTrainer(config);
            trainer.setMetrics(new Metrics());
            Shape encoderInputShape = new Shape(batchSize, maxTokenLength);

            trainer.initialize(encoderInputShape);
            EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
            System.out.println("Training ended: "+trainer.getTrainingResult());

            model.save(Paths.get("build/model"), "bert-intent.param");
            System.out.println("model saved");

        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
    }

    static CsvDataset getIntentDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit, String path, String feature, String label, ArrayList<String> intents, boolean randomSample) {
        float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
        return CsvDataset.builder()
                .optCsvFile(Paths.get(path)) // load from File
                .setCsvFormat(CSVFormat.Builder.create(CSVFormat.TDF)
                        .setQuote(null)
                        .setDelimiter("|")
                        .setRecordSeparator("\n")
                        .setHeader()
                        .build()) // Setting loading format
                .setSampling(batchSize, randomSample) // make sample size and random access
                .optLimit(limit)
                .addFeature(
                        new Feature(feature, new BertFeaturizer(tokenizer, maxLength))
                        )
                .addLabel(
                        new Feature(
                                label, (buf, data) -> {
                                    float intentNumber = intents.indexOf(data.toLowerCase());
                                    buf.put(intentNumber);
                                }))
                .optDataBatchifier(
                        PaddingStackBatchifier.builder()
                                .optIncludeValidLengths(false)
                                .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                                .build()) // define how to pad dataset to a fix length
                .build();
    }

}

I looked at your suggested example, but, sadly, I cannot spot any logical differences, so I'm hoping for some help from you.

adepase commented 1 year ago

@KexinFeng Do we need your #2203 PR to succeed in fine-tuning the model?

adepase commented 1 year ago

I'm trying to reproduce the results from https://aclanthology.org/2021.naacl-industry.38.pdf for Distilbert and CLINC150: my hyperparameters don't match those of the paper (they are really very different!), but that's probably because I didn't realize I was using a frozen block (as in the Amazon example, also at the beginning of this post). I expect the paper's parameters will match the final ones once we can fine-tune the Distilbert model (the results probably won't, but only because they use AdamW and not Adam; I opened enhancement #2185 for that). Interestingly, even with a frozen BERT you can reach about 91% accuracy, with just a simple classifier on top.

Well, for your information, there is probably another minor issue in the current DJL framework that stands between us and the results of the cited paper (I'll check it again before opening it, and it is not so important for the current input, but it seems useful to highlight here for those reading). When you use batches, you end up padding tokens, at least to align sentences to the same length within a batch (the maximum sentence length in that batch), while different batches can have different maximum lengths.

This changes the input sentence lengths, and BERT models are not trained for a fixed length, so I expect a possible degradation.

Indeed, with a post-training workaround that pads the same test sentences up to a predefined length (only if a sentence is shorter than that length) and runs them manually on the final model (inference only, not training), you get better accuracy than without the padding.
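For illustration only, here is a minimal sketch of the kind of manual padding described above (tokenIds, fixedLen, tokenizer and manager are placeholders, not taken from the code above):

long padId = tokenizer.getVocabulary().getIndex("[PAD]");    // [PAD] id from the BERT vocabulary
long[] padded = new long[fixedLen];
java.util.Arrays.fill(padded, padId);                        // start from an all-[PAD] sentence
System.arraycopy(tokenIds, 0, padded, 0, Math.min(tokenIds.length, fixedLen));
NDArray input = manager.create(padded).reshape(1, fixedLen); // batch of one for manual inference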

KexinFeng commented 1 year ago

@KexinFeng Do we need your #2203 PR to succeed in fine-tuning the model?

@adepase Yes, the #2203 PR is about enabling fine-tuning of the embedder layer, in this case the distilbert model. So you can take a look at it and experiment with trainParam="true" when setting up the Criteria.builder used to load distilbert, to see if the metric improves.

Another heads-up is that you might need to assign a smaller learning rate to the pre-trained layer (i.e. the embedding layer), which makes sure it is "fine" tuning. This is demonstrated again in examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java.
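For reference, here is a minimal sketch of that idea, loosely following the TransferFreshFruit example (the learning-rate values are placeholders, embedding is the ZooModel loaded earlier in this thread, and this needs a DJL version that already contains the transfer-learning changes):

import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.util.Pair;

// Smaller learning rate for the pre-trained block, larger one for the new classifier head.
Block baseBlock = embedding.getBlock();
FixedPerVarTracker.Builder lrBuilder = FixedPerVarTracker.builder().setDefaultValue(1e-3f);
for (Pair<String, Parameter> param : baseBlock.getParameters()) {
    lrBuilder.put(param.getValue().getId(), 1e-5f); // pre-trained parameters get a smaller rate
}
Optimizer optimizer = Adam.builder().optLearningRateTracker(lrBuilder.build()).build();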

By the way, the major difference in this #2203 PR is that the model now consists of a Lambda block, then the distilbert, then the classifier layers that follow. Previously the Lambda block and the distilbert were together in one layer.

KexinFeng commented 1 year ago

In reply to https://github.com/deepjavalibrary/djl/issues/2144#issuecomment-1333468446 @adepase

I see that you are trying to reproduce the results in https://aclanthology.org/2021.naacl-industry.38.pdf for Distilbert and CLINC150. In this paper the best accuracy for this task is 85.7%. Is it already better than: "with a frozen BERT you can reach about a 91% in accuracy"?

Maybe you can first experiment with the unfrozen BERT, as you planned, and then see if it gets better.

By the way, you mentioned "hyperparameters don't match (are really very different!) to those of the paper". Even though unfreezing of the layer is enabled, it is still important to make sure the hyperparameters, as well as the model and pre-processing, agree with the paper, in order to reproduce the results or make a sensible comparison.

adepase commented 1 year ago

@KexinFeng In this paper the best accuracy for this task is 85.7%. Is it already better than: "with a frozen BERT you can reach about a 91% in accuracy"?

Mmm, no: the 85.7% you mention appears in the difficult case (table 3), but Distilbert in the full case scores 96.3%, and that's what I'm trying to emulate now and where I reach 91%, so I'm missing about 5%.

Moreover, on few-example tests (e.g. Curekart full, see table 7, where BERT reaches 83.6%) the gap is even greater (frozen Distilbert reaches about 29%, which is not remotely comparable and is probably explainable only by the missing full training of all layers, i.e. by Distilbert being frozen).

So experimenting with the unfrozen model is a must.

By the way, you mentioned "hyperparameters don't match (are really very different!) to those of the paper". Even though unfreezing of the layer is enabled, it is still important to make sure the hyperparameters, as well as the model and pre-processing, agree with the paper, in order to reproduce the results or make a sensible comparison.

Obviously, it's important to use the same hyperparameters in the unfrozen case, but in the frozen one the paper's hyperparameters perform far worse than the values you see in the code. That's why you see those values there. My note in the previous comment was only meant to explain that difference; it wasn't intended to highlight an issue. I expect to find similar results with similar hyperparameters in the unfrozen case.

Still, I'm not sure I understand what I should do to make the unfrozen case work.

adepase commented 1 year ago

@KexinFeng Sorry, I'm not sure I understand. Let's go point by point:

Yes, the https://github.com/deepjavalibrary/djl/pull/2203 PR is about enabling fine-tuning of the embedder layer, in this case the distilbert model.

So, should I wait for the next release, or is it enough to take the 2 classes you modified and put the new versions of them first in the classpath (i.e., in my project) to test whether this fixes the NullPointerException that occurred to me and to @SuperMaskv (as in his last comment)? Or do you think there is some error in the example code I posted in my previous comment, as you requested?

So you can take a look at it, and experiment with trainParam="true" when setting the Criteria.builder, used to load distilbert, to see if the metric is improved.

At this time, more than the improvement of the metric (I understand you're referring to the accuracy; let me know if I misunderstood, please), I'm interested in understanding how to fine-tune all layers (i.e. how to have an unfrozen Distilbert, if I understand correctly), not in the metric itself (well, yes, but that's the final result: I'm trying to understand the correct setup to avoid the NullPointerException shown in the last comment of @SuperMaskv).

Another heads-up is that you might need to assign a smaller learning rate to the pre-trained layer (i.e. the embedding layer), which makes sure it is "fine" tuning. This implementation is demonstrated again in examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java.

Ok, I see that there is a different learning rate tracker in the example; I'll try it, thank you, but I expect it only improves the results rather than avoiding the NullPointerException.

By the way, the major difference in this https://github.com/deepjavalibrary/djl/pull/2203 PR is that the model now consists of a Lambda block, then the distilbert, then the following classifier layers. Previously Lambda block and the distilbert are together in one layer.

Sorry, here I don't understand this point at all. Can you rephrase, please? Should I add a Lambda block before Distilbert? I can't see where the Lambda block should go in the example (I think you refer to lines 83-89 of https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java, don't you?). I cannot see anything before the original model: what am I missing? Which line adds the Lambda block?

Thank you again for your support

KexinFeng commented 1 year ago

The PR mentioned has been merged; you can now fetch the latest snapshot version.

To unfreeze the distilbert layer you can use the option "trainParam" in the lines shown below. https://github.com/deepjavalibrary/djl/blob/97004d690d997275de18592f65f63f1db1001200/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java#L84-L92
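In your code that would look roughly like the following sketch, assembled from the snippets already in this thread (modelUrl is the same variable used in the first post):

Criteria<NDList, NDList> criteria = Criteria.builder()
        .optApplication(Application.NLP.WORD_EMBEDDING)
        .setTypes(NDList.class, NDList.class)
        .optModelUrls(modelUrl)
        .optOption("trainParam", "true") // unfreezes the loaded PyTorch block for training
        .optProgress(new ProgressBar())
        .build();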

Then, in your code, after

ZooModel<NDList, NDList> embedding = criteria.loadModel();
Block embeddingBlock = embedding.getBlock();

the embeddingBlock will be tuned during training.

But I notice another minor thing,

Block classifier = new SequentialBlock()
        .add(embedding.getBlock())

but in the example, examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java, there is a LambdaBlock before the embeddingBlock. Probably this is not an issue though.
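For clarity, here is a minimal sketch of that ordering, reusing the lambda from the earlier snippets in this thread (nothing new beyond the block order):

SequentialBlock classifier = new SequentialBlock()
        .add(new LambdaBlock(ndList -> {
            // token ids, an all-ones mask, and a broadcast index range, as in the earlier snippets
            NDArray data = ndList.singletonOrThrow();
            NDList inputs = new NDList();
            inputs.add(data.toType(DataType.INT64, false));
            inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
            inputs.add(data.getManager()
                    .arange(data.getShape().get(1))
                    .toType(DataType.INT64, false)
                    .broadcast(data.getShape()));
            return inputs;
        }))
        .add(embedding.getBlock()); // the (now trainable) distilbert block follows the lambda
// ... then the classifier layers (Dropout, Linear, Activation::relu, etc.) as in the earlier snippets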

KexinFeng commented 1 year ago

AdamW: https://github.com/deepjavalibrary/djl/pull/2206

adepase commented 1 year ago

@KexinFeng Thank you for your further effort in #2206. About the previous issues, I:

  1. substituted the latest version (0.20.0 in Maven since yesterday), hoping that it contained your PR
  2. added .optOption("trainParam", "true") to the Criteria

No more NullPointerException (fine, thank you again). But I got the following exception:

[main] INFO ai.djl.training.listener.LoggingTrainingListener - Load PyTorch Engine Version 1.13.0 in 0,045 ms.
ai.djl.engine.EngineException: forward() is missing value for argument 'argument_2'. Declaration: forward(torch.transformers.models.distilbert.modeling_distilbert.DistilBertModel self, Tensor input_ids, Tensor argument_2, Tensor input) -> ((Tensor))
    at ai.djl.pytorch.jni.PyTorchLibrary.moduleForward(Native Method)
    at ai.djl.pytorch.jni.IValueUtils.forward(IValueUtils.java:47)
    at ai.djl.pytorch.engine.PtSymbolBlock.forwardInternal(PtSymbolBlock.java:144)
    at ai.djl.pytorch.engine.PtSymbolBlock.getOutputShapes(PtSymbolBlock.java:224)
    at ai.djl.nn.SequentialBlock.initializeChildBlocks(SequentialBlock.java:225)
    at ai.djl.nn.AbstractBaseBlock.initialize(AbstractBaseBlock.java:184)
    at ai.djl.training.Trainer.initialize(Trainer.java:118)
    at it.algaware.mrjvs.djl.Test2.getIntentAll(Test2.java:230)
    at it.algaware.mrjvs.djl.Test2.main(Test2.java:77)

I debugged and found that 0.20.0 does not yet contain your merged updates. :( Ok, I took the individual files from your PR (except the AmazonReview example) and put them in my project. Same error.

Then I figured it out: I looked at the AmazonReview updates and understood how to change the code, and also the reference in your previous comment to the Lambda block, which is in examples/src/main/java/ai/djl/examples/training/transferlearning/TrainAmazonReviewRanking.java and not in examples/src/main/java/ai/djl/examples/training/transferlearning/TransferFreshFruit.java as you commented above.

The training started and it behaved very differently from before (still with the old hyperparameters, but now I'm changing them to follow the paper, hoping that's enough).

Thank you so much again for your impressively fast answers.

adepase commented 1 year ago

@KexinFeng Just trying the first hyperparameter setup, I already reached 94.17%, which is not the 96.3% of the paper, but I'm confident I can find a better setup with more work, and the code is using Adam rather than AdamW (I expect a further minor improvement from that). Great work, thank you again.

KexinFeng commented 1 year ago

I'm glad to know that the answers helped!

By the way, the newest snapshot version is not included in the 0.20.0 release. It is accessed by

dependencies {
    implementation platform("ai.djl:bom:<UPCOMING VERSION>-SNAPSHOT")
}

Currently, the <UPCOMING VERSION> = 0.21.0. See https://github.com/deepjavalibrary/djl/blob/master/docs/get.md#using-built-from-source-version-in-another-project

adepase commented 1 year ago

In my previous case everything worked with the latest snapshot (even AdamW) with the Distilbert model. I almost reached the results I expected (well, with different parameters, but I think the original paper I already mentioned isn't so clear: it leaves many possible interpretations, at least for me, as I'm not so good at interpreting all the models). Anyway, I tried to change my model to bert-base-uncased. I saved it using this code:

import torch
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")

text = "Replace me by any text you'd like."
inputs = tokenizer(text, return_tensors='pt')
model.config.return_dict=False;
model.config.torchscript=True;
model.eval();
traced_model = torch.jit.trace(model, [inputs.data.get('input_ids'),inputs.data.get('attention_mask'),inputs.data.get('token_type_ids')])
torch.jit.save(traced_model, "C:\\bertBase\\bertBase.pt")

Then using the following code:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.io.input.BOMInputStream;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.engine.Engine;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.Classifications.Classification;
import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.AdamW;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.FixedPerVarTracker;
import ai.djl.training.tracker.LinearTracker;
import ai.djl.training.tracker.Tracker;
import ai.djl.training.tracker.WarmUpTracker;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;

public class Test2 {

    public static void main(String[] args) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {

        System.out.println("You are using: " + Engine.getInstance().getEngineName() + " Engine");
        System.out.println("GPU count: " + Engine.getInstance().getGpuCount());
        Device[] pass1 = Engine.getInstance().getDevices();
        System.out.println("GPU: "+pass1[0]);

        getIntentAll(5, "C:\\clinc150_uci\\train",
                "C:\\clinc150_uci\\test", "feature", "label");

    }

    public static void getIntentAll(int epoch, String trainingPath, String validationPath, String feature, String label) throws ModelNotFoundException, MalformedModelException, IOException {
        // let's find intents, first
        ArrayList<String> intents = new ArrayList<>();

        File file = new File(trainingPath);
        Stream<String> stream = null;
        int trainingSetSize = 0;
        try {
            stream = Files.lines(file.toPath(), StandardCharsets.UTF_8);
            Iterator<String> iterator = stream.iterator();
            boolean headerFound = false;
            while (iterator.hasNext()) {
                trainingSetSize+=1;
                String line = iterator.next();
                if (headerFound) {
                    String[] values = line.split("[|]");
                    if (!intents.contains(values[0].toLowerCase())) {
                        intents.add(values[0].toLowerCase());
                    }
                } else {
                    headerFound = true;
                }
            }
        } finally {
            if (stream != null) {
                stream.close();
                stream = null;
            }
        }

        file = new File(validationPath);
        try {
            stream = Files.lines(file.toPath());
            Iterator<String> iterator = stream.iterator();
            boolean headerFound = false;
            while (iterator.hasNext()) {
                String line = iterator.next();
                if (headerFound) {
                    String[] values = line.split("[|]");
                    if (!intents.contains(values[0].toLowerCase())) {
                        intents.add(values[0].toLowerCase());
                    }
                } else {
                    headerFound = true;
                }
            }
        } finally {
            if (stream != null) {
                stream.close();
                stream = null;
            }
        }
        int avgExample = trainingSetSize/intents.size();
        System.out.println("found "+intents.size()+" intents - Avg example per intent: "+ avgExample);

        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optApplication(Application.NLP.WORD_EMBEDDING)
                .setTypes(NDList.class, NDList.class)
                .optModelPath(Paths.get("build/pytorch/bertBase"))
                .optProgress(new ProgressBar())
                .optOption("trainParam", "true") // this unlocks the finetuning on the loaded model
                .build();

        ZooModel<NDList, NDList> embedding = criteria.loadModel();

        System.out.println("Model Loaded");

        DefaultVocabulary vocabulary = DefaultVocabulary.builder()
                .addFromTextFile(embedding.getArtifact("vocab.txt"))
                .optUnknownToken("[UNK]")
                .build();

        int maxTokenLength = 128; 
        int batchSize = 32;
        int limit = Integer.MAX_VALUE;

        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
        CsvDataset trainingDataset;
        CsvDataset validationDataset;
        try {
            trainingDataset = getIntentDataset(batchSize, tokenizer, maxTokenLength, limit, trainingPath, feature, label, intents, true);

            RandomAccessDataset trainingSet = trainingDataset;
            RandomAccessDataset validationSet = null;
            if (validationPath != null) {
                validationDataset = getIntentDataset(batchSize, tokenizer, maxTokenLength, limit, validationPath, feature, label, intents, false);
                validationSet = validationDataset;
            } else {
                RandomAccessDataset[] datasets = trainingDataset.randomSplit(7, 3);
                trainingSet = datasets[0];
                validationSet = datasets[1];
            }

            System.out.println("Vocabulary and dataset prepared");

            SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
            listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model2 = trainer.getModel();
                    // track for accuracy and loss
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model2.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model2.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });

            int epochStep = trainingSetSize/batchSize+1;
            Tracker lrt = Tracker.fixed(0.00004f); // adam default: 0.001f, 4*10^-5 is HINT paper setup for BERT
            LinearTracker lit = LinearTracker.builder()
                    .setBaseValue(0.0001f)
                    .optSlope(-0.000015f/epochStep) 
                    .optMinValue(0.00002f) 
                    .build();

            Optimizer adam = AdamW.builder().optLearningRateTracker(lit).optWeightDecays(0.015f).build();

            Loss loss = Loss.softmaxCrossEntropyLoss();
            DefaultTrainingConfig config = new DefaultTrainingConfig(loss) // loss type
            .optOptimizer(adam) // Optimizer
            .addEvaluator(new Accuracy())
            .optDevices(Engine.getInstance().getDevices(1)) // train using single GPU
            .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
            .addTrainingListeners(listener);
            System.out.println("Training config set up");

            Model model = Model.newInstance("IntentClassification");
            model.setBlock(getBlock(embedding.getBlock(), intents, avgExample));
            System.out.println("Embedding generated");

            Trainer trainer = model.newTrainer(config);
            trainer.setMetrics(new Metrics());
            Shape encoderInputShape = new Shape(batchSize, maxTokenLength);

            trainer.initialize(encoderInputShape);
            EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
            System.out.println("Training ended: "+trainer.getTrainingResult());

            model.save(Paths.get("build/model"), "bert-intent.param");

        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
    }

    private static Block getBlock(Block embedder, ArrayList<String> intents, int avgExample) {
        SequentialBlock classifier = new SequentialBlock();
        // text embedding layer
        if ("PyTorch".equals(Engine.getDefaultEngineName())) {
            LambdaBlock lambda =
                    new LambdaBlock(
                            ndList -> {
                                NDArray data = ndList.singletonOrThrow();
                                NDList inputs = new NDList();
                                inputs.add(data.toType(DataType.INT64, false));
                                inputs.add(
                                        data.getManager().full(data.getShape(), 1, DataType.INT64));
                                inputs.add(
                                        data.getManager()
                                                .arange(data.getShape().get(1)) // maxLen
                                                .toType(DataType.INT64, false)
                                                .broadcast(data.getShape()));
                                return inputs;
                            });
            classifier.add(lambda);
            classifier.add(embedder);
        } else {
            // MXNet
            LambdaBlock lambda =
                    new LambdaBlock(
                            ndList -> {
                                NDArray data = ndList.singletonOrThrow();
                                long batchSize = data.getShape().get(0);
                                float maxLength = data.getShape().get(1);
                                return embedder.forward(
                                        new ParameterStore(),
                                        new NDList(
                                                data,
                                                data.getManager()
                                                        .full(new Shape(batchSize), maxLength)),
                                        true);
                            });
            classifier.add(lambda);
        }
        // Classification layers
        classifier
            .add(Dropout.builder().optRate(0.05f).build())
                .add(Linear.builder().setUnits(768).build()) // pre classifier
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.18f).build())
                .add(Linear.builder().setUnits(intents.size()).build())
                .addSingleton(nd -> nd.get(":,0")); // follow HF classifier
        return classifier;
    }

    static CsvDataset getIntentDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit, String path, String feature, String label, ArrayList<String> intents, boolean randomSample) {
        float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
        return CsvDataset.builder()
                .optCsvFile(Paths.get(path)) // load from File
                .setCsvFormat(CSVFormat.Builder.create(CSVFormat.TDF)
                        .setQuote(null)
                        .setDelimiter("|")
                        .setRecordSeparator("\n")
                        .setHeader()
                        .build()) // Setting loading format
                .setSampling(batchSize, randomSample) // make sample size and random access
                .optLimit(limit)
                .addFeature(
                        new Feature(feature, new BertFeaturizer(tokenizer, maxLength))
                        )
                .addLabel(
                        new Feature(
                                label, (buf, data) -> {
                                    float intentNumber = intents.indexOf(data.toLowerCase());
                                    buf.put(intentNumber);
                                }))
                .optDataBatchifier(
                        PaddingStackBatchifier.builder()
                                .optIncludeValidLengths(false)
                                .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                                .build()) // define how to pad dataset to a fix length
                .build();
    }

}

I get the following error:

java.lang.IndexOutOfBoundsException: Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was 2
    at ai.djl.ndarray.NDList.singletonOrThrow(NDList.java:218)
    at ai.djl.nn.norm.Dropout.forwardInternal(Dropout.java:74)
    at ai.djl.nn.AbstractBaseBlock.forwardInternal(AbstractBaseBlock.java:128)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:93)
    at ai.djl.nn.SequentialBlock.forwardInternal(SequentialBlock.java:209)
    at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:93)
    at ai.djl.training.Trainer.forward(Trainer.java:189)
    at ai.djl.training.EasyTrain.trainSplit(EasyTrain.java:122)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:110)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:58)
    at it.algaware.mrjvs.djl.Test2.getIntentAll(Test2.java:[different line: I cleaned the above code before posting])

I understand (well, probably...) from https://huggingface.co/docs/transformers/model_doc/bert#transformers.models.bert.modeling_bert.BertForPreTrainingOutput that it depends upon the fact that bert-base returns 2 outputs (prediction_logits and seq_relationship_logits, I suppose).

But how can I select the correct output to match the 768-unit Linear block that follows, just after the BERT model? When debugging I found it among the two NDList elements from the error, but I don't know how to express that in the getBlock method.

adepase commented 1 year ago

Well, ok, in the previous code the modelDir does not seem to match the physical one where I saved the model, but that's just incomplete cleanup of the code before posting: obviously they match, otherwise I'd have received an error earlier.

KexinFeng commented 1 year ago

It looks like this requires some model-level design and debugging. Also, the output of BERT is decided by the model, and the subsequent blocks need to be designed accordingly. Is this issue solved?

KexinFeng commented 1 year ago

Regarding

But how can I select the correct output, matching my following 768 Linear block, just after the BERT model?

consider using the LambdaBlock or other kinds of blocks in DJL to build it.
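For example, here is a minimal sketch of that idea, assuming the first element of the returned NDList is the hidden-state tensor you want to feed into the Linear block (that assumption has to be checked against the traced model's outputs). lambda and embedder are the variables from the getBlock method above:

SequentialBlock classifier = new SequentialBlock()
        .add(lambda)                                            // input-preparation LambdaBlock
        .add(embedder)                                          // traced bert-base block (two outputs)
        .add(new LambdaBlock(list -> new NDList(list.get(0))))  // keep only the first output
        .add(Dropout.builder().optRate(0.05f).build())
        .add(Linear.builder().setUnits(768).build());
// ... remaining classification layers as before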

adepase commented 1 year ago

@KexinFeng

Is this issue solved?

Well... I didn't open this issue. It seems to me that the original reason for opening it has been solved, but I think @SuperMaskv should confirm that. As for my last questions (raised only as continuations of previous comments), it's probably better if I open separate issues, to keep things cleaner.

Anyway, I tried the LambdaBlock, but I ran into another issue; I'll open it separately. Thank you again for your support.