ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Set tensors to cpu in _append_to_hooked_list #278

Closed: danbraunai-apollo closed this issue 7 months ago

danbraunai-apollo commented 8 months ago

Currently, rib.hook_fns._append_to_hooked_list looks like:

def _append_to_hooked_list(
    hooked_data: dict[str, Any],
    hook_name: str,
    data_key: str,
    element_to_append: Any,
) -> None:
    """Append the given element to a hooked list. Creates the list if it doesn't exist.

    Args:
        hooked_data: Dictionary of hook data that will be updated.
        hook_name: Name of hook. Used as a 1st-level key in `hooked_data`.
        data_key: Name of data. Used as a 2nd-level key in `hooked_data`.
        element_to_append: Appended to hooked data.
    """
    hooked_data.setdefault(hook_name, {}).setdefault(data_key, [])
    hooked_data[hook_name][data_key].append(element_to_append)

We should move tensors to the CPU before storing them in these lists, and move them back to the GPU later, wherever the stored tensors are used.