chanzuckerberg / cellxgene-census

CZ CELLxGENE Discover Census
https://chanzuckerberg.github.io/cellxgene-census/
MIT License

Unexpected behavior when getting length of ExperimentDatapipe #1183

Open jdenhof opened 4 weeks ago

jdenhof commented 4 weeks ago

I am using the ExperimentDataPipe together with experiment_dataloader and noticed that len() on an ExperimentDataPipe returns the total number of samples and does not take batch size into account. I pass the result of len() to random_split, as the documentation says this is best practice. I believe random_split expects the total number of iterations, not the total number of samples. I could divide this number by the batch size, but I feel this may cause issues if there are leftover samples.

To Reproduce

import tiledbsoma as soma
import cellxgene_census as cell_census
import cellxgene_census.experimental.ml as census_ml

census = cell_census.open_soma(census_version="2023-12-15")

obs_query_value_filter = 'is_primary_data == True and assay in ["microwell-seq", "10x 3\' v1", "10x 3\' v2", "10x 3\' v3", "10x 3\' transcription profiling", "10x 5\' transcription profiling", "10x 5\' v1", "10x 5\' v2"]'

obs_col_names = (
    "dataset_id",
    "assay",
    "donor_id",
)

def create_experiment_dp(batch_size):
    return census_ml.ExperimentDataPipe(
        experiment=census["census_data"]["homo_sapiens"],
        measurement_name="RNA",
        X_name="normalized",
        obs_query=soma.AxisQuery(value_filter=obs_query_value_filter),
        obs_column_names=obs_col_names,
        shuffle=True,
        batch_size=batch_size,
        seed=42,
        soma_chunk_size=None,
    )

print(len(create_experiment_dp(64))) # 31293506
print(len(create_experiment_dp(128))) # 31293506

Is this expected behavior? If so, how would you go about dropping the last batch if it doesn't align with batch_size? You cannot pass this to the DataLoader as a kwarg because batch_size=None is mutually exclusive with drop_last. Would it also cause problems to pass this as total_samples to random_split?
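For reference, the batch-count arithmetic behind the question can be sketched in plain Python. This is only an illustration: `num_batches` is a hypothetical helper, not part of cellxgene_census or PyTorch, and it assumes every batch except possibly the last holds exactly batch_size samples.

```python
def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: how many batches n_samples yields at a given batch_size."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        # Floor division discards the final, incomplete batch.
        return n_samples // batch_size
    # Ceiling division keeps the final, possibly smaller batch.
    return (n_samples + batch_size - 1) // batch_size


total = 31293506  # the value len() reported above for both batch sizes
print(num_batches(total, 64))                   # 488962
print(num_batches(total, 64, drop_last=True))   # 488961
print(num_batches(total, 128))                  # 244481
print(num_batches(total, 128, drop_last=True))  # 244480
```

With 31293506 samples, a batch size of 64 leaves 2 leftover samples and a batch size of 128 leaves 66, so the two counts differ by one in each case.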

ebezzi commented 4 weeks ago

Hey @jdenhof,

I did some research and it definitely looks like len() should return the number of batches produced by the iteration. Does that align with your expectation?

jdenhof commented 4 weeks ago

Yes, that is what I was expecting when calling len() on the experiment_datapipe.

ebezzi commented 4 weeks ago

Sounds good, I'll push a fix for this. I'll let you know when it's released.
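To illustrate the behavior agreed on in this thread, here is a minimal pure-Python sketch of an iterable whose len() matches the number of batches it yields. `BatchedSamples` is a hypothetical toy class written for this illustration. It is not the ExperimentDataPipe implementation and does not reflect how the actual fix is written.

```python
class BatchedSamples:
    """Toy iterable (hypothetical): yields lists of sample indices, batch_size at a time."""

    def __init__(self, n_samples: int, batch_size: int):
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self.n_samples = n_samples
        self.batch_size = batch_size

    def __iter__(self):
        # Yield consecutive index ranges; the last batch may be smaller.
        for start in range(0, self.n_samples, self.batch_size):
            stop = min(start + self.batch_size, self.n_samples)
            yield list(range(start, stop))

    def __len__(self) -> int:
        # Number of batches produced by iteration, not the number of samples.
        return (self.n_samples + self.batch_size - 1) // self.batch_size


batches = BatchedSamples(n_samples=10, batch_size=4)
print(len(batches))               # 3
print([len(b) for b in batches])  # [4, 4, 2]
```

The point of the sketch is the invariant `len(list(batches)) == len(batches)`: whatever count len() reports should equal the number of items a consumer sees when iterating.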