From https://github.com/alvations/disambiguate#disambiguating-raw-text, decoding is done through the `decode.sh` script:

```sh
./decode.sh --data_path $DATADIR --weights $DATADIR/model_weights_wsd*
```
The `decode.sh` script simply does this:

```sh
#!/bin/bash
"$(dirname "$0")"/java/launch.sh NeuralWSDDecode --python_path "$(dirname "$0")"/python "$@"
```
The `java/launch.sh` calls the `NeuralWSDDecode` class through Maven:

```sh
#!/bin/bash
mvn exec:java -f "$(dirname "$0")"/pom.xml -e --quiet -Dexec.mainClass="$1" -Dexec.args="${*:2}"
```
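So the `decode.sh` invocation above effectively expands to a Maven command like this (illustrative only; paths spelled out by hand, and just the first weights file shown instead of the glob):

```sh
mvn exec:java -f java/pom.xml -e --quiet \
    -Dexec.mainClass="NeuralWSDDecode" \
    -Dexec.args="--python_path python --data_path $DATADIR --weights $DATADIR/model_weights_wsd0"
```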
From https://github.com/alvations/disambiguate/blob/master/java/src/main/java/NeuralWSDDecode.java, the Python path is set in the `NeuralDisambiguator` object at https://github.com/alvations/disambiguate/blob/master/java/src/main/java/NeuralWSDDecode.java#L80:

```java
neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize);
```

Then the `NeuralDisambiguator` calls `initPythonProcess` to load the model: https://github.com/alvations/disambiguate/blob/daaee2938501a793ba6a44269082393d79fe3421/java/src/main/java/getalp/wsd/method/neural/NeuralDisambiguator.java#L152
And boom, a load of spaghetti madness. Let's try another approach: go directly to the `model_weights_wsd*` files and load them in torch.
```python
>>> import torch
>>> model = torch.load('model_weights_wsd0')
>>> type(model)
<class 'dict'>
>>> model.keys()
dict_keys(['config', 'backend_encoder', 'optimizer', 'backend_decoder_classification'])
```
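To get a feel for what's inside each entry, a quick inspection loop (a minimal sketch; it assumes the `backend_*` entries are ordinary state dicts, which the keys suggest):

```python
# Peek at the first few keys of every dict-valued entry in the checkpoint.
for key, value in model.items():
    if isinstance(value, dict):
        print(key, '->', list(value.keys())[:5])
    else:
        print(key, '->', type(value))
```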
Okay, so the problem is that the model object is not just a native torch tensor. Let's dig through the components:
```python
>>> type(model['config'])
<class 'dict'>
>>> model['config']
{'input_embeddings_size': [1024], 'input_embeddings_tokenize_model': [None], 'input_elmo_path': [None], 'input_bert_path': ['bert-large-cased'], 'input_auto_model': [None], 'input_auto_path': [None], 'input_word_dropout_rate': None, 'input_resize': [None], 'input_linear_size': None, 'input_dropout_rate': None, 'encoder_type': 'lstm', 'encoder_lstm_hidden_size': 1000, 'encoder_lstm_layers': 1, 'encoder_lstm_dropout': 0.5, 'encoder_transformer_hidden_size': 512, 'encoder_transformer_layers': 6, 'encoder_transformer_heads': 8, 'encoder_transformer_dropout': 0.1, 'encoder_transformer_positional_encoding': True, 'encoder_transformer_scale_embeddings': True, 'decoder_translation_transformer_hidden_size': 512, 'decoder_translation_transformer_layers': 6, 'decoder_translation_transformer_heads': 8, 'decoder_translation_transformer_dropout': 0.1, 'decoder_translation_scale_embeddings': True, 'decoder_translation_share_embeddings': False, 'decoder_translation_share_encoder_embeddings': False, 'decoder_translation_tokenizer_bert': None}
```
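From that dump, the fields that matter for rebuilding the network are easy to pull out (values are the ones already shown above):

```python
cfg = model['config']
print(cfg['input_bert_path'])           # ['bert-large-cased']
print(cfg['encoder_type'])              # 'lstm'
print(cfg['encoder_lstm_hidden_size'])  # 1000
print(cfg['encoder_lstm_layers'])       # 1
```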
So `config`, as the key suggests, is just a key-value map.

Now, the `optimizer`:
```python
>>> type(model['optimizer'])
<class 'dict'>
>>> model['optimizer'].keys()
dict_keys(['state', 'param_groups'])
>>> model['optimizer']['param_groups']
[{'lr': 0.0001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'params': [140512760111744, 140513220910144, 140513220907776, 140513220909440, 140513211867584, 140512760915392, 140512760915328, 140512760915200, 140513220907584, 140512760914368]}]
```
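Those long integers under `params` look like in-memory `id()`-style handles rather than tensors; my assumption is that they key into `model['optimizer']['state']`. A quick check of that assumption:

```python
# Do the param handles in param_groups match the keys of the optimizer state?
state_keys = set(model['optimizer']['state'].keys())
param_ids = set(model['optimizer']['param_groups'][0]['params'])
print(param_ids <= state_keys)
```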
Okay, it would be madness trying to change every tensor inside the torch model object one by one. What if we just try https://discuss.pytorch.org/t/how-to-convert-gpu-trained-model-on-cpu-model/63514/4?
```python
import torch

for i in range(8):
    model = torch.load(f'model_weights_wsd{i}', map_location=torch.device("cpu"))
    torch.save(model, f'cpu_model_weights_wsd{i}')
```
That works!
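A quick sanity check that the conversion took (a minimal sketch, assuming `backend_encoder` is a plain state dict of tensors):

```python
import torch

ckpt = torch.load('cpu_model_weights_wsd0')
print(all(t.device.type == 'cpu'
          for t in ckpt['backend_encoder'].values()
          if torch.is_tensor(t)))
```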
Even after converting the models to CPU, the `decode.sh` script still complains.

Let's see... From https://github.com/alvations/disambiguate/blob/master/java/src/main/java/NeuralWSDDecode.java#L80, basically the decoder does:
```java
private void decode(String[] args) throws Exception
{
    ArgumentParser parser = new ArgumentParser();
    parser.addArgument("python_path");
    parser.addArgument("data_path");
    parser.addArgumentList("weights");
    parser.addArgument("lowercase", "false");
    parser.addArgument("sense_compression_hypernyms", "true");
    parser.addArgument("sense_compression_instance_hypernyms", "false");
    parser.addArgument("sense_compression_antonyms", "false");
    parser.addArgument("sense_compression_file", "");
    parser.addArgument("clear_text", "true");
    parser.addArgument("batch_size", "1");
    parser.addArgument("truncate_max_length", "150");
    parser.addArgument("filter_lemma", "true");
    parser.addArgument("mfs_backoff", "true");
    if (!parser.parse(args)) return;
    String pythonPath = parser.getArgValue("python_path");
    String dataPath = parser.getArgValue("data_path");
    List<String> weights = parser.getArgValueList("weights");
    boolean lowercase = parser.getArgValueBoolean("lowercase");
    boolean senseCompressionHypernyms = parser.getArgValueBoolean("sense_compression_hypernyms");
    boolean senseCompressionInstanceHypernyms = parser.getArgValueBoolean("sense_compression_instance_hypernyms");
    boolean senseCompressionAntonyms = parser.getArgValueBoolean("sense_compression_antonyms");
    String senseCompressionFile = parser.getArgValue("sense_compression_file");
    boolean clearText = parser.getArgValueBoolean("clear_text");
    int batchSize = parser.getArgValueInteger("batch_size");
    int truncateMaxLength = parser.getArgValueInteger("truncate_max_length");
    filterLemma = parser.getArgValueBoolean("filter_lemma");
    mfsBackoff = parser.getArgValueBoolean("mfs_backoff");
    Map<String, String> senseCompressionClusters = null;
    if (senseCompressionHypernyms || senseCompressionAntonyms)
    {
        senseCompressionClusters = WordnetUtils.getSenseCompressionClusters(WordnetHelper.wn30(), senseCompressionHypernyms, senseCompressionInstanceHypernyms, senseCompressionAntonyms);
    }
    if (!senseCompressionFile.isEmpty())
    {
        senseCompressionClusters = WordnetUtils.getSenseCompressionClustersFromFile(senseCompressionFile);
    }
    CorpusPOSTaggerAndLemmatizer tagger = new CorpusPOSTaggerAndLemmatizer();
    firstSenseDisambiguator = new FirstSenseDisambiguator(WordnetHelper.wn30());
    neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize);
    neuralDisambiguator.lowercaseWords = lowercase;
    neuralDisambiguator.filterLemma = filterLemma;
    neuralDisambiguator.reducedOutputVocabulary = senseCompressionClusters;
    reader = new BufferedReader(new InputStreamReader(System.in));
    writer = new BufferedWriter(new OutputStreamWriter(System.out));
    List<Sentence> sentences = new ArrayList<>();
    for (String line = reader.readLine(); line != null ; line = reader.readLine())
    {
        Sentence sentence = new Sentence(line);
        if (sentence.getWords().size() > truncateMaxLength)
        {
            sentence.getWords().stream().skip(truncateMaxLength).collect(Collectors.toList()).forEach(sentence::removeWord);
        }
        if (filterLemma)
        {
            tagger.tag(sentence.getWords());
        }
        sentences.add(sentence);
        if (sentences.size() >= batchSize)
        {
            decodeSentenceBatch(sentences);
            sentences.clear();
        }
    }
    decodeSentenceBatch(sentences);
    writer.close();
    reader.close();
    neuralDisambiguator.close();
}
```
The steps are:

- `senseCompressionClusters`:
  - Either read it from a file: `WordnetUtils.getSenseCompressionClustersFromFile(senseCompressionFile)`, or
  - Load the hypernym- or antonym-compressed clusters: `WordnetUtils.getSenseCompressionClusters(WordnetHelper.wn30(), senseCompressionHypernyms, senseCompressionInstanceHypernyms, senseCompressionAntonyms)`
Then the rest, annotated:

```java
// POS tagger + lemmatizer.
CorpusPOSTaggerAndLemmatizer tagger = new CorpusPOSTaggerAndLemmatizer();

// Backoff disambiguator.
firstSenseDisambiguator = new FirstSenseDisambiguator(WordnetHelper.wn30());

// The actual model object.
neuralDisambiguator = new NeuralDisambiguator(pythonPath, dataPath, weights, clearText, batchSize);

// Somehow the model has several configurations, seemingly:
// a) whether the words should be lowercased,
neuralDisambiguator.lowercaseWords = lowercase;
// b) whether to filter and only use the lemmas,
neuralDisambiguator.filterLemma = filterLemma;
// c) reduce the OOV by loading the sense compression clusters.
neuralDisambiguator.reducedOutputVocabulary = senseCompressionClusters;

// Read batchSize sentences and put them into the `sentences` list.
List<Sentence> sentences = new ArrayList<>();
for (String line = reader.readLine(); line != null ; line = reader.readLine())
{
    Sentence sentence = new Sentence(line);
    // Truncate long sentences to some max length.
    if (sentence.getWords().size() > truncateMaxLength)
    {
        sentence.getWords().stream().skip(truncateMaxLength).collect(Collectors.toList()).forEach(sentence::removeWord);
    }
    // If filtering lemmas, feed the sentence through the tagger+lemmatizer.
    // I'm assuming that this only returns a list of lemmas.
    if (filterLemma)
    {
        tagger.tag(sentence.getWords());
    }
    // Collate the sentences per batch.
    sentences.add(sentence);
    if (sentences.size() >= batchSize)
    {
        // This is the actual WSD function.
        decodeSentenceBatch(sentences);
        sentences.clear();
    }
}
```
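For my own sanity, here's the same control flow in Python (a minimal sketch; `tag` and `decode_sentence_batch` are hypothetical stand-ins for the Java tagger and the actual WSD call):

```python
import sys

def tag(words):                    # hypothetical stand-in for the tagger+lemmatizer
    return words

def decode_sentence_batch(batch):  # hypothetical stand-in for the actual WSD call
    print(len(batch), "sentences decoded")

batch, batch_size, truncate_max_length = [], 1, 150

for line in sys.stdin:
    words = line.split()[:truncate_max_length]  # truncate long sentences
    tag(words)                                  # POS-tag/lemmatize
    batch.append(words)                         # collate sentences per batch
    if len(batch) >= batch_size:
        decode_sentence_batch(batch)
        batch.clear()

decode_sentence_batch(batch)                    # flush the final partial batch
```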
Where is `decodeSentenceBatch` coming from? The call chain goes `decodeSentenceBatch` -> `neuralDisambiguator.disambiguateDynamicSentenceBatch` -> `disambiguateFixedSentenceBatch` -> `disambiguateNoCatch` -> `readPredictOutput`... which leads to the Python subprocess. The actual tagging seems to come from https://github.com/alvations/disambiguate/blob/daaee2938501a793ba6a44269082393d79fe3421/java/src/main/java/getalp/wsd/method/neural/NeuralDisambiguator.java#L152:
```java
private void initPythonProcess(String pythonPath, String neuralPath, List<String> weightsPaths) throws IOException
{
    List<String> args = new ArrayList<>(Arrays.asList(pythonPath + "/launch.sh", "getalp.wsd.predict", "--data_path", neuralPath, "--weights"));
    args.addAll(weightsPaths);
    if (clearText) args.add("--clear_text");
    args.addAll(Arrays.asList("--batch_size", "" + batchSize));
    if (disambiguate) args.add("--disambiguate");
    if (translate) args.add("--translate");
    args.addAll(Arrays.asList("--beam_size", "" + beamSize));
    if (extraLemma) args.add("--output_all_features");
    ProcessBuilder pb = new ProcessBuilder(args);
    pb.redirectError(ProcessBuilder.Redirect.INHERIT);
    pythonProcess = pb.start();
    pythonProcessReader = new BufferedReader(new InputStreamReader(pythonProcess.getInputStream()));
    pythonProcessWriter = new BufferedWriter(new OutputStreamWriter(pythonProcess.getOutputStream()));
}
```
This seems to be the Python code from `getalp.wsd.predict`.
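In other words, the Java side just spawns a Python subprocess and pipes sentences through stdin/stdout. The same thing can be reproduced directly in Python (a minimal sketch; the flags come straight from the Java code above, but the exact line protocol the predictor expects is an assumption here):

```python
import subprocess

# Spawn the predictor the same way NeuralDisambiguator.initPythonProcess does.
proc = subprocess.Popen(
    ["python/launch.sh", "getalp.wsd.predict",
     "--data_path", "data", "--weights", "cpu_model_weights_wsd0",
     "--clear_text", "--batch_size", "1",
     "--disambiguate", "--beam_size", "1"],
    stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True)

# Write one sentence in, read one line of predictions back.
proc.stdin.write("The quick brown fox jumps over the lazy dog\n")
proc.stdin.flush()
print(proc.stdout.readline())
```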
Oh, it looks like all the heavy lifting after all the tagging/filter tricks is at https://github.com/alvations/disambiguate/blob/master/python/getalp/wsd/predicter.py#L20 =)
Somehow the loading of the ensemble becomes a bottleneck. While this is fast:
```python
from lazyme import find_files
import torch

model_filenames = list(find_files("models/", "cpu_model_weights_wsd*"))
torch.load(model_filenames[0])
```
...loading it the way https://github.com/alvations/disambiguate/blob/master/python/getalp/wsd/predicter.py does is uber slow:
```python
from getalp.wsd.predicter import Predicter
from getalp.wsd.model import Model, ModelConfig, DataConfig
from lazyme import find_files

data_config = DataConfig()
data_config.load_from_file('config.json')

model_config = ModelConfig(data_config)
model_config.load_from_file('config.json')

clear_text = True
if clear_text:
    model_config.data_config.input_clear_text = [True for _ in range(model_config.data_config.input_features)]

to_disambiguate = data_config.output_features > 0
to_translate = data_config.output_translations > 0

model_filenames = list(find_files("models/", "cpu_model_weights_wsd*"))[:1]

p = Predicter()
p.create_ensemble(model_config, model_filenames)
```
Yatta!! Got it to work!!
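For the record, a quick wall-clock comparison of the two loading paths (a hedged sketch reusing the names from the snippet above; timings will obviously vary by machine):

```python
import time
import torch

t0 = time.time()
torch.load(model_filenames[0], map_location="cpu")
print(f"raw torch.load:  {time.time() - t0:.1f}s")

t0 = time.time()
Predicter().create_ensemble(model_config, model_filenames)
print(f"create_ensemble: {time.time() - t0:.1f}s")
```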
Steps to reproduce:

```sh
git clone https://github.com/alvations/disambiguate.git
cd capellini
wget https://sgp1.digitaloceanspaces.com/nltk/disambiguate/cpu_models.zip
unzip cpu_models.zip
jupyter notebook Bolognese.ipynb
```
Needs a bit more:

```sh
$ git clone https://github.com/alvations/disambiguate.git
$ cd disambiguate/capellini
$ # pip install -r ../requirements.txt # if not done already
$ pip install lazyme nltk jupyter # need a few more things
$ python -m nltk.downloader wordnet
$ wget https://sgp1.digitaloceanspaces.com/nltk/disambiguate/cpu_models.zip
$ unzip cpu_models.zip -d models # notebook expects directory "models/"
$ PYTHONPATH=../python/ jupyter notebook Bolognese.ipynb
```
Then I get results that look like the "compressed" senses and not the resolved lemmatic senses. Was it the Java code that did this sense resolution? Or maybe the results are just totally wrong.

Maybe there's more after the Python passes it back to the Java? Let's see if I can distract myself with this some time tomorrow afternoon =)

Somehow doing the decoding on CPU makes PyTorch unhappy. So let's document how to fix this.