Closed: mmcdermott closed this 4 months ago
@mmcdermott what polars and numpy version were used for this branch?
Building the dataset works fine with this branch.
When running pretrain, I initially got a runtime error stemming from here: https://github.com/mmcdermott/EventStreamGPT/blob/29c3b9f53f8b1d1e16188f81a18face2d1c4adce/EventStream/data/pytorch_dataset.py#L253-L255
The error being:

```
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (55877,) + inhomogeneous part.
```
When I change the numpy array dtype to object, it just hangs (it runs for an incredibly long time, and I haven't been patient enough to wait it out). I don't know whether safetensors or nested_ragged_tensors is to blame (I assume the latter). This occurs when running pretrain on a sample of about 90,000 subjects from a full dataset of about 600,000.
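For reference, this `ValueError` is easy to reproduce in isolation: on NumPy >= 1.24, stacking rows of unequal length without `dtype=object` raises exactly this error rather than silently building an object array. A minimal sketch, independent of the ESGPT code:

```python
import numpy as np

# Equal-length rows stack into a dense 2-D array without issue.
dense = np.array([[1, 2], [3, 4]])
print(dense.shape)  # (2, 2)

# Ragged rows cannot: on NumPy >= 1.24 this raises the same ValueError
# seen above ("... inhomogeneous shape after 1 dimensions ...").
try:
    np.array([[1, 2], [3, 4, 5]])
except ValueError as err:
    print(type(err).__name__, "-", err)
```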
@juancq , definitely don't feel obliged to wait it out when setting it to an object; that will be very slow and defeats the purpose of this change. One question, though; in your data do you have reason to believe that your subjects will have varying numbers of observations of static measurements per row? I believe that what is going on here is that something that is treated as a dense tensor is in reality a ragged one.
I don't follow; what would be an example?
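As a concrete illustration of "varying numbers of static observations per row" (a hypothetical sketch, not the actual ESGPT schema): if one subject has two static measurements recorded and another has three, the per-subject arrays are ragged and cannot be stacked into a plain dense 2-D array.

```python
import numpy as np

# Hypothetical per-subject static measurement indices: subject 0 has two
# static observations, subject 1 has three -- a ragged structure.
static_indices = [
    [10, 42],      # subject 0
    [10, 42, 97],  # subject 1
]

# Lengths differ, so these rows only stack as dtype=object (slow) or via
# a ragged/padded representation, never as a plain dense 2-D array.
print([len(row) for row in static_indices])  # [2, 3]

# A common dense workaround is padding to the max length plus a mask:
max_len = max(len(row) for row in static_indices)
padded = np.zeros((len(static_indices), max_len), dtype=np.int64)
mask = np.zeros((len(static_indices), max_len), dtype=bool)
for i, row in enumerate(static_indices):
    padded[i, : len(row)] = row
    mask[i, : len(row)] = True
print(padded.tolist())  # [[10, 42, 0], [10, 42, 97]]
```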
This branch also has parts that break with certain polars versions, which makes it hard to test because I don't know which exact polars version to use.
@mmcdermott I have sorted out polars issues and have gotten further testing this.
My pretraining now hangs here: https://github.com/mmcdermott/EventStreamGPT/blob/fa33387247f43e760d11c117c3eec5a983778f2d/EventStream/data/pytorch_dataset.py#L161
When I kill the script, the stack trace is something along the following lines (in ragged tensors):
If this is too cryptic, let me know and I'll rephrase or can post an issue on nested_ragged_tensors repo.
Thank you @juancq -- I've been travelling quite a bit and been otherwise occupied for the last month and a half, but I'm trying to push a new major version of ESGPT that addresses these issues and the other memory issues. I assume your last comment still reflects the state of things with this change for you?
@coderabbitai review
The recent changes in the EventStream project involve enhancing data handling, caching mechanisms, and error management across multiple files. These updates include refining file extension checks, improving exception handling for data conversion, restructuring caching mechanisms for efficiency, and aligning test cases with the updated data structures and logic.
| Files | Change Summary |
|---|---|
| `EventStream/data/config.py` | Modified `tensorized_cached_files` to use a dictionary comprehension with a different file extension check. |
| `EventStream/data/dataset_polars.py` | Added `defaultdict` import, improved exception handling in `_filter_col_inclusion`, and updated `build_DL_cached_representation` with new aggregation of `time_delta`. |
| `EventStream/data/pytorch_dataset.py` | Restructured caching mechanisms, introduced new tensor structures for caching, revamped logic for handling dense and ragged tensors, and updated the caching process for improved efficiency. |
| `EventStream/baseline/FT_task_baseline.py` | Renamed `ConstructorPytorchDataset` to `PytorchDataset`, affecting task normalization. |
| `EventStream/data/dataset_base.py` | Added imports for polars and `JointNestedRaggedTensorDict`, modified caching logic for DL representations, handled sharding, and cached NRT representations using Polars. |
| `tests/data/test_pytorch_dataset.py` | Removed unnecessary imports, updated references and data structures, refactored methods for handling temporary directories, and revised test cases to align with the new data structures and logic. |
| `EventStream/data/types.py` | Changed representation of null values from `"null"` to `"nul"` in `convert_to_DL_DF`. |
| Objective (Issue #73) | Addressed | Explanation |
|---|---|---|
| Memory consumption increase with DataLoader and `num_workers > 0` | ❓ | The changes include significant restructuring of caching mechanisms and data handling, which might indirectly address memory consumption issues. However, no direct fix for the issue is evident. |
Yes, the state of things on my end is the same as of the time I wrote the last comment.
This is blocked by #104
@juancq I know this update is long overdue, but the recent pushes use the nested ragged tensor code in a new way that should dramatically reduce the CPU memory burden during model training. You'll need to re-build the base dataset object first to produce the right cached files (though I may write a conversion script to ease that cost, if that would be helpful). Once that is done, this should, with minimal to no impact on throughput, pull patient data from disk only as needed rather than loading it all into memory up front. Note this requires updating to the latest version of nested_ragged_tensors as well. If you try it and find it useful, or find issues with it, I'd be very appreciative and curious to hear your findings!
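The "pull from disk as needed" pattern described above can be approximated with memory-mapped NumPy files; this is a generic sketch of the idea only, not the actual nested_ragged_tensors implementation (the file name and shapes here are invented for illustration):

```python
import os
import tempfile

import numpy as np

# Write a toy cached tensor to disk (stand-in for a cached shard).
tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "cached_events.npy")
np.save(path, np.arange(1_000_000, dtype=np.float32).reshape(10_000, 100))

# mmap_mode="r" maps the file instead of reading it into RAM; only the
# rows a __getitem__ call touches are actually paged in from disk.
cached = np.load(path, mmap_mode="r")
row = np.asarray(cached[1234])  # materialize just this subject's row
print(row.shape)  # (100,)
```

The same principle applied per subject is what keeps the dataloader's resident memory roughly constant regardless of cohort size.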
@mmcdermott thanks for all the hard work. I tested this branch on my dataset. The previous bugs are gone. I am now seeing about a 7% runtime improvement per epoch and about 30% lower memory usage.
Fantastic! Thanks so much @juancq. I'll do some final testing just to make sure there are no issues and plan to merge this branch in soon. Glad this has resolved your issues and brought other improvements besides.
This should fix #73 as well