SainsburyWellcomeCentre / aeon_mecha

Project Aeon's main library for interfacing with acquired data. Contains modules for raw data file io, data querying, data processing, data qc, database ingestion, and building computational data pipelines.
BSD 3-Clause "New" or "Revised" License
3 stars 5 forks source link

Create a general function to assign events to subjects #335

Open jkbhagatio opened 6 months ago

jkbhagatio commented 6 months ago

Currently, I have some logic to assign pellets and wheel values to subjects. This could be made into a general function so that any event of interest can be assigned to a subject when the subject is not known a priori.

Example of current logic:

# Assign per-patch wheel movement and pellet events to subjects.
#
# For every patch in `patches` this script loads the wheel encoder and depletion-state
# streams for the block, derives the pellet events, and attributes both the wheel distance
# and each pellet to a subject. With exactly one subject everything is attributed to that
# subject; otherwise each event goes to the subject whose centroid is closest to the patch
# at (or shortly after) the event time.
#
# Outputs:
#   patch_stats_df     - per-patch "mean" and "offset" (-> patch_info)
#   cum_wheel_dist_dm  - per-patch DataFrame of cumulative wheel distance, one column per
#                        subject (-> cum_wheel_dist)
#   pellets_stats_df   - one row per pellet event: "time", "patch", "threshold", "id"
#                        (-> pellet_info)

# <s Get pose-tracking info in order to do subject-specific assignments
pose_df = aeon.load(block.root, social02.CameraTop.Pose, block.start, block.end)
pose_df = reader.Pose.class_int2str(pose_df, block.sleap_model_dir)
single_subject = len(subjects) == 1
if single_subject:  # fix mistaken sleap assignments for single-subject blocks
    pose_df["class"] = subjects[0]
else:
    # Distance-to-patch is computed from centroid rows only. Identities and coordinates are
    # both taken from this one filtered frame so that they stay row-aligned; pairing
    # centroid-only coordinates with the unfiltered `pose_df` fails on a length mismatch
    # whenever more than one body part is tracked. The filter does not depend on the patch,
    # so it is done once here rather than once per patch.
    centroid_df = pose_df[pose_df["part"] == "centroid"]
    centroid_xy = centroid_df[["x", "y"]].to_numpy()
# /s>
# <s Get per patch data (fill in `patch_info`, `cum_wheel_dist`, `pellet_info` cols of `blocks_df`)
patch_stats_df = pd.DataFrame(index=patches, columns=["mean", "offset"])  # -> patch_info
cum_wheel_dist_dm = DotMap()  # -> cum_wheel_dist
pellets_stats_df = pd.DataFrame(columns=["time", "patch", "threshold", "id"])  # -> pellet_info
for patch in patches:
    # Plain attribute lookup; no need to `eval` a constructed source string.
    patch_streams = getattr(social02, patch)
    # <ss Get wheel data
    wheel_df = aeon.load(block.root, patch_streams.Encoder, block.start, block.end)
    # Keep every 50th sample, rounded to 1 decimal, as float32.
    wheel_df = wheel_df[::50].round(1).astype(np.float32)
    # NOTE(review): the sign flip presumably makes the expected rotation direction
    # positive - confirm against `distancetravelled`.
    cum_wheel_dist = -distancetravelled(wheel_df.angle)
    # /ss>
    # <ss Get pellets data
    patch_df = aeon.load(block.root, patch_streams.DepletionState, block.start, block.end)
    rate, offset = patch_df[["rate", "offset"]].iloc[0]
    # "mean" is 1 / rate floored to a multiple of 100.
    patch_stats_df.loc[patch, ["mean", "offset"]] = (1 / rate // 100 * 100, offset)
    # Keep a row only if the next row is more than 1 s later; the last row is always kept.
    patch_df_good_indxs = np.concatenate((np.diff(patch_df.index) > pd.Timedelta("1s"), (True,)))
    patch_df_for_pellets_df = (
        patch_df[patch_df_good_indxs]
        .reset_index()[["time", "threshold"]]
        .dropna(subset=["threshold"])
        .iloc[1:]  # drop 1st val as is from block start
        .reset_index(drop=True)
    )
    patch_df_for_pellets_df["patch"] = patch
    patch_df_for_pellets_df["id"] = None
    # /ss>
    # <ss Assign data to subjects
    if single_subject:
        cum_wheel_dist_dm[patch] = cum_wheel_dist.to_frame(name=subjects[0])
        patch_df_for_pellets_df["id"] = subjects[0]
    else:
        # <sss Assign id based on which subject was closest to patch at time of event
        # <ssss Get distance-to-patch at each pose data timestep
        patch_xy = np.array(patch_locs[patch][arena]).astype(np.uint32)
        dist_to_patch_df = centroid_df[["class"]].copy()
        dist_to_patch_df["dist_to_patch"] = np.sqrt(np.sum((centroid_xy - patch_xy) ** 2, axis=1))
        # /ssss>
        # <ssss Get distance-to-patch at each wheel ts and pel del ts, organized by subject
        dist_to_patch_wheel_ts_id_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects)
        dist_to_patch_pel_ts_id_df = pd.DataFrame(
            index=patch_df_for_pellets_df["time"], columns=subjects
        )
        for subject in subjects:
            # This subject's pose samples, sliced once and reused for both merges.
            subj_dist_df = dist_to_patch_df[dist_to_patch_df["class"] == subject]
            # Match each wheel timestamp to the next pose sample within 100 ms
            dist_to_patch_wheel_ts_id_df[subject] = pd.merge_asof(
                left=dist_to_patch_wheel_ts_id_df[subject],
                right=subj_dist_df,
                left_index=True,
                right_index=True,
                direction="forward",
                tolerance=pd.Timedelta("100ms"),
            )["dist_to_patch"]
            # Match each pellet timestamp to the next pose sample within 200 ms
            dist_to_patch_pel_ts_id_df[subject] = pd.merge_asof(
                left=dist_to_patch_pel_ts_id_df[subject],
                right=subj_dist_df,
                left_index=True,
                right_index=True,
                direction="forward",
                tolerance=pd.Timedelta("200ms"),
            )["dist_to_patch"]
        # /ssss>
        # NOTE(review): timestamps with no pose sample inside the tolerance give all-NaN
        # rows; how `idxmin` treats those depends on the pandas version - confirm.
        # <ssss Get closest subject to patch at each pel del timestep
        patch_df_for_pellets_df["id"] = dist_to_patch_pel_ts_id_df.idxmin(axis=1).values
        # /ssss>
        # <ssss Get closest subject to patch at each wheel timestep
        cum_wheel_dist_subj_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects, data=0.)
        closest_subjects = dist_to_patch_wheel_ts_id_df.idxmin(axis=1)
        # Per-step wheel increments; the first step carries the initial cumulative value.
        wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0])
        # Assign wheel dist to closest subject for each wheel timestep
        for subject in subjects:
            subj_idxs = cum_wheel_dist_subj_df[closest_subjects == subject].index
            cum_wheel_dist_subj_df.loc[subj_idxs, subject] = wheel_dist[subj_idxs]
        cum_wheel_dist_dm[patch] = cum_wheel_dist_subj_df.cumsum(axis=0)
        # /ssss> #/sss> #/ss>
    pellets_stats_df = pd.concat([pellets_stats_df, patch_df_for_pellets_df], ignore_index=True)