MoFHeka opened this issue 4 months ago
@denera Could you take a look?
Hi @MoFHeka -- the JAX/XLA custom op for te_scaled_upper_triang_masked_softmax_forward
is implemented here, exposed via PyBind11 here and registered with XLA for the CUDA platform here.
TE/Flax modules invoke this custom op via the scaled_upper_triang_softmax_fwd()
API. Is that what you're trying to use in your application?
If this is not working for you, could you provide us with a minimal reproducer, along with some information about your platform such as GPU type, CUDA driver version, and CUDA Toolkit version?
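For the platform details, something along these lines is usually enough to capture what we need (a minimal sketch; the transformer_engine.__version__ attribute is an assumption about your install, and pip show transformer-engine works as a fallback):

import jax
import transformer_engine

# Prints jax/jaxlib/Python versions, visible devices and, on NVIDIA systems,
# the nvidia-smi output (which includes the driver version).
jax.print_environment_info()

# Transformer Engine version; __version__ is assumed to exist on your install.
print("transformer_engine:", getattr(transformer_engine, "__version__", "unknown"))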
I tried to run python -c 'from transformer_engine.jax.flax.transformer import DotProductAttention, MultiHeadAttention, TransformerLayer'
in ghcr.io/nvidia/jax:maxtext-2024-07-17 on an H100 with driver 550.54.14 and CUDA 12.4, and I can't reproduce the error.
Could the problem be that the Kubernetes host machine's CUDA driver version (535) is too old? Here is the simple demo code, nothing special.
from dataclasses import dataclass
from functools import partial
import functools
import os
import sys
import time

import flax.linen as fnn
import flax.struct
import jax
import jax.nn as jnn
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
import optax
from flax.linen import partitioning as nn_partitioning
from flax.linen.linear import DotGeneralT, PrecisionLike
from jax.ad_checkpoint import checkpoint_name
from transformer_engine.jax.flax.transformer import DotProductAttention, MultiHeadAttention, TransformerLayer
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
os.environ["XLA_FLAGS"] = """
--xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=2
--xla_gpu_enable_custom_fusions=true
--xla_gpu_enable_address_computation_fusion=true
"""
@dataclass
class ModelConfig:
    """Configuration for the language models."""
    seq_len: int
    n_layers: int
    d_model: int
    num_heads: int
    ff_dim: int
    dropout: float
    batch_size: int
    learning_rate: float
    max_num_batch: int
class RandomDS:
    """Yields random byte batches so the demo needs no external data."""
    def __init__(self, batch_size: int, seq_len: int, use_jax=False):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.use_jax = use_jax
        if self.use_jax:
            self.rng = jax.random.PRNGKey(1)

    def __iter__(self):
        if self.use_jax:
            batches = [
                jax.random.bits(self.rng, shape=(self.batch_size, self.seq_len), dtype=jnp.uint8)
                for _ in range(0, self.batch_size)
            ]
        else:
            batches = [
                np.random.randint(low=0, high=16, size=(self.batch_size, self.seq_len), dtype=np.uint8)
                for _ in range(0, self.batch_size)
            ]
        return iter(batches)
# Note: this local class shadows the TransformerLayer imported from transformer_engine above.
class TransformerLayer(fnn.Module):
    d_model: int
    num_heads: int
    ff_dim: int
    dropout: float

    def setup(self):
        self.mha = MultiHeadAttention(
            head_dim=self.d_model // self.num_heads,
            num_attention_heads=self.num_heads,
            input_layernorm=False,
            dtype=jnp.bfloat16,
        )
        self.layer_norm_1 = fnn.LayerNorm(epsilon=1e-5, dtype=jnp.bfloat16)
        self.linear_1 = fnn.Dense(
            features=self.ff_dim,
            kernel_init=fnn.initializers.variance_scaling(1 / 3, "fan_in", "uniform"),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16,
        )
        self.linear_2 = fnn.Dense(
            features=self.d_model,
            kernel_init=fnn.initializers.variance_scaling(1 / 3, "fan_in", "uniform"),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16,
        )
        self.layer_norm_2 = fnn.LayerNorm(epsilon=1e-5, dtype=jnp.bfloat16)
        self.dropout_layer = fnn.Dropout(self.dropout, deterministic=False)

    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
        # Pre-LN block: residual over the attention output, then over the feed-forward output.
        residual = x
        x = self.layer_norm_1(x)
        x = self.mha(inputs_q=x, inputs_kv=x, mask=mask)[0]
        x = residual + self.dropout_layer(x)
        x = x + self.dropout_layer(self._ff_block(self.layer_norm_2(x)))
        return x

    def _ff_block(self, x):
        x = jnn.relu(self.linear_1(x))
        x = self.dropout_layer(x)
        x = self.linear_2(x)
        return x
class LM(fnn.Module):
    cfg: ModelConfig

    def setup(self):
        self.byte_embedding = fnn.Embed(
            num_embeddings=256,
            features=self.cfg.d_model,
            embedding_init=jnn.initializers.normal(),
            param_dtype=jnp.bfloat16,
        )
        self.positional_encoding = self.param(
            "positional_encoding",
            jnn.initializers.normal(),
            (self.cfg.seq_len, self.cfg.d_model),
            dtype=jnp.bfloat16,
        )
        self.dropout_layer = fnn.Dropout(self.cfg.dropout, deterministic=False)
        self.transformer_layers = [
            TransformerLayer(
                self.cfg.d_model, self.cfg.num_heads, self.cfg.ff_dim, self.cfg.dropout
            )
            for _ in range(self.cfg.n_layers)
        ]
        self.prob_decoder = fnn.Dense(
            features=256,
            kernel_init=fnn.initializers.variance_scaling(1 / 3, "fan_in", "uniform"),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16,
        )

    def __call__(self, text):
        x = self.byte_embedding(text)
        # Shift x right so causality isn't violated.
        x = jnp.concatenate(
            [jnp.zeros([text.shape[0], 1, self.cfg.d_model], dtype=x.dtype), x[:, :-1, :]], axis=1
        )
        x = x + self.positional_encoding
        x = self.dropout_layer(x)
        mask = fnn.attention.make_causal_mask(text)
        for layer in self.transformer_layers:
            x = layer(x, mask=mask)
        return self.prob_decoder(x)
rng = jax.random.PRNGKey(1)

def compute_loss(params, model: LM, text):
    model_out = model.apply(params, text=text, rngs={"dropout": rng})
    one_hots = jnn.one_hot(text, 256)
    loss = optax.softmax_cross_entropy(model_out, one_hots)
    return loss

def setup_model(rng, cfg: ModelConfig):
    model = LM(cfg)
    rng_p, rng_d = jax.random.split(rng)
    params = model.init(
        {"params": rng_p, "dropout": rng_d},
        jnp.zeros([cfg.batch_size, cfg.seq_len], dtype=jnp.uint8),
    )
    return params, model

def setup_optimizer(params, cfg: ModelConfig):
    optimizer = optax.adam(cfg.learning_rate)
    opt_state = optimizer.init(params)
    return optimizer, opt_state
def train_loop(model: LM, optimizer, opt_state, params, cfg: ModelConfig):
    def run_train_step(params, opt_state, text_batch):
        loss, grad = jax.value_and_grad(
            lambda p: compute_loss(p, model, text=text_batch).mean()
        )(params)
        updates, opt_state = optimizer.update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    fast_train_step = jax.jit(run_train_step, donate_argnums=[0, 1])
    losses = []
    log_per = 20

    def multi_train_steps(params, opt_state, data):
        # Unused helper: run one jitted step per batch in `data`.
        loss = None
        for single_step_data in data:
            params, opt_state, loss = fast_train_step(params, opt_state, single_step_data)
        return params, opt_state, loss

    dataset = list(RandomDS(cfg.batch_size, cfg.seq_len, use_jax=True))

    # Warm-up: run a few steps so compilation is excluded from the timing below.
    for idx, batch in enumerate(dataset):
        params, opt_state, loss = fast_train_step(params, opt_state, batch)
        if (idx + 1) % log_per == 0:
            break

    iter_num = 0
    t = time.time()
    for batch in dataset:
        params, opt_state, loss = fast_train_step(params, opt_state, batch)
        losses.append(loss)
        iter_num += 1
        if iter_num % log_per == 0:
            time_elps = time.time() - t
            speed = log_per * cfg.batch_size / time_elps
            print(
                f"At iter {iter_num}, loss: {np.mean(losses):.4f}, Speed: {int(speed):d}"
            )
            t = time.time()
            losses = []
    return params, opt_state
def setup_all(cfg: ModelConfig, rng=None):
    if rng is None:
        rng = jax.random.PRNGKey(1)
    params, model = setup_model(rng, cfg)
    optimizer, opt_state = setup_optimizer(params, cfg)
    return params, model, optimizer, opt_state

if __name__ == "__main__":
    cfg = ModelConfig(
        seq_len=100,
        n_layers=1,
        d_model=512,
        num_heads=2,
        ff_dim=1024,
        dropout=0.1,
        batch_size=128,
        learning_rate=1e-3,
        max_num_batch=5000,
    )
    params, model, optimizer, opt_state = setup_all(cfg)
    # The original script also cast everything with amp_policy.cast_to_compute(...),
    # but amp_policy is never defined in this snippet, so that line is omitted.
    params, opt_state = train_loop(model, optimizer, opt_state, params, cfg)
When I use the TE Flax layers, all of them fail with the same "no implementation" error for the custom op.
Image: ghcr.io/nvidia/jax:maxtext-2024-07-17
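A smaller repro that only touches MultiHeadAttention may help isolate the failing custom op. This is a sketch distilled from the demo above: the constructor and call arguments are copied from it, while the exact rng collections and shapes used for init/apply are my assumptions about the TE/Flax defaults.

import jax
import jax.numpy as jnp
import flax.linen as fnn
from transformer_engine.jax.flax.transformer import MultiHeadAttention

rng_p, rng_d = jax.random.split(jax.random.PRNGKey(1))
tokens = jnp.zeros([2, 100], dtype=jnp.uint8)       # (batch, seq)
x = jnp.zeros([2, 100, 512], dtype=jnp.bfloat16)    # (batch, seq, d_model)
mask = fnn.attention.make_causal_mask(tokens)

# Same constructor arguments as in the demo above (d_model=512, num_heads=2).
mha = MultiHeadAttention(
    head_dim=256,
    num_attention_heads=2,
    input_layernorm=False,
    dtype=jnp.bfloat16,
)
variables = mha.init({"params": rng_p, "dropout": rng_d}, inputs_q=x, inputs_kv=x, mask=mask)
out = mha.apply(variables, inputs_q=x, inputs_kv=x, mask=mask, rngs={"dropout": rng_d})
print(jax.tree_util.tree_map(jnp.shape, out))

In my environment this fails the same way as the full script, which is why I suspect the host driver version.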