google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

Machine Translation Reformer model.pkl for trax 1.4.1? #1765

Open ymcki opened 1 year ago

ymcki commented 1 year ago

Description

I am trying to port the Reformer machine translation notebook, https://github.com/google/trax/blob/master/trax/models/reformer/machine_translation.ipynb, which was written for trax 1.2.9, so that it works with trax 1.4.1. However, my program crashes when I try to load model.pkl, with this error: "ModuleNotFoundError: No module named 'trax.history'". My understanding is that trax.history in 1.2.9 was moved to trax.supervised.history in 1.4.1, so I believe I need a new model.pkl built for 1.4.1. Where can I download the new model.pkl? It would also be great if I could download a new config.gin as well. Thanks a lot in advance.
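If the only incompatibility is the module path, one possible workaround (instead of a re-exported model.pkl) is to remap the module name while unpickling. The following is a minimal sketch under that assumption. `RenamingUnpickler`, `load_with_renames`, and `MODULE_RENAMES` are hypothetical names introduced here for illustration, and the `trax.history` to `trax.supervised.history` entry is taken from my understanding above, not verified against the 1.4.1 source. It also assumes the class stored in the pickle still exists under the same name in the new module.

```python
import pickle

# Assumed mapping from old module paths to new ones. The trax entry reflects
# the move described in this issue and has not been verified.
MODULE_RENAMES = {"trax.history": "trax.supervised.history"}


class RenamingUnpickler(pickle.Unpickler):
    """Unpickler that rewrites module paths before classes are looked up."""

    def __init__(self, file, renames=None):
        super().__init__(file)
        self._renames = dict(MODULE_RENAMES if renames is None else renames)

    def find_class(self, module, name):
        # pickle calls find_class for every global it needs to import;
        # redirect old module paths to their new location.
        return super().find_class(self._renames.get(module, module), name)


def load_with_renames(file, renames=None):
    """Load a pickle from a binary file object, applying module renames."""
    return RenamingUnpickler(file, renames).load()
```

With this in place, the loading step would become something like `model_weights = load_with_renames(f)['weights']` inside the same `with GFile(...)` block, provided the file object supports `read` and `readline`. Whether the loaded weights then match the 1.4.1 model layout is a separate question that this sketch does not address.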

Environment information

OS: Ubuntu 20.04

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
tensorboard @ file:///home/conda/feedstock_root/build_artifacts/tensorboard_1664238338171/work/tensorboard-2.10.1-py3-none-any.whl
tensorboard-data-server @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-data-server_1649932776625/work/tensorboard_data_server-0.6.0-py3-none-manylinux2010_x86_64.whl
tensorboard-plugin-wit @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-plugin-wit_1641458951060/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
tensorflow==2.10.0
tensorflow-datasets==4.7.0
tensorflow-estimator @ file:///home/conda/feedstock_root/build_artifacts/tensorflow-split_1663957899180/work/tensorflow-estimator/wheel_dir/tensorflow_estimator-2.10.0-py2.py3-none-any.whl
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.27.0
tensorflow-metadata==1.10.0
tensorflow-text==2.10.0

$ pip freeze | grep jax
jax @ file:///home/conda/feedstock_root/build_artifacts/jax_1665610009116/work
jaxlib==0.3.22

$ python -V
Python 3.8.10

For bugs: reproduction and error logs

# Steps to reproduce:
...

import sys
import gin
import os
import pickle
import jax
import trax
import numpy as np
import jax.numpy as jnp
import sacrebleu
from trax.data.text_encoder import SubwordTextEncoder
from tensorflow.io.gfile import GFile

Load the source text and reference translations into Python

refs = []
for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.ref'), 1):
    if line.endswith('\n'):
        line = line[:-1]
    refs.append(line)
srcs = []
for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.src'), 1):
    if line.endswith('\n'):
        line = line[:-1]
    srcs.append(line)

Set up our sub-word tokenizer

tokenizer = SubwordTextEncoder(
    'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')

Encode source sentences using the tokenizer

input_ids = np.zeros((len(srcs), 128), dtype=jnp.int64)
for i, x in enumerate(srcs):
    x = tokenizer.encode(x)
    assert len(x) <= 127
    input_ids[i, :len(x)] = x
    input_ids[i, len(x)] = 1

We'll be using a pre-trained reversible transformer-base model.

First, load the config (which sets all needed hyperparameters).

!gsutil cp gs://trax-ml/reformer/mt/config.gin ./config.gin

gin.parse_config_file('./config.gin')

Now we load the pre-trained model weights.

with GFile('gs://trax-ml/reformer/mt/model.pkl', 'rb') as f:
    model_weights = pickle.load(f)['weights']

# Error logs:
...

Traceback (most recent call last):
  File "reformer.py", line 65, in <module>
    model_weights = pickle.load(f)['weights']
ModuleNotFoundError: No module named 'trax.history'
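To check whether trax.history is the only stale module path in the file, it may help to list every module and name the pickle refers to without importing any of them. The sketch below does this by substituting a placeholder for each reference. `referenced_globals`, `_RecordingUnpickler`, and `_Placeholder` are hypothetical helper names, not part of trax or the standard library. This is a diagnostic only: the returned object graph is discarded, and pickles that rely on container subclasses or other unusual reduce logic may still fail to load this way.

```python
import pickle


class _Placeholder:
    """Stand-in for any class or function the pickle refers to."""

    def __new__(cls, *args, **kwargs):
        # Accept and ignore whatever constructor arguments the pickle supplies.
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        pass

    def __setstate__(self, state):
        # Ignore the stored state; only the references matter here.
        pass


class _RecordingUnpickler(pickle.Unpickler):
    """Unpickler that records (module, name) pairs instead of importing them."""

    def __init__(self, file):
        super().__init__(file)
        self.references = set()

    def find_class(self, module, name):
        self.references.add((module, name))
        return _Placeholder


def referenced_globals(file):
    """Return the sorted (module, name) pairs a pickle file refers to."""
    unpickler = _RecordingUnpickler(file)
    unpickler.load()
    return sorted(unpickler.references)
```

Calling `referenced_globals(f)` on the opened model.pkl should show every module path that has to resolve under trax 1.4.1, which makes it easier to tell whether a simple rename is enough or a re-exported checkpoint is really needed.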