dynverse / dyngen

Simulating single-cell data using gene regulatory networks 📠
https://dyngen.dynverse.org

Estimate fate probabilities for sampled cells #31

Closed: zsteve closed this issue 3 years ago

zsteve commented 3 years ago

Hi Robrecht,

I am interested in estimating fate probabilities for sampled cells. Specifically, for a backbone with e.g. 2 terminal states A and B, each sampled cell should have a probability vector of length 2 containing Prob(cell fate ∈ {A, B} | observed cell state).

For an observed cell state x, one way to estimate this would be to simulate N new cells that start from x and see where they end up. Can this be done within the dyngen framework? The idea would be to use these fate probabilities as a ground truth for benchmarking trajectory inference algorithms.
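For concreteness, here is a toy sketch of the Monte Carlo estimator I have in mind (plain R, nothing dyngen-specific; the helper name is made up):

```r
# toy sketch: empirical fate probabilities from the end states of
# N simulations that all start from the same observed cell state x.
# `estimate_fate_probs` is a made-up helper, not part of dyngen.
estimate_fate_probs <- function(end_states, fates = c("A", "B")) {
  counts <- table(factor(end_states, levels = fates))
  as.numeric(counts) / length(end_states) # Prob(fate | observed state x)
}

# e.g. if 70 out of N = 100 simulations from x reached A and 30 reached B:
estimate_fate_probs(rep(c("A", "B"), times = c(70, 30)))
#> [1] 0.7 0.3
```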

Thanks!

Stephen

rcannood commented 3 years ago

Hey Stephen,

How about the following code? I first run 1 simulation with a bifurcating backbone. The simulation ends up in one of two end states, EndC or EndD. I pick many different starting states along this one simulation. For each starting state, I run 100 dyngen simulations with the same underlying model but a different starting state and no burn-in. By aggregating the metadata of the last state in each simulation, I can plot the distribution of end states for each of the different starting positions (see below).

model_init was created with just 1 simulation so that the plot at the end nicely shows that the further along the simulation I sample, the more likely the cell is to end up in end state EndD. If you want to do something with this type of analysis, I'd suggest running the initial simulation with num_simulations > 30 and sampling the starting states randomly along the original simulations, for example as sketched below.
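Something along these lines, for example (untested sketch; it assumes model_init was generated with num_simulations > 30):

```r
# untested sketch: instead of evenly spaced time points, sample 50
# starting states uniformly at random across all initial simulations
set.seed(2)
starting_state_ix <- sample(nrow(model_init$simulations$meta), 50)
starting_state_times <- model_init$simulations$meta$sim_time[starting_state_ix]
```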

I might consider turning the fate_probs.R code below into a vignette. Is this okay with you?

Robrecht

fate_probs.R

```r
library(tidyverse)
library(dyngen)

set.seed(1)

backbone <- backbone_bifurcating()
sim_time <- simtime_from_backbone(backbone)

# run it once to get the starting states
# dyngen will throw a warning because we're only using 1 simulation
# and thus one of the end points will not have been reached.
model_init <- initialise_model(
  backbone = backbone,
  simulation = simulation_default(
    total_time = sim_time,
    ssa_algorithm = ssa_etl(tau = .1),
    experiment_params = simulation_type_wild_type(num_simulations = 1)
  )
) %>%
  generate_tf_network() %>%
  generate_feature_network() %>%
  generate_kinetics() %>%
  generate_gold_standard() %>%
  generate_cells()

plot_simulations(model_init)
plot_simulation_expression(model_init)

# choose different starting states spread along the simulation above
starting_state_times <- c(seq(0, sim_time, by = 20), sim_time)

# get the index of the states closest to the specified starting times
starting_state_ix <- map_int(starting_state_times, function(dt) {
  which.min(abs(model_init$simulations$meta$sim_time - dt))
})

# run `num_simulations` simulations from each starting point
num_simulations <- 100

# gather initial states
outputs <- map(seq_along(starting_state_ix), function(i) {
  cat("Run ", i, "/", length(starting_state_ix), "\n", sep = "")

  # get starting state and time
  init_sim_ix <- starting_state_ix[[i]]
  init_state <- model_init$simulations$counts[init_sim_ix, , drop = TRUE]
  init_sim_time <- model_init$simulations$meta$sim_time[[init_sim_ix]]

  # prepare model
  model_many_base <- model_init
  model_many_base$verbose <- FALSE

  # remove previous simulations
  model_many_base$simulations <- NULL

  # set initial state
  model_many_base$simulation_system$initial_state <- init_state

  # run `num_simulations` simulations
  model_many_base$simulation_params$experiment_params <-
    simulation_type_wild_type(num_simulations = num_simulations)

  # simulate remaining time plus a bit extra
  model_many_base$simulation_params$total_time <-
    sim_time - init_sim_time + 0.5 * sim_time

  # no burn-in required
  model_many_base$simulation_params$burn_time <- 0

  # run simulations
  model_many_base <- model_many_base %>%
    generate_cells()

  # gather output: the last state of each simulation
  meta <- model_many_base$simulations$meta %>%
    group_by(simulation_i) %>%
    slice(n()) %>%
    ungroup() %>%
    mutate(init_i = i, init_sim_ix, init_sim_time)

  # you can choose to also return the model,
  # though this is not necessary for this script
  list(
    # model = model_many_base,
    meta = meta
  )
})

# collect all the metadata
meta_all <- map_df(outputs, "meta") %>%
  mutate(
    init_sim_time = starting_state_times[init_i],
    edge = paste0(from, "->", to)
  )

# plot the distribution of final edges per starting time
meta_summ <- meta_all %>%
  group_by(init_i, init_sim_ix, init_sim_time, edge) %>%
  summarise(n = n(), .groups = "drop") %>%
  group_by(init_i) %>%
  mutate(
    pct = n / sum(n), # fraction of runs per starting state
    ymax = cumsum(n),
    ymin = ymax - n
  ) %>%
  ungroup()

ggplot(meta_summ) +
  geom_rect(aes(
    xmin = init_sim_time - 9, xmax = init_sim_time + 9,
    ymin = ymin, ymax = ymax, fill = edge
  )) +
  theme_bw() +
  labs(x = "Starting time", y = "Count") +
  scale_fill_brewer(palette = "Set2")
```

[Plot: for each starting time (x-axis), stacked counts (y-axis) of the backbone edge on which the 100 simulations ended, coloured by edge]
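If you want the actual probability vectors rather than counts, something like the following (untested) on top of meta_all should do; it assumes the extra simulation time is enough for every run to reach a terminal edge:

```r
# untested sketch: empirical fate probabilities per starting state,
# keeping only runs whose final edge leads into EndC or EndD
fate_probs <- meta_all %>%
  filter(to %in% c("EndC", "EndD")) %>%
  count(init_sim_time, fate = to) %>%
  group_by(init_sim_time) %>%
  mutate(prob = n / sum(n)) %>% # Prob(fate | starting state)
  ungroup()
```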

rcannood commented 3 years ago

Hey Stephen @zsteve,

Does the code proposed above come close to what you were looking for?

Robrecht

rcannood commented 3 years ago

I'm assuming this issue is solved. If it isn't, feel free to reply.