DeepTrackAI / DeepTrack2

MIT License

Issue Translating Trained MAGIK Model to Real Data #207

Status: Open. akilgall opened this issue 7 months ago.

akilgall commented 7 months ago

Hello,

I have been trying to link detections using a MAGIK model trained on simulations, but I cannot get any reasonable trajectories from real data. I have attached my train.py and test.py scripts, which generate simulated trajectories with varying density and apply the trained model to a real dataset. I'm finding it hard to tell whether the problem lies in the framework or whether I am overfitting to my simulations, since the edge dataframe yields no trajectories at all.

  trajs = []
  for i, (t, c) in enumerate(traj):
      trajs.append(nodes[t])
  print("trajs: ", trajs)
  print("num traj: ", len(trajs))
  print("mean len: ", np.mean([len(a) for a in trajs]))

Output:

  traj:  []
  trajs:  []
  num traj:  0

Here is an example of my training simulations from traindf.csv: [image]

Here is some of the data from dfIndexedToEvaluate.csv: [image]

I am training and evaluating with the following scripts (uploaded as .txt files): train.py.txt, test.py.txt

Data files: dfIndexedToEvaluate.csv, traindf.csv

If I understand the results correctly, there should be many tracks, not none. Am I interpreting the output of dt.models.gnns.get_traj correctly?

Thank you!

JesusPinedaC commented 7 months ago

I have checked your example carefully and have some recommendations.

Why do you get no trajectories?

This is caused by the threshold parameter (th) of the get_traj function. By default, th is set to 8, so any trajectory shorter than 8 frames is filtered out. Getting no trajectories at all therefore means that the predicted links only form fragments shorter than 8 frames, which suggests that your model has not been trained correctly.
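To see how a length threshold can remove every trajectory, here is a minimal sketch of chaining predicted links into trajectories and discarding the short ones. The helpers `link_edges` and `filter_by_length` are hypothetical and are not DeepTrack2's `get_traj`; they assume the predicted edges have already been resolved so that each detection has at most one successor and one predecessor, and that every edge points forward in time.

```python
def link_edges(num_nodes, edges):
    """Chain detections into trajectories from predicted (source, target) links.

    Hypothetical helper: assumes each node has at most one successor and one
    predecessor, and that every edge points forward in time (no cycles).
    """
    successor = {}
    has_predecessor = set()
    for src, dst in edges:
        successor[src] = dst
        has_predecessor.add(dst)

    trajectories = []
    for start in range(num_nodes):
        if start in has_predecessor:
            continue  # not the first detection of a trajectory
        chain = [start]
        while chain[-1] in successor:
            chain.append(successor[chain[-1]])
        trajectories.append(chain)
    return trajectories


def filter_by_length(trajectories, th=8):
    """Drop trajectories with fewer than `th` detections (analogous to `th` in get_traj)."""
    return [t for t in trajectories if len(t) >= th]


# One 10-frame track plus short fragments and isolated detections
good = [(i, i + 1) for i in range(9)] + [(10, 11), (12, 13)]
print(len(filter_by_length(link_edges(20, good))))  # 1

# A poorly trained model that only predicts 2-detection fragments
fragments = [(2 * i, 2 * i + 1) for i in range(10)]
print(len(filter_by_length(link_edges(20, fragments))))  # 0
```

With only fragments, the linking step itself works, but every resulting trajectory is shorter than the threshold, which matches the empty output reported above.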

Why has the model not been trained properly?

The reason behind this is related to both your dataset and MAGIK's data generation pipeline.

Your dataset consists mostly of confined trajectories, which is completely fine! However, GraphExtractor uses an augmentation function that creates subgraphs for batch training by extracting N trajectories from the training graph. By default, MAGIK extracts between 5 and 12 trajectories to create each subgraph. This works well for cells, but when particle motion is confined, as in your case, the few sampled trajectories rarely come close enough to one another to be connected by candidate edges. The result is a set of disconnected subgraphs in which every edge is a true connection (label 1). This drives the neural network into a local minimum: MAGIK cannot learn to distinguish true from false edges because it never encounters false ones during training. A balance between true and false connections is desirable.
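The effect can be reproduced with a toy example. The sketch below is hypothetical and much simpler than MAGIK's actual graph construction: one particle is confined around each point of a grid, candidate edges connect detections in consecutive frames that are closer than an assumed search radius, and an edge is labelled 1 when both ends belong to the same particle. Sampling a handful of well-separated trajectories gives only true edges, whereas using the whole field of view also produces false ones.

```python
import numpy as np


def simulate_confined(n_side=10, n_frames=5, spacing=0.05, jitter=0.01, seed=0):
    """Toy data: one particle confined around each point of an n_side x n_side grid.

    Returns positions with shape (n_frames, n_side**2, 2).
    """
    rng = np.random.default_rng(seed)
    gx, gy = np.meshgrid(np.arange(n_side), np.arange(n_side), indexing="ij")
    centers = np.stack([gx.ravel(), gy.ravel()], axis=1) * spacing
    # Each frame: grid position plus a small bounded displacement
    noise = rng.uniform(-jitter, jitter, (n_frames, len(centers), 2))
    return centers[None, :, :] + noise


def edge_labels(positions, track_ids, radius=0.08):
    """Labels (1 = same particle) of candidate edges between consecutive frames."""
    pos = positions[:, track_ids]  # keep only the sampled trajectories
    labels = []
    for t in range(pos.shape[0] - 1):
        # Distances between every detection in frame t and every one in frame t + 1
        d = np.linalg.norm(pos[t][:, None, :] - pos[t + 1][None, :, :], axis=-1)
        src, dst = np.nonzero(d < radius)
        labels.append((src == dst).astype(int))
    return np.concatenate(labels)


positions = simulate_confined()
few = edge_labels(positions, np.array([0, 9, 33, 66, 99]))  # 5 well-separated tracks
many = edge_labels(positions, np.arange(100))                # the whole field of view
print(few.mean(), many.mean())  # fraction of true edges: 1.0 for `few`, below 1 for `many`
```

In the sparse subgraph there is nothing for a classifier to reject, so predicting "1" for every edge already gives perfect training accuracy.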

How do we solve the problem?

At the moment, I see two possible solutions to the problem.

First, try simulating multiple videos instead of a single one, and store the video index in traindf["set"]: all detections in the first video get set 0, those in the second video set 1, and so on. Also, make the spatial distribution of the particles in each video similar to that of the experimental data. This ensures that the graphs have similar connectivity in the training and test sets, which is currently not the case.
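A minimal pandas sketch of building such a multi-video training dataframe is shown below. `simulate_video` is a hypothetical stand-in for your own simulator, and the column names `frame`, `centroid-0`, `centroid-1`, and `label` are assumptions; adapt them (and add any other columns your current traindf contains) so they match what your pipeline expects. Here trajectory labels restart at 0 in each video.

```python
import numpy as np
import pandas as pd


def simulate_video(n_particles, n_frames, rng):
    """Hypothetical stand-in for your simulator: particles confined near random centres."""
    centers = rng.uniform(0.05, 0.95, (n_particles, 2))
    rows = []
    for label, center in enumerate(centers):
        for frame in range(n_frames):
            x, y = center + rng.normal(0.0, 0.005, 2)
            rows.append({"frame": frame, "centroid-0": x, "centroid-1": y, "label": label})
    return pd.DataFrame(rows)


rng = np.random.default_rng(0)
videos = []
for video_id, n_particles in enumerate([40, 60, 80]):  # vary density between videos
    df = simulate_video(n_particles, n_frames=20, rng=rng)
    df["set"] = video_id  # all detections of this video share one set ID
    videos.append(df)

traindf = pd.concat(videos, ignore_index=True)
print(traindf.groupby("set").size())
```

Choosing the number of particles per video from the densities observed in your experiments is one way to match the connectivity of the test data.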

In this case, I also recommend modifying the default augmentation of the graph extractor as follows:

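import numpy as np
import deeptrack as dt

# Assumed import paths (as in the MAGIK tutorial); adjust them to your DeepTrack2 version
from deeptrack.models.gnns.generators import GraphGenerator
from deeptrack.models.gnns.augmentations import (
    GetSubSet, AugmentCentroids, NoisyNode, NodeDropout
)

def GetFeature(full_graph, **kwargs):
    return (
        dt.Value(full_graph)
        # Select one simulated video (set) at random
        >> dt.Lambda(
            GetSubSet,
            randset=lambda: np.random.randint(
                np.max(full_graph[-1][0][:, 0]) + 1),
        )
        # Random rotation, small translation, and horizontal/vertical flips
        >> dt.Lambda(
            AugmentCentroids,
            rotate=lambda: np.random.rand() * 2 * np.pi,
            translate=lambda: np.random.randn(2) * 0.05,
            flip_x=lambda: np.random.randint(2),
            flip_y=lambda: np.random.randint(2),
        )
        >> dt.Lambda(NoisyNode)                       # add noise to the nodes
        >> dt.Lambda(NodeDropout, dropout_rate=0.03)  # randomly drop ~3% of the nodes
    )

# `variables` is assumed to be the dt.DummyFeature holding the graph parameters,
# defined earlier in the script as in the MAGIK tutorial
generator = GraphGenerator(
    nodesdf=traindf,
    properties=["centroid"],
    min_data_size=51,
    max_data_size=52,
    batch_size=4,
    feature_function=GetFeature,
    **variables.properties()
)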

Now the GetFeature function randomly selects one training video and applies rotation, translation, flip, noise, and dropout augmentations to the nodes and edges of its graph.
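Rigid augmentations like these do not change the linking problem itself, only how it looks to the network. The sketch below is a hypothetical re-implementation for illustration (not DeepTrack2's AugmentCentroids): it assumes centroids normalised to [0, 1], rotates about the assumed image centre (0.5, 0.5), and draws its parameters the same way as the code above. Rotations, flips, and translations preserve all pairwise distances, so the ground-truth links stay valid.

```python
import numpy as np


def augment_centroids(centroids, angle, translate, flip_x, flip_y):
    """Rigidly transform normalised (N, 2) centroids (illustrative sketch only).

    Rotates by `angle` about the assumed image centre (0.5, 0.5), optionally
    mirrors each axis, then shifts by `translate`.
    """
    c = centroids - 0.5
    rotation = np.array([[np.cos(angle), -np.sin(angle)],
                         [np.sin(angle), np.cos(angle)]])
    c = c @ rotation.T
    if flip_x:
        c[:, 0] *= -1
    if flip_y:
        c[:, 1] *= -1
    return c + 0.5 + translate


def pairwise_distances(points):
    return np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)


rng = np.random.default_rng(1)
centroids = rng.uniform(0, 1, (10, 2))
augmented = augment_centroids(
    centroids,
    angle=rng.random() * 2 * np.pi,
    translate=rng.standard_normal(2) * 0.05,
    flip_x=rng.integers(2),
    flip_y=rng.integers(2),
)
# Distances between detections (and hence the true links) are unchanged
print(np.allclose(pairwise_distances(centroids), pairwise_distances(augmented)))  # True
```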

The second option is to keep the existing dataset and increase the number of trajectories used to build each training subgraph (for example, 200 to 300 instead of the default 5 to 12):

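import random

import numpy as np
import deeptrack as dt

# Assumed import paths (as in the MAGIK tutorial); adjust them to your DeepTrack2 version
from deeptrack.models.gnns.generators import GraphGenerator
from deeptrack.models.gnns.augmentations import (
    GetSubGraphFromLabel, AugmentCentroids, NoisyNode, NodeDropout
)

def GetFeature(full_graph, **kwargs):
    return (
        dt.Value(full_graph)
        # Build each subgraph from a random sample of trajectory labels.
        # np.random.randint(200, 300) draws 200 to 299 (upper bound exclusive);
        # random.sample raises ValueError if fewer labels are available.
        >> dt.Lambda(
            GetSubGraphFromLabel,
            samples=lambda: np.array(
                sorted(
                    random.sample(
                        list(full_graph[-1][0][:, -1]),
                        np.random.randint(200, 300),  # Modify this: (min, max)
                    )
                )
            ),
        )
        # Random rotation, small translation, and horizontal/vertical flips
        >> dt.Lambda(
            AugmentCentroids,
            rotate=lambda: np.random.rand() * 2 * np.pi,
            translate=lambda: np.random.randn(2) * 0.05,
            flip_x=lambda: np.random.randint(2),
            flip_y=lambda: np.random.randint(2),
        )
        >> dt.Lambda(NoisyNode)                       # add noise to the nodes
        >> dt.Lambda(NodeDropout, dropout_rate=0.03)  # randomly drop ~3% of the nodes
    )

# `variables` is assumed to be the dt.DummyFeature holding the graph parameters,
# defined earlier in the script as in the MAGIK tutorial
generator = GraphGenerator(
    nodesdf=traindf,
    properties=["centroid"],
    min_data_size=51,
    max_data_size=52,
    batch_size=4,
    feature_function=GetFeature,
    **variables.properties()
)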

Both solutions increase the computational cost. Therefore, I recommend using the MPN block instead of MAGIK's default FGNN block, i.e.,

import deeptrack as dt

_OUTPUT_TYPE = "edges"  # edge classification, as used for linking

model = dt.models.gnns.MAGIK(
    dense_layer_dimensions=(64, 96),    # number of features in each dense encoder layer
    base_layer_dimensions=(96,) * 4,    # latent dimension throughout the message-passing layers
    number_of_node_features=2,          # number of node features in the graphs
    number_of_edge_features=1,          # number of edge features in the graphs
    number_of_edge_outputs=1,           # number of predicted features
    edge_output_activation="sigmoid",   # activation function of the output layer
    output_type=_OUTPUT_TYPE,           # output type: "edges", "nodes", or "graph"
    graph_block="MPN",                  # MPN block instead of the default FGNN block
)

akilgall commented 6 months ago

Thank you for your suggestions!
With your first suggestion, I was able to train the model properly. On my simulations, it gives a notable improvement over TrackMate's LAP tracker in metrics such as the F1 score.

When I apply the model to my experimental data, however, trajectory reconstruction is quite slow because of the number of localizations (100k to 1000k). I parallelized parts of the to_trajectories function and processed the data in batches of slices (256 × 256 pixels × 1000 frames), but some samples still take hours to process. Are there plans to speed up this functionality? Do you have any recommendations for speeding up this function?
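For reference, the slicing described above can be sketched with pandas as follows. The column names `x`, `y`, and `frame` and the function `link_chunk` are hypothetical placeholders: `link_chunk` stands in for building the graph, running the trained model, and reconstructing trajectories for one chunk. In this sketch, tracks that cross a tile or time-chunk border are split and would have to be stitched afterwards.

```python
from concurrent.futures import ProcessPoolExecutor

import pandas as pd


def split_into_chunks(df, tile=256, n_frames=1000):
    """Group localizations into tile x tile pixel x n_frames chunks."""
    keys = [
        (df["x"] // tile).astype(int).rename("tx"),
        (df["y"] // tile).astype(int).rename("ty"),
        (df["frame"] // n_frames).astype(int).rename("tt"),
    ]
    return [chunk for _, chunk in df.groupby(keys)]


def link_chunk(chunk):
    """Hypothetical placeholder for graph building + MAGIK inference on one chunk."""
    return len(chunk)


def process_all(chunks, max_workers=None):
    # Chunks are independent, so they can be processed in separate processes
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(link_chunk, chunks))


locs = pd.DataFrame({
    "x": [10.0, 300.0, 10.0, 300.0, 20.0],
    "y": [10.0, 10.0, 300.0, 300.0, 15.0],
    "frame": [0, 0, 1500, 1500, 999],
})
chunks = split_into_chunks(locs)
print([len(c) for c in chunks])  # [2, 1, 1, 1]
```

When calling `process_all` from a script, put the call under an `if __name__ == "__main__":` guard, which `ProcessPoolExecutor` needs on platforms that start worker processes with the spawn method.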

imzhangyd commented 3 days ago

@akilgall @JesusPinedaC Hello, I would like to save the model after training. Here is my code:

with generator:
    model.fit(generator, epochs=1)
    model.save('./model.keras')

However, calling model.save('./model.keras') raises the following error:

Traceback (most recent call last):
  File "/mnt/data1/ZYDdata/proj_MAGIK/code/DeepTrack2/train_maigc.py", line 145, in <module>
    model.save('./model.keras')
  File "/ldap_shared/home/s_zyd/anaconda3/envs/magic_withpkg/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/ldap_shared/home/s_zyd/anaconda3/envs/magic_withpkg/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py", line 395, in _get_class_or_fn_config
    raise TypeError(
TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement the `get_config()` method.

I am not very familiar with the TensorFlow framework. Can you help me with this issue?