mberr / torch-ppr

(Personalized) Page-Rank computation using PyTorch
MIT License

Formulate page-rank as a torch.nn Layer #21

Open LM-AuroTripathy opened 2 years ago

LM-AuroTripathy commented 2 years ago

Thank you for this repo!

The reason for requesting a 'layer' formulation is to convert the function page_rank to an ONNX graph with torch.onnx, which only accepts models.

Once I have the ONNX model, I can compile it for different hardware (other than CUDA).

I probably need just the forward pass, not a backward pass, although I think the computation will be differentiable.

Thanks.

cthoyt commented 2 years ago

Do you mean wrap the stateless page_rank() function in a stateful torch.nn.Module class?

LM-AuroTripathy commented 2 years ago

Yes, why didn't I think of that? I can take a crack at it.

cthoyt commented 2 years ago

A simple way is to accept all of the optional parameters in the module's __init__ and have the module's forward() take the required parameters; inside forward(), pass everything to page_rank() together.

LM-AuroTripathy commented 2 years ago

Below is what I came up with.

import torch
from torch_ppr import page_rank

class PageRank(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return page_rank(edge_index=x)

edge_index = torch.as_tensor(data=[(0, 1), (1, 2), (1, 3), (2, 4)]).t()
model = PageRank()
print(model(edge_index))

cthoyt commented 2 years ago

What about the rest of the arguments? All of the following can be passed to __init__:

https://github.com/mberr/torch-ppr/blob/a5de68835cd4dd879b386f0f28bfe6297feb6262/src/torch_ppr/api.py#L34-L40

LM-AuroTripathy commented 2 years ago

I hope the usage below is correct:

import torch
from torch_ppr import page_rank
from typing import Optional, Union

DeviceHint = Union[None, str, torch.device]

class PageRank(torch.nn.Module):
    def __init__(self,
                 add_identity: bool = False,
                 max_iter: int = 1000,
                 alpha: float = 0.05,
                 epsilon: float = 1.0e-04,
                 x0: Optional[torch.Tensor] = None,
                 use_tqdm: bool = False,
                 device: DeviceHint = None):
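        super().__init__()
        # Keep the optional page_rank arguments so forward() can pass them on.
        self.kwargs = dict(
            add_identity=add_identity,
            max_iter=max_iter,
            alpha=alpha,
            epsilon=epsilon,
            x0=x0,
            use_tqdm=use_tqdm,
            device=device,
        )

    def forward(self, x):
        return page_rank(edge_index=x, **self.kwargs)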

edge_index = torch.as_tensor(data=[(0, 1), (1, 2), (1, 3), (2, 4)]).t()
model = PageRank(device='cuda')
print(model(edge_index))

# Input something to the model
x = edge_index

torch.onnx.export(model,               # model being run
                  x,                   # model input (or a tuple for multiple inputs)
                  "page_rank.onnx",    # where to save the model (can be a file or file-like object)
                  export_params=False)  # store the trained parameter weights inside the model file

LM-AuroTripathy commented 2 years ago

The ONNX export does not support the sparse_coo_tensor operator. The error message is below.

    raise symbolic_registry.UnsupportedOperatorError(
torch.onnx.symbolic_registry.UnsupportedOperatorError: Exporting the operator ::sparse_coo_tensor 
to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request 
on PyTorch GitHub.

Is there a workaround, e.g., running the power iteration with a regular (dense) matrix?

mberr commented 2 years ago

In theory, you can run the page-rank iterations with a dense matrix; however, you lose most of the computational benefits of the sparse formulation and are restricted to rather small graphs.
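To get a rough sense of that size limit: the dense adjacency matrix alone needs n * n entries. The following is a back-of-the-envelope sketch, assuming float32 entries (4 bytes each) and ignoring every other buffer; the helper name dense_adjacency_bytes is made up for illustration.

```python
def dense_adjacency_bytes(num_nodes: int, bytes_per_entry: int = 4) -> int:
    """Memory needed for a dense num_nodes x num_nodes matrix (float32 by default)."""
    return num_nodes * num_nodes * bytes_per_entry


for n in (1_000, 100_000):
    gb = dense_adjacency_bytes(n) / 1e9
    print(f"n = {n:>7,}: {gb:.3f} GB")
```

Under these assumptions, a graph with 1,000 nodes needs about 4 MB for the matrix, while 100,000 nodes already need about 40 GB.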

Essentially, you would need:

  1. Prepare the adjacency matrix A as a dense matrix, cf. edge_index_to_sparse_matrix

    adj = torch.zeros(n, n)
    adj[edge_index[0], edge_index[1]] = 1.0
  2. Then normalize A so that it fulfils the page-rank properties, cf. prepare_page_rank_adjacency

    adj = adj + adj.t()
    if add_identity:
        adj = adj + torch.eye(adj.shape[0])
    adj = adj / adj.sum(dim=1, keepdims=True).clamp_min(1.0e-08)
  3. In the power iteration, replace torch.sparse.addmm by a dense multiplication, i.e.,

    x = (1 - alpha) * (adj @ x) + alpha * x0
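
Putting the three steps together, a minimal sketch of a dense variant wrapped in a torch.nn.Module might look as follows. The class name DensePageRank and its num_iter parameter are hypothetical and not part of torch-ppr. The sketch runs a fixed number of iterations instead of checking epsilon, starts from a uniform x0, and normalizes columns (dim=0) so that adj @ x keeps x summing to one, which differs from the row-wise division in the snippet above. Whether the result exports to ONNX has not been verified here.

```python
import torch


class DensePageRank(torch.nn.Module):
    """Dense page-rank with a fixed number of power-iteration steps (hypothetical helper)."""

    def __init__(self, num_nodes: int, alpha: float = 0.05,
                 num_iter: int = 100, add_identity: bool = False):
        super().__init__()
        self.num_nodes = num_nodes
        self.alpha = alpha
        self.num_iter = num_iter
        self.add_identity = add_identity

    def forward(self, edge_index: torch.Tensor) -> torch.Tensor:
        n = self.num_nodes
        # Step 1: dense adjacency matrix built from the edge list.
        adj = torch.zeros(n, n)
        adj[edge_index[0], edge_index[1]] = 1.0
        # Step 2: symmetrize, optionally add self-loops, then normalize.
        adj = adj + adj.t()
        if self.add_identity:
            adj = adj + torch.eye(n)
        # Assumption: column normalization, so that adj @ x preserves sum(x).
        adj = adj / adj.sum(dim=0, keepdim=True).clamp_min(1.0e-08)
        # Step 3: dense power iteration, starting from a uniform distribution.
        x0 = torch.full((n,), 1.0 / n)
        x = x0
        for _ in range(self.num_iter):
            x = (1 - self.alpha) * (adj @ x) + self.alpha * x0
        return x


edge_index = torch.as_tensor([(0, 1), (1, 2), (1, 3), (2, 4)]).t()
scores = DensePageRank(num_nodes=5)(edge_index)
print(scores)
```

A fixed iteration count avoids a data-dependent stopping condition inside forward(), at the cost of possibly running more or fewer steps than the epsilon-based check would.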