kevinsbello / dagma

A Python 3 package for learning Bayesian Networks (DAGs) from data. Official implementation of the paper "DAGMA: Learning DAGs via M-matrices and a Log-Determinant Acyclicity Characterization"
https://dagma.readthedocs.io/en/latest/
Apache License 2.0

Queries on pre-processing, variable types & overall performance #3

Closed: soya-beancurd closed this issue 1 year ago

soya-beancurd commented 1 year ago

Hello! I've just read the DAGMA paper, and I found it really interesting how you reframed the typical continuous optimization-based approach (from NOTEARS) into one that leverages the properties of DAGs!

However, I'm no expert in Causal Discovery (nor am I proficient in optimization methods, as I only barely grasped the high-level intuition of DAGMA), so I am writing this post to clarify some doubts I have regarding the practical application and implementation of DAGMA:

(1) Like NOTEARS (based on the paper "Unsuitability of NOTEARS for Causal Graph Discovery"), is DAGMA susceptible to rescaling of the data, i.e., not scale-invariant? If so, is standardizing all variables to unit variance a necessary preprocessing step?

(2) Which variable types (i.e., continuous, discrete, and categorical) can DAGMA take in at once? Or can DAGMA (at least for linear models) only handle variables of a single type at any one time?

(3) I understand that you compared the performance and scalability of DAGMA with other continuous optimization-based approaches such as NOTEARS and GOLEM. If possible, where do you think it fits among approaches that are known to be more 'scalable' to large numbers of nodes, such as NOBEARS and LEAST? I am curious how DAGMA's scalability compares with these algorithms; a recent survey from early last year gives a nice summary of continuous optimization-based approaches (image attached below, from D’ya Like DAGs? A Survey on Structure Learning and Causal Discovery).

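[Image: summary of continuous optimization-based approaches, from D’ya Like DAGs? A Survey on Structure Learning and Causal Discovery]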

I understand that answering all of these questions might be more than a mouthful, but I am trying to clarify all of these doubts as I'm really interested in seeing if there could be a practical application of DAGMA for my real world dataset with ~300-700 variables (columns) & ~1m rows (as my current implementation using DirectLiNGAM is not really scalable at all and unfortunately is constrained to only continuous variables).

Thanks so much in advance!

kevinsbello commented 1 year ago

Hi, thanks for your questions.

  1. There is a large misunderstanding regarding NOTEARS-like algorithms. NOTEARS, GOLEM, DAGMA, and similar methods are approaches that aim to optimize a given loss/score function. These methods offer a way to optimize the given loss function via gradient descent or its variants. Thus, you should think about these methods as optimization methods.

The "issues" that have appeared in a few papers are not about NOTEARS or DAGMA themselves but about the loss function. For example, assuming linear models, if one uses vanilla least squares as the loss function and the data is standardized, then the model becomes unidentifiable: the global minimum of the loss function need not correspond to the true underlying DAG. Thus, it does not matter whether you use a greedy method, NOTEARS, DAGMA, or any other optimizer, because even if you find the global minimum, it will not necessarily be the ground-truth DAG.

For this reason, it makes sense to restrict attention to models that are identifiable via a given loss function, such as linear models with equal noise variances (in this case, the global minimum of vanilla least squares corresponds to the ground-truth DAG), or the CAM model.
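To make the standardization point concrete, here is a small numpy sketch (an illustration, not code from the dagma package). It simulates a two-variable linear SEM X1 -> X2 with equal noise variances and compares the least-squares score of the two possible causal orders. The helper `ls_score` is a hypothetical name: it sums each variable's OLS residual variance given its predecessors in the order, which equals the minimum of (1/n)·||X - XW||_F^2 over all weight matrices W consistent with that order. On the raw data the true order should score lower; after standardizing to unit variance, both orders score the same, so least squares can no longer tell them apart.

```python
import numpy as np

def ls_score(X, order):
    """Sum of per-variable OLS residual variances, regressing each variable on
    the variables that precede it in `order`. Equals the minimum of
    (1/n) * ||X - X W||_F^2 over weight matrices W consistent with `order`."""
    n = X.shape[0]
    total = 0.0
    for k, j in enumerate(order):
        y = X[:, j]
        preds = list(order[:k])
        if preds:
            A = X[:, preds]
            coef, *_ = np.linalg.lstsq(A, y, rcond=None)
            y = y - A @ coef          # residual after regressing on predecessors
        total += float(y @ y) / n
    return total

rng = np.random.default_rng(0)
n = 100_000
x1 = rng.normal(size=n)               # noise variance 1
x2 = 2.0 * x1 + rng.normal(size=n)    # noise variance 1 (equal variances)
X = np.column_stack([x1, x2])
X = X - X.mean(axis=0)                # center so no intercept is needed

Z = X / X.std(axis=0)                 # standardize to unit variance

raw_true, raw_rev = ls_score(X, [0, 1]), ls_score(X, [1, 0])
std_true, std_rev = ls_score(Z, [0, 1]), ls_score(Z, [1, 0])
print(f"raw:          true {raw_true:.3f}  reversed {raw_rev:.3f}")
print(f"standardized: true {std_true:.3f}  reversed {std_rev:.3f}")
```

With these settings the raw scores should land near 2.0 (true order) versus about 5.2 (reversed), while the two standardized scores coincide up to floating-point error: for two standardized variables with sample correlation r, both orders give 1 + (1 - r^2).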

  2. The current implementation of DAGMA can only take one of the following: continuous variables (if using linear models, make sure to use least squares; if using nonlinear models, use the default loss function), or binary (0/1) variables (here one could use linear models with the logistic loss; I have not actually tried what happens when using nonlinear models with the default loss).

  3. The main difference from NOBEARS, at least, is that DAGMA's characterization of acyclicity is exact. Even though NOBEARS motivates its approach by constraining the spectral radius to be 0, its algorithm actually uses only a rough approximation of it. Looking at their experiments for 100/300 nodes, I'm under the impression that DAGMA would perform better.
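For reference, the log-determinant acyclicity function from the DAGMA paper is h^s(W) = -log det(sI - W∘W) + d·log s, defined for s larger than the spectral radius of W∘W, and it is zero exactly when W is a DAG. The numpy sketch below (the function name `h_logdet` is hypothetical, and s = 1.0 is an illustrative choice) evaluates it on a 3-node chain and on the same chain with one extra edge closing a cycle.

```python
import numpy as np

def h_logdet(W, s=1.0):
    """Log-det acyclicity function h^s(W) = -log det(sI - W*W) + d*log(s),
    where W*W is the elementwise square. Meaningful only when sI - W*W is an
    M-matrix, i.e. s exceeds the spectral radius of W*W."""
    d = W.shape[0]
    sign, logabsdet = np.linalg.slogdet(s * np.eye(d) - W * W)
    # A positive determinant is necessary (not sufficient) for an M-matrix.
    if sign <= 0:
        raise ValueError("s*I - W*W has non-positive determinant; increase s")
    return -logabsdet + d * np.log(s)

chain = np.array([[0.0, 0.5, 0.0],
                  [0.0, 0.0, 0.5],
                  [0.0, 0.0, 0.0]])   # 0 -> 1 -> 2 (a DAG)
cycle = chain.copy()
cycle[2, 0] = 0.5                      # adds 2 -> 0, closing a cycle

print(h_logdet(chain))  # 0 up to floating point
print(h_logdet(cycle))  # log(64/63), about 0.0157
```

For the chain, sI - W∘W is triangular with s on the diagonal, so the determinant is exactly s^d and h vanishes; for the cycle, det(I - W∘W) = 1 - 0.25^3 = 63/64, so h is small but strictly positive.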

> I understand that answering all of these questions might be more than a mouthful, but I am trying to clarify all of these doubts as I'm really interested in seeing if there could be a practical application of DAGMA for my real world dataset with ~300-700 variables (columns) & ~1m rows (as my current implementation using DirectLiNGAM is not really scalable at all and unfortunately is constrained to only continuous variables).

DAGMA should easily handle that number of variables. With linear models, 300 nodes should take a few minutes and 700 nodes about 30 minutes to 1 hour on CPU! The current implementation, though, first computes the covariance matrix as the plain sample covariance X^T X / n, so it might run out of memory with 1M rows. One option would be to use SGD, or a server with more memory :)
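Since the linear least-squares loss depends on the data only through the d × d matrix X^T X / n, via the identity (1/(2n))·||X - XW||_F^2 = (1/2)·tr((I - W)^T (X^T X / n)(I - W)), one way around the memory issue is to accumulate that matrix over row chunks, e.g. from a `np.memmap` on disk. The sketch below uses hypothetical helpers (`chunked_second_moment`, `l2_loss_from_cov`) that are not part of the dagma package; feeding the result into DAGMA would mean adapting its code, since I'm not assuming it accepts a precomputed matrix.

```python
import numpy as np

def chunked_second_moment(X, chunk_rows=100_000, center=False):
    """Accumulate X^T X / n over row chunks so that only one chunk is
    converted to float64 at a time (X may be e.g. a np.memmap)."""
    n, d = X.shape
    S = np.zeros((d, d))
    col_sum = np.zeros(d)
    for start in range(0, n, chunk_rows):
        block = np.asarray(X[start:start + chunk_rows], dtype=np.float64)
        S += block.T @ block
        col_sum += block.sum(axis=0)
    M = S / n
    if center:
        mu = col_sum / n
        M -= np.outer(mu, mu)   # equals (1/n) * sum of (x - mu)(x - mu)^T
    return M

def l2_loss_from_cov(W, cov):
    """Least-squares loss 0.5 * tr((I - W)^T cov (I - W))."""
    I = np.eye(W.shape[0])
    return 0.5 * np.trace((I - W).T @ cov @ (I - W))

rng = np.random.default_rng(0)
X = rng.normal(size=(10_000, 5))
W = np.triu(rng.normal(size=(5, 5)), k=1)     # arbitrary upper-triangular weights

cov = chunked_second_moment(X, chunk_rows=1_000)
direct = 0.5 / X.shape[0] * np.linalg.norm(X - X @ W, "fro") ** 2
print(np.isclose(l2_loss_from_cov(W, cov), direct))  # True
```

For scale: 1M rows × 700 columns in float64 is already about 5.6 GB for X alone, whereas the accumulated matrix is only 700 × 700, so with a memory-mapped X just one chunk needs to sit in RAM at a time.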

Hope it helps.

soya-beancurd commented 1 year ago

Hi Kevin,

Thanks so much for the clarification; it has really broadened my perspective on DAGMA (and NOTEARS-esque optimization methods in general)!

Lastly, I'd like to clarify a doubt that came up over the past week while experimenting with DAGMA. As you mentioned, DAGMA is considerably faster than both vanilla NOTEARS and LiNGAM in general (even the fastest variant, ICA-LiNGAM with FastICA). This held across the board for 20 - 200 nodes. However, when testing these algorithms with 200 nodes, I noticed that the adjacency matrix (i.e., the weight matrix W_est) estimated by DAGMA was not a DAG: it failed the dagma.utils.is_dag check.

Setup

Although DAGMA took only 5-10 minutes to run (compared to at least half a day for the LiNGAM variants), the weight matrix it produced was not a DAG, whereas LiNGAM guarantees a DAG by construction.

I understand that one could potentially remove edges either by pruning each non-zero weight (edge) based on some kind of independence test between the two variables it connects (e.g., Fast Conditional Independence Test for Vector Variables with Large Sample Sizes), or by forcibly removing the edges that break acyclicity by setting their weights to 0 (similar to your solution in another DAGMA GitHub issue). However, the former does not guarantee that the pruned weight matrix is a DAG, while the latter might not be scalable.
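A cheap variant of the second workaround is to raise a hard threshold on |W| until the support becomes acyclic. Because removing edges from a DAG never creates a cycle, acyclicity is monotone in the threshold, so a binary search over the distinct |w| values needs only O(log E) acyclicity checks. The sketch below is a generic numpy heuristic; `is_dag_support` and `threshold_to_dag` are hypothetical names, not dagma functions. It may drop true edges, so it patches the output rather than explaining why the optimizer left the DAG domain.

```python
import numpy as np

def is_dag_support(W):
    """Kahn's algorithm on the support of W (edge i -> j iff W[i, j] != 0)."""
    A = W != 0
    d = A.shape[0]
    indeg = A.sum(axis=0)            # in-degree of each node (self-loops count)
    ready = [j for j in range(d) if indeg[j] == 0]
    visited = 0
    while ready:
        i = ready.pop()
        visited += 1
        for j in np.flatnonzero(A[i]):
            indeg[j] -= 1
            if indeg[j] == 0:
                ready.append(j)
    return visited == d              # every node removed <=> no cycle

def threshold_to_dag(W):
    """Zero out the smallest-|w| entries until the support is acyclic.
    Returns (pruned W, threshold); entries with |w| <= threshold are removed."""
    if is_dag_support(W):
        return W.copy(), 0.0
    cand = np.unique(np.abs(W[W != 0]))          # sorted distinct magnitudes

    def prune(k):
        return np.where(np.abs(W) > cand[k], W, 0.0)

    lo, hi = 0, len(cand) - 1                     # prune(hi) removes every edge
    while lo < hi:
        mid = (lo + hi) // 2
        if is_dag_support(prune(mid)):
            hi = mid
        else:
            lo = mid + 1
    return prune(lo), float(cand[lo])

W = np.array([[0.0, 0.9, 0.0],
              [0.0, 0.0, 0.8],
              [0.1, 0.0, 0.0]])   # 0 -> 1 -> 2 -> 0 is a cycle
W_dag, thr = threshold_to_dag(W)
print(thr)                        # 0.1: only the weakest edge 2 -> 0 is dropped
```

With d = 200 there are at most 39,800 off-diagonal candidates, so the search needs at most about 16 acyclicity checks, each O(d^2) on a dense matrix.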

Is there any explanation for this anomaly, and are there any scalable workarounds to ensure that the weight matrices produced by DAGMA are DAGs? (Could it be that the continuous objective optimized by DAGMA only captures conditional dependencies between variables rather than the true causal relationships, and by extension cannot guarantee a DAG? Or could it be due to the way I generated my synthetic data?)

Thank you!!

kevinsbello commented 1 year ago

Hey there,

That's surprising; I suspect it is just a matter of optimization. My guess is that during optimization the solution path went outside the feasible domain of M-matrices and never came back, so the output is not a DAG.

I would like to reproduce the problem, so it would be great if you could share the exact code that throws this error.

In the meantime, just in case, try T > 5 (e.g., 6 or 7). If that doesn't help, I'll need to look deeper.

Finally, do you get the error with Gaussian noise as well, or only with uniform noise?

Best, Kevin

soya-beancurd commented 1 year ago

Hi Kevin,

The same error appears for me with noise_type = "gauss" or "uniform", and with lambda1 = 0.2, 0.1, or 0.01. However, your suggestion (increasing the number of iterations T to any of 6, 8, or 10) does produce a DAG for every one of these argument values.

Here's the code that I used:

from dagma import utils as dagma_utils
from dagma.linear import DagmaLinear

############## HYPERPARAMETERS ##############

# Seed
dagma_utils.set_random_seed(1)

# Number of nodes
num_nodes = 200

# Sparsity rate to determine number of edges
# 10% of maximum permissible edges for a DAG
s0_rate = 0.1

# Type of noise for synthetic data
# ("uniform" noise type was also tested)
noise_type = "gauss"

# l1 regularization parameter
# (lambda1 values of 0.2 / 0.1 / 0.01 were all tested)
lambda1 = 0.2

# loss_type parameter
loss_type = "l2"

# Number of outer iterations
# (values of T = 6 / 8 / 10 were also tested)
T = 5

#############################################

# Create an Erdos-Renyi DAG (ER in expectation) with noise of type sem_type
# number of samples n
# number of nodes d
# number of edges s0 - computed as s0_rate * max_num_of_edges (= d(d-1)/2)
d = num_nodes
s0 = int(s0_rate * ((d * (d-1))/2))
n = 750000
graph_type = "ER"
sem_type = noise_type

# Data Generation
B_true = dagma_utils.simulate_dag(d, s0, graph_type)
W_true = dagma_utils.simulate_parameter(B_true)
X = dagma_utils.simulate_linear_sem(W_true, n, sem_type)

# Running DAGMA
dagma_model = DagmaLinear(loss_type=loss_type)
W_est = dagma_model.fit(X, lambda1=lambda1, T=T)

# Check if DAG
assert dagma_utils.is_dag(W_est), "Result is not a DAG"

How would you explain the link between a larger number of iterations and the solution path converging towards a DAG?

kevinsbello commented 1 year ago

Hi there, sorry for the late response. T controls the number of outer iterations and is tied to the decay of mu: as mu gets closer to 0, the "DAGness" of the solution increases. I initially set T to 5 since that seemed to be enough for most cases, but in your setting T=5 was apparently not enough.
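A toy way to see why more outer iterations help: assume each outer iteration t multiplies mu by a fixed factor, mu_t = mu_init · mu_factor^t. The values mu_init = 1.0 and mu_factor = 0.1 below are assumptions; check the `fit` docstring for the actual defaults. In a one-dimensional stand-in for the objective, mu·(x - 1)^2 + x^2, where x^2 plays the role of the acyclicity function h, the minimizer is x*(mu) = mu / (1 + mu), so the leftover "non-DAGness" shrinks roughly in proportion to mu. This is only an analogy for the central-path idea, not DAGMA's actual objective or code.

```python
def mu_schedule(T, mu_init=1.0, mu_factor=0.1):
    """Hypothetical geometric decay of mu over T outer iterations
    (mu_init and mu_factor are assumed values, not confirmed defaults)."""
    return [mu_init * mu_factor ** t for t in range(T)]

def toy_violation(mu):
    """Closed-form minimizer of mu*(x - 1)**2 + x**2, i.e. mu / (1 + mu);
    x**2 stands in for the acyclicity penalty."""
    return mu / (1.0 + mu)

for T in (5, 6, 8, 10):
    mu_final = mu_schedule(T)[-1]
    print(f"T={T:2d}  final mu={mu_final:.0e}  toy violation={toy_violation(mu_final):.1e}")
```

Under these assumptions each extra outer iteration shrinks the final mu, and with it the toy violation, by another factor of 10, which matches the observation that T = 6, 8, or 10 tipped the 200-node run into the DAG domain.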

Scriddie commented 1 month ago

Hi @kevinsbello @soya-beancurd, thanks for starting this interesting discussion on standardization. From what I understand, any (non-uniform) multiplicative rescaling would render the setting unidentifiable, so one would have to know the "true" data scale to stay within the model class. Hence my question:

How would this work on real-world data where data scales are arbitrary?

For example, I may measure GDP in dollars or yuan, height in inches or meters, time in seconds or hours, etc. Given the arbitrariness of real-world data scales, how would the algorithm be applicable to such data? I know real-world applications of causal discovery are difficult in general, so I'm looking forward to hearing your thoughts :)
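To illustrate the rescaling concern with least squares, take the equal-noise model X1 -> X2 with X2 = 2·X1 + noise and unit noise variances. Merely expressing X1 in a unit 10 times smaller (e.g. meters to decimeters) makes the reversed order score better. The helper `order_scores` is a hypothetical name for this sketch, which works with population covariances rather than samples and is not anything from the dagma package.

```python
import numpy as np

def order_scores(cov):
    """Least-squares score (sum of residual variances) of both causal orders
    of two variables, computed from a 2x2 population covariance matrix."""
    a, b, c = cov[0, 0], cov[1, 1], cov[0, 1]
    score_01 = a + (b - c**2 / a)   # X1 first, then X2 | X1
    score_10 = b + (a - c**2 / b)   # X2 first, then X1 | X2
    return float(score_01), float(score_10)

w = 2.0
cov = np.array([[1.0, w],
                [w, w**2 + 1.0]])   # X1 = e1, X2 = w*X1 + e2, Var(e1) = Var(e2) = 1

D = np.diag([10.0, 1.0])            # measure X1 in a 10x smaller unit
cov_rescaled = D @ cov @ D

print(order_scores(cov))            # (2.0, 5.2): true order X1 -> X2 scores lower
print(order_scores(cov_rescaled))   # (101.0, 25.0): reversed order scores lower
```

The unit change inflates the first variable's marginal variance from 1 to 100, and the order that places it downstream of X2 (where only its conditional variance, 100 × 0.2 = 20, is paid) now wins, even though the underlying causal structure is unchanged.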