google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0

rm fairscale #46

Closed: Mon-ius closed this 6 months ago

Mon-ius commented 6 months ago

Since the only use of fairscale is from fairscale.nn.model_parallel.utils import divide_and_check_no_remainder, split_tensor_along_last_dim, it can simply be migrated as:

from typing import Tuple

import torch

def ensure_divisibility(numerator: int, denominator: int) -> None:
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)

def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator

def split_tensor_along_last_dim(
    tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False
) -> Tuple[torch.Tensor, ...]:
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: If True, make each chunk contiguous
                                in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide_and_check_no_remainder(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list

It makes the build simpler and cleaner 🤗
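For reference, here is a quick sanity check of the helpers above (an illustrative sketch, not part of the PR; it assumes split_tensor_along_last_dim from the snippet is in scope):

import torch

# Split a [2, 4, 8] tensor into 4 chunks along the last dimension.
x = torch.randn(2, 4, 8)
chunks = split_tensor_along_last_dim(x, num_partitions=4, contiguous_split_chunks=True)

assert len(chunks) == 4
assert all(c.shape == (2, 4, 2) for c in chunks)  # last dim split evenly
assert all(c.is_contiguous() for c in chunks)     # contiguous_split_chunks=True
assert torch.equal(torch.cat(chunks, dim=-1), x)  # concatenation restores the input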

google-cla[bot] commented 6 months ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

pengchongjin commented 6 months ago

Sounds good. I also prefer simpler code and fewer dependencies.

Have you tested the code using run_xla.py? You can follow the instructions in the "Try It out with PyTorch/XLA" section of README.md.

Btw, since you are here, could you also remove all the redundant dependencies in the dockerfiles? https://github.com/google/gemma_pytorch/tree/main/docker
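One way to double-check that the migration is behavior-identical before dropping the dependency (an illustrative snippet, assuming fairscale is still installed in the test environment and the inlined helpers above are importable):

import torch
from fairscale.nn.model_parallel.utils import (
    divide_and_check_no_remainder as fs_divide,
    split_tensor_along_last_dim as fs_split,
)

# Compare the inlined helpers against fairscale's originals.
assert divide_and_check_no_remainder(12, 4) == fs_divide(12, 4)

x = torch.randn(3, 6, 12)
for ours, theirs in zip(split_tensor_along_last_dim(x, 4), fs_split(x, 4)):
    assert torch.equal(ours, theirs)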

Thanks

Mon-ius commented 6 months ago

Absolutely! I tested it on both TPU v4-8 x8 and A100 GPU x8.

Mon-ius commented 6 months ago

Done 🤗

pengchongjin commented 6 months ago

Maybe also remove the dependencies in the Dockerfile? https://github.com/google/gemma_pytorch/blob/main/docker/Dockerfile

Mon-ius commented 6 months ago

My mistake, it should all be done now 🤗

pengchongjin commented 6 months ago

Awesome, thanks for the contribution. Merging it now.