Closed ryxli closed 4 weeks ago
Hi @IsaevIlya,

> provide a brief reasoning for adding the tell method to the S3Writer class.
More generally, adding a tell() method to the S3Writer class makes sense because S3Writer extends io.BufferedIOBase, and implementing tell() improves the interoperability and usability of S3Writer as a drop-in replacement for a path-like object. Without it, parent-class functionality such as seek() won't work, even though S3Writer doesn't support the buffer protocol for now (ref).
The S3Reader class implements the tell() method in a similar fashion: https://github.com/awslabs/s3-connector-for-pytorch/blob/main/s3torchconnector/src/s3torchconnector/s3reader.py#L200
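For illustration, here is a minimal sketch of how a write-side tell() can be implemented by tracking the cumulative number of bytes written. The class and attribute names are hypothetical, not the actual S3Writer internals:

```python
import io


class CountingWriter(io.BufferedIOBase):
    """Hypothetical write-only stream that tracks its position, like S3Writer."""

    def __init__(self, sink):
        self._sink = sink   # underlying destination, e.g. a BytesIO
        self._position = 0  # cumulative bytes written so far

    def writable(self) -> bool:
        return True

    def write(self, data) -> int:
        written = self._sink.write(data)
        self._position += written
        return written

    def tell(self) -> int:
        # Report the stream position without needing seek() support.
        return self._position


writer = CountingWriter(io.BytesIO())
writer.write(b"hello ")
writer.write(b"world")
print(writer.tell())  # 11
```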
More specifically, here is one use case where having the stream position is useful. When integrating the MountPointS3Client with PyTorch distributed checkpointing, one of the exposed APIs is torch.distributed.checkpoint.FileSystemWriter. For each tensor/byte object in the sharded state dict, the torch implementation creates a plan to write each object into a stream:
```python
with create_stream(file_name, "wb") as stream:
    for write_item in bytes_w:
        data = planner.resolve_data(write_item)
        write_results.append(
            _write_item(stream, data, write_item, storage_key)
        )
    for tensor, write_item in loader.values():
        assert tensor.is_cpu
        write_results.append(
            _write_item(stream, tensor, write_item, storage_key)
        )
```
The _write_item helper is written as follows:
```python
def _write_item(
    stream: io.IOBase,
    data: Union[io.BytesIO, torch.Tensor],
    write_item: WriteItem,
    storage_key: str,
) -> WriteResult:
    offset = stream.tell()
    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, cast(IO[bytes], stream))
    length = stream.tell() - offset
    return WriteResult(
        index=write_item.index,
        size_in_bytes=length,
        storage_data=_StorageInfo(storage_key, offset, length),
    )
```
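To make the role of tell() here concrete, the same offset/length bookkeeping can be demonstrated with a plain io.BytesIO standing in for the checkpoint stream (a simplified sketch, with no torch types involved):

```python
import io


def record_write(stream, payload):
    # Same pattern as torch's _write_item: capture the offset before the
    # write and derive the length from the stream position afterwards.
    offset = stream.tell()
    stream.write(payload)
    length = stream.tell() - offset
    return offset, length


stream = io.BytesIO()
records = [record_write(stream, p) for p in (b"tensor-0", b"bytes-1")]
print(records)  # [(0, 8), (8, 7)]
```

Each (offset, length) pair is exactly what _StorageInfo stores, so a write-only stream only needs a working tell() for this to work, not full seek() support.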
There is an intermediate API which stores the write results, and they are used later when the distributed checkpoint plan is executed in a queue-like manner.
With the current state of S3Writer, it is possible to track the position of the stream via the return value of S3Writer.write, but that would require additional modification to the PyTorch source code rather than a simple drop-in extension of FileSystemWriter. Something similar to:
```python
def _modified_write_item(
    stream: S3Writer,
    data: Union[io.BytesIO, torch.Tensor],
    write_item: WriteItem,
    storage_key: str,
) -> WriteResult:
    # With S3Writer, on __enter__ the start position should always be zero,
    # since we write the S3 object from scratch rather than appending to an
    # existing S3 object.
    offset = 0
    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        length = stream.write(data.getbuffer())  # <-- get length from write's return value
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        # <-- would need additional modification to calculate the size of the
        # tensor, since torch.save does not return a length and there is no
        # way to retrieve the number of bytes written via S3Writer when it is
        # passed into torch.save.
        torch.save(data, cast(IO[bytes], stream))
        # length = something
    return WriteResult(
        index=write_item.index,
        size_in_bytes=length,
        storage_data=_StorageInfo(storage_key, offset, length),
    )
```
Another workaround would be to collect all of the results into an in-memory bytes buffer first and then write them to the S3Writer, but that is a less preferred solution and could result in OOM errors.
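For reference, that workaround would look roughly like the following sketch (the helper name is made up for illustration); the obvious downside is that the entire serialized object has to be materialized in memory before being forwarded to the S3 stream:

```python
import io


def buffered_write(s3_stream, serialize):
    # Workaround: serialize into an in-memory buffer first, so the length is
    # known up front, then forward the bytes to the write-only S3 stream in
    # one call. This holds the whole serialized payload in RAM, so large
    # tensors risk OOM errors.
    buffer = io.BytesIO()
    serialize(buffer)  # e.g. torch.save(tensor, buffer)
    payload = buffer.getvalue()
    s3_stream.write(payload)
    return len(payload)


sink = io.BytesIO()  # stand-in for an S3Writer
length = buffered_write(sink, lambda b: b.write(b"fake-serialized-tensor"))
print(length)  # 22
```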
> include unit/integration tests to ensure the new functionality works as expected.

Sure, I can add them. I didn't include any initially because I couldn't find existing unit tests for S3Writer; for S3Reader I found them here: https://github.com/awslabs/s3-connector-for-pytorch/blob/main/s3torchconnector/tst/unit/test_s3reader.py
Added tests and removed unused code. Tests passed:
```
% pytest s3torchconnector/tst/unit/test_s3writer.py -s
============================================================= test session starts =============================================================
platform linux -- Python 3.10.14, pytest-8.3.2, pluggy-1.5.0
rootdir: /fsx-Training/shopqa-training-fsx-prod-us-east-1/home/rynli/workspace/s3-connector-for-pytorch/s3torchconnector
configfile: pyproject.toml
plugins: hypothesis-6.111.2, timeout-2.3.1
collected 5 items

s3torchconnector/tst/unit/test_s3writer.py .....

============================================================== warnings summary ===============================================================
../../../../../../opt/conda/envs/s3connector/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:258
  /opt/conda/envs/s3connector/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:258: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
    cpu = _conversion_method_template(device=torch.device("cpu"))

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================== 5 passed, 1 warning in 3.69s =========================================================
```
Description
Add tell() to the writer to get the current position of the PutObjectStream.
Additional context
Related items
Testing
By submitting this pull request, I confirm that my contribution is made under the terms of BSD 3-Clause License and I agree to the terms of the LICENSE.