When working with multiple databases, you often need to get the labels for a particular scheme, e.g. gender. Because those labels can be stored in very different ways (directly in a column, as labels of another scheme, or in a misc table), it is hard to write code that works for all databases.
I propose adding audformat.Database.get(schemes), which accepts a single scheme or a list of schemes and returns all relevant data points as a series.
An example implementation is given by:
import typing

import pandas as pd

import audb
import audeer
import audformat


def get(
    db: audformat.Database,
    schemes: typing.Union[str, typing.Sequence[str]],
) -> pd.Series:
    r"""Get all data points for the requested scheme(s)."""
    requested_schemes = audeer.to_list(schemes)

    # Check if requested schemes
    # are stored as labels in other schemes
    scheme_mappings = []
    for scheme_id, scheme in db.schemes.items():
        if scheme.uses_table and scheme_id in db.misc_tables:
            # Labels stored as misc table
            for column_id, column in db[scheme_id].columns.items():
                for requested_scheme in requested_schemes:
                    if (
                        column.scheme_id is not None
                        and requested_scheme in column.scheme_id
                    ):
                        scheme_mappings.append((scheme_id, requested_scheme))
                        break
        else:
            # Labels stored in scheme:
            # look for the requested scheme
            # as a key in the label dictionaries
            labels = scheme.labels if isinstance(scheme.labels, dict) else {}
            for requested_scheme in requested_schemes:
                if any(
                    isinstance(value, dict) and requested_scheme in value
                    for value in labels.values()
                ):
                    scheme_mappings.append((scheme_id, requested_scheme))
                    break

    # Get data points for requested schemes
    ys = []
    for table_id, table in db.tables.items():
        for column_id, column in table.columns.items():
            if any_scheme_in_column(requested_schemes, column, column_id):
                y = db[table_id][column_id].get()
                ys.append(clean_y(y))
            else:
                for (scheme_id, mapping) in scheme_mappings:
                    if scheme_id in column_id:
                        y = db[table_id][column_id].get(map=mapping)
                        ys.append(clean_y(y))

    index = audformat.utils.union([y.index for y in ys])
    y = audformat.utils.concat(ys).loc[index]
    y.name = ', '.join(requested_schemes)
    return y


def any_scheme_in_column(schemes, column, column_id):
    r"""Check if any of the schemes is attached to the column.

    At the moment we also check for the column name,
    as some databases might forget to attach the scheme.

    """
    return any(
        scheme in column_id
        or (column.scheme_id is not None and scheme in column.scheme_id)
        for scheme in schemes
    )


def clean_y(y: pd.Series) -> pd.Series:
    r"""Remove NaN and normalize dtype of series."""
    # TODO: at the moment we simply convert to string
    # to avoid errors for different categorical data types.
    # In real implementation we need to adjust those dtypes.
    return y.dropna().astype('string')

We can test with:
db = audb.load(
    'emodb',
    version='1.4.1',
    only_metadata=True,
    full_path=False,
    verbose=False,
)
y = get(db, ['gender', 'sex'])
print(y.head())
which returns
file
wav/03a01Fa.wav male
wav/03a01Nc.wav male
wav/03a01Wa.wav male
wav/03a02Fc.wav male
wav/03a02Nc.wav male
Name: gender, sex, dtype: string
What is not yet solved is the handling of multi-channel databases. E.g. if you have a two-speaker conversation with gender labels for both speakers, the same index entry occurs with two labels. If the labels differ, audformat.utils.concat(ys).loc[index] will fail; if they are identical, it will not fail, but one of the two data points is lost.
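To make the multi-channel problem concrete, here is a minimal sketch that uses plain pandas instead of audformat. The file name, the labels, and the helper find_conflicts() are made up for illustration and are not part of audformat; the helper only shows how index entries with more than one distinct label could be detected before concatenating.

```python
import pandas as pd

# Hypothetical gender labels for a two-speaker recording:
# one series per channel, both indexed by the same file
index = pd.Index(['wav/call1.wav'], name='file')
channel_0 = pd.Series(['female'], index=index)
channel_1 = pd.Series(['male'], index=index)


def find_conflicts(ys):
    """Return index entries that carry more than one distinct label."""
    y = pd.concat(ys)
    # Count distinct labels per index entry
    n_labels = y.groupby(level=0).nunique()
    return n_labels[n_labels > 1].index


# Different labels for the same file are reported as a conflict
conflicts = find_conflicts([channel_0, channel_1])
print(list(conflicts))

# Identical labels are not a conflict,
# but only one data point per file would survive a merge
no_conflicts = find_conflicts([channel_0, channel_0])
print(list(no_conflicts))
```

A real solution would probably need to keep the channel or speaker as part of the index or as separate columns, which this sketch does not attempt.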
We might also consider adding an argument to get() that restricts it to a given table type, and an index argument that limits it to certain samples.
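As a rough idea of what the index restriction could look like, the following sketch filters a label series with plain pandas. The helper restrict_to_index() and the example data are assumptions for illustration only; how such an argument would be integrated into get(), and how a table type filter would work, is left open.

```python
import pandas as pd


def restrict_to_index(y: pd.Series, index: pd.Index) -> pd.Series:
    """Keep only the entries of y whose index is also part of index."""
    # Entries of index that are missing in y are silently ignored
    return y.loc[y.index.intersection(index)]


# Hypothetical labels as they might be returned by get()
y = pd.Series(
    ['male', 'female', 'male'],
    index=pd.Index(['wav/a.wav', 'wav/b.wav', 'wav/c.wav'], name='file'),
    name='gender',
)

# Request one file that has a label and one that does not
requested = pd.Index(['wav/b.wav', 'wav/x.wav'], name='file')
subset = restrict_to_index(y, requested)
print(subset)
```

The sketch assumes a unique, filewise index; segmented indices or duplicated entries would need extra care.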
/cc @ChristianGeng @frankenjoe @felixbur @audeerington