gilestrolab / ethoscope

a platform for monitoring animal behaviour in real time from a Raspberry Pi
http://lab.gilest.ro/ethoscope/
GNU General Public License v3.0
17 stars 25 forks source link

Problem with how xy_dist_log10x1000 is calculated #194

Closed ggilestro closed 5 months ago

ggilestro commented 5 months ago

Unfortunately, this commit introduced a major issue in the tracking output.

This is due to how the value is calculated. The old way was:

xy_dist = round(log10(1./float(w_im) + abs(pos - self._old_pos))*1000)

vs post-commit:

xy_dist = round(log10(abs(pos - self._old_pos) + 1) * 1000)

The two formulas produce two different distributions of values:

image

The good news is that data collected with the new (faulty) algorithm can be fixed; see the solution below.

ggilestro commented 5 months ago

This is now fixed. Any .db file that was acquired using the faulty tracking algorithm can be recovered with the script below, which recomputes xy_dist_log10x1000 from the stored x and y coordinates using the old formula.

#!/usr/bin/env python3
#
#  update_xy_dist.py
#  
#  Copyright 2024 Giorgio <giorgio@gilest.ro>
#  
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#  
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#  
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software
#  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
#  MA 02110-1301, USA.
#  
#  

import os
import fnmatch
import optparse
import sqlite3
import pandas as pd
import numpy as np

def calculate_xy_dist_old(df, w_im):
    """Recompute ``xy_dist_log10x1000`` with the original (pre-commit) formula.

    For each row the value is ``round(log10(1/w_im + |pos - prev|) * 1000)``,
    where ``pos`` is the row's (x, y) position encoded as the complex number
    ``x + 1j*y`` and divided by ``w_im``, and ``prev`` is the same quantity
    for the preceding row.  The first row has no predecessor, so its value
    is NaN.

    Args:
        df: DataFrame with ``x`` and ``y`` columns, rows in temporal order.
        w_im: image width used to normalise the positions.

    Returns:
        The same DataFrame object, with ``xy_dist_log10x1000`` (re)assigned.
    """
    xs, ys = df['x'], df['y']

    # Complex-number positions, normalised by the image width.
    here = (xs + 1.0j * ys) / w_im
    before = (xs.shift(1) + 1.0j * ys.shift(1)) / w_im

    step = abs(here - before)
    # The 1/w_im offset keeps log10 finite when the animal did not move.
    floor_term = 1.0 / float(w_im)

    df['xy_dist_log10x1000'] = np.round(np.log10(floor_term + step) * 1000)
    return df

def process_table(conn, table_name, w_im):
    """Recompute ``xy_dist_log10x1000`` for one table, updating it in place.

    The column is recalculated from the stored ``x``/``y`` coordinates with
    ``calculate_xy_dist_old`` and written back with ``UPDATE ... WHERE rowid``.
    Only that one column changes, so the table's schema (column types,
    primary key, indexes) is preserved.

    Args:
        conn: open ``sqlite3.Connection`` to the database file.
        table_name: name of the table to process (e.g. ``ROI_1``).
        w_im: image width used to normalise positions.

    Tables that do not exist or are empty are skipped with a message.
    """
    print(f"Processing table {table_name}...")

    # Callers may ask for more ROI tables than a given database contains;
    # reading a missing table would raise and abort the whole run.
    exists = conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
        (table_name,),
    ).fetchone()
    if exists is None:
        print(f"Table {table_name} not found. Skipping...")
        return

    # Quote the identifier so an unusual table name cannot break the SQL.
    quoted = '"' + table_name.replace('"', '""') + '"'

    # The distance is computed between consecutive rows, so rows must come
    # back in insertion (rowid) order; SQL guarantees no order otherwise.
    # NOTE(review): assumes the table is a regular rowid table.
    df = pd.read_sql_query(
        f"SELECT rowid AS row_id, x, y FROM {quoted} ORDER BY rowid", conn
    )

    if df.empty:
        print(f"Table {table_name} is empty. Skipping...")
        return

    df = calculate_xy_dist_old(df, w_im)

    # NaN (first row: no previous position) is stored as NULL; every other
    # value is stored as an integer. tolist() yields plain Python numbers,
    # which sqlite3 can bind (numpy scalars cannot always be bound).
    updates = [
        (None if np.isnan(value) else int(value), row_id)
        for value, row_id in zip(
            df['xy_dist_log10x1000'].tolist(), df['row_id'].tolist()
        )
    ]

    # Update in place instead of DataFrame.to_sql(if_exists='replace'),
    # which drops and recreates the table with pandas-inferred column types.
    with conn:  # commits on success, rolls back on error
        conn.executemany(
            f"UPDATE {quoted} SET xy_dist_log10x1000 = ? WHERE rowid = ?",
            updates,
        )
    print(f"Table {table_name} processed and updated.")

def find_db_file(base_path, ethoscope_number, experiment_date):
    """Search *base_path* recursively for matching ``.db`` files.

    A file matches when its name starts with ``<experiment_date>_`` and
    contains the ethoscope number zero-padded to three digits after a
    further underscore.  Paths are returned in ``os.walk`` traversal order.

    Args:
        base_path: root directory to walk.
        ethoscope_number: ethoscope ID (formatted as ``%03d``).
        experiment_date: date prefix of the filename (e.g. ``YYYY-MM-DD``).

    Returns:
        List of full paths of matching files (possibly empty).
    """
    # The pattern now matches the date and ethoscope ID in the filename
    pattern = f"{experiment_date}_*_*{ethoscope_number:03d}*.db"

    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(base_path)
        for name in fnmatch.filter(names, pattern)
    ]

def _process_db_file(file_path, w_im, n_rois=20):
    """Run ``process_table`` on tables ROI_1..ROI_<n_rois> of one db file."""
    print(f"Connecting to database {file_path}...")
    conn = sqlite3.connect(file_path)
    try:
        for i in range(1, n_rois + 1):
            table_name = f'ROI_{i}'
            process_table(conn, table_name, w_im)
    finally:
        # Release the connection even if processing a table raises.
        conn.close()
    print("All tables processed. Connection closed.")


def main():
    """Command-line entry point.

    Either processes the single database given with ``-f/--file``, or looks
    up all databases under ``-b/--base`` matching ``-e/--ethoscope`` and
    ``-d/--date`` and processes each of them.  ``-w/--wim`` sets the image
    width used by the distance formula; ``-r/--rois`` sets how many ROI
    tables (ROI_1..ROI_N) are processed per file.
    """
    parser = optparse.OptionParser()
    parser.add_option('-e', '--ethoscope', dest='ethoscope_number', type='int', help='Ethoscope number')
    parser.add_option('-d', '--date', dest='experiment_date', help='Experiment date (YYYY-MM-DD)')
    parser.add_option('-w', '--wim', dest='w_im', type='float', default=560, help='w_im value (default: 560)')
    parser.add_option('-b', '--base', dest='base_path', default='/mnt/data/results', help='Base path (default: /mnt/data/results)')
    parser.add_option('-f', '--file', dest='file_path', help='Full path to the db file')
    parser.add_option('-r', '--rois', dest='n_rois', type='int', default=20, help='Number of ROI tables to process, ROI_1..ROI_N (default: 20)')

    (options, args) = parser.parse_args()

    if options.n_rois < 1:
        parser.error('Number of ROIs must be at least 1')

    if options.file_path:
        _process_db_file(options.file_path, options.w_im, options.n_rois)
        return

    # `is None` rather than falsiness, so that ethoscope number 0 is accepted.
    if options.ethoscope_number is None:
        parser.error('Ethoscope number not given')
    if not options.experiment_date:
        parser.error('Experiment date not given')

    matching_files = find_db_file(options.base_path, options.ethoscope_number, options.experiment_date)

    if not matching_files:
        print("No matching database file found.")
        return

    for file_path in matching_files:
        print(f"Matching database file: {file_path}")
        _process_db_file(file_path, options.w_im, options.n_rois)


if __name__ == "__main__":
    main()