mahmoodlab / HEST

HEST: Bringing Spatial Transcriptomics and Histopathology together - NeurIPS 2024

Barcodes loaded in inconsistent format (different data types based on batch size) #62

Open hugo-ekinge opened 4 weeks ago

hugo-ekinge commented 4 weeks ago

Hi, and thank you for all the work you have done so far.

I am trying to adapt your `embed_tiles` function to create embeddings for arbitrary samples, not just the ones included in HEST-Benchmark. When looping over all samples, `assert dset.dtype == val.dtype` in `save_hdf5` fails on some of them. On closer inspection, the problem is that the barcodes are not always of fixed length. On some samples, the first batch of, say, 128 tiles contains only numeric barcodes such as [b'0'], [b'1'], [b'2'], ..., [b'994'], instead of fixed-length ones like [b'AAACACCAATAACTGC-1']. This becomes a problem once the number of characters increases, e.g. to [b'1014'], because the dtype suddenly becomes |S4 instead of |S3.

One thing I tried was increasing the batch size to 1024. That eliminated the switch from |S3 to |S4, but I still get the same error (although less often) when going from |S4 to |S5 in samples with a large number of barcodes.
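The dtype switch follows from how NumPy builds a fixed-width bytes array: `np.array` on a list of `bytes` picks `S<n>`, where `n` is the length of the longest element in that particular list. A minimal sketch, assuming purely numeric barcodes `b'0'`, `b'1'`, ... as in the affected samples (`batch_dtypes` is a hypothetical helper, not part of HEST):

```python
import numpy as np


def batch_dtypes(n_barcodes, batch_size):
    """Return the NumPy dtype inferred for each batch of numeric barcodes."""
    barcodes = [str(i).encode() for i in range(n_barcodes)]
    dtypes = []
    for start in range(0, n_barcodes, batch_size):
        # np.array sizes the S dtype by the longest barcode in *this* batch only
        dtypes.append(np.array(barcodes[start:start + batch_size]).dtype)
    return dtypes


small = batch_dtypes(1100, 128)
large = batch_dtypes(11000, 1024)
print(small[0], small[-1])  # |S3 |S4
print(large[0], large[-1])  # |S4 |S5
```

A larger batch size only moves the boundary to the first batch that contains a longer barcode, which is consistent with the |S4 to |S5 failures.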

This is as far as I have been able to get on my own. What would you recommend to fix this more robustly, so that all samples in the dataset are handled properly?

(Interestingly, in the hest_data/patches folder the barcodes dtype seems to be object rather than any S type. So the conversion seems to happen somewhere in the dataset/dataloader pipeline. Is this intentional, and if so, what is the purpose behind it?)
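One plausible explanation (an assumption; this thread does not show where HEST performs the conversion): h5py reports variable-length string datasets as NumPy `object` dtype and reads them back as an object array of `bytes`. Any later step that rebuilds a regular NumPy array from those values re-infers a fixed-width `S` dtype from whatever slice it sees. A NumPy-only sketch of that effect, with the object array standing in for what h5py returns:

```python
import numpy as np

# Stand-in for a variable-length string dataset read through h5py:
# an object array holding bytes objects of different lengths.
stored = np.array([b"0", b"994", b"1014"], dtype=object)
print(stored.dtype)  # object

# Rebuilding a plain array from a slice re-infers a fixed width
# from the longest element present in that slice only.
rebuilt = np.array(stored[:2].tolist())
print(rebuilt.dtype)  # |S3
```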

guillaumejaume commented 4 weeks ago

Can you post the code to replicate? Thanks!

hugo-ekinge commented 4 weeks ago

I wrote the function below, which is largely a trimmed-down version of `predict_single_split` with some modifications. It lives in `benchmark.py`:

```python
def create_embeddings(args, model_name, device, custom_encoder):
    """ Create embeddings for a single model """

    embedding_dir = os.path.join(get_path(args.embed_dataroot), "ALL", model_name)
    os.makedirs(embedding_dir, exist_ok=True)

    # Embed patches
    logger.info(f"Embedding ALL tiles using {model_name} encoder and custom code")
    weights_path = get_bench_weights(args.weights_root, model_name)
    if model_name == 'custom_encoder':
        encoder = custom_encoder
    else:
        encoder: InferenceEncoder = inf_encoder_factory(model_name)(weights_path)
    precision = encoder.precision

    patches_dir = os.path.join(get_path('hest_data'), 'patches')
    for root, _, files in os.walk(patches_dir):
        for file in tqdm(files):
            if file.endswith('.h5'):
                tile_h5_path = os.path.join(root, file)
                sample_id = os.path.splitext(file)[0]
                embed_path = os.path.join(embedding_dir, f'{sample_id}.h5')
                if not os.path.isfile(embed_path) or args.overwrite:
                    _ = encoder.eval()
                    encoder.to(device)

                    tile_dataset = H5HESTDataset(tile_h5_path, chunk_size=args.batch_size, img_transform=encoder.eval_transforms)
                    tile_dataloader = torch.utils.data.DataLoader(tile_dataset,
                                                                  batch_size=1,
                                                                  shuffle=False,
                                                                  num_workers=args.num_workers)

                    _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
                else:
                    logger.info(f"Skipping {sample_id} as it already exists")
```
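A caller-side workaround for a loop like this, sketched under the assumption that the full list of barcodes for a sample can be read once before batching: compute the maximum barcode length per sample and cast every batch to that width before it reaches `save_hdf5` (i.e. inside `embed_tiles` or wherever the batch's barcodes are collected). `fixed_width` and `pad_barcodes` are hypothetical helpers, not part of HEST:

```python
import numpy as np


def fixed_width(all_barcodes):
    """Return one bytes dtype wide enough for every barcode in the sample."""
    width = max(len(b) for b in all_barcodes)
    return np.dtype(f"S{width}")


def pad_barcodes(batch, dtype):
    """Cast a batch of barcodes to a shared fixed-width bytes dtype."""
    # Widening (e.g. S3 -> S4) pads with NUL bytes, which NumPy strips on access.
    # A narrower dtype would silently truncate, hence sizing from the whole sample.
    return np.asarray(batch).astype(dtype)


barcodes = [str(i).encode() for i in range(1100)]
dt = fixed_width(barcodes)
first = pad_barcodes(barcodes[:128], dt)
last = pad_barcodes(barcodes[1024:], dt)
print(first.dtype, last.dtype)  # |S4 |S4
```

This keeps a fixed-width `S` dtype on disk; the trade-off is that the width has to be known before the first write.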

To connect it to the existing benchmark workflow, I added the argument `parser.add_argument('--create_embeddings', type=bool, default='false', help='only create embeddings')`.
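A side note on the flag as written: argparse applies `type` to string defaults, and `bool('false')` is `True`, so with `type=bool, default='false'` the flag is truthy even when it is not passed (whether that matters here depends on how `bench_config.yaml` values are merged into `args`). A sketch of the standard argparse on/off switch, `action='store_true'`, next to the original pattern (renamed to the hypothetical `--bool_typed_flag` for comparison):

```python
import argparse

parser = argparse.ArgumentParser()
# Original pattern: argparse runs type=bool on the string default 'false',
# and bool('false') is True because the string is non-empty.
parser.add_argument("--bool_typed_flag", type=bool, default="false")
# Conventional on/off switch: False unless the flag is passed.
parser.add_argument("--create_embeddings", action="store_true",
                    help="only create embeddings")

defaults = parser.parse_args([])
print(defaults.bool_typed_flag, defaults.create_embeddings)  # True False
print(parser.parse_args(["--create_embeddings"]).create_embeddings)  # True
```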

I then added the following to `benchmark_grid`, at the top of the `for model_name in model_names:` loop:

```python
if args.create_embeddings:
    create_embeddings(args, model_name, device, custom_encoder)
    continue
```

and right after the loop:

```python
if args.create_embeddings:
    break
```

It's not the prettiest, but I just wanted to hook into the existing code to run a few quick tests. To reproduce, add `create_embeddings: True` to `bench_config.yaml` and run the benchmark as usual. It will then try to embed every sample in `hest_data/patches` and write the results to an `ALL` folder next to the other embeddings, with one subfolder per model.

With a batch size of 128, I hit the issue with, for example, sample TENX137, but there are others as well.

To debug, I also added the following right before `assert dset.dtype == val.dtype` in `file_utils.py`:

```python
if dset.dtype != val.dtype:
    print(f'Path: {output_fpath}')
    print(f'Key: {key}')
    print(f'dset.dtype: {dset.dtype}')
    print(f'val.dtype: {val.dtype}')
    print(f'dset: {dset}\n{dset[:]}')
    print(f'val: {val}')
```

Note that unless `overwrite` is set to True, simply rerunning will skip the samples where the assertion failed previously, because their (incomplete) output files already exist.
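To avoid leaving incomplete files that a rerun then skips, one general option (not something this thread shows HEST doing) is to write each sample to a temporary path and rename it into place only after every batch has been saved. `write_atomically` is a hypothetical helper, and `write_fn` stands in for the `embed_tiles(...)` call with the temporary path as its output file:

```python
import os


def write_atomically(final_path, write_fn):
    """Run write_fn(tmp_path) and move the result to final_path only on success."""
    tmp_path = final_path + ".partial"
    if os.path.exists(tmp_path):
        # Leftover from a crashed run; discard it so an append-mode ('a')
        # writer does not extend a half-written file.
        os.remove(tmp_path)
    try:
        write_fn(tmp_path)
        # os.replace is atomic when both paths are on the same filesystem.
        os.replace(tmp_path, final_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

With this pattern, the existing `os.path.isfile(embed_path)` check only ever sees complete files.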

hugo-ekinge commented 3 weeks ago

Have you been able to replicate?

guillaumejaume commented 3 weeks ago

It's hard to reproduce without seeing all your modifications. You might be able to solve the issue by modifying `save_hdf5` so that the string dtype does not change with the length of the strings. Can you confirm whether saving to h5 with this function solves the issue? Thanks.

```python
import h5py
import numpy as np


def save_hdf5_revised(output_fpath,
                      asset_dict,
                      attr_dict=None,
                      mode='a',
                      auto_chunk=True,
                      chunk_size=None):
    with h5py.File(output_fpath, mode) as f:
        for key, val in asset_dict.items():
            data_shape = val.shape
            if len(data_shape) == 1:
                val = np.expand_dims(val, axis=1)
                data_shape = val.shape

            # Store strings as variable-length UTF-8 so the dataset dtype does not
            # depend on the longest string in a batch (|S3 vs |S4, ...).
            # np.bytes_/np.str_ replace np.string_/np.unicode_, removed in NumPy 2.0.
            if np.issubdtype(val.dtype, np.bytes_) or np.issubdtype(val.dtype, np.str_):
                data_type = h5py.string_dtype(encoding='utf-8')
                # h5py does not support NumPy 'U' arrays; object arrays of
                # str/bytes can be written to a variable-length string dataset.
                val = val.astype(object)
            else:
                data_type = val.dtype

            if key not in f:  # if key does not exist, create dataset
                if auto_chunk:
                    chunks = True  # let h5py decide chunk size
                else:
                    chunks = (chunk_size,) + data_shape[1:]
                dset = f.create_dataset(
                    key,
                    shape=data_shape,
                    chunks=chunks,
                    maxshape=(None,) + data_shape[1:],
                    dtype=data_type
                )
                # Save attribute dictionary
                if attr_dict is not None:
                    if key in attr_dict.keys():
                        for attr_key, attr_val in attr_dict[key].items():
                            dset.attrs[attr_key] = attr_val
                dset[:] = val
            else:
                dset = f[key]
                # Check the dtype before resizing so a mismatch leaves the dataset unchanged.
                if dset.dtype != data_type:
                    raise TypeError(f"Data type mismatch for key '{key}'. Dataset dtype: {dset.dtype}, value dtype: {data_type}")
                dset.resize(len(dset) + data_shape[0], axis=0)
                dset[-data_shape[0]:] = val
```
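To sanity-check the principle behind `save_hdf5_revised` in isolation, here is a minimal sketch that appends two barcode batches of different fixed widths (|S3, then |S4) to one variable-length string dataset. It assumes h5py 3.x (for `Dataset.asstr()`); `append_strings` is a hypothetical helper, not part of HEST, and it keeps the data 1-D for brevity:

```python
import os
import tempfile

import h5py
import numpy as np


def append_strings(path, key, batch):
    """Append a batch of barcodes to a resizable variable-length string dataset."""
    # Decode bytes to str so every element is a Python str of any length;
    # the dataset dtype therefore never depends on the longest barcode.
    val = np.array([b.decode() if isinstance(b, bytes) else str(b) for b in batch],
                   dtype=object)
    with h5py.File(path, "a") as f:
        if key not in f:
            f.create_dataset(key, shape=(0,), maxshape=(None,), chunks=(1024,),
                             dtype=h5py.string_dtype(encoding="utf-8"))
        dset = f[key]
        start = dset.shape[0]
        dset.resize(start + len(val), axis=0)
        dset[start:] = val


path = os.path.join(tempfile.mkdtemp(), "demo.h5")
append_strings(path, "barcodes", np.array([b"998", b"999"]))    # |S3 batch
append_strings(path, "barcodes", np.array([b"1000", b"1014"]))  # |S4 batch
with h5py.File(path, "r") as f:
    print(f["barcodes"].dtype, f["barcodes"].asstr()[:].tolist())
```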