pytorch / text

Models, data loaders and abstractions for language processing, powered by PyTorch
https://pytorch.org/text
BSD 3-Clause "New" or "Revised" License

How to use TorchText with Java #1369

Open anjali-chadha opened 3 years ago

anjali-chadha commented 3 years ago

❓ Questions and Help

Description

I have a SentencePiece model which I serialized using sentencepiece_processor. My end goal is to use this TorchScript-serialized tokenizer in Java along with the DJL PyTorch dependency. I am looking for guidance on how I can import the torchtext dependency in a Java environment.

Steps:

1. Serializing the SPM tokenizer using TorchText. The TorchScript-serialized file is saved as 'spm-jit.pt'.

import torch
from torchtext.experimental.transforms import sentencepiece_processor

# Build the SentencePiece processor from the pretrained model, script it,
# and save the TorchScript archive
spm_processor = sentencepiece_processor('foo.model')
jit_spm_processor = torch.jit.script(spm_processor)
torch.jit.save(jit_spm_processor, 'spm-jit.pt')

2. Deserializing the SPM tokenizer in Python. Loading 'spm-jit.pt' without importing torchtext fails with the following error.

import torch
spm_tokenizer = torch.jit.load('spm-jit.pt')  # Fails when torchtext is not imported

Error

/usr/local/lib/python3.6/dist-packages/torch/jit/_serialization.py in load(f, map_location, _extra_files)
    159     cu = torch._C.CompilationUnit()
    160     if isinstance(f, str) or isinstance(f, pathlib.Path):
--> 161         cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
    162     else:
    163         cpp_module = torch._C.import_ir_module_from_buffer(

RuntimeError: 
Unknown type name '__torch__.torch.classes.torchtext.SentencePiece':
Serialized   File "code/__torch__/torchtext/experimental/transforms.py", line 6
  training : bool
  _is_full_backward_hook : Optional[bool]
  sp_model : __torch__.torch.classes.torchtext.SentencePiece
             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  def forward(self: __torch__.torchtext.experimental.transforms.SentencePieceProcessor,
    line: str) -> List[int]:

After importing torchtext, I am able to load the tokenizer from the TorchScript file.

import torch
import torchtext
spm_tokenizer = torch.jit.load('spm-jit.pt')  # Succeeds

This led me to the conclusion that the serialized file depends on torchtext in order to load successfully in a Java/Python/C++ environment.

Any guidance on how I can use torchtext in Java and/or C++ would be appreciated.

Thanks!

parmeet commented 3 years ago

Hi @anjali-chadha, unfortunately I am not too familiar with Java.

Regarding Python: we need to load the library _torchtext.so in order for its functionality to be available at Python run time. Importing torchtext implicitly loads this library, but you can also do so explicitly. Have a look here: https://github.com/pytorch/text/blob/760a625f8796293145c7a9bf4d5c710cfb0aabc8/torchtext/__init__.py#L24
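For example, a minimal sketch of loading the extension explicitly (the install path below is just for illustration; point it at your own torchtext installation):

import torch

# Explicitly load torchtext's extension library so that the custom
# torch.classes.torchtext.SentencePiece class is registered with TorchScript
torch.ops.load_library('/usr/local/lib/python3.8/site-packages/torchtext/_torchtext.so')

spm_tokenizer = torch.jit.load('spm-jit.pt')  # loads without `import torchtext`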

Regarding C++: you would need to link your application code against _torchtext.so and the Python library. You can find more details here: https://github.com/pytorch/text/issues/1255#issuecomment-867821021. Unfortunately, the Python library still needs to be linked because we do not yet have a run-time library that is independent of Python. There is a plan to do this, but I cannot say much about the timelines. cc: @mthrok, @hudeven

frankfliu commented 3 years ago

@anjali-chadha You can try to load _torchtext.so manually:

System.load("/Library/Python/3.8/site-packages/torchtext/_torchtext.so");

With DJL 0.12.0, you simply define an environment variable, and DJL will load it for you:

export PYTORCH_EXTRA_LIBRARY_PATH=/Library/Python/3.8/site-packages/torchtext/_torchtext.so