audeering / audformat

Format to store media files and annotations
https://audeering.github.io/audformat/

Introduce Database.get() to access all datapoints for a scheme #398

Closed: hagenw closed this issue 10 months ago

hagenw commented 11 months ago

When working with multiple databases, you are often faced with the task of getting labels for a particular scheme, e.g. gender. But since such labels can be stored in very different ways across databases, it is very hard to write code that works for all of them.

I propose adding audformat.Database.get(schemes), which accepts a single scheme or a list of schemes and returns all relevant data points as a single series.
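To make the problem concrete, here is a plain-pandas sketch (no audformat involved) of two hypothetical databases. One stores gender directly per file. The other stores a speaker ID per file and keeps gender in a separate speaker table. The helper `get_labels` and all table and column names are illustrative assumptions. Resolving the second case through a lookup table is roughly what the `map` argument of `Column.get()` does in the implementation below.

```python
import pandas as pd

# Hypothetical database A: gender stored directly per file
files_a = pd.DataFrame(
    {"gender": ["female", "male"]},
    index=pd.Index(["a1.wav", "a2.wav"], name="file"),
)

# Hypothetical database B: speaker ID per file, gender in a speaker table
files_b = pd.DataFrame(
    {"speaker": ["s1", "s2", "s1"]},
    index=pd.Index(["b1.wav", "b2.wav", "b3.wav"], name="file"),
)
speakers_b = pd.DataFrame(
    {"gender": ["male", "female"]},
    index=pd.Index(["s1", "s2"], name="speaker"),
)


def get_labels(files, scheme, lookups=None):
    """Hypothetical helper: return per-file labels for `scheme`.

    `lookups` maps a column name in `files` (e.g. 'speaker')
    to a table indexed by the values of that column.
    """
    if scheme in files.columns:
        # Labels stored directly in the file table
        return files[scheme].rename(scheme)
    for column, table in (lookups or {}).items():
        if column in files.columns and scheme in table.columns:
            # Labels stored indirectly: replace each ID by its label
            return files[column].map(table[scheme]).rename(scheme)
    raise KeyError(f"no labels found for {scheme!r}")


y = pd.concat([
    get_labels(files_a, "gender"),
    get_labels(files_b, "gender", lookups={"speaker": speakers_b}),
])
print(y)
```

The caller has to know which layout each database uses; the proposed get() is meant to hide exactly this decision.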

An example implementation is given by:

import typing

import pandas as pd

import audb
import audeer
import audformat

def get(
        db: audformat.Database,
        schemes: typing.Union[str, typing.Sequence[str]],
) -> pd.Series:
    r"""Get all data points for the requested scheme(s)."""

    requested_schemes = audeer.to_list(schemes)

    # Check if requested schemes
    # are stored as labels in other schemes,
    # e.g. gender stored per speaker in a speaker scheme.
    # Each entry is (scheme ID, value for the map argument of Column.get())
    scheme_mappings = []
    for scheme_id, scheme in db.schemes.items():

        if scheme.uses_table and scheme.labels in db.misc_tables:
            # Labels stored as misc table,
            # whose ID is given by scheme.labels
            misc_table = db[scheme.labels]
            for column_id, column in misc_table.columns.items():
                # Exact match, as a substring match
                # would e.g. map 'age' to a 'language' column
                if column.scheme_id in requested_schemes:
                    # map expects a column of the misc table
                    scheme_mappings.append((scheme_id, column_id))
        elif isinstance(scheme.labels, dict):
            # Labels stored in scheme as dictionary,
            # e.g. {'spk1': {'gender': 'female', 'age': 31}}
            for requested_scheme in requested_schemes:
                if any(
                    isinstance(fields, dict) and requested_scheme in fields
                    for fields in scheme.labels.values()
                ):
                    scheme_mappings.append((scheme_id, requested_scheme))
                    break

    # Get data points for requested schemes
    # (after the loop above, so all mappings are known)
    ys = []
    for table_id, table in db.tables.items():
        for column_id, column in table.columns.items():
            if any_scheme_in_column(requested_schemes, column, column_id):
                y = db[table_id][column_id].get()
                ys.append(clean_y(y))
            else:
                for (scheme_id, mapping) in scheme_mappings:
                    if column.scheme_id == scheme_id or scheme_id in column_id:
                        y = db[table_id][column_id].get(map=mapping)
                        ys.append(clean_y(y))

    name = ', '.join(requested_schemes)
    if not ys:
        # None of the tables holds the requested schemes
        return pd.Series(dtype='string', name=name)

    index = audformat.utils.union([y.index for y in ys])
    y = audformat.utils.concat(ys).loc[index]
    y.name = name

    return y

def any_scheme_in_column(schemes, column, column_id):
    r"""Check if any of the schemes is attached to the column.

    At the moment we also check for the column name,
    as some databases might forget to attach the scheme.

    """
    return any(
        [
            scheme in column_id
            or (column.scheme_id is not None and scheme in column.scheme_id)
            for scheme in schemes
        ]
    )

def clean_y(y: pd.Series) -> pd.Series:
    r"""Remove NaN and normalize dtype of series."""
    # TODO: at the moment we simply convert to string
    # to avoid errors for different categorical data types.
    # In real implementation we need to adjust those dtypes.
    return y.dropna().astype('string')
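The substring checks in any_scheme_in_column() are a deliberate heuristic, but they can over-match: `'age' in 'language'` is `True`, so a column named `language` would be picked up when requesting `age`. A stricter matcher could compare whole tokens of the identifier instead. This sketch assumes that IDs use `.`, `_` or `-` as separators (an assumption, not an audformat convention), and `matches_scheme` is a hypothetical name.

```python
import re


def matches_scheme(scheme, identifier):
    """Hypothetical matcher: True if `scheme` is a whole token of `identifier`.

    Tokens are assumed to be separated by '.', '_' or '-',
    so 'gender' matches 'speaker.gender' but 'age' does not match 'language'.
    """
    if identifier is None:
        # Column without attached scheme
        return False
    tokens = re.split(r"[._-]", identifier)
    return scheme in tokens


print("age" in "language")                        # True (substring check)
print(matches_scheme("age", "language"))          # False (token check)
print(matches_scheme("gender", "speaker.gender"))  # True
```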

We can test with:

db = audb.load(
    'emodb',
    version='1.4.1',
    only_metadata=True,
    full_path=False,
    verbose=False,
)
y = get(db, ['gender', 'sex'])
print(y.head())

which returns

file
wav/03a01Fa.wav    male
wav/03a01Nc.wav    male
wav/03a01Wa.wav    male
wav/03a02Fc.wav    male
wav/03a02Nc.wav    male
Name: gender, sex, dtype: string

and

y = get(db, 'age')
print(y.head())

which returns

file
wav/03a01Fa.wav    31
wav/03a01Nc.wav    31
wav/03a01Wa.wav    31
wav/03a02Fc.wav    31
wav/03a02Nc.wav    31
Name: age, dtype: string
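Both results have dtype string because clean_y() converts everything to strings, as noted in its TODO. One possible way to keep categorical dtypes instead is to give all categorical series the union of their categories before concatenating, using pandas' `union_categoricals`. This is a plain-pandas sketch; `harmonize_categories` is a hypothetical helper, and non-categorical series are left untouched.

```python
import pandas as pd
from pandas.api.types import union_categoricals


def harmonize_categories(ys):
    """Hypothetical helper: give all categorical series the same categories."""
    categoricals = [y for y in ys if isinstance(y.dtype, pd.CategoricalDtype)]
    if not categoricals:
        return list(ys)
    # Union of all categories, in order of first appearance
    categories = union_categoricals(categoricals).categories
    dtype = pd.CategoricalDtype(categories)
    return [
        y.astype(dtype) if isinstance(y.dtype, pd.CategoricalDtype) else y
        for y in ys
    ]


y1 = pd.Series(["male", "female"], dtype="category")
y2 = pd.Series(["female", "other"], dtype="category")

# Different categories: plain concat falls back to object dtype
print(pd.concat([y1, y2]).dtype)
# Shared categories: the categorical dtype is preserved
y = pd.concat(harmonize_categories([y1, y2]), ignore_index=True)
print(y.dtype)
print(list(y.cat.categories))
```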

What is not yet solved is the handling of multi-channel databases. For example, if you have a two-speaker conversation with gender labels for both speakers, audformat.utils.concat(ys).loc[index] will fail because the same index entry carries different labels. If both speakers happen to have the same label, it will not fail, but the result holds only a single label for both speakers, so data is silently lost.
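A plain-pandas illustration of the multi-channel problem, with hypothetical gender labels for a two-speaker conversation stored in two series over the same file index. The helper `conflicting_entries` (a hypothetical name) mimics the failure condition, same index entry with different values, rather than calling audformat.utils.concat itself. One possible direction, keeping one column per source instead of merging into a single series, is shown at the end.

```python
import pandas as pd

index = pd.Index(["conv1.wav", "conv2.wav"], name="file")
# Hypothetical gender labels for the two speakers of each conversation
speaker1 = pd.Series(["female", "male"], index=index, name="gender")
speaker2 = pd.Series(["male", "male"], index=index, name="gender")


def conflicting_entries(ys):
    """Return index entries that carry more than one distinct value."""
    stacked = pd.concat(ys)
    n_values = stacked.groupby(level=0).nunique()
    return list(n_values[n_values > 1].index)


print(conflicting_entries([speaker1, speaker2]))
# conv1.wav conflicts: a single merged series cannot hold both labels.
# conv2.wav agrees: merging keeps one 'male' and drops the second speaker.

# One possible alternative: keep one column per source
frame = pd.concat([speaker1, speaker2], axis=1, keys=["speaker1", "speaker2"])
print(frame)
```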

We might also consider adding arguments to get() that restrict it to a given table type, or limit it to certain samples via an index argument.
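Both arguments could be applied to the per-table series collected in the `ys` list of the example implementation, before they are concatenated. The sketch below relies on the audformat index conventions (filewise: a single `file` level; segmented: `file`, `start`, `end` levels); the helper `select` and its parameters `table_type` and `index` are hypothetical and not part of any agreed API.

```python
import pandas as pd

# Index level names by table type, following audformat conventions
LEVELS = {
    "filewise": ["file"],
    "segmented": ["file", "start", "end"],
}


def select(ys, *, table_type=None, index=None):
    """Hypothetical filter for the per-table series collected in get()."""
    selected = []
    for y in ys:
        if table_type is not None and list(y.index.names) != LEVELS[table_type]:
            continue  # series comes from a table of another type
        if index is not None:
            if list(y.index.names) != list(index.names):
                continue  # different index type, nothing to keep
            y = y[y.index.isin(index)]  # keep only requested entries
        selected.append(y)
    return selected


filewise = pd.Series(
    ["male", "female"],
    index=pd.Index(["a.wav", "b.wav"], name="file"),
    name="gender",
)
segmented = pd.Series(
    ["male"],
    index=pd.MultiIndex.from_arrays(
        [
            ["c.wav"],
            pd.to_timedelta([0], unit="s"),
            pd.to_timedelta([1.5], unit="s"),
        ],
        names=["file", "start", "end"],
    ),
    name="gender",
)

print(select([filewise, segmented], table_type="filewise"))
print(select([filewise, segmented], index=pd.Index(["a.wav"], name="file")))
```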

/cc @ChristianGeng @frankenjoe @felixbur @audeerington

felixbur commented 11 months ago

I'd love that

frankenjoe commented 11 months ago

Yes, sounds like a useful feature, and I can see you already started working on it in https://github.com/audeering/audformat/pull/399