google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

BERT model fails on its initialization #1655

Open: manifest opened this issue 3 years ago

manifest commented 3 years ago

Description

Initializing a PretrainedBERT model fails. Calling trax.models.bert.BERT with an init_checkpoint raises a ValueError inside PretrainedBERT.__init__.

Environment information

OS: macOS Big Sur 11.4

$ pip freeze | grep trax
trax==1.3.9

$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.5.0
tensorflow-datasets==4.3.0
tensorflow-estimator==2.5.0
tensorflow-hub==0.12.0
tensorflow-metadata==1.0.0
tensorflow-text==2.5.0

$ pip freeze | grep jax
jax==0.2.13
jaxlib==0.1.67

$ python -V
Python 3.8.10

For bugs: reproduction and error logs

# Steps to reproduce:
import trax
trax.models.bert.BERT(init_checkpoint="bert-base-uncased")
# Error logs:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/manifest/essential/utils/python/env/lib/python3.8/site-packages/trax/models/research/bert.py", line 160, in BERT
    bert = PretrainedBERT(
  File "/Users/manifest/essential/utils/python/env/lib/python3.8/site-packages/trax/models/research/bert.py", line 178, in __init__
    self.init_checkpoint = None
  File "/Users/manifest/essential/utils/python/env/lib/python3.8/site-packages/trax/layers/base.py", line 703, in __setattr__
    raise ValueError(
ValueError: Trax layers only allow to set ('weights', 'state', 'rng') as public attribues, not init_checkpoint.
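The traceback shows the check happening in Layer.__setattr__ (trax/layers/base.py). Here is a minimal, self-contained sketch of how a guard like that would produce exactly this message. It does not import Trax. MockLayer and MockPretrainedBERT are hypothetical stand-ins, not Trax's classes. The log does not confirm that underscore-prefixed names bypass the check; that part is an assumption.

```python
# Hypothetical mock of the attribute guard described by the error message.
# MockLayer / MockPretrainedBERT are illustrative names, not Trax classes.


class MockLayer:
    """Stand-in base class that rejects unknown public attributes."""

    def _settable_attrs(self):
        # Public attribute names the base class allows to be assigned.
        return ('weights', 'state', 'rng')

    def __setattr__(self, attr, value):
        # Assumption: private names (leading underscore) are always allowed;
        # public names must be listed in _settable_attrs().
        if not attr.startswith('_') and attr not in self._settable_attrs():
            # Message text copied from the log above, spelling included.
            raise ValueError(
                f'Trax layers only allow to set {self._settable_attrs()} as public '
                f'attribues, not {attr}.')
        super().__setattr__(attr, value)


class MockPretrainedBERT(MockLayer):
    def __init__(self, init_checkpoint=None):
        # Mirrors the failing line in the traceback (bert.py, line 178).
        self.init_checkpoint = None


try:
    MockPretrainedBERT(init_checkpoint='bert-base-uncased')
except ValueError as err:
    print(err)
```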
manifest commented 3 years ago

Related issue

manifest commented 3 years ago

In the PR above, I've overridden the _settable_attrs method of PretrainedBERT so that it also allows setting the init_checkpoint attribute, which is required for loading the model from its checkpoint.
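This self-contained sketch shows the override pattern the comment describes. The subclass extends the tuple of settable attributes instead of replacing it. The PR's actual diff is not shown in this thread, and the sketch does not import Trax. MockLayer and MockPretrainedBERT are hypothetical stand-ins for trax.layers.base.Layer and PretrainedBERT.

```python
# Hypothetical sketch of the fix pattern; not the PR's actual code.


class MockLayer:
    """Stand-in base class that rejects unknown public attributes."""

    def _settable_attrs(self):
        return ('weights', 'state', 'rng')

    def __setattr__(self, attr, value):
        # Assumption: underscore-prefixed names bypass the check.
        if not attr.startswith('_') and attr not in self._settable_attrs():
            raise ValueError(f'cannot set public attribute {attr!r}')
        super().__setattr__(attr, value)


class MockPretrainedBERT(MockLayer):
    def __init__(self, init_checkpoint=None):
        # Allowed now, because _settable_attrs() below includes this name.
        self.init_checkpoint = init_checkpoint

    def _settable_attrs(self):
        # Keep the base-class names and add the one this model needs.
        return super()._settable_attrs() + ('init_checkpoint',)


model = MockPretrainedBERT(init_checkpoint='bert-base-uncased')
print(model.init_checkpoint)  # bert-base-uncased
```

Extending through super() rather than hard-coding a new tuple keeps 'weights', 'state' and 'rng' settable. The guard still catches typos in other public attribute names.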