metaopt / optree

OpTree: Optimized PyTree Utilities
https://optree.readthedocs.io
Apache License 2.0

[Feature Request] Implement `optree.ravel_pytree` similar to `from jax.flatten_util import ravel_pytree` #96

Closed patel-zeel closed 1 year ago

patel-zeel commented 1 year ago

Required prerequisites

Motivation

Hi,

What does this function do?

This function is motivated by JAX's ravel_pytree, which converts a PyTree into a 1D tensor and provides a function to convert that 1D tensor back into the original PyTree format.

from optree import tree_structure, ravel_pytree

flat_tensor, unravel_fn = ravel_pytree(pytree)
unraveled_pytree = unravel_fn(flat_tensor)
assert tree_structure(pytree) == tree_structure(unraveled_pytree)
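
For reference, the existing JAX utility that motivated this request behaves as follows (a minimal sketch; the example pytree is made up and assumes JAX is installed):

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

pytree = {"w": jnp.ones((2, 3)), "b": jnp.zeros(2)}
flat, unravel_fn = ravel_pytree(pytree)  # flat is a 1D array with 8 elements
restored = unravel_fn(flat)              # same structure and shapes as pytree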

This function is useful for the following applications: i) easing the process of defining "Hypernetworks"; ii) Hessian computation for neural networks (e.g., the Laplace approximation).

Challenges in Hypernetworks

In "Hypernetworks" line of work, we have a hypernetwork hypernet, which generates the parameters of another network net. Typically, hypernet's last layer will be a Linear/Dense layer, and thus it will be in a "flattened" format (1D tensor). It is challenging to convert this flat tensor into parameters of net.

Challenges in computing Hessians for the Laplace approximation

The Laplace approximation for a neural network requires computing the Hessian of the negative log joint w.r.t. the parameters and inverting it; the inverse of the Hessian is the covariance matrix of the posterior normal distribution. The parameters have to be flattened to compute this Hessian. Also, a parameter sample drawn from this posterior is a flat array that needs to be restructured before it can be plugged back into the network.
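
In symbols, this is the standard Laplace form: the posterior over the flattened parameters θ is approximated as N(θ_MAP, H⁻¹), where H = ∇²_θ [-log p(θ, D)] is the Hessian of the negative log joint evaluated at θ_MAP.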

Solution

Currently, I have written my own version of ravel_pytree using the existing functions in optree:

import optree
import torch


def ravel_pytree(pytree):
    leaves, structure = optree.tree_flatten(pytree)
    shapes = [leaf.shape for leaf in leaves]
    sizes = [leaf.numel() for leaf in leaves]
    flat_params = torch.cat([leaf.flatten() for leaf in leaves])

    def unravel_function(flat_params):
        assert flat_params.numel() == sum(sizes), f"Invalid size {flat_params.numel()} != {sum(sizes)}"
        assert len(flat_params.shape) == 1, f"Invalid shape {flat_params.shape}"
        flat_leaves = flat_params.split(sizes)
        leaves = [leaf.reshape(shape) for leaf, shape in zip(flat_leaves, shapes)]
        return optree.tree_unflatten(structure, leaves)

    return flat_params, unravel_function
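
A quick round-trip sanity check of this helper might look like this (the example tensors below are made up):

pytree = {"weight": torch.randn(2, 3), "bias": torch.zeros(2)}
flat, unravel_fn = ravel_pytree(pytree)
print(flat.shape)  # torch.Size([8])
restored = unravel_fn(flat)
assert optree.tree_structure(restored) == optree.tree_structure(pytree)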

Here is how it can be used for different applications.

Hypernetworks

We would need to transform the flat output of hypernet into a structure that matches net.named_parameters() and then trigger a functional call with torch.func.functional_call(net, named_parameters, inputs). A code snippet would look like this:

import torch
import torch.nn as nn
from torch.func import functional_call

x_context = torch.randn(10, 3)
y_context = torch.randn(10, 1)
x_target = torch.randn(5, 3)

hypernet = nn.Sequential(nn.Linear(3+1, 10), nn.ReLU(), nn.Linear(10, 3*2 + 2))
net = nn.Linear(3, 2)

# Get flat parameters from the hypernet
output = hypernet(torch.cat([x_context, y_context], dim=-1))  # torch.Size([10, 8])
flat_params = output.mean(dim=0)  # torch.Size([8])

# Get a function to structure flat_params for net 
net_params_template = dict(net.named_parameters())
_, unravel_fn = ravel_pytree(net_params_template)

# Structure the flat_params
net_params = unravel_fn(flat_params)

# Get the output
output = functional_call(net, net_params, x_target)
print(output.shape)
# torch.Size([5, 2])

Laplace Approximation

import torch
import torch.nn as nn
from torch.distributions import Normal, MultivariateNormal
from torch.func import functional_call, hessian
torch.manual_seed(10)

net = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1))
map_params_template = dict(net.named_parameters())  # map means maximum a posteriori
map_flat_params, unravel_fn = ravel_pytree(map_params_template)

def neg_log_joint(flat_params, inputs, outputs):
    log_prior = Normal(0, 1.0).log_prob(flat_params).sum()

    params = unravel_fn(flat_params)
    output = functional_call(net, params, inputs)
    log_likelihood = Normal(output.ravel(), 0.1).log_prob(outputs).sum()

    return -(log_prior + log_likelihood)

inputs = torch.rand(11, 2)
outputs = torch.rand(11, 1)

H = hessian(neg_log_joint)(map_flat_params, inputs, outputs)
# print(H.shape)
# torch.Size([33, 33])

covar = torch.inverse(H)  # Not ideal, but okay for this MWE
covar = covar + torch.eye(covar.shape[0]) * 1e-2  # Add jitter for numerical stability
posterior = MultivariateNormal(map_flat_params, covar)

# Sample from the posterior
sample = posterior.sample()
print(sample.shape) # torch.Size([33])

params = unravel_fn(sample)
output = functional_call(net, params, inputs)
print(output.shape) # torch.Size([11, 1])

Alternatives

An alternative is to use the manually defined function above or to use some other way.

Additional context

No response

XuehaiPan commented 1 year ago

Hi @patel-zeel, thanks for raising this! This is a very useful feature. And thanks for the very detailed comment and the code snippet.

Currently, I have written my own version of ravel_pytree using the existing functions in optree:

def ravel_pytree(pytree):
    leaves, structure = optree.tree_flatten(pytree)
    shapes = [leaf.shape for leaf in leaves]
    sizes = [leaf.numel() for leaf in leaves]
    flat_params = torch.cat([leaf.flatten() for leaf in leaves])

    def unravel_function(flat_params):
        assert flat_params.numel() == sum(sizes), f"Invalid size {flat_params.numel()} != {sum(sizes)}"
        assert len(flat_params.shape) == 1, f"Invalid shape {flat_params.shape}"
        flat_leaves = flat_params.split(sizes)
        leaves = [leaf.reshape(shape) for leaf, shape in zip(flat_leaves, shapes)]
        return optree.tree_unflatten(structure, leaves)

    return flat_params, unravel_function

I think the ravel_pytree API should be available for all pytrees that contain NDArray(s). We would like to have a generic implementation for torch.Tensor, numpy.ndarray, jax.Array, etc. There are some small differences between these data structures. For example, numpy.ndarray and jax.Array have no numel method; we can compute the size with arr.ravel().shape[0] instead.

Here is the initial API design:

concatenate_func: Callable[[Sequence[ArrayLike]], ArrayLike]
split_func: Callable[[ArrayLike, Sequence[int]], List[ArrayLike]]

def ravel_pytree(pytree, concatenate_func, split_func):
    leaves, treespec = optree.tree_flatten(pytree)
    flat_leaves = [leaf.ravel() for leaf in leaves]
    flat_params = concatenate_func(flat_leaves)
    shapes = [leaf.shape for leaf in leaves]
    sizes = [flat_leaf.shape[0] for flat_leaf in flat_leaves]
    total_size = flat_params.shape[0]

    def unravel_function(flat_params):
        assert len(flat_params.shape) == 1, f"Invalid shape {flat_params.shape}"
        assert flat_params.shape[0] == total_size, f"Invalid size {flat_params.shape[0]} != {total_size}"
        flat_leaves = split_func(flat_params, sizes)
        leaves = [flat_leaf.reshape(shape) for flat_leaf, shape in zip(flat_leaves, shapes)]
        return optree.tree_unflatten(treespec, leaves)

    return flat_params, unravel_function
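
As an illustration, the framework-specific callables could be supplied by the caller as below (a hypothetical sketch; the example pytrees are made up). Here ArrayLike stands for torch.Tensor, numpy.ndarray, jax.Array, etc., and note that numpy.split expects cumulative split indices rather than chunk sizes:

import numpy as np
import torch

torch_pytree = {"w": torch.ones(2, 3), "b": torch.zeros(2)}
np_pytree = {"w": np.ones((2, 3)), "b": np.zeros(2)}

# torch.cat concatenates the 1D leaves; Tensor.split accepts a list of chunk sizes.
torch_flat, torch_unravel = ravel_pytree(
    torch_pytree,
    concatenate_func=torch.cat,
    split_func=lambda flat, sizes: list(flat.split(sizes)),
)

# numpy.concatenate / numpy.split; chunk sizes are converted to cumulative indices first.
np_flat, np_unravel = ravel_pytree(
    np_pytree,
    concatenate_func=np.concatenate,
    split_func=lambda flat, sizes: np.split(flat, np.cumsum(sizes)[:-1]),
)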

Now optree is going to be the C++ implementation for the pytree utility in the PyTorch package (maybe GA in PyTorch 2.2). Would you like this ravel_pytree to be implemented in the PyTorch library or in the optree package here? Any thoughts are much appreciated.

patel-zeel commented 1 year ago

Hi @XuehaiPan,

Thank you for sharing your feedback and your idea about the generic interface! I think it will be very useful to have a PyTree library that supports all NDArray(s).

Now optree is going to be the C++ implementation for the pytree utility in the PyTorch package (maybe GA in PyTorch 2.2). Would you like this ravel_pytree to be implemented in the PyTorch library or in the optree package here? Any thoughts are much appreciated.

I didn't know about PyTorch's plan to implement its own PyTree utility. Thanks for the info! Are they going to support most NDArray types, including JAX and NumPy? I think that, other than ravel_pytree, the functions from jax.tree_util and optree already support all NDArray(s) and much more. Is that right?

If PyTorch is not planning to generalize, then I'd support implementing it in optree; otherwise, PyTorch has wider visibility and significantly more people would be able to use it for various purposes. At the least, I think PyTorch's PyTree should have a ravel_pytree function that supports torch tensors.

XuehaiPan commented 1 year ago

If PyTorch is not planning to generalize, then I'd support implementing it in optree;

Let's do this. I'll submit a PR to resolve this.

XuehaiPan commented 1 year ago

Hi @patel-zeel, I created a PR to resolve this issue. You can try it via:

pip3 install git+https://github.com/XuehaiPan/optree.git@integration
import optree

flat, unravel_func = optree.integration.torch.tree_ravel(tensor_tree)

# or
from optree.integration.torch import tree_ravel

flat, unravel_func = tree_ravel(tensor_tree)

Here is the output of your example program:

import torch
import torch.nn as nn
from torch.func import functional_call
import optree

torch.manual_seed(0)

x_context = torch.randn(10, 3)
y_context = torch.randn(10, 1)
x_target = torch.randn(5, 3)

hypernet = nn.Sequential(
    nn.Linear(3 + 1, 10),
    nn.ReLU(),
    nn.Linear(10, 3 * 2 + 2),
)
net = nn.Linear(3, 2)

# Get flat parameters from the hypernet
output = hypernet(torch.cat([x_context, y_context], dim=-1))  # torch.Size([10, 8])
flat_params = output.mean(dim=0)
print(flat_params.shape)  # torch.Size([8])

# Get a function to structure flat_params for net
net_params_template = dict(net.named_parameters())
dummy_flat, unravel_func = optree.integration.torch.tree_ravel(net_params_template)
print(dummy_flat)  # flat tensor with 8 elements
print(dummy_flat.shape)  # torch.Size([8])

# Structure the flat_params
net_params = unravel_func(flat_params)
print(net_params)  # {'weight': tensor of size (2, 3), 'bias': tensor of size (2,)}

# Get the output
output = functional_call(net, net_params, x_target)
print(output.shape)  # torch.Size([5, 2])
$ python3 test.py
torch.Size([8])
tensor([-0.1859,  0.0009,  0.2589, -0.4081, -0.2447,  0.1698,  0.1906,  0.4331],
       grad_fn=<CatBackward0>)
torch.Size([8])
{'weight': tensor([[ 0.1633,  0.3138,  0.0756],
        [ 0.1910,  0.1452, -0.1463]], grad_fn=<ViewBackward0>), 'bias': tensor([-0.5388, -0.3698], grad_fn=<ViewBackward0>)}
torch.Size([5, 2])
patel-zeel commented 1 year ago

This is great, @XuehaiPan! I can see that tree_ravel is added for JAX and NumPy as well.

I could not install this branch on Google Colab for some reason. Here is the Traceback.

Collecting git+https://github.com/XuehaiPan/optree.git@integration
  Cloning https://github.com/XuehaiPan/optree.git (to revision integration) to /tmp/pip-req-build-8ami0owz
  Running command git clone --filter=blob:none --quiet https://github.com/XuehaiPan/optree.git /tmp/pip-req-build-8ami0owz
  Running command git checkout -b integration --track origin/integration
  Switched to a new branch 'integration'
  Branch 'integration' set up to track remote branch 'integration' from 'origin'.
  Resolved https://github.com/XuehaiPan/optree.git to commit 86f4f7c61ad793dff951fd454a2f926de2a355a4
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from optree==0.9.3.dev30+g86f4f7c) (4.5.0)
Building wheels for collected packages: optree
  error: subprocess-exited-with-error

  × Building wheel for optree (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> See above for output.

  note: This error originates from a subprocess, and is likely not a problem with pip.
  Building wheel for optree (pyproject.toml) ... error
  ERROR: Failed building wheel for optree
Failed to build optree
ERROR: Could not build wheels for optree, which is required to install pyproject.toml-based projects
XuehaiPan commented 1 year ago

@patel-zeel Could you enable the verbose mode while installing packages?

!pip3 install --upgrade pip setuptools wheel
!pip3 install -vvv git+https://github.com/XuehaiPan/optree.git@integration
patel-zeel commented 1 year ago

Sure, here is the full traceback in verbose mode. It complains about cmake.

Using pip 23.3.1 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
Non-user install because site-packages writeable
Created temporary directory: /tmp/pip-build-tracker-5g5opr_9
Initialized build tracking at /tmp/pip-build-tracker-5g5opr_9
Created build tracker: /tmp/pip-build-tracker-5g5opr_9
Entered build tracker: /tmp/pip-build-tracker-5g5opr_9
Created temporary directory: /tmp/pip-install-6q3rr3l7
Created temporary directory: /tmp/pip-ephem-wheel-cache-d85ykh7r
Collecting git+https://github.com/XuehaiPan/optree.git@integration
  Created temporary directory: /tmp/pip-req-build-r4xwy2fp
  Cloning https://github.com/XuehaiPan/optree.git (to revision integration) to /tmp/pip-req-build-r4xwy2fp
  Running command git version
  git version 2.34.1
  Running command git clone --filter=blob:none --verbose --progress https://github.com/XuehaiPan/optree.git /tmp/pip-req-build-r4xwy2fp
  Cloning into '/tmp/pip-req-build-r4xwy2fp'...
  POST git-upload-pack (175 bytes)
  POST git-upload-pack (gzip 1072 to 579 bytes)
  remote: Enumerating objects: 940, done.
  remote: Counting objects: 100% (313/313), done.
  remote: Compressing objects: 100% (127/127), done.
  remote: Total 940 (delta 228), reused 242 (delta 185), pack-reused 627
  Receiving objects: 100% (940/940), 150.04 KiB | 3.41 MiB/s, done.
  Resolving deltas: 100% (508/508), done.
  Running command git show-ref integration
  6835b51e24b9a8aad299ee2950b66bf31a7758a1 refs/remotes/origin/integration
  Rev options <RevOptions git: rev='6835b51e24b9a8aad299ee2950b66bf31a7758a1'>, branch_name integration
  Running command git symbolic-ref -q HEAD
  refs/heads/main
  Running command git checkout -b integration --track origin/integration
  Switched to a new branch 'integration'
  Branch 'integration' set up to track remote branch 'integration' from 'origin'.
  Resolved https://github.com/XuehaiPan/optree.git to commit 6835b51e24b9a8aad299ee2950b66bf31a7758a1
  Running command git rev-parse HEAD
  6835b51e24b9a8aad299ee2950b66bf31a7758a1
  Added git+https://github.com/XuehaiPan/optree.git@integration to build tracker '/tmp/pip-build-tracker-5g5opr_9'
  Created temporary directory: /tmp/pip-build-env-g_x2tbsp
  Running command pip subprocess to install build dependencies
  Using pip 23.3.1 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
  Collecting setuptools
    Obtaining dependency information for setuptools from https://files.pythonhosted.org/packages/bb/26/7945080113158354380a12ce26873dd6c1ebd88d47f5bc24e2c5bb38c16a/setuptools-68.2.2-py3-none-any.whl.metadata
    Downloading setuptools-68.2.2-py3-none-any.whl.metadata (6.3 kB)
  Collecting pybind11
    Obtaining dependency information for pybind11 from https://files.pythonhosted.org/packages/06/55/9f73c32dda93fa4f539fafa268f9504e83c489f460c380371d94296126cd/pybind11-2.11.1-py3-none-any.whl.metadata
    Downloading pybind11-2.11.1-py3-none-any.whl.metadata (9.5 kB)
  Downloading setuptools-68.2.2-py3-none-any.whl (807 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 807.9/807.9 kB 4.6 MB/s eta 0:00:00
  Downloading pybind11-2.11.1-py3-none-any.whl (227 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 227.7/227.7 kB 5.0 MB/s eta 0:00:00
  Installing collected packages: setuptools, pybind11
    Creating /tmp/pip-build-env-g_x2tbsp/overlay/local/bin
    changing mode of /tmp/pip-build-env-g_x2tbsp/overlay/local/bin/pybind11-config to 755
  ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
  ipython 7.34.0 requires jedi>=0.16, which is not installed.
  lida 0.0.10 requires fastapi, which is not installed.
  lida 0.0.10 requires kaleido, which is not installed.
  lida 0.0.10 requires python-multipart, which is not installed.
  lida 0.0.10 requires uvicorn, which is not installed.
  Successfully installed pybind11-2.11.1 setuptools-68.2.2
  WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
  Installing build dependencies ... done
  Running command Getting requirements to build wheel
  running egg_info
  creating optree.egg-info
  writing optree.egg-info/PKG-INFO
  writing dependency_links to optree.egg-info/dependency_links.txt
  writing requirements to optree.egg-info/requires.txt
  writing top-level names to optree.egg-info/top_level.txt
  writing manifest file 'optree.egg-info/SOURCES.txt'
  reading manifest file 'optree.egg-info/SOURCES.txt'
  reading manifest template 'MANIFEST.in'
  adding license file 'LICENSE'
  writing manifest file 'optree.egg-info/SOURCES.txt'
  Getting requirements to build wheel ... done
  Running command pip subprocess to install backend dependencies
  Using pip 23.3.1 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
  Collecting wheel
    Obtaining dependency information for wheel from https://files.pythonhosted.org/packages/fa/7f/4c07234086edbce4a0a446209dc0cb08a19bb206a3ea53b2f56a403f983b/wheel-0.41.3-py3-none-any.whl.metadata
    Downloading wheel-0.41.3-py3-none-any.whl.metadata (2.2 kB)
  Downloading wheel-0.41.3-py3-none-any.whl (65 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 65.8/65.8 kB 1.5 MB/s eta 0:00:00
  Installing collected packages: wheel
    Creating /tmp/pip-build-env-g_x2tbsp/normal/local/bin
    changing mode of /tmp/pip-build-env-g_x2tbsp/normal/local/bin/wheel to 755
  Successfully installed wheel-0.41.3
  WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
  Installing backend dependencies ... done
  Created temporary directory: /tmp/pip-modern-metadata-0p7swwds
  Running command Preparing metadata (pyproject.toml)
  running dist_info
  creating /tmp/pip-modern-metadata-0p7swwds/optree.egg-info
  writing /tmp/pip-modern-metadata-0p7swwds/optree.egg-info/PKG-INFO
  writing dependency_links to /tmp/pip-modern-metadata-0p7swwds/optree.egg-info/dependency_links.txt
  writing requirements to /tmp/pip-modern-metadata-0p7swwds/optree.egg-info/requires.txt
  writing top-level names to /tmp/pip-modern-metadata-0p7swwds/optree.egg-info/top_level.txt
  writing manifest file '/tmp/pip-modern-metadata-0p7swwds/optree.egg-info/SOURCES.txt'
  reading manifest file '/tmp/pip-modern-metadata-0p7swwds/optree.egg-info/SOURCES.txt'
  reading manifest template 'MANIFEST.in'
  adding license file 'LICENSE'
  writing manifest file '/tmp/pip-modern-metadata-0p7swwds/optree.egg-info/SOURCES.txt'
  creating '/tmp/pip-modern-metadata-0p7swwds/optree-0.9.3.dev32+g6835b51.dist-info'
  Preparing metadata (pyproject.toml) ... done
  Source in /tmp/pip-req-build-r4xwy2fp has version 0.9.3.dev32+g6835b51, which satisfies requirement optree==0.9.3.dev32+g6835b51 from git+https://github.com/XuehaiPan/optree.git@integration
  Removed optree==0.9.3.dev32+g6835b51 from git+https://github.com/XuehaiPan/optree.git@integration from build tracker '/tmp/pip-build-tracker-5g5opr_9'
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from optree==0.9.3.dev32+g6835b51) (4.5.0)
Created temporary directory: /tmp/pip-unpack-h5ry5_js
Building wheels for collected packages: optree
  Running command git rev-parse HEAD
  6835b51e24b9a8aad299ee2950b66bf31a7758a1
  Created temporary directory: /tmp/pip-wheel-emk2r81l
  Destination directory: /tmp/pip-wheel-emk2r81l
  Running command Building wheel for optree (pyproject.toml)
  running bdist_wheel
  running build
  running build_py
  creating build
  creating build/lib.linux-x86_64-cpython-310
  creating build/lib.linux-x86_64-cpython-310/optree
  copying optree/__init__.py -> build/lib.linux-x86_64-cpython-310/optree
  copying optree/utils.py -> build/lib.linux-x86_64-cpython-310/optree
  copying optree/typing.py -> build/lib.linux-x86_64-cpython-310/optree
  copying optree/version.py -> build/lib.linux-x86_64-cpython-310/optree
  copying optree/registry.py -> build/lib.linux-x86_64-cpython-310/optree
  copying optree/ops.py -> build/lib.linux-x86_64-cpython-310/optree
  creating build/lib.linux-x86_64-cpython-310/optree/integration
  copying optree/integration/__init__.py -> build/lib.linux-x86_64-cpython-310/optree/integration
  copying optree/integration/jax.py -> build/lib.linux-x86_64-cpython-310/optree/integration
  copying optree/integration/numpy.py -> build/lib.linux-x86_64-cpython-310/optree/integration
  copying optree/integration/torch.py -> build/lib.linux-x86_64-cpython-310/optree/integration
  running egg_info
  writing optree.egg-info/PKG-INFO
  writing dependency_links to optree.egg-info/dependency_links.txt
  writing requirements to optree.egg-info/requires.txt
  writing top-level names to optree.egg-info/top_level.txt
  reading manifest file 'optree.egg-info/SOURCES.txt'
  reading manifest template 'MANIFEST.in'
  adding license file 'LICENSE'
  writing manifest file 'optree.egg-info/SOURCES.txt'
  copying optree/_C.pyi -> build/lib.linux-x86_64-cpython-310/optree
  copying optree/py.typed -> build/lib.linux-x86_64-cpython-310/optree
  running build_ext
  x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/usr/include/python3.10 -c flagcheck.cpp -o flagcheck.o -std=c++17
  /usr/local/bin/cmake /tmp/pip-req-build-r4xwy2fp -DCMAKE_BUILD_TYPE=Release -DCMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE=/tmp/pip-req-build-r4xwy2fp/build/lib.linux-x86_64-cpython-310/optree -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_RELEASE=/tmp/pip-req-build-r4xwy2fp/build/temp.linux-x86_64-cpython-310 -DPYTHON_EXECUTABLE=/usr/bin/python3 -DPYTHON_INCLUDE_DIR=/usr/include/python3.10 -DPYBIND11_CMAKE_DIR=/tmp/pip-build-env-g_x2tbsp/overlay/local/lib/python3.10/dist-packages/pybind11/share/cmake/pybind11
  Traceback (most recent call last):
    File "/usr/local/bin/cmake", line 5, in <module>
      from cmake import cmake
  ModuleNotFoundError: No module named 'cmake'
  error: command '/usr/local/bin/cmake' failed with exit code 1
  error: subprocess-exited-with-error

  × Building wheel for optree (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> See above for output.

  note: This error originates from a subprocess, and is likely not a problem with pip.
  full command: /usr/bin/python3 /usr/local/lib/python3.10/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py build_wheel /tmp/tmpeu0x1mlx
  cwd: /tmp/pip-req-build-r4xwy2fp
  Building wheel for optree (pyproject.toml) ... error
  ERROR: Failed building wheel for optree
Failed to build optree
ERROR: Could not build wheels for optree, which is required to install pyproject.toml-based projects
Exception information:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 180, in exc_logging_wrapper
    status = run_func(*args)
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/req_command.py", line 245, in wrapper
    return func(self, options, args)
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/commands/install.py", line 429, in run
    raise InstallationError(
pip._internal.exceptions.InstallationError: Could not build wheels for optree, which is required to install pyproject.toml-based projects
Removed build tracker: '/tmp/pip-build-tracker-5g5opr_9'
WARNING: The following packages were previously imported in this runtime:
  [_distutils_hack,pkg_resources,setuptools]
You must restart the runtime in order to use newly installed versions.
XuehaiPan commented 1 year ago

It complains about cmake.

@patel-zeel You could run (add option --no-build-isolation):

!python -m pip install --upgrade pip setuptools
!python -m pip install --upgrade wheel cmake
!python -m pip install --no-build-isolation -vvv git+https://github.com/XuehaiPan/optree.git@integration
patel-zeel commented 1 year ago

Thanks for the resolution, @XuehaiPan!

Now, it works for JAX, PyTorch and NumPy.


import torch
import jax.numpy as jnp
import numpy as np

from optree.integration.torch import tree_ravel as torch_tree_ravel
from optree.integration.jax import tree_ravel as jax_tree_ravel
from optree.integration.numpy import tree_ravel as np_tree_ravel

## NumPy
np_pytree = {"dict": {"ones": np.ones(2)}, "list": [np.arange(4)]}
flat_array, unravel_fn = np_tree_ravel(np_pytree)
print(flat_array)
print(unravel_fn(flat_array))
# [1. 1. 0. 1. 2. 3.]
# {'dict': {'ones': array([1., 1.])}, 'list': [array([0, 1, 2, 3])]}

## Torch
torch_pytree = {"dict": {"ones": torch.ones(2)}, "list": [torch.arange(4)]}
flat_array, unravel_fn = torch_tree_ravel(torch_pytree)
print(flat_array)
print(unravel_fn(flat_array))
# tensor([1., 1., 0., 1., 2., 3.])
# {'dict': {'ones': tensor([1., 1.])}, 'list': [tensor([0, 1, 2, 3])]}

## JAX
jax_pytree = {"dict": {"ones": jnp.ones(2)}, "list": [jnp.arange(4)]}
flat_array, unravel_fn = jax_tree_ravel(jax_pytree)
print(flat_array)
print(unravel_fn(flat_array))
# [1. 1. 0. 1. 2. 3.]
# {'dict': {'ones': Array([1., 1.], dtype=float32)}, 'list': [Array([0, 1, 2, 3], dtype=int32)]}
XuehaiPan commented 1 year ago

Now, it works for JAX, PyTorch and NumPy.

Sounds good. Let's ship this feature.