hugo-ekinge opened this issue 4 weeks ago
Can you post the code to replicate? Thanks!
I have made this function, which is largely a trimmed-down version of predict_single_split with some modifications. It is placed inside benchmark.py:
    def create_embeddings(args, model_name, device, custom_encoder):
        """ Create embeddings for a single model """
        embedding_dir = os.path.join(get_path(args.embed_dataroot), "ALL", model_name)
        os.makedirs(embedding_dir, exist_ok=True)

        # Embed patches
        logger.info(f"Embedding ALL tiles using {model_name} encoder and custom code")
        weights_path = get_bench_weights(args.weights_root, model_name)
        if model_name == 'custom_encoder':
            encoder = custom_encoder
        else:
            encoder: InferenceEncoder = inf_encoder_factory(model_name)(weights_path)
        precision = encoder.precision

        patches_dir = os.path.join(get_path('hest_data'), 'patches')
        for root, _, files in os.walk(patches_dir):
            for file in tqdm(files):
                if file.endswith('.h5'):
                    tile_h5_path = os.path.join(root, file)
                    sample_id = os.path.splitext(file)[0]
                    embed_path = os.path.join(embedding_dir, f'{sample_id}.h5')
                    if not os.path.isfile(embed_path) or args.overwrite:
                        _ = encoder.eval()
                        encoder.to(device)
                        tile_dataset = H5HESTDataset(tile_h5_path, chunk_size=args.batch_size, img_transform=encoder.eval_transforms)
                        tile_dataloader = torch.utils.data.DataLoader(tile_dataset,
                                                                      batch_size=1,
                                                                      shuffle=False,
                                                                      num_workers=args.num_workers)
                        _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
                    else:
                        logger.info(f"Skipping {sample_id} as it already exists")
To connect it to the existing benchmark workflow, I added an argument:

    parser.add_argument('--create_embeddings', type=bool, default=False, help='only create embeddings')
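(As an aside, and not part of the change I actually made: argparse's type=bool does not parse strings the way one might expect, since bool('false') is True. If this flag ever needs to be set from the command line rather than only from bench_config.yaml, a store_true action is the more reliable pattern:

    parser.add_argument('--create_embeddings', action='store_true', help='only create embeddings')

)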
I then added to benchmark_grid, at the top of the for model_name in model_names: loop:

    if args.create_embeddings:
        create_embeddings(args, model_name, device, custom_encoder)
        continue

and right after the loop:

    if args.create_embeddings:
        break
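Roughly, the two hooks end up placed like this (a simplified sketch only; the signature of benchmark_grid and the outer loop over datasets are assumptions paraphrased from memory, not copied from the upstream code, and only the two if blocks are mine):

    def benchmark_grid(args, device, custom_encoder, model_names, datasets):
        for dataset in datasets:  # assumed outer loop
            for model_name in model_names:
                if args.create_embeddings:
                    create_embeddings(args, model_name, device, custom_encoder)
                    continue  # skip the normal per-model benchmarking
                # ... existing benchmarking code runs here ...
            if args.create_embeddings:
                break  # embeddings only need to be created once, not once per dataset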
Not the prettiest, but I just wanted to hook into the existing code to run a few quick tests. Anyway, simply add create_embeddings: True to bench_config.yaml and run the benchmark like normal; it will try to create embeddings of all the data in hest_data/patches and place them in an ALL folder alongside the other embeddings, sorted per included model.
With a batch size of 128, I noticed the issue with, for example, id TENX137, but there are also others.
To debug, I also added

    if dset.dtype != val.dtype:
        print(f'Path: {output_fpath}')
        print(f'Key: {key}')
        print(f'dset.dtype: {dset.dtype}')
        print(f'val.dtype: {val.dtype}')
        print(f'dset: {dset}\n{dset[:]}')
        print(f'val: {val}')

right before the assert dset.dtype == val.dtype in file_utils.py.
Note that if overwrite is not set to True, just rerunning the same thing will skip the samples where the assertion failed previously, since those files have already been created (although they are incomplete).
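A quick way to recover from such partial runs is to delete the incomplete files before rerunning; a minimal sketch (the default required_keys here are an assumption on my side and should be adjusted to whatever datasets embed_tiles actually writes):

    import os
    import h5py

    def remove_incomplete_embeddings(embedding_dir, required_keys=('embeddings', 'barcodes', 'coords')):
        """Delete .h5 files that cannot be opened or that are missing expected datasets."""
        for fname in os.listdir(embedding_dir):
            if not fname.endswith('.h5'):
                continue
            path = os.path.join(embedding_dir, fname)
            try:
                with h5py.File(path, 'r') as f:
                    incomplete = any(key not in f for key in required_keys)
            except OSError:
                incomplete = True  # file is truncated or otherwise unreadable
            if incomplete:
                print(f'Removing incomplete embedding file: {path}')
                os.remove(path)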
Have you been able to replicate?
It's hard to reproduce without seeing all your modifications. You might be able to solve the issue by modifying save_hdf5, making sure the string types do not change depending on the size of the string. Can you confirm whether saving to h5 using this function solves the issue? Thanks.
    def save_hdf5_revised(output_fpath,
                          asset_dict,
                          attr_dict=None,
                          mode='a',
                          auto_chunk=True,
                          chunk_size=None):
        with h5py.File(output_fpath, mode) as f:
            for key, val in asset_dict.items():
                data_shape = val.shape
                if len(data_shape) == 1:
                    val = np.expand_dims(val, axis=1)
                    data_shape = val.shape

                # Determine if the data is of string type
                if np.issubdtype(val.dtype, np.string_) or np.issubdtype(val.dtype, np.unicode_):
                    data_type = h5py.string_dtype(encoding='utf-8')
                else:
                    data_type = val.dtype

                if key not in f:  # if key does not exist, create dataset
                    if auto_chunk:
                        chunks = True  # let h5py decide chunk size
                    else:
                        chunks = (chunk_size,) + data_shape[1:]
                    dset = f.create_dataset(
                        key,
                        shape=data_shape,
                        chunks=chunks,
                        maxshape=(None,) + data_shape[1:],
                        dtype=data_type
                    )
                    # Save attribute dictionary
                    if attr_dict is not None:
                        if key in attr_dict.keys():
                            for attr_key, attr_val in attr_dict[key].items():
                                dset.attrs[attr_key] = attr_val
                    dset[:] = val
                else:
                    dset = f[key]
                    dset.resize(len(dset) + data_shape[0], axis=0)
                    if dset.dtype != data_type:
                        raise TypeError(f"Data type mismatch for key '{key}'. Dataset dtype: {dset.dtype}, value dtype: {data_type}")
                    dset[-data_shape[0]:] = val
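For example, a quick standalone check along these lines should show whether the varying-width barcodes are handled (just a sketch; the file name and key are arbitrary):

    import numpy as np
    import h5py  # plus save_hdf5_revised from above

    # First batch: short, purely numerical "barcodes" -> fixed-width dtype |S3
    batch_1 = {'barcodes': np.array([b'0', b'1', b'994'])}
    # Second batch: longer strings -> |S4, which used to trip the dtype assert
    batch_2 = {'barcodes': np.array([b'1014', b'1015'])}

    save_hdf5_revised('test_barcodes.h5', batch_1, mode='w')
    save_hdf5_revised('test_barcodes.h5', batch_2, mode='a')

    with h5py.File('test_barcodes.h5', 'r') as f:
        print(f['barcodes'].dtype)  # object (variable-length string), not |S3/|S4
        print(f['barcodes'][:])     # all five barcodes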
Hi, and thank you for all the work you have done so far.

I am trying to adapt your embed_tiles function to create embeddings for arbitrary samples, not just the ones included in HEST-Benchmark. One issue that I came across when looping over all samples is that assert dset.dtype == val.dtype in save_hdf5 triggered on some samples. Upon further inspection, the issue seems to be that barcodes is not always of fixed length. On some samples, the first batch of, say, 128 tiles might have only numerical barcodes like [b'0'] [b'1'] [b'2'] ... [b'994'], instead of for example [b'AAACACCAATAACTGC-1'], which is always of fixed length. This is a problem when the number of characters increases to, say, [b'1014'], because suddenly the dtype is |S4 instead of |S3.

One thing I tried was increasing the batch size to 1024, which eliminated the problem of going from |S3 to |S4, but instead I get the same issue (although less often) when going from |S4 to |S5 in samples with a large number of barcodes.

This is what I have been able to figure out myself so far, but what would you recommend in order to fix this in a more robust way and properly handle all samples in the dataset?
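To make the failure mode concrete, here is a minimal illustration of the underlying numpy behaviour (this snippet is only for explanation and is not part of the benchmark code):

    import numpy as np

    # Fixed-width bytes dtypes are sized by the longest string in the batch
    print(np.array([b'0', b'1', b'994']).dtype)  # |S3
    print(np.array([b'1014', b'1015']).dtype)    # |S4

    # So two consecutive batches of barcodes can end up with different dtypes,
    # and save_hdf5's `assert dset.dtype == val.dtype` then fails when the
    # second batch is appended to a dataset created from the first one.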
(An interesting note I made is that in the hest_data/patches folder the dtype seems to be object rather than any S type at all for barcodes. So the conversion seems to happen somewhere in the dataset/dataloader pipeline, and I am wondering if this is intentional or not, and what the purpose behind it is in that case?)