alvations / disambiguate

Disambiguate is a tool for training and using state-of-the-art neural WSD models.
https://arxiv.org/abs/1905.05677

Decoding #2

alvations opened this issue 4 years ago

alvations commented 4 years ago

For some reason, running the decoding on CPU makes PyTorch unhappy.

So let's document how to fix this.

alvations commented 4 years ago

From https://github.com/alvations/disambiguate#disambiguating-raw-text

Decoding is done through the decode.sh script:

./decode.sh --data_path $DATADIR --weights $DATADIR/model_weights_wsd*

The decode.sh script simply does this:

#!/bin/bash
"$(dirname "$0")"/java/launch.sh NeuralWSDDecode --python_path "$(dirname "$0")"/python "$@"

java/launch.sh then runs the NeuralWSDDecode class through Maven:

#!/bin/bash
mvn exec:java -f "$(dirname "$0")"/pom.xml -e --quiet -Dexec.mainClass="$1" -Dexec.args="${*:2}"
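
Putting the two scripts together, the command that finally runs is a single Maven invocation. A minimal Python sketch that rebuilds that argument vector, assuming `repo_dir` (a hypothetical parameter) is the repository root and that no argument contains spaces, since `${*:2}` joins everything after the class name with single spaces into one `-Dexec.args` string:

```python
import os
import shlex


def build_decode_command(repo_dir, user_args):
    """Approximate the argv that decode.sh -> java/launch.sh ends up running.

    decode.sh passes NeuralWSDDecode plus --python_path <repo>/python and the
    user's arguments to java/launch.sh, which forwards everything after the
    class name as one space-joined -Dexec.args string (mirroring "${*:2}").
    """
    java_dir = os.path.join(repo_dir, "java")
    python_dir = os.path.join(repo_dir, "python")
    # Arguments decode.sh hands to launch.sh, minus the class name ($1).
    forwarded = ["--python_path", python_dir] + list(user_args)
    return [
        "mvn", "exec:java",
        "-f", os.path.join(java_dir, "pom.xml"),
        "-e", "--quiet",
        "-Dexec.mainClass=NeuralWSDDecode",
        "-Dexec.args=" + " ".join(forwarded),
    ]


cmd = build_decode_command(
    "disambiguate",
    ["--data_path", "data", "--weights", "data/model_weights_wsd0"],
)
print(shlex.join(cmd))
```

In the real command, the shell expands `$DATADIR/model_weights_wsd*` into one path per ensemble member before decode.sh even starts, which fits NeuralWSDDecode parsing `weights` with `addArgumentList`.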

The class itself lives at https://github.com/alvations/disambiguate/blob/master/java/src/main/java/NeuralWSDDecode.java

The Python path is passed to the NeuralDisambiguator object at https://github.com/alvations/disambiguate/blob/master/java/src/main/java/NeuralWSDDecode.java#L80:

neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize);

Then NeuralDisambiguator calls initPythonProcess to load the model, see https://github.com/alvations/disambiguate/blob/daaee2938501a793ba6a44269082393d79fe3421/java/src/main/java/getalp/wsd/method/neural/NeuralDisambiguator.java#L152

alvations commented 4 years ago

And boom, a load of spaghetti madness. Let's try another approach: go directly to the model_weights_wsd* files and load them in torch.

>>> import torch
>>> model = torch.load('model_weights_wsd0')

>>> type(model)
<class 'dict'>

>>> model.keys()
dict_keys(['config', 'backend_encoder', 'optimizer', 'backend_decoder_classification'])

Okay, so the problem is that the saved model is not just a native torch object; it is a plain dict holding several components.

Let's dig through the components.

>>> type(model['config'])
<class 'dict'>
>>> model['config']
{'input_embeddings_size': [1024], 'input_embeddings_tokenize_model': [None], 'input_elmo_path': [None], 'input_bert_path': ['bert-large-cased'], 'input_auto_model': [None], 'input_auto_path': [None], 'input_word_dropout_rate': None, 'input_resize': [None], 'input_linear_size': None, 'input_dropout_rate': None, 'encoder_type': 'lstm', 'encoder_lstm_hidden_size': 1000, 'encoder_lstm_layers': 1, 'encoder_lstm_dropout': 0.5, 'encoder_transformer_hidden_size': 512, 'encoder_transformer_layers': 6, 'encoder_transformer_heads': 8, 'encoder_transformer_dropout': 0.1, 'encoder_transformer_positional_encoding': True, 'encoder_transformer_scale_embeddings': True, 'decoder_translation_transformer_hidden_size': 512, 'decoder_translation_transformer_layers': 6, 'decoder_translation_transformer_heads': 8, 'decoder_translation_transformer_dropout': 0.1, 'decoder_translation_scale_embeddings': True, 'decoder_translation_share_embeddings': False, 'decoder_translation_share_encoder_embeddings': False, 'decoder_translation_tokenizer_bert': None}

So `config`, as the key suggests, is just a key-value map of model settings.

Now, optimizer:

>>> type(model['optimizer'])
<class 'dict'>
>>> model['optimizer'].keys()
dict_keys(['state', 'param_groups'])

>>> model['optimizer']['param_groups']
[{'lr': 0.0001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'params': [140512760111744, 140513220910144, 140513220907776, 140513220909440, 140513211867584, 140512760915392, 140512760915328, 140512760915200, 140513220907584, 140512760914368]}]
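
Rather than poking at each key by hand, a small recursive walker can list everything in the checkpoint without printing whole tensors. A minimal sketch; `summarize` is a hypothetical helper, and it relies on duck typing (anything with a `shape` attribute, such as a torch tensor, is treated as a leaf), so it also runs on plain dicts without torch installed:

```python
def summarize(obj, prefix="", max_depth=3):
    """Return {dotted_key: description} for a nested checkpoint dict.

    Dicts are descended into for up to max_depth levels; objects exposing a
    `shape` attribute (e.g. torch tensors) are reported as type plus shape;
    everything else is reported by its type name only.
    """
    out = {}
    if isinstance(obj, dict) and max_depth > 0:
        for key, value in obj.items():
            path = f"{prefix}.{key}" if prefix else str(key)
            out.update(summarize(value, path, max_depth - 1))
    elif hasattr(obj, "shape"):
        out[prefix] = f"{type(obj).__name__}{tuple(obj.shape)}"
    else:
        out[prefix] = type(obj).__name__
    return out
```

With torch installed, `summarize(torch.load('model_weights_wsd0', map_location='cpu'))` would group its output under the four top-level keys shown above (`config`, `backend_encoder`, `optimizer`, `backend_decoder_classification`). Empty sub-dicts produce no entries, which keeps the listing short.
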
alvations commented 4 years ago

Okay, trying to change every tensor inside the torch model object by hand is madness.

What if we just follow the approach in https://discuss.pytorch.org/t/how-to-convert-gpu-trained-model-on-cpu-model/63514/4 instead?

TL;DR

import torch

# Load each ensemble member with every tensor remapped to CPU, then re-save it.
for i in range(8):
    model = torch.load(f'model_weights_wsd{i}', map_location=torch.device("cpu"))
    torch.save(model, f'cpu_model_weights_wsd{i}')

That works!
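
The loop above hard-codes eight ensemble members. A slightly more general sketch, assuming the checkpoints follow the `model_weights_wsd*` naming: it globs for them instead and writes each one next to the original with a `cpu_` prefix. `convert_to_cpu` and `cpu_output_path` are hypothetical helper names, and torch is imported lazily so the path helper works without it:

```python
import glob
import os


def cpu_output_path(path, prefix="cpu_"):
    """models/model_weights_wsd3 -> models/cpu_model_weights_wsd3"""
    directory, name = os.path.split(path)
    return os.path.join(directory, prefix + name)


def convert_to_cpu(pattern="model_weights_wsd*"):
    """Re-save every checkpoint matching `pattern` with all tensors on CPU."""
    import torch  # lazy import: only needed when actually converting

    written = []
    for path in sorted(glob.glob(pattern)):
        # map_location="cpu" remaps GPU storages onto the CPU while loading.
        checkpoint = torch.load(path, map_location="cpu")
        out = cpu_output_path(path)
        torch.save(checkpoint, out)
        written.append(out)
    return written
```

Because glob patterns are anchored at the start of the file name, the `cpu_` copies do not match `model_weights_wsd*`, so re-running the function will not convert them a second time.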

alvations commented 4 years ago

Even after converting the weights to CPU, decode.sh is still complaining.

Let's see... From https://github.com/alvations/disambiguate/blob/master/java/src/main/java/NeuralWSDDecode.java#L80

Basically the decoder does:

    private void decode(String[] args) throws Exception
    {
        ArgumentParser parser = new ArgumentParser();
        parser.addArgument("python_path");
        parser.addArgument("data_path");
        parser.addArgumentList("weights");
        parser.addArgument("lowercase", "false");
        parser.addArgument("sense_compression_hypernyms", "true");
        parser.addArgument("sense_compression_instance_hypernyms", "false");
        parser.addArgument("sense_compression_antonyms", "false");
        parser.addArgument("sense_compression_file", "");
        parser.addArgument("clear_text", "true");
        parser.addArgument("batch_size", "1");
        parser.addArgument("truncate_max_length", "150");
        parser.addArgument("filter_lemma", "true");
        parser.addArgument("mfs_backoff", "true");
        if (!parser.parse(args)) return;

        String pythonPath = parser.getArgValue("python_path");
        String dataPath = parser.getArgValue("data_path");
        List<String> weights = parser.getArgValueList("weights");
        boolean lowercase = parser.getArgValueBoolean("lowercase");
        boolean senseCompressionHypernyms = parser.getArgValueBoolean("sense_compression_hypernyms");
        boolean senseCompressionInstanceHypernyms = parser.getArgValueBoolean("sense_compression_instance_hypernyms");
        boolean senseCompressionAntonyms = parser.getArgValueBoolean("sense_compression_antonyms");
        String senseCompressionFile = parser.getArgValue("sense_compression_file");
        boolean clearText = parser.getArgValueBoolean("clear_text");
        int batchSize = parser.getArgValueInteger("batch_size");
        int truncateMaxLength = parser.getArgValueInteger("truncate_max_length");
        filterLemma = parser.getArgValueBoolean("filter_lemma");
        mfsBackoff = parser.getArgValueBoolean("mfs_backoff");

        Map<String, String> senseCompressionClusters = null;
        if (senseCompressionHypernyms || senseCompressionAntonyms)
        {
            senseCompressionClusters = WordnetUtils.getSenseCompressionClusters(WordnetHelper.wn30(), senseCompressionHypernyms, senseCompressionInstanceHypernyms, senseCompressionAntonyms);
        }
        if (!senseCompressionFile.isEmpty())
        {
            senseCompressionClusters = WordnetUtils.getSenseCompressionClustersFromFile(senseCompressionFile);
        }

        CorpusPOSTaggerAndLemmatizer tagger = new CorpusPOSTaggerAndLemmatizer();
        firstSenseDisambiguator = new FirstSenseDisambiguator(WordnetHelper.wn30());
        neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize);
        neuralDisambiguator.lowercaseWords = lowercase;
        neuralDisambiguator.filterLemma = filterLemma;
        neuralDisambiguator.reducedOutputVocabulary = senseCompressionClusters;

        reader = new BufferedReader(new InputStreamReader(System.in));
        writer = new BufferedWriter(new OutputStreamWriter(System.out));
        List<Sentence> sentences = new ArrayList<>();
        for (String line = reader.readLine(); line != null ; line = reader.readLine())
        {
            Sentence sentence = new Sentence(line);
            if (sentence.getWords().size() > truncateMaxLength)
            {
                sentence.getWords().stream().skip(truncateMaxLength).collect(Collectors.toList()).forEach(sentence::removeWord);
            }
            if (filterLemma)
            {
                tagger.tag(sentence.getWords());
            }
            sentences.add(sentence);
            if (sentences.size() >= batchSize)
            {
                decodeSentenceBatch(sentences);
                sentences.clear();
            }
        }
        decodeSentenceBatch(sentences);
        writer.close();
        reader.close();
        neuralDisambiguator.close();
    }
alvations commented 4 years ago

The steps are:

1. Populate senseCompressionClusters, either by (the file wins if both are set):

- reading it from a file: `WordnetUtils.getSenseCompressionClustersFromFile(senseCompressionFile)`, or
- loading the hypernym- or antonym-compressed clusters from WordNet: `WordnetUtils.getSenseCompressionClusters(WordnetHelper.wn30(), senseCompressionHypernyms, senseCompressionInstanceHypernyms, senseCompressionAntonyms)`
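
The precedence in step 1 is easy to misread, so here is a minimal Python sketch of just that logic. `load_from_wordnet` and `load_from_file` are hypothetical callables standing in for the two WordnetUtils methods; nothing here touches WordNet itself:

```python
def choose_sense_clusters(hypernyms, instance_hypernyms, antonyms, cluster_file,
                          load_from_wordnet, load_from_file):
    """Mirror the cluster selection at the top of NeuralWSDDecode.decode().

    load_from_wordnet / load_from_file are stand-ins for
    WordnetUtils.getSenseCompressionClusters(...) and
    WordnetUtils.getSenseCompressionClustersFromFile(...).
    """
    clusters = None
    # Instance hypernyms alone do not trigger loading; they only refine it.
    if hypernyms or antonyms:
        clusters = load_from_wordnet(hypernyms, instance_hypernyms, antonyms)
    # A non-empty file path overrides whatever WordNet produced.
    if cluster_file:
        clusters = load_from_file(cluster_file)
    return clusters
```

With the parser defaults (`sense_compression_hypernyms=true`, the other two flags false, empty file), this picks the WordNet hypernym clusters.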

2. Load some objects.

// POS Tagger + Lemmatizer
CorpusPOSTaggerAndLemmatizer tagger = new CorpusPOSTaggerAndLemmatizer();

// Backoff disambiguator (first WordNet sense).
firstSenseDisambiguator = new FirstSenseDisambiguator(WordnetHelper.wn30());

// The actual model object.
neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize);

// The model also has several settings, seemingly:

// a) whether the words should be lowercased.
neuralDisambiguator.lowercaseWords = lowercase;

// b) whether to filter by lemma (seems to rely on the tagger/lemmatizer above).
neuralDisambiguator.filterLemma = filterLemma;

// c) reduce the output vocabulary by mapping senses to the sense compression clusters.
neuralDisambiguator.reducedOutputVocabulary = senseCompressionClusters;

3. The actual WSD process.


// Read batchSize sentences at a time into the `sentences` list.
List<Sentence> sentences = new ArrayList<>();
for (String line = reader.readLine(); line != null ; line = reader.readLine())
{
    Sentence sentence = new Sentence(line);

    // Truncate long sentences to some max length.
    if (sentence.getWords().size() > truncateMaxLength)
    {
        sentence.getWords().stream().skip(truncateMaxLength).collect(Collectors.toList()).forEach(sentence::removeWord);
    }

    // If filtering lemmas, feed it through the tagger+lemmatizer.
    // I'm assuming this annotates the words in place (the return value isn't used).
    if (filterLemma)
    {
        tagger.tag(sentence.getWords());
    }

    // Collate the sentences per batch.
    sentences.add(sentence);
    if (sentences.size() >= batchSize)
    {
        // This is the actual WSD function.
        decodeSentenceBatch(sentences);
        sentences.clear();
    }
}
// Decode whatever is left in the last, partial batch.
decodeSentenceBatch(sentences);
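
The read/truncate/batch loop in step 3 can be sketched in Python as a generator. This is a simplification under stated assumptions: a Sentence is approximated by whitespace tokenization (the real getalp Sentence class may tokenize differently), the tagging step is omitted, and unlike the Java code it does not emit an empty final batch. `batch_sentences` is a hypothetical name:

```python
def batch_sentences(lines, batch_size=1, truncate_max_length=150):
    """Yield lists of tokenized sentences, `batch_size` at a time.

    Defaults match the parser defaults in NeuralWSDDecode (batch_size=1,
    truncate_max_length=150). Tokenization is plain whitespace splitting.
    """
    batch = []
    for line in lines:
        words = line.split()[:truncate_max_length]  # drop words past the limit
        batch.append(words)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:  # leftover partial batch, like the final decodeSentenceBatch call
        yield batch
```
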

Where's this decodeSentenceBatch coming from?

decodeSentenceBatch -> neuralDisambiguator.disambiguateDynamicSentenceBatch -> disambiguateFixedSentenceBatch -> disambiguateNoCatch -> readPredictOutput ..., which eventually hands off to a Python process.

alvations commented 4 years ago

The actual tagging seems to happen in a Python process launched from https://github.com/alvations/disambiguate/blob/daaee2938501a793ba6a44269082393d79fe3421/java/src/main/java/getalp/wsd/method/neural/NeuralDisambiguator.java#L152

    private void initPythonProcess(String pythonPath, String neuralPath, List<String> weightsPaths) throws IOException
    {
        List<String> args = new ArrayList<>(Arrays.asList(pythonPath + "/launch.sh", "getalp.wsd.predict", "--data_path", neuralPath, "--weights"));
        args.addAll(weightsPaths);
        if (clearText) args.add("--clear_text");
        args.addAll(Arrays.asList("--batch_size", "" + batchSize));
        if (disambiguate) args.add("--disambiguate");
        if (translate) args.add("--translate");
        args.addAll(Arrays.asList("--beam_size", "" + beamSize));
        if (extraLemma) args.add("--output_all_features");
        ProcessBuilder pb = new ProcessBuilder(args);
        pb.redirectError(ProcessBuilder.Redirect.INHERIT);
        pythonProcess = pb.start();
        pythonProcessReader = new BufferedReader(new InputStreamReader(pythonProcess.getInputStream()));
        pythonProcessWriter = new BufferedWriter(new OutputStreamWriter(pythonProcess.getOutputStream()));
    }

So it seems to run the Python module getalp.wsd.predict through python/launch.sh.
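
To call getalp.wsd.predict directly and skip the Java layer, one would need the same argument list. A minimal Python sketch that mirrors initPythonProcess; the defaults for `disambiguate`, `translate`, `beam_size` and `extra_lemma` are guesses, because the excerpt does not show how those Java fields are initialised. `build_predict_args` is a hypothetical name:

```python
def build_predict_args(python_path, neural_path, weights, clear_text=True,
                       batch_size=1, disambiguate=True, translate=False,
                       beam_size=1, extra_lemma=False):
    """Rebuild the argv that initPythonProcess passes to ProcessBuilder.

    Defaults for disambiguate/translate/beam_size/extra_lemma are assumptions.
    """
    args = [python_path + "/launch.sh", "getalp.wsd.predict",
            "--data_path", neural_path, "--weights"]
    args += list(weights)
    if clear_text:
        args.append("--clear_text")
    args += ["--batch_size", str(batch_size)]
    if disambiguate:
        args.append("--disambiguate")
    if translate:
        args.append("--translate")
    args += ["--beam_size", str(beam_size)]
    if extra_lemma:
        args.append("--output_all_features")
    return args
```

Launching it with `subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True)` and leaving `stderr` at its default would roughly match the Java side, which pipes stdin/stdout and inherits stderr.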

alvations commented 4 years ago

Oh, it looks like all the heavy lifting after the tagging/filtering tricks happens at https://github.com/alvations/disambiguate/blob/master/python/getalp/wsd/predicter.py#L20 =)

alvations commented 4 years ago

Somehow, loading the ensemble becomes a bottleneck.

While this is fast:

from lazyme import find_files
import torch

model_filenames = list(find_files("models/", "cpu_model_weights_wsd*"))
torch.load(model_filenames[0])

Loading it the way https://github.com/alvations/disambiguate/blob/master/python/getalp/wsd/predicter.py does is uber slow...

from getalp.wsd.predicter import Predicter
from getalp.wsd.model import Model, ModelConfig, DataConfig

from lazyme import find_files

data_config = DataConfig()
data_config.load_from_file('config.json')
model_config = ModelConfig(data_config)
model_config.load_from_file('config.json')

clear_text = True

if clear_text:
    model_config.data_config.input_clear_text = [True for _ in range(model_config.data_config.input_features)]

to_disambiguate = False if data_config.output_features <= 0 else True
to_translate = False if data_config.output_translations <= 0 else True

model_filenames = list(find_files("models/", "cpu_model_weights_wsd*"))[:1]

p = Predicter()
p.create_ensemble(model_config, model_filenames)
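
To see which step actually eats the time (the raw `torch.load` versus `create_ensemble`), a tiny timing context manager could wrap each call. A minimal standard-library sketch; `timed` is a hypothetical helper and the commented usage assumes the `torch`, `p`, `model_config` and `model_filenames` objects from the snippets above:

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, sink=print):
    """Report how long the wrapped block took, in seconds, via `sink`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        sink(f"{label}: {time.perf_counter() - start:.2f}s")


# Usage, with the objects defined earlier in this thread:
# with timed("torch.load"):
#     torch.load(model_filenames[0])
# with timed("create_ensemble"):
#     p.create_ensemble(model_config, model_filenames)
```
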
alvations commented 4 years ago

Yatta!! Got it to work!!

git clone https://github.com/alvations/disambiguate.git
cd disambiguate/capellini
wget https://sgp1.digitaloceanspaces.com/nltk/disambiguate/cpu_models.zip
unzip cpu_models.zip

jupyter notebook Bolognese.ipynb
goodmami commented 4 years ago

Needs a bit more:

$ git clone https://github.com/alvations/disambiguate.git
$ cd disambiguate/capellini
$ # pip install -r ../requirements.txt  # if not done already
$ pip install lazyme nltk jupyter  # need a few more things
$ python -m nltk.downloader wordnet
$ wget https://sgp1.digitaloceanspaces.com/nltk/disambiguate/cpu_models.zip
$ unzip cpu_models.zip -d models  # notebook expects directory "models/"
$ PYTHONPATH=../python/ jupyter notebook Bolognese.ipynb
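
Before opening the notebook, a quick pre-flight check can catch the usual missing pieces. A minimal sketch, assuming the unzipped files sit directly in `models/` and that `getalp` only becomes importable when `PYTHONPATH` includes `../python/`; it does not verify that the NLTK WordNet data was downloaded. `check_setup` is a hypothetical name:

```python
import glob
import importlib.util
import os


def check_setup(models_dir="models", modules=("torch", "nltk", "lazyme", "getalp")):
    """Return a list of human-readable problems; an empty list means ready."""
    problems = []
    for name in modules:
        # find_spec returns None for top-level modules that are not importable.
        if importlib.util.find_spec(name) is None:
            problems.append(f"module not importable: {name}")
    pattern = os.path.join(models_dir, "cpu_model_weights_wsd*")
    if not glob.glob(pattern):
        problems.append(f"no files matching {pattern}")
    return problems
```
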

Then I get results that look like the "compressed" senses and not the resolved lemmatic senses. Was it the Java code that did this sense resolution? Or maybe the results are just totally wrong.
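
One possible explanation is that the model predicts labels in the compressed output vocabulary, and a separate step must map each compressed label back to one of the word's actual WordNet senses. The sketch below shows what such a step might look like; it is an assumption about the design, not a reading of the repo's Java code. `clusters` stands in for the senseCompressionClusters map (sense key to compressed label), `candidate_senses` for the senses WordNet lists for the lemma, and `scores` for the model's per-label scores; all names are hypothetical:

```python
def resolve_sense(candidate_senses, clusters, scores):
    """Pick the candidate sense whose compressed label scores highest.

    candidate_senses: WordNet sense keys for the word's lemma (and POS).
    clusters: sense key -> compressed label; senses missing from the map
              are treated as their own label.
    scores: compressed label -> model score.
    Returns None if none of the candidates' labels was scored.
    """
    best, best_score = None, float("-inf")
    for sense in candidate_senses:
        label = clusters.get(sense, sense)
        score = scores.get(label)
        # Strict ">" keeps the earliest candidate on ties, so if candidates
        # are in WordNet order, shared clusters fall back to the first sense.
        if score is not None and score > best_score:
            best, best_score = sense, score
    return best
```
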

alvations commented 4 years ago

Maybe there's more processing after the Python side passes its output back to Java? Let's see if I can distract myself with it sometime tomorrow afternoon =)