erfanzar / EasyDeL

Accelerate and optimize performance with streamlined training and serving options in JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Question: Does low-bit config reduce TPU HBM memory usage when training? #50

Closed. Beomi closed this issue 9 months ago.

Beomi commented 10 months ago

Hello,

Firstly, I'd like to express my appreciation for your work on this repository. I noticed that it supports low-bit (4 or 8 bits) formats during training, which is quite intriguing.

I have a query regarding TPU compatibility, particularly before TPUv4. As TPUs typically don't support low-bit formats like 4 or 8 bits until TPUv4 (which supports int8), I'm curious about how this implementation works. My current understanding is that the code might be converting 4 or 8-bit formats into bfloat16 or float16 formats. If this is the case, would it imply that the memory usage reduction typically expected from lower bit formats might not be realized?

Could you please clarify if my understanding is correct? Thanks for your time and effort in developing and maintaining this repository.

erfanzar commented 10 months ago

Hello, and thanks. Parameters are stored as float16, bfloat16, or float32, while the operations themselves run in lower precision such as 8, 6, or 4 bits. You can still train your model this way, and it stays more accurate than bitsandbytes or PEFT fine-tuning. If I have to give you a simple blueprint of how this works, let me describe it this way:

# hidden_state: an array of shape (..., hidden_dim)
# q_proj: the query projection, a Dense layer in Flax
# q_flax.QDotGeneral: built into FJFormer (the computation backend for EasyDeL)

from flax import linen as nn
from fjformer.bits import q_flax, config as q_config

if self.config.bits is not None:
    _dot_general_cls = q_config.fully_quantized(
        fwd_bits=self.config.bits,
        bwd_bits=self.config.bits
    )
else:
    _dot_general_cls = None

# when None is passed, the default jax.lax.dot_general is used instead
dot_general_cls = q_flax.QDotGeneral(_dot_general_cls)

q_proj = nn.Dense(..., dot_general=dot_general_cls)
# this layer will now take any array and perform its matmul / dot operation in
# the given number of bits (for example 8, 6, or 4) instead of using plain
# jax.lax.dot_general, e.g.:
#   q_flax.QDotGeneral(q_config.fully_quantized(fwd_bits=4, bwd_bits=4))

hidden_state = q_proj(hidden_state)
# hidden_state is still bfloat16, float16, or float32, but the operations
# inside were computed in the given bits
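If you want to check that the stored parameters keep their float dtype (so the HBM used for the weights themselves does not shrink), here is a minimal sketch along the same lines; it assumes the same q_flax / fully_quantized API used above:

import jax
import jax.numpy as jnp
from flax import linen as nn
from fjformer.bits import q_flax, config as q_config

# quantized dot_general, same pattern as the blueprint above
dot_general = q_flax.QDotGeneral(q_config.fully_quantized(fwd_bits=8, bwd_bits=8))

# a Dense layer whose parameters are stored in bfloat16
layer = nn.Dense(features=128, param_dtype=jnp.bfloat16, dot_general=dot_general)

params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 64), dtype=jnp.bfloat16))
print(jax.tree_util.tree_map(lambda x: x.dtype, params))
# the Dense kernel and bias are still bfloat16: the 8-bit setting changes how
# the matmul is computed, not the dtype the weights are stored in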

Run it and find out

So let's see how the code actually works by running it:

import chex

from fjformer.bits import config, q_flax as q
from flax import linen as nn
from typing import Optional
import jax

class MLP(nn.Module):
    qnt_config: Optional[config.DotGeneral] = None

    @nn.compact
    def __call__(self, inputs: chex.Array):
        dot_general = q.QDotGeneral(self.qnt_config) # in case of passing None the jax.lax.dot_general will be used
        x = nn.Dense(dot_general=dot_general, features=inputs.shape[-1] * 4)(inputs)
        x = nn.silu(x)
        x = nn.Dense(dot_general=dot_general, features=inputs.shape[-1])(x)
        return x + inputs

def init_and_eval(name, inputs, mlp_block, init_seed=0, eval_seed=0):
    params = mlp_block.init(jax.random.PRNGKey(init_seed), inputs)
    out = mlp_block.apply(params, inputs, rngs={'params': jax.random.key(eval_seed)})
    print(f"{name}:\n", out)

def main():
    int8_config = config.fully_quantized(fwd_bits=8, bwd_bits=8)

    mlp_fp16 = MLP(qnt_config=None)
    mlp_int8 = MLP(qnt_config=int8_config)

    input_ = jax.random.normal(jax.random.key(0), (1, 600, 4))

    init_and_eval('mlp_fp16', input_, mlp_fp16)
    init_and_eval('mlp_int8', input_, mlp_int8)

if __name__ == "__main__":
    main()

and the output should look like this:

Float16 Computation

mlp_fp16:
 [[[ 0.6609616  -0.11108614 -0.0223496  -0.81255555]
  [-0.25502944  0.51054865 -1.0421164   0.79341686]
  [ 0.7009573  -0.5298119  -0.57316536  1.602029  ]
  ...
  [-0.08386192 -0.34028694  0.05781332 -0.06801909]
  [ 0.6388167   2.3675287  -1.3287199  -2.387364  ]
  [-0.42537162  0.5458796   0.40854657  1.6986433 ]]]

Int8 Computation

mlp_int8:
 [[[ 0.6650832  -0.11005261 -0.01998997 -0.8111123 ]
  [-0.26523077  0.5025017  -1.0404696   0.79337656]
  [ 0.6945003  -0.5313612  -0.5694739   1.6055017 ]
  ...
  [-0.08491541 -0.34102592  0.05706861 -0.06926808]
  [ 0.64493084  2.3578165  -1.3139739  -2.3902466 ]
  [-0.4261561   0.5440683   0.401039    1.68537   ]]]
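
If you want to quantify how close the two outputs are, here is a quick sketch along the same lines (not part of the script above); it reuses the MLP class and the fjformer config import from it:

import jax.numpy as jnp

def compare_outputs():
    # rebuild both models exactly as in main() and measure the gap
    int8_config = config.fully_quantized(fwd_bits=8, bwd_bits=8)
    inputs = jax.random.normal(jax.random.key(0), (1, 600, 4))
    outputs = {}
    for name, qnt in (("fp", None), ("int8", int8_config)):
        mlp = MLP(qnt_config=qnt)
        params = mlp.init(jax.random.PRNGKey(0), inputs)
        outputs[name] = mlp.apply(params, inputs)
    print("max abs diff:", jnp.max(jnp.abs(outputs["fp"] - outputs["int8"])))

compare_outputs()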
erfanzar commented 10 months ago


If you still have any other questions, I'll be happy to answer them and help you. And yes, in some cases the training process is ~1.9x faster, and maybe up to 3x faster. But I don't know where I'm doing this wrong right now: when using the platform-specific attentions built into FJFormer, such as tpu_flash_attention, ring_attention_standard, ring_flash_attention_gpu, or gpu_flash_attention, it doesn't work well and uses much more memory in the backward pass. Memory-footprint issues like this happen on TPU v3 and TPU v5e, but not on GPUs or TPU v4. It also depends on how you are using the model.
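
If you want to check the footprint on your own hardware, one rough way (just a sketch; allocator statistics are only reported on backends that support them) is to read the device memory stats around a training step:

import jax

def report_hbm(tag: str):
    # Device.memory_stats() returns allocator statistics on TPU/GPU backends
    # that support it (may be None elsewhere)
    stats = jax.local_devices()[0].memory_stats()
    if stats:
        print(tag,
              "bytes_in_use:", stats.get("bytes_in_use"),
              "peak_bytes_in_use:", stats.get("peak_bytes_in_use"))

# call report_hbm("before step") / report_hbm("after step") around one
# training step to compare the quantized and non-quantized setups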