awslabs / s3-connector-for-pytorch

The Amazon S3 Connector for PyTorch delivers high throughput for PyTorch training jobs that access and store data in Amazon S3.
BSD 3-Clause "New" or "Revised" License

support tell for s3writer #228

Closed ryxli closed 4 weeks ago

ryxli commented 1 month ago

Description

Add `tell()` to S3Writer to get the current position of the PutObjectStream.

Additional context

Related items

Testing


By submitting this pull request, I confirm that my contribution is made under the terms of BSD 3-Clause License and I agree to the terms of the LICENSE.

ryxli commented 1 month ago

Hi @IsaevIlya

> provide a brief reasoning for adding the tell method to the S3Writer class.

More generally, adding a tell method to the S3Writer class makes sense because S3Writer extends io.BufferedIOBase, and tell improves the interoperability and usability of S3Writer as an in-place drop-in replacement for a file-like object. Without it, parent-class functionality such as seek() won't work, even if S3Writer doesn't support the buffer protocol for now (ref). The S3Reader class implements tell() in a similar fashion: https://github.com/awslabs/s3-connector-for-pytorch/blob/main/s3torchconnector/src/s3torchconnector/s3reader.py#L200
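To make the idea concrete, here is a minimal, hypothetical sketch (not the actual S3Writer implementation) of how a write-only stream can support tell() by accumulating the number of bytes written, which mirrors the position-tracking approach S3Reader uses for reads. The `_SketchWriter` class and its `_chunks` attribute are illustrative stand-ins, not part of the connector's API:

```python
import io


class _SketchWriter(io.BufferedIOBase):
    """Hypothetical sketch: track the stream position by counting
    bytes written, since the underlying PutObjectStream is
    append-only and cannot be queried for a position."""

    def __init__(self):
        self._position = 0
        self._chunks = []  # stand-in for the underlying PutObjectStream

    def write(self, data) -> int:
        b = bytes(data)
        self._chunks.append(b)
        self._position += len(b)
        return len(b)

    def tell(self) -> int:
        # Position is simply the total number of bytes written so far.
        return self._position


w = _SketchWriter()
w.write(b"hello ")
w.write(b"world")
print(w.tell())  # 11
```

Because the object is always written from scratch, tracking a running byte count is sufficient; no call into the underlying stream is needed.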

More specifically, I can give an example of one use case where having the stream position is useful:

When integrating the MountPointS3Client with PyTorch distributed checkpointing, one of the APIs exposed is torch.distributed.checkpoint.FileSystemWriter. For each tensor/byte object in the sharded state dict, the torch implementation creates a plan to write each object into a stream:

            with create_stream(file_name, "wb") as stream:
                for write_item in bytes_w:
                    data = planner.resolve_data(write_item)
                    write_results.append(
                        _write_item(stream, data, write_item, storage_key)
                    )

                for tensor, write_item in loader.values():
                    assert tensor.is_cpu
                    write_results.append(
                        _write_item(stream, tensor, write_item, storage_key)
                    )

The _write_item helper is written as follows:

def _write_item(
    stream: io.IOBase,
    data: Union[io.BytesIO, torch.Tensor],
    write_item: WriteItem,
    storage_key: str,
) -> WriteResult:
    offset = stream.tell()

    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, cast(IO[bytes], stream))
    length = stream.tell() - offset

    return WriteResult(
        index=write_item.index,
        size_in_bytes=length,
        storage_data=_StorageInfo(storage_key, offset, length),
    )
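The offset/length bookkeeping in `_write_item` works against any stream that supports tell(). A quick illustration using io.BytesIO as a stand-in for S3Writer (the `write_and_measure` helper is hypothetical, written only to isolate the bookkeeping pattern):

```python
import io

stream = io.BytesIO()


def write_and_measure(stream, payload: bytes):
    # Mirror of _write_item's bookkeeping: record the offset before
    # the write, then derive the length from the position afterwards.
    offset = stream.tell()
    stream.write(payload)
    length = stream.tell() - offset
    return offset, length


print(write_and_measure(stream, b"tensor-bytes"))  # (0, 12)
print(write_and_measure(stream, b"more"))          # (12, 4)
```

This is exactly why tell() on the writer is enough for FileSystemWriter: the length never has to come from the serializer (torch.save), only from the stream position before and after the write.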

There is an intermediate API that stores the write results and is used later when the distributed checkpoint plan is executed in a queue-like manner.

With the current state of S3Writer, it is possible to get the stream position via the return value of S3Writer.write, but that would require modifying the PyTorch source rather than simply extending FileSystemWriter as a drop-in replacement. Something similar to:

def _modified_write_item(
    stream: S3Writer,
    data: Union[io.BytesIO, torch.Tensor],
    write_item: WriteItem,
    storage_key: str,
) -> WriteResult:
    offset = 0  # start position; with S3Writer this is always zero on __enter__, since we write the S3 object from scratch rather than appending to an existing object

    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        length = stream.write(data.getbuffer()) # <--- get length from write return here
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, cast(IO[bytes], stream))  # <--- would need additional modification to
        # calculate the size of the tensor: torch.save does not return a length, and there is
        # no way to retrieve the number of bytes written through S3Writer when it is passed
        # into torch.save
        # length = something

    return WriteResult(
        index=write_item.index,
        size_in_bytes=length,
        storage_data=_StorageInfo(storage_key, offset, length),
    )

Another workaround would be to collect all of the results into an in-memory bytes buffer first and then write that to the S3Writer, but that is a less preferred solution and could result in OOM errors.
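The buffering workaround would look roughly like this (a hypothetical sketch; `buffered_write` and its parameters are illustrative, not connector API). Serializing into RAM first makes the length known without tell(), but peak memory grows with the size of the largest serialized object:

```python
import io


def buffered_write(writer, serialize, obj) -> int:
    # Hypothetical workaround: serialize the object into an in-memory
    # buffer first so its length is known, then write it out in one call.
    # Peak memory is proportional to len(data), which is what risks OOM
    # for large tensors.
    buf = io.BytesIO()
    serialize(obj, buf)
    data = buf.getvalue()
    length = writer.write(data)  # write() returns the byte count
    return length


sink = io.BytesIO()  # stand-in for S3Writer
n = buffered_write(sink, lambda o, b: b.write(o), b"payload")
print(n)  # 7
```

A tell() on the writer avoids this extra copy entirely, since the length can be derived from stream positions instead of from a materialized buffer.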

> include unit/integration tests to ensure the new functionality works as expected.

Sure, I can add them. I didn't include any because I couldn't find existing unit tests for S3Writer; for S3Reader I found these: https://github.com/awslabs/s3-connector-for-pytorch/blob/main/s3torchconnector/tst/unit/test_s3reader.py

ryxli commented 1 month ago

added tests

ryxli commented 4 weeks ago

removed unused code.

tests passed:

% pytest s3torchconnector/tst/unit/test_s3writer.py -s
============================================================= test session starts =============================================================
platform linux -- Python 3.10.14, pytest-8.3.2, pluggy-1.5.0
rootdir: /fsx-Training/shopqa-training-fsx-prod-us-east-1/home/rynli/workspace/s3-connector-for-pytorch/s3torchconnector
configfile: pyproject.toml
plugins: hypothesis-6.111.2, timeout-2.3.1
collected 5 items                                                                                                                             

s3torchconnector/tst/unit/test_s3writer.py .....

============================================================== warnings summary ===============================================================
../../../../../../opt/conda/envs/s3connector/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:258
  /opt/conda/envs/s3connector/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:258: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
    cpu = _conversion_method_template(device=torch.device("cpu"))

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================== 5 passed, 1 warning in 3.69s =========================================================