pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

ENH: Sparse matrix handling with ICAR prior #7406

Closed jfhawkin closed 1 month ago

jfhawkin commented 1 month ago

Before

@classmethod
    def dist(cls, W, sigma=1, zero_sum_stdev=0.001, **kwargs):
        # Note: These checks are forcing W to be non-symbolic
        if not W.ndim == 2:
            raise ValueError("W must be matrix with ndim=2")

        if not W.shape[0] == W.shape[1]:
            raise ValueError("W must be a square matrix")

        if not np.allclose(W.T, W):
            raise ValueError("W must be a symmetric matrix")

        if np.any((W != 0) & (W != 1)):
            raise ValueError("W must be composed of only 1s and 0s")

        W = pt.as_tensor_variable(W, dtype=int)
        sigma = pt.as_tensor_variable(sigma)
        zero_sum_stdev = pt.as_tensor_variable(zero_sum_stdev)
        return super().dist([W, sigma, zero_sum_stdev], **kwargs)

    def support_point(rv, size, W, sigma, zero_sum_stdev):
        N = pt.shape(W)[-2]
        return pt.zeros(N)

    def logp(value, W, sigma, zero_sum_stdev):
        # convert adjacency matrix to edgelist representation
        # An edgelist is a pair of lists.
        # If node i and node j are connected then one list
        # will contain i and the other will contain j at the same
        # index value.
        # We only use the lower triangle here because adjacency
        # is an undirected connection.
        N = pt.shape(W)[-2]
        node1, node2 = pt.eq(pt.tril(W), 1).nonzero()

        pairwise_difference = (-1 / (2 * sigma**2)) * pt.sum(pt.square(value[node1] - value[node2]))
        zero_sum = (
            -0.5 * pt.pow(pt.sum(value) / (zero_sum_stdev * N), 2)
            - pt.log(pt.sqrt(2.0 * np.pi))
            - pt.log(zero_sum_stdev * N)
        )

        return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0")
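As a small, purely illustrative aside: the edgelist conversion described in the logp comments works like this for a hypothetical 3-node graph in which node 0 is connected to nodes 1 and 2 (names are illustrative only):

import numpy as np

# Hypothetical 3-node adjacency matrix; only the lower triangle is used,
# so each undirected edge appears exactly once.
W_example = np.array([
    [0, 1, 1],
    [1, 0, 0],
    [1, 0, 0],
])
node1, node2 = np.nonzero(np.tril(W_example) == 1)
# node1 = [1, 2], node2 = [0, 0]: paired indices give the edges 1-0 and 2-0.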

After

@classmethod
    def dist(cls, W, sigma=1, zero_sum_stdev=0.001, **kwargs):
        # Note: These checks are forcing W to be non-symbolic
        if not W.ndim == 2:
            raise ValueError("W must be matrix with ndim=2")

        if not W.shape[0] == W.shape[1]:
            raise ValueError("W must be a square matrix")

        if not np.allclose(W.data.T, W.data):
            raise ValueError("W must be a symmetric matrix")

        if np.any((W.data != 0) & (W.data != 1)):
            raise ValueError("W must be composed of only 1s and 0s")

        sigma = pt.as_tensor_variable(sigma)
        zero_sum_stdev = pt.as_tensor_variable(zero_sum_stdev)
        return super().dist([W, sigma, zero_sum_stdev], **kwargs)

    def support_point(rv, size, W, sigma, zero_sum_stdev):
        N = pt.shape(W)[-2]
        return pt.zeros(N)

    def logp(value, W, sigma, zero_sum_stdev):
        # convert adjacency matrix to edgelist representation
        # An edgelist is a pair of lists.
        # If node i and node j are connected then one list
        # will contain i and the other will contain j at the same
        # index value.
        # We only use the lower triangle here because adjacency
        # is an undirected connection.
        N = pt.shape(W)[-2]
        node1, node2 = W.nonzero()
        node1 = pt.as_tensor_variable(node1, dtype=int)
        node2 = pt.as_tensor_variable(node2, dtype=int)
        W = pytensor.sparse.as_sparse_or_tensor_variable(W)

        pairwise_difference = (-1 / (2 * sigma**2)) * pt.sum(pt.square(value[node1] - value[node2]))
        zero_sum = (
            -0.5 * pt.pow(pt.sum(value) / (zero_sum_stdev * N), 2)
            - pt.log(pt.sqrt(2.0 * np.pi))
            - pt.log(zero_sum_stdev * N)
        )

        return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0")
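For reference, a scipy sparse matrix's .nonzero() returns a (rows, cols) pair, which is what the revised logp relies on. A hypothetical 3-node example (illustrative only):

import numpy as np
import scipy.sparse as sp

# Hypothetical 3-node graph: node 0 connected to nodes 1 and 2.
W_sp = sp.csr_matrix(np.array([
    [0, 1, 1],
    [1, 0, 0],
    [1, 0, 0],
]))
rows, cols = W_sp.nonzero()
# rows = [0, 0, 1, 2], cols = [1, 2, 0, 0]: one index pair per stored nonzero entry.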

Context for the issue:

The CAR prior allows the user to input a sparse matrix, but the ICAR prior does not have this functionality. For large spatial adjacency matrices, sparsity is critical to ensure efficient memory allocation.

The above solution makes minor revisions so that the input checks and the node-pair extraction work on a sparse matrix. It may be necessary to wrap the sparse-specific handling in an if statement, similar to the approach used in the CAR class. This version runs for me on a test dataset.
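For reference, a minimal usage sketch of what the change is meant to enable, assuming the revisions above are applied; the model lines mirror the snippet in the traceback below, and the 4-node chain graph is purely illustrative:

import numpy as np
import pymc as pm
import scipy.sparse as sp

# Hypothetical 4-node chain graph stored as a sparse adjacency matrix.
adj_matrix = sp.csr_matrix(np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
]))

with pm.Model():
    # spatially dependent random effect with non-centered parameterization
    icar_sigma = pm.Exponential("icar_sigma", 1)
    phi = pm.ICAR("phi", W=adj_matrix)  # sparse W, per the proposed change
    mu_icar = icar_sigma * phi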

welcome[bot] commented 1 month ago

:tada: Welcome to PyMC! :tada: We're really excited to have your input into the project! :sparkling_heart:
If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

ricardoV94 commented 1 month ago

You want W wrapped in a pt.as_tensor_or_sparse (or something like that). The shape checks can either go away or be done in the logp, unless the shape is static, in which case they can be done immediately.
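One way to read that suggestion (a sketch only, not the maintainer's exact code): wrap W once so the same entry point accepts either a dense array or a scipy sparse matrix, and defer the validity checks unless the static shape is known:

import numpy as np
import scipy.sparse as sp
import pytensor.sparse

dense = np.array([[0, 1], [1, 0]])

# Either input becomes a symbolic variable of the appropriate kind.
W_dense = pytensor.sparse.as_sparse_or_tensor_variable(dense)
W_sparse = pytensor.sparse.as_sparse_or_tensor_variable(sp.csr_matrix(dense))

# With W symbolic, checks can move into logp (e.g. via check_parameters),
# or run eagerly in dist only when the static shape (W_dense.type.shape) is known.
print(type(W_dense), type(W_sparse))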

jfhawkin commented 1 month ago

@ricardoV94 Running the latest version of multivariate.py, I get the following error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 12
     10 # spatially dependent random effect with non-centered parameterization
     11 icar_sigma = pm.Exponential('icar_sigma', 1)
---> 12 phi = pm.ICAR("phi", W=adj_matrix)
     13 mu_icar = icar_sigma * phi
     15 # CBSA model

File ~/.conda/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:536, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, default_transform, *args, **kwargs)
    533     elif observed is not None:
    534         kwargs["shape"] = tuple(observed.shape)
--> 536 rv_out = cls.dist(*args, **kwargs)
    538 rv_out = model.register_rv(
    539     rv_out,
    540     name,
   (...)
    546     initval=initval,
    547 )
    549 # add in pretty-printing support

File ~/.conda/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/multivariate.py:2423, in ICAR.dist(cls, W, sigma, zero_sum_stdev, **kwargs)
   2420 sigma = pt.as_tensor_variable(sigma)
   2421 zero_sum_stdev = pt.as_tensor_variable(zero_sum_stdev)
-> 2423 return super().dist([W, node1, node2, N, sigma, zero_sum_stdev], **kwargs)

File ~/.conda/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:618, in Distribution.dist(cls, dist_params, shape, **kwargs)
    615     ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp
    617 create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
--> 618 rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
    620 rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
    621 rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")

TypeError: ICARRV.__call__() got multiple values for argument 'size'

jfhawkin commented 1 month ago

I found the error. PyMC 5.16.2 on conda-forge differs from the GitHub main repo in important ways. It's missing size = kwargs.pop("size", None) and several similar lines that address the positional/keyword argument issue.
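For illustration only (this is not PyMC's actual code, and the names are hypothetical): the clash in the traceback above is the generic Python situation where a keyword ends up being supplied both explicitly and via **kwargs; popping it out first, as the size = kwargs.pop("size", None) line does, avoids it:

def rv_call(*params, size=None):
    return params, size

def dist(params, **kwargs):
    # Without this pop, `size` would stay in kwargs and also be passed
    # explicitly below, raising "got multiple values for keyword argument 'size'".
    size = kwargs.pop("size", None)
    return rv_call(*params, size=size, **kwargs)

dist([1, 2], size=3)  # -> ((1, 2), 3)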