google / aqt

Apache License 2.0

flax_e2e_model.py example fails #667

Closed. Jconn closed this issue 1 month ago.

Jconn commented 2 months ago

I'm getting the error below when running python3 flax_e2e_model.py. I think it comes from the lhs quant mode being QuantMode.CONVERT, which makes the lhs freezer try to store the lhs scale during serving, when the "aqt" collection is immutable.

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/var/tmp/aqt/aqt/jax/v2/examples/flax_e2e_model.py", line 490, in <module>
    app.run(main)
  File "/var/tmp/aqt/aqt_env/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/var/tmp/aqt/aqt_env/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/var/tmp/aqt/aqt/jax/v2/examples/flax_e2e_model.py", line 485, in main
    loss = serve(state, weight_only=False)
  File "/var/tmp/aqt/aqt/jax/v2/examples/flax_e2e_model.py", line 440, in serve
    logits = serve_fn(
  File "/var/tmp/aqt/aqt/jax/v2/examples/flax_e2e_model.py", line 84, in __call__
    x = nn.Dense(features=256, dot_general_cls=aqt_dg)(x)
  File "/var/tmp/aqt/aqt_env/lib/python3.10/site-packages/flax/linen/linear.py", line 276, in __call__
    y = dot_general(
  File "/var/tmp/aqt/aqt/jax/v2/flax/aqt_flax.py", line 515, in __call__
    return ret_dg(
  File "/var/tmp/aqt/aqt/jax/v2/tiled_dot_general.py", line 527, in tiled_dot_general
    return tiled_dot_general_with_tiling_states(
  File "/var/tmp/aqt/aqt/jax/v2/tiled_dot_general.py", line 419, in tiled_dot_general_with_tiling_states
    out = dot_general(
  File "/var/tmp/aqt/aqt/jax/v2/flax/aqt_flax.py", line 459, in ret_dg
    lhs_freezer.set(out_lhs_qt)
  File "/var/tmp/aqt/aqt/jax/v2/flax/freezer.py", line 100, in set
    return self._get_or_set(inputs, is_set=True)
  File "/var/tmp/aqt/aqt/jax/v2/flax/freezer.py", line 63, in _get_or_set
    s.value = inputs
flax.errors.ModifyScopeVariableError: Cannot update variable "frozen" in "/Dense_0/AqtDotGeneral_0/qlhs" because collection "aqt" is immutable. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ModifyScopeVariableError)