kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.27k stars 890 forks source link

how to restart training #197

Closed whoislimshady closed 2 years ago

whoislimshady commented 2 years ago

I trained a model up to step 1000, where a checkpoint was supposed to be saved, but I think I cancelled the run a little too early. How do I restart training from the last checkpoint?

Do I just point the model path to the last checkpoint, or is there a better way to do that? By the way, when I do point it to the last checkpoint, I get the following error:

```
  File "device_train.py", line 259, in <module>
    network = CausalTransformer(params)
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 277, in __init__
    self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 516, in fun_mapped
    out_flat = xmap_p.bind(
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 652, in bind
    return core.call_bind(self, fun, *args, **params)  # type: ignore
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 655, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 539, in xmap_impl
    return make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 555, in make_xmap_callable
    jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 182, in init
    params = param_init_fn(key, x, x)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/transform.py", line 113, in init_fn
    params, state = f.init(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/transform.py", line 364, in init_fn
    f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/random.py", line 96, in wrapper
    jax.eval_shape(pure_fun, params, state, rng, *args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/random.py", line 93, in pure_fun
    return fun(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 178, in train_loss
    return transformer.loss(x, y)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 65, in loss
    loss, correct = self.eval(ctx, tgt, float(z_loss), mask=mask)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 60, in eval
    x = x + hk.remat(l)(x, attn_bias)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/stateful.py", line 361, in wrapper
    out, state = dec_stateful_fun(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/stateful.py", line 353, in stateful_fun
    out = fun(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/layers.py", line 310, in __call__
    attn_out = self.self_attn(q, v, k, bias)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/layers.py", line 269, in self_attn
    q_rot = apply_rotary_pos_emb(q_rot, sincos)
  File "/home/harsh/gptj/mesh_transformer/layers.py", line 147, in apply_rotary_pos_emb
    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2)[-x.shape[0]:, None, :], sincos)
  File "/home/harsh/gptj/mesh_transformer/layers.py", line 147, in <lambda>
    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2)[-x.shape[0]:, None, :], sincos)
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/einops.py", line 502, in repeat
    return reduce(tensor, pattern, reduction='repeat', **axes_lengths)
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/einops.py", line 382, in reduce
    return recipe.apply(tensor)
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/einops.py", line 203, in apply
    backend = get_backend(tensor)
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/_backends.py", line 49, in get_backend
    if backend.is_appropriate_type(tensor):
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/_backends.py", line 450, in is_appropriate_type
    return isinstance(tensor, (self.tf.Tensor, self.tf.Variable))
jax._src.traceback_util.FilteredStackTrace: AttributeError: module 'tensorflow' has no attribute 'Tensor'

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/transform.py", line 364, in init_fn
    f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/random.py", line 96, in wrapper
    jax.eval_shape(pure_fun, params, state, rng, *args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/api.py", line 2315, in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 404, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1178, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/random.py", line 93, in pure_fun
    return fun(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 178, in train_loss
    return transformer.loss(x, y)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 65, in loss
    loss, correct = self.eval(ctx, tgt, float(z_loss), mask=mask)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 60, in eval
    x = x + hk.remat(l)(x, attn_bias)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/stateful.py", line 361, in wrapper
    out, state = dec_stateful_fun(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/api.py", line 2408, in fun_remat
    out_flat = pe.remat_call(flat_fun, *args_flat, name=flat_fun.__name__,
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/core.py", line 1402, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/core.py", line 1405, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1061, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/stateful.py", line 353, in stateful_fun
    out = fun(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/layers.py", line 310, in __call__
    attn_out = self.self_attn(q, v, k, bias)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/harsh/gptj/mesh_transformer/layers.py", line 269, in self_attn
    q_rot = apply_rotary_pos_emb(q_rot, sincos)
  File "/home/harsh/gptj/mesh_transformer/layers.py", line 147, in apply_rotary_pos_emb
    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2)[-x.shape[0]:, None, :], sincos)
  File "/home/harsh/gptj/mesh_transformer/layers.py", line 147, in <lambda>
    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2)[-x.shape[0]:, None, :], sincos)
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/einops.py", line 502, in repeat
    return reduce(tensor, pattern, reduction='repeat', **axes_lengths)
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/einops.py", line 382, in reduce
    return recipe.apply(tensor)
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/einops.py", line 203, in apply
    backend = get_backend(tensor)
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/_backends.py", line 49, in get_backend
    if backend.is_appropriate_type(tensor):
  File "/home/harsh/.local/lib/python3.8/site-packages/einops/_backends.py", line 450, in is_appropriate_type
    return isinstance(tensor, (self.tf.Tensor, self.tf.Variable))
AttributeError: module 'tensorflow' has no attribute 'Tensor'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "device_train.py", line 259, in <module>
    network = CausalTransformer(params)
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 277, in __init__
    self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 516, in fun_mapped
    out_flat = xmap_p.bind(
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 652, in bind
    return core.call_bind(self, fun, *args, **params)  # type: ignore
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 655, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/core.py", line 600, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 539, in xmap_impl
    return make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 555, in make_xmap_callable
    jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1209, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/harsh/.local/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/harsh/gptj/mesh_transformer/transformer_shard.py", line 182, in init
    params = param_init_fn(key, x, x)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/transform.py", line 113, in init_fn
    params, state = f.init(*args, **kwargs)
  File "/home/harsh/.local/lib/python3.8/site-packages/haiku/_src/transform.py", line 365, in init_fn
    except jax.errors.UnexpectedTracerError as e:
AttributeError: module 'jax.errors' has no attribute 'UnexpectedTracerError'```