haoliuhl / ringattention

Transformers with Arbitrarily Large Context

Test Script Issues #15

Open djbyrne opened 7 months ago

djbyrne commented 7 months ago

Hi Hao,

First off, a big thank you for the huge amount of work that has gone into open-sourcing the implementation of your research; it is highly appreciated!

While going through the repo and trying to understand the method in depth, I discovered some issues with the test script.

1) The test script does not appear to run different attention methods; it only ever compares against the default setting. My initial impression from the code was that setting the 'attention_label' would update the config and run the attention mechanism associated with that label (i.e. standard, ring_blockwise, etc.). After further inspection, however, it seems this no longer does anything, and the method always runs according to whatever is defined in the base config via the scan_attention, scan_mlp, scan_layers and mesh_dim arguments. To actually compare methods you have to update the config at each iteration, e.g.:

for attention_type in attention_types:
    # Copy the base config so each method gets its own settings.
    llama_config_copy = copy.deepcopy(llama_config)
    llama_config_copy.update(dict(attention_type=attention_type))
    if attention_type == 'standard':
        # Standard attention: no scanning, plain remat, and its own mesh layout.
        llama_config_copy.update(dict(scan_attention=False, scan_mlp=False, scan_layers=False,
                                      remat_attention='', remat_mlp='', mesh_dim='1,-1,2,1'))
    elif attention_type == 'ring_blockwise':
        # Blockwise ring attention: enable scanning, set chunk sizes and the ring mesh layout.
        llama_config_copy.update(dict(scan_attention=True, scan_mlp=True, scan_layers=True,
                                      mesh_dim='1,1,2,-1',
                                      scan_query_chunk_size=1024, scan_key_chunk_size=1024,
                                      scan_mlp_chunk_size=1024))
    model = FlaxLLaMAForCausalLMModule(
        llama_config_copy, dtype=get_float_dtype_by_name(FLAGS.dtype)
    )
    models.append(model)
model = models[0]

2) It appears that it isn't possible to change mesh_dim between methods: the mesh is defined once at the start of the test and used as a context manager for the whole run, so I don't think we can switch between the ring and blockwise layouts during the test. Rebuilding the mesh per iteration (see the sketch just below) would be needed for that.
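Something along these lines is what I had in mind: rebuild the mesh for each attention type rather than holding one mesh context open for the whole test. This is only a sketch using plain JAX APIs; the axis names, mesh shapes and the run_forward_and_backward step are placeholders, not the script's real helpers.

import numpy as np
import jax

def make_mesh(mesh_shape, axis_names=('dp', 'fsdp', 'tp', 'sp')):
    # mesh_shape is assumed to already be resolved from the mesh_dim string,
    # e.g. (1, 2, 2, 1) or (1, 1, 2, 2) on a v4-8 host with 4 devices.
    devices = np.asarray(jax.devices()).reshape(mesh_shape)
    return jax.sharding.Mesh(devices, axis_names)

for attention_type, mesh_shape in [('standard', (1, 2, 2, 1)),
                                   ('ring_blockwise', (1, 1, 2, 2))]:
    # Enter a fresh mesh context per method instead of one global context.
    with make_mesh(mesh_shape):
        run_forward_and_backward(attention_type)  # placeholder for the test step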

3) It doesn't look like the grads being returned are a FrozenDict, so the unfreeze at line 163 is not needed (I think it's fine that they're not frozen in this case). A defensive guard, sketched just below, would cover both cases.
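For example (just a sketch of the guard I mean, assuming the variable is called grads as in the script):

from flax.core import FrozenDict, unfreeze

# Only unfreeze when the pytree is actually a FrozenDict; a plain dict passes through.
if isinstance(grads, FrozenDict):
    grads = unfreeze(grads)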

4) After applying my naive updates to compare Standard with Ring Blockwise, I am now seeing a larger diff in the logits and grads than expected:

standard
logits: 0.0 1.6717689 1.6717689
grads: 0.0 0.11031877 0.11031877

ring_blockwise
logits: 0.0044222176 1.6717689 1.6717689
grads: 6.278977e-05 0.11030923 0.11031877

Is this similar to your own results, or should they be more closely aligned with Standard Attention? My understanding is that Blockwise Ring Attention is numerically equivalent. Could you please confirm whether my configs are correct for comparing these methods? There is a good chance I have made a mistake somewhere. For reference, I am running on a TPU v4-8, so I only have 4 devices.
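For clarity, the kind of comparison I have in mind is a maximum elementwise difference over the whole pytree, roughly like this sketch (the standard_*/ring_* names are placeholders for the two runs' outputs, not variables from the script):

import jax
import jax.numpy as jnp

def max_abs_diff(a, b):
    # Largest elementwise absolute difference across two matching pytrees.
    diffs = jax.tree_util.tree_map(lambda x, y: jnp.max(jnp.abs(x - y)), a, b)
    return max(jax.tree_util.tree_leaves(diffs))

print('logits diff:', max_abs_diff(standard_logits, ring_logits))
print('grads diff: ', max_abs_diff(standard_grads, ring_grads))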

I would like to confirm whether you agree with these observations, or have I just done something silly when applying my changes? If these are in fact issues that have crept in, I am happy to submit a fix 😃

Cheers,

Donal