google / praxis

Apache License 2.0
176 stars 43 forks source link

Cross-layer attention weight sharing fails in different scopes #52

Status: Open · mqyqlx opened this issue 7 months ago

mqyqlx commented 7 months ago

Hi. I am trying to share attention weights across layers, following the test case in shared_layers_test.py:

  def testSharedTemplateLayer(self):
    """Builds a layer whose two FeedForward templates share one linear projection.

    Both `sub1_tpl` and `sub2_tpl` are clones of a FeedForward config whose
    `linear_tpl` carries `shared_weight_layer_id='shared_weight'`. Only the
    projection is marked as shared, not the whole FeedForward layer. This
    excerpt ends at `layer.init(...)`; `init_vars` is not inspected here.
    """
    sub_params = pax_fiddle.Config(
        linears.FeedForward, input_dims=8, output_dims=8
    )
    # Only share the linear projection, not the entire FeedForward layer.
    sub_params.linear_tpl.shared_weight_layer_id = 'shared_weight'
    test_layer_p = pax_fiddle.Config(
        SimpleShared01,
        name='test',
        # Two separate clones; they are tied together only through the
        # shared_weight_layer_id set on linear_tpl above.
        sub1_tpl=sub_params.clone(),
        sub2_tpl=sub_params.clone(),
    )
    x_in = jnp.ones([2, 8])
    # Instantiate and initialize inside a fresh JaxContext.
    with base_layer.JaxContext.new_context():
      prng_key = jax.random.PRNGKey(1234)
      layer = base_layer.instantiate(test_layer_p)
      # NOTE(review): init_vars is unused in this excerpt; presumably the full
      # test checks that only one shared projection weight was created.
      init_vars = layer.init(prng_key, x_in)

However, the weights are not shared, because different scopes are used when the shared-layer cache is set and when it is looked up:

  def lookup_shared_layer(
      self, root_scope: flax_core.Scope, shared_layer_id: str
  ) -> _SharedLayerCacheEntry | None:
    """Fetches the shared-layer cache entry for `shared_layer_id` in `root_scope`.

    The cache is indexed first by root scope and then by id. An entry
    registered under one root scope is therefore not visible from a different
    root scope. The `| None` return annotation suggests unknown ids yield
    None; that depends on how `_root_scope_to_shared_layers_map` is built,
    which is not shown here.
    """
    logging.info('lookup_shared_layer called with id: %s in the scope of %s',
                 shared_layer_id, root_scope)
    # Resolve the per-root-scope table first, then the entry within it.
    layers_in_scope = self._root_scope_to_shared_layers_map[root_scope]
    return layers_in_scope[shared_layer_id]

  def set_shared_layer(self, root_scope: flax_core.Scope, shared_layer_id: str,
                       wrapper: _WrapperLayer, layer_hparams):
    """Registers `wrapper.cld` as shared layer `shared_layer_id` in `root_scope`.

    Args:
      root_scope: Root scope the entry is stored under. Lookups made with a
        different root scope will not find this entry.
      shared_layer_id: Id under which the shared layer is registered.
      wrapper: Wrapper layer; its `cld` child is cached as the shared layer
        and the wrapper itself is cached alongside it.
      layer_hparams: Hyper-parameters of the shared layer. A clone
        (`layer_hparams.clone()`) is stored, not the object passed in.

    Raises:
      AssertionError: If `shared_layer_id` is already registered in
        `root_scope`.
    """
    logging.info('set_shared_layer called with id: %s in the scope of %s',
                 shared_layer_id, root_scope)
    existing = self.lookup_shared_layer(root_scope, shared_layer_id)
    # Explicit check instead of a bare `assert`: asserts are stripped under
    # `python -O`, which would let a duplicate registration silently
    # overwrite the cached shared layer. Same exception type as before.
    if existing is not None:
      raise AssertionError(
          f'Shared layer {shared_layer_id!r} is already registered in scope '
          f'{root_scope}.'
      )
    self._root_scope_to_shared_layers_map[root_scope][shared_layer_id] = (
        _SharedLayerCacheEntry(
            layer=wrapper.cld, hparams=layer_hparams.clone(), wrapper=wrapper
        )
    )

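Specifically, I implemented a 24-layer Llama model with StackedTransformer (not StackedTransformerRepeated). In the setup function of StackedTransformer, I set shared_weight_layer_id so that attention weights are shared with an interval of 6: layer i uses the id `shared_attn_{i % 6}`, so layers 0, 6, 12 and 18 share weights, and so on. The main changes relative to the upstream code (shown in bold in the original issue; the formatting was lost here) are the new `share_interval` field and the two lines in `_layer_params` that set `p_i.tr_atten_tpl.shared_weight_layer_id`. I also set remat=True and checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING in StackedTransformer.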


class StackedTransformer(base_layer.BaseLayer):
  """Stack of `num_layers` Transformer layers built from per-layer templates.

  `setup` clones a layer template for each index i, names it `layer_{i}`,
  copies the stack-level hyper-parameters onto it, and creates the clones as
  the `x_layers` children. In this modified version, `setup` also assigns
  each layer's attention template the id `shared_attn_{i % share_interval}`.
  Layers whose indices differ by a multiple of `share_interval` therefore
  request the same shared attention weights.

  No `__call__` appears in this excerpt. `remat` and `checkpoint_policy` are
  not read in `setup`.
  """
  use_cross_attention: bool = False
  # Not read in the setup shown here.
  mask_self_attention: bool = False
  # Number of stacked layers; must be > 0.
  num_layers: int = 0
  # Copied to each layer's `input_dims`; must be > 0.
  model_dims: int = 0
  # Copied to each layer's `hidden_dims` (overridden by the MoE template's
  # hidden_dims on MoE layers when that is set); must be > 0.
  hidden_dims: int = 0
  num_heads: int = 0
  dim_per_head: int | None = None
  # Fallback for the three dropout probabilities below; must be in [0, 1).
  dropout_prob: float = 0.0
  # Each of these falls back to dropout_prob via `or`, so an explicit 0.0
  # (not just None) also falls back.
  atten_dropout_prob: float | None = None
  residual_dropout_prob: float | None = None
  relu_dropout_prob: float | None = None
  # When > 0, layer i gets residual_droppath_prob * i / num_layers.
  residual_droppath_prob: float = 0.0
  # When > 0, an `input_dropout` child is created; must be in [0, 1).
  input_dropout_prob: float = 0.0
  # Copied onto the MoE template for layers listed in moe_layers.
  gating_func: str = 'top2'
  # Not read in the setup shown here.
  unadjusted_expert_capacity_factor: float = 2.0
  # Either one template for every layer, or a sequence whose length must
  # divide num_layers. With a sequence, consecutive groups of
  # num_layers // len(sequence) layers use the same template.
  transformer_layer_params_tpl: LayerTpl | Sequence[LayerTpl] = template_field(
      Transformer
  )
  packed_input: bool = False
  # Not read in the setup shown here.
  fold_padding_with_segment_mask: bool = False
  # Replaces tr_fflayer_tpl on the layers listed in moe_layers.
  moe_layer_tpl: LayerTpl | None = template_field(TransformerFeedForwardMoe)
  num_experts: int = 0
  num_groups: int = 1
  min_group_size: int | None = None
  # Indices of layers whose feed-forward sub-layer becomes an MoE layer.
  moe_layers: Sequence[int] | None = ()
  # Optional per-layer ngrammer templates, indexed by layer index.
  ngrammer_tpls: Sequence[LayerTpl] | None = template_field(None)
  # Not read in the setup shown here.
  # NOTE(review): the issue reports that, with remat=True, shared-layer
  # registration and lookup happen under different root scopes, so the
  # attention weights are not shared. Presumably rematerialization wraps each
  # layer in a lifted transform that creates its own scope — confirm against
  # the (not shown) forward pass.
  remat: bool = False
  # Modulus used to build the attention shared_weight_layer_id of each layer
  # (`shared_attn_{i % share_interval}`). Must be > 0; 0 raises
  # ZeroDivisionError in setup.
  share_interval: int = 6
  # Not read in the setup shown here.
  checkpoint_policy: AutodiffCheckpointType = (
      AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS_FFN1
  )

  def _clone_layer_params(self, layer_tpl: LayerTpl) -> LayerTpl:
    """Useful to let subclasses switch the class (e.g. Streaming version)."""
    return layer_tpl.clone()

  def setup(self) -> None:
    """Creates the `x_layers` children and, optionally, `input_dropout`."""
    assert self.num_layers > 0
    assert self.model_dims > 0
    assert self.hidden_dims > 0
    assert self.num_heads > 0
    assert 0.0 <= self.dropout_prob < 1.0
    assert 0.0 <= self.input_dropout_prob < 1.0
    def _layer_params(i):
      """Construct i-th layer params."""
      if isinstance(self.transformer_layer_params_tpl, Sequence):
        # Map layer i to its template: `factor` consecutive layers per entry.
        factor = self.num_layers // len(self.transformer_layer_params_tpl)
        ii = i // factor
        p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii])
      else:
        p_i = self._clone_layer_params(self.transformer_layer_params_tpl)
      p_i.name = f'layer_{i}'

      # Rebinds `ii` (the template index above is no longer needed). Layers
      # with equal i % share_interval get the same shared attention id.
      ii = i % self.share_interval  # ii is in the range [0,5] when share_interval = 6
      # Overwrites any shared_weight_layer_id already set on the template.
      # NOTE(review): sharing presumably requires the attention hparams of all
      # layers with the same id to match; with a Sequence of templates, layers
      # i and i + share_interval may come from different templates — confirm.
      p_i.tr_atten_tpl.shared_weight_layer_id = f'shared_attn_{ii}'

      p_i.use_cross_attention = self.use_cross_attention
      p_i.num_heads = self.num_heads
      p_i.dim_per_head = self.dim_per_head
      p_i.input_dims = self.model_dims
      p_i.packed_input = self.packed_input
      # `or` falls back to dropout_prob for None and for 0.0.
      p_i.atten_dropout_prob = self.atten_dropout_prob or self.dropout_prob
      p_i.residual_dropout_prob = (
          self.residual_dropout_prob or self.dropout_prob
      )
      p_i.relu_dropout_prob = self.relu_dropout_prob or self.dropout_prob
      p_i.hidden_dims = self.hidden_dims
      if self.residual_droppath_prob > 0.0:
        # Drop-path probability grows linearly with depth.
        p_i.residual_droppath_prob = (
            self.residual_droppath_prob * i / max(1, self.num_layers)
        )
      if self.moe_layers and i in self.moe_layers:
        # Swap this layer's feed-forward template for a configured MoE one.
        assert self.num_experts > 0
        assert self.moe_layer_tpl is not None
        moe_p = self.moe_layer_tpl.clone()
        moe_p.num_experts = self.num_experts
        moe_p.num_groups = self.num_groups
        moe_p.min_group_size = self.min_group_size
        moe_p.gating_func = self.gating_func
        if moe_p.hidden_dims:
          # MoE hidden_dims could be different from FFN hidden_dims
          p_i.hidden_dims = moe_p.hidden_dims
        p_i.tr_fflayer_tpl = moe_p
      if self.ngrammer_tpls is not None:
        # Indexed by layer, so ngrammer_tpls presumably has >= num_layers
        # entries (IndexError otherwise).
        if self.ngrammer_tpls[i] is not None:
          p_i.ngrammer_tpl = self.ngrammer_tpls[i]
      return p_i

    # Validate the template sequence before _layer_params is invoked below.
    if isinstance(self.transformer_layer_params_tpl, (list, tuple)):
      if self.num_layers % len(self.transformer_layer_params_tpl):
        raise ValueError(
            'num_layers should be divisible by transformer_layer_params_tpl'
        )

    layer_params = [_layer_params(i) for i in range(self.num_layers)]
    self.create_children('x_layers', layer_params)

    if self.input_dropout_prob > 0.0:
      self.create_child(
          'input_dropout',
          pax_fiddle.Config(
              stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob
          ),
      )

Could you explain why the scopes are different when sharing attention weights across layers? Is it related to layer-wise checkpointing (remat)? I would be grateful for a demonstration of how to share attention weights, or for any other advice you might offer.

justzh commented 6 months ago

Use #pragma instead of #code or whatever you put at the top of the document.

justzh commented 6 months ago

Get rid of this code. It's ugly. Read the book Clean Code.

justzh commented 6 months ago

That's much better. Awesome! Are you a speed-reader?!