NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: No registered implementation for custom call to xxx for platform CUDA #1024

Open · MoFHeka opened this issue 4 months ago

MoFHeka commented 4 months ago

jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: No registered implementation for custom call to te_scaled_upper_triang_masked_softmax_forward for platform CUDA

from transformer_engine.jax.flax.transformer import DotProductAttention, MultiHeadAttention, TransformerLayer

When I use the TE Flax layers, all of them raise this "no registered implementation" error.

Image: ghcr.io/nvidia/jax:maxtext-2024-07-17

ptrendx commented 4 months ago

@denera Could you take a look?

denera commented 3 months ago

Hi @MoFHeka -- the JAX/XLA custom op for te_scaled_upper_triang_masked_softmax_forward is implemented here, exposed via PyBind11 here and registered with XLA for the CUDA platform here.

TE/Flax modules invoke this custom op via the scaled_upper_triang_softmax_fwd() API. Is that what you're trying to use in your application?

If this is not working for you, could you provide us a minimal reproducer along with some information about your platform like GPU type, CUDA driver version and CUDA Toolkit version?
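One way to collect that platform information from inside the container is sketched below. This is only a best-effort diagnostic, not part of TE: jax.print_environment_info() is a standard JAX utility, and the transformer_engine.__version__ attribute is looked up defensively in case it is absent.

import subprocess

import jax
import transformer_engine
import transformer_engine.jax  # importing this should also load the extension that registers the CUDA custom calls

# TE version (best-effort lookup)
print("TransformerEngine:", getattr(transformer_engine, "__version__", "unknown"))

# JAX/jaxlib versions, visible devices, and CUDA build info (also runs nvidia-smi)
print("Devices:", jax.devices())
jax.print_environment_info()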

zlsh80826 commented 3 months ago

I tried to run python -c 'from transformer_engine.jax.flax.transformer import DotProductAttention, MultiHeadAttention, TransformerLayer' in ghcr.io/nvidia/jax:maxtext-2024-07-17 on an H100 with driver 550.54.14 and CUDA 12.4, and I cannot reproduce the error.

MoFHeka commented 3 months ago

Here is a simple demo, nothing special. Could the problem be that the Kubernetes host machine's CUDA driver version (535) is too old? (A driver/toolkit check sketch follows the demo code below.)

from dataclasses import dataclass
import os
import time

import flax.linen as fnn
import jax
import jax.nn as jnn
import jax.numpy as jnp
import numpy as np
import optax

from transformer_engine.jax.flax.transformer import DotProductAttention, MultiHeadAttention, TransformerLayer

os.environ["NVIDIA_TF32_OVERRIDE"] = "1"

os.environ["XLA_FLAGS"] = """
    --xla_gpu_enable_triton_gemm=false
    --xla_gpu_graph_level=2
    --xla_gpu_enable_custom_fusions=true
    --xla_gpu_enable_address_computation_fusion=true
"""

@dataclass
class ModelConfig:
    """Configuration for the language models."""

    seq_len: int
    n_layers: int
    d_model: int
    num_heads: int
    ff_dim: int
    dropout: float

    batch_size: int
    learning_rate: float
    max_num_batch: int

class RandomDS:
    def __init__(self, batch_size: int, seq_len: int, use_jax=False):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.use_jax = use_jax
        if self.use_jax:
            self.rng = jax.random.PRNGKey(1)

    def __iter__(self):
        if self.use_jax:
            batches = [
                jax.random.bits(self.rng, shape=(self.batch_size, self.seq_len), dtype=jnp.uint8)
                for _ in range(self.batch_size)
            ]
        else:
            batches = [
                np.random.randint(low=0, high=16, size=(self.batch_size, self.seq_len), dtype=np.uint8)
                for _ in range(self.batch_size)
            ]
        return iter(batches)

class TransformerLayer(fnn.Module):
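    # This local module shadows the TransformerLayer imported from
    # transformer_engine.jax.flax.transformer above; only TE's
    # MultiHeadAttention is actually exercised in this demo.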
    d_model: int
    num_heads: int
    ff_dim: int
    dropout: float

    def setup(self):
        self.mha = MultiHeadAttention(
            head_dim=self.d_model // self.num_heads,
            num_attention_heads=self.num_heads,
            input_layernorm=False,
            dtype=jnp.bfloat16,
        )
        self.layer_norm_1 = fnn.LayerNorm(epsilon=1e-5, dtype=jnp.bfloat16,)
        self.linear_1 = fnn.Dense(
            features=self.ff_dim,
            kernel_init=fnn.initializers.variance_scaling(1/3, "fan_in", "uniform"),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16
        )
        self.linear_2 = fnn.Dense(
            features=self.d_model,
            kernel_init=fnn.initializers.variance_scaling(1/3, "fan_in", "uniform"),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16
        )
        self.layer_norm_2 = fnn.LayerNorm(epsilon=1e-5, dtype=jnp.bfloat16)
        self.dropout_layer = fnn.Dropout(self.dropout, deterministic=False)

    def __call__(
        self, x: jnp.ndarray, mask: jnp.ndarray
    ) -> jnp.ndarray:
        # jnp.ndarray is the appropriate annotation for JAX arrays (they behave like numpy ndarrays)
        x = self.layer_norm_1(x)
        x = self.mha(inputs_q=x, inputs_kv=x, mask=mask)[0]
        x = x + self.dropout_layer(x)
        x = x + self.dropout_layer(self._ff_block(self.layer_norm_2(x)))
        return x

    def _ff_block(self, x):
        x = jnn.relu(self.linear_1(x))
        x = self.dropout_layer(x)
        x = self.linear_2(x)
        return x

class LM(fnn.Module):
    cfg: ModelConfig

    def setup(self):
        self.byte_embedding = fnn.Embed(
            num_embeddings=256,
            features=self.cfg.d_model,
            embedding_init=jnn.initializers.normal(),
            param_dtype=jnp.bfloat16
        )
        self.positional_encoding = self.param(
            "positional_encoding",
            jnn.initializers.normal(),
            (self.cfg.seq_len, self.cfg.d_model),
            dtype=jnp.bfloat16,
        )
        self.dropout_layer = fnn.Dropout(self.cfg.dropout, deterministic=False)

        self.transformer_layers = [
            TransformerLayer(
                self.cfg.d_model, self.cfg.num_heads, self.cfg.ff_dim, self.cfg.dropout
            )
            for _ in range(self.cfg.n_layers)
        ]
        self.prob_decoder = fnn.Dense(
            features=256,
            kernel_init=fnn.initializers.variance_scaling(1/3, "fan_in", "uniform"),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16
        )

    def __call__(self, text):
        x = self.byte_embedding(text)
        # Shift x right so causality isn't violated
        x = jnp.concatenate(
                [jnp.zeros([text.shape[0], 1, self.cfg.d_model], dtype=x.dtype), x[:, :-1, :]], axis=1
            )
        x = x + self.positional_encoding
        x = self.dropout_layer(x)

        mask = fnn.attention.make_causal_mask(text)
        for layer in self.transformer_layers:
            x = layer(x, mask=mask)

        return self.prob_decoder(x)

rng = jax.random.PRNGKey(1)
def compute_loss(params, model: LM, text):
    model_out = model.apply(params, text=text, rngs={"dropout": rng})
    one_hots = jnn.one_hot(text, 256)
    loss = optax.softmax_cross_entropy(model_out, one_hots)
    return loss

def setup_model(rng, cfg: ModelConfig):
    model = LM(cfg)

    rng_p, rng_d = jax.random.split(rng)
    params = model.init(
        {"params": rng_p, "dropout": rng_d}, jnp.zeros([cfg.batch_size, cfg.seq_len], dtype=jnp.uint8)
    )
    return params, model

def setup_optimizer(params, cfg: ModelConfig):
    optimizer = optax.adam(cfg.learning_rate)
    opt_state = optimizer.init(params)
    return optimizer, opt_state

def train_loop(
    model: LM, optimizer, opt_state, params, cfg: ModelConfig, datapath: str = ""
):
    # datapath is unused in this demo; batches come from RandomDS below.

    def run_train_step(params, opt_state, text_batch):
        loss, grad = jax.value_and_grad(lambda p: compute_loss(p, model, text=text_batch).mean())(params)
        updates, opt_state = optimizer.update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    fast_train_step = jax.jit(run_train_step, donate_argnums=[0, 1])

    losses = []
    t = time.time()
    log_per = 20

    # Unused in this demo.
    def multi_train_steps(state, data):
        params, opt_state = state
        loss = None
        for single_step_data in data:
            params, opt_state, loss = fast_train_step(params, opt_state, single_step_data)
        return params, opt_state, loss

    dataset = list(RandomDS(cfg.batch_size, cfg.seq_len, use_jax=True))

    # Warm-up: run a few steps so jit compilation is excluded from the timing below.
    for idx, batch in enumerate(dataset):
        params, opt_state, loss = fast_train_step(params, opt_state, batch)
        if (idx + 1) % log_per == 0:
            break

    iter_num = 0
    t = time.time()
    for batch in dataset:
        params, opt_state, loss = fast_train_step(params, opt_state, batch)
        losses.append(loss)
        iter_num += 1
    time_elps = time.time() - t
    speed = iter_num * cfg.batch_size / time_elps
    print(
        f"After {iter_num} iters, mean loss: {np.mean(losses):.4f}, speed: {int(speed):d} samples/s"
    )
    t = time.time()
    losses = []

    return params, opt_state

def setup_all(cfg: ModelConfig, rng=None):
    if rng is None:
        rng = jax.random.PRNGKey(1)
    params, model = setup_model(rng, cfg)
    optimizer, opt_state = setup_optimizer(params, cfg)

    return params, model, optimizer, opt_state

if __name__ == "__main__":
    cfg = ModelConfig(
        seq_len=100,
        n_layers=1,
        d_model=512,
        num_heads=2,
        ff_dim=1024,
        dropout=0.1,
        batch_size=128,
        learning_rate=1e-3,
        max_num_batch=5000,
    )

    params, model, optimizer, opt_state = setup_all(cfg)
    # amp_policy is not defined in this snippet; the mixed-precision cast is
    # commented out so the reproducer runs as-is.
    # params, model, optimizer, opt_state = amp_policy.cast_to_compute((params, model, optimizer, opt_state))
    params, opt_state = train_loop(model, optimizer, opt_state, params, cfg)
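
Regarding the driver question above: a quick way to check which driver and toolkit versions the container actually sees is sketched below. This assumes nvidia-smi and nvcc are available inside the image; the driver version is inherited from the Kubernetes host, while the CUDA toolkit version comes from the container image, so the two can mismatch.

import subprocess

# Driver version exposed to the container (comes from the host).
query = ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"]
print(subprocess.run(query, capture_output=True, text=True).stdout.strip())

# CUDA toolkit version the image was built with (nvcc may be absent in runtime-only images).
print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout.strip())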