Closed patel-zeel closed 1 year ago
Hi @patel-zeel, thanks for raising this! This is a very useful feature. And thanks for the very detailed comment and the code snippet.
Currently, I have written my own version of ravel_pytree using existing functions of optree:
import optree
import torch

def ravel_pytree(pytree):
    leaves, structure = optree.tree_flatten(pytree)
    shapes = [leaf.shape for leaf in leaves]
    sizes = [leaf.numel() for leaf in leaves]
    flat_params = torch.cat([leaf.flatten() for leaf in leaves])

    def unravel_function(flat_params):
        assert flat_params.numel() == sum(sizes), f"Invalid size {flat_params.numel()} != {sum(sizes)}"
        assert len(flat_params.shape) == 1, f"Invalid shape {flat_params.shape}"
        flat_leaves = flat_params.split(sizes)
        leaves = [leaf.reshape(shape) for leaf, shape in zip(flat_leaves, shapes)]
        return optree.tree_unflatten(structure, leaves)

    return flat_params, unravel_function
I think the ravel_pytree API should be available for all pytrees that contain NDArray(s). We would like to have a generic implementation for torch.Tensor, numpy.ndarray, jax.Array, etc. There are some small differences between these data structures. For example, numpy.ndarray and jax.Array have no numel method. We can get the element count with arr.ravel().shape[0] instead.
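As a small illustration of the point about element counts, the count can also be derived from `.shape` alone, which torch.Tensor, numpy.ndarray, and jax.Array all expose. This is only a sketch: the helper name `num_elements` is hypothetical and not part of optree, and NumPy is used here just to have something to run it on.

```python
import math

import numpy as np


def num_elements(array):
    """Return the number of elements using only the `.shape` attribute."""
    # `num_elements` is a hypothetical helper, not an optree function.
    # math.prod of an empty shape () is 1, which matches a 0-d (scalar) array.
    return math.prod(array.shape)


weights = np.zeros((2, 3))
bias = np.zeros(2)
scalar = np.float32(1.0)

print(num_elements(weights))  # 6
print(num_elements(bias))     # 2
print(num_elements(scalar))   # 1
```

For these inputs the result agrees with `arr.ravel().shape[0]`, without creating the flattened array first.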
Here is the initial API design:
concatenate_func: Callable[[Sequence[ArrayLike]], ArrayLike]
split_func: Callable[[ArrayLike, Sequence[int]], List[ArrayLike]]
def ravel_pytree(pytree, concatenate_func, split_func):
    leaves, treespec = optree.tree_flatten(pytree)
    flat_leaves = [leaf.ravel() for leaf in leaves]
    flat_params = concatenate_func(flat_leaves)
    shapes = [leaf.shape for leaf in leaves]
    sizes = [flat_leaf.shape[0] for flat_leaf in flat_leaves]
    total_size = flat_params.shape[0]

    def unravel_function(flat_params):
        assert len(flat_params.shape) == 1, f"Invalid shape {flat_params.shape}"
        assert flat_params.shape[0] == total_size, f"Invalid size {flat_params.shape[0]} != {total_size}"
        flat_leaves = split_func(flat_params, sizes)
        leaves = [flat_leaf.reshape(shape) for flat_leaf, shape in zip(flat_leaves, shapes)]
        return optree.tree_unflatten(treespec, leaves)

    return flat_params, unravel_function
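One detail matters when choosing split_func for this design: torch.split accepts chunk sizes, whereas numpy.split (and jax.numpy.split) expects the positions at which to cut. Below is a minimal sketch of a NumPy adapter that matches the split_func signature above. The name `split_by_sizes` is hypothetical, and the leaves are written out by hand as stand-ins for the result of optree.tree_flatten so that the snippet only needs NumPy.

```python
import itertools

import numpy as np


def split_by_sizes(flat, sizes):
    """Split a 1-D array into consecutive chunks with the given sizes."""
    # `split_by_sizes` is a hypothetical adapter, not an optree function.
    # np.split wants cut positions (cumulative offsets), not chunk sizes;
    # the last offset equals len(flat) and would yield an empty trailing chunk.
    offsets = list(itertools.accumulate(sizes))[:-1]
    return np.split(flat, offsets)


# Stand-ins for the leaves that optree.tree_flatten would return.
leaves = [np.arange(6.0).reshape(2, 3), np.array([10.0, 20.0])]
shapes = [leaf.shape for leaf in leaves]
flat_leaves = [leaf.ravel() for leaf in leaves]
sizes = [flat_leaf.shape[0] for flat_leaf in flat_leaves]

flat_params = np.concatenate(flat_leaves)
restored = [
    chunk.reshape(shape)
    for chunk, shape in zip(split_by_sizes(flat_params, sizes), shapes)
]

print(flat_params.shape)  # (8,)
print(restored[0].shape)  # (2, 3)
print(restored[1])        # [10. 20.]
```

With such an adapter, np.concatenate could serve as concatenate_func and `split_by_sizes` as split_func, while for PyTorch torch.cat and torch.split fit the two signatures directly.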
Now optree is going to be the C++ implementation for the pytree utility in the PyTorch package (maybe GA in PyTorch 2.2). Would you like to have this ravel_pytree to be implemented in the PyTorch library or the optree package here? Any thoughts are much appreciated.
Hi @XuehaiPan,
Thank you for sharing your feedback and your idea about the generic interface! I think it will be very useful to have a PyTree library that supports all NDArray(s).
Now optree is going to be the C++ implementation for the pytree utility in the PyTorch package (maybe GA in PyTorch 2.2). Would you like to have this ravel_pytree to be implemented in the PyTorch library or the optree package here? Any thoughts are much appreciated.
I didn't know about PyTorch's plan to implement their own PyTree. Thanks for the info! Are they going to support most of the NDArray(s), including JAX and NumPy? I think that, apart from ravel_pytree, the other functions in jax.tree_util and optree already support all NDArray(s) and much more. Is that right?
If PyTorch is not planning to generalize, then I'd support implementing it in optree; otherwise, PyTorch will have a wider visibility and significantly more people would be able to use it for various purposes. I think, at the least, PyTorch's PyTree should have a ravel_pytree function that supports torch tensors.
If PyTorch is not planning to generalize, then I'd support implementing it in optree;
Let's do this. I'll submit a PR to resolve this.
Hi @patel-zeel, I created a PR to resolve this issue. You can try it via:
pip3 install git+https://github.com/XuehaiPan/optree.git@integration
import optree
flat, unravel_func = optree.integration.torch.tree_ravel(tensor_tree)
# or
from optree.integration.torch import tree_ravel
flat, unravel_func = tree_ravel(tensor_tree)
Here is the output of your example program:
import torch
import torch.nn as nn
from torch.func import functional_call
import optree
torch.manual_seed(0)
x_context = torch.randn(10, 3)
y_context = torch.randn(10, 1)
x_target = torch.randn(5, 3)
hypernet = nn.Sequential(
    nn.Linear(3 + 1, 10),
    nn.ReLU(),
    nn.Linear(10, 3 * 2 + 2),
)
net = nn.Linear(3, 2)
# Get flat parameters from the hypernet
output = hypernet(torch.cat([x_context, y_context], dim=-1)) # torch.Size([10, 8])
flat_params = output.mean(dim=0)
print(flat_params.shape)  # torch.Size([8])
# Get a function to structure flat_params for net
net_params_template = dict(net.named_parameters())
dummy_flat, unravel_func = optree.integration.torch.tree_ravel(net_params_template)
print(dummy_flat)  # 1-D tensor holding the flattened template parameters
print(dummy_flat.shape)  # torch.Size([8])
# Structure the flat_params
net_params = unravel_func(flat_params)
print(net_params) # {'weight': tensor of size (2, 3), 'bias': tensor of size (2,)}
# Get the output
output = functional_call(net, net_params, x_target)
print(output.shape) # torch.Size([5, 2])
$ python3 test.py
torch.Size([8])
tensor([-0.1859, 0.0009, 0.2589, -0.4081, -0.2447, 0.1698, 0.1906, 0.4331],
grad_fn=<CatBackward0>)
torch.Size([8])
{'weight': tensor([[ 0.1633, 0.3138, 0.0756],
[ 0.1910, 0.1452, -0.1463]], grad_fn=<ViewBackward0>), 'bias': tensor([-0.5388, -0.3698], grad_fn=<ViewBackward0>)}
torch.Size([5, 2])
This is great, @XuehaiPan! I can see that tree_ravel is added for JAX and NumPy as well.
I could not install this branch on Google Colab for some reason. Here is the traceback.
Collecting git+https://github.com/XuehaiPan/optree.git@integration
Cloning https://github.com/XuehaiPan/optree.git (to revision integration) to /tmp/pip-req-build-8ami0owz
Running command git clone --filter=blob:none --quiet https://github.com/XuehaiPan/optree.git /tmp/pip-req-build-8ami0owz
Running command git checkout -b integration --track origin/integration
Switched to a new branch 'integration'
Branch 'integration' set up to track remote branch 'integration' from 'origin'.
Resolved https://github.com/XuehaiPan/optree.git to commit 86f4f7c61ad793dff951fd454a2f926de2a355a4
Installing build dependencies ... done
Getting requirements to build wheel ... done
Installing backend dependencies ... done
Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from optree==0.9.3.dev30+g86f4f7c) (4.5.0)
Building wheels for collected packages: optree
error: subprocess-exited-with-error
× Building wheel for optree (pyproject.toml) did not run successfully.
│ exit code: 1
╰─> See above for output.
note: This error originates from a subprocess, and is likely not a problem with pip.
Building wheel for optree (pyproject.toml) ... error
ERROR: Failed building wheel for optree
Failed to build optree
ERROR: Could not build wheels for optree, which is required to install pyproject.toml-based projects
@patel-zeel Could you enable the verbose mode while installing packages?
!pip3 install --upgrade pip setuptools wheel
!pip3 install -vvv git+https://github.com/XuehaiPan/optree.git@integration
Sure, here is the full traceback in verbose mode. It complains about cmake.
Using pip 23.3.1 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
Non-user install because site-packages writeable
Created temporary directory: /tmp/pip-build-tracker-5g5opr_9
Initialized build tracking at /tmp/pip-build-tracker-5g5opr_9
Created build tracker: /tmp/pip-build-tracker-5g5opr_9
Entered build tracker: /tmp/pip-build-tracker-5g5opr_9
Created temporary directory: /tmp/pip-install-6q3rr3l7
Created temporary directory: /tmp/pip-ephem-wheel-cache-d85ykh7r
Collecting git+https://github.com/XuehaiPan/optree.git@integration
Created temporary directory: /tmp/pip-req-build-r4xwy2fp
Cloning https://github.com/XuehaiPan/optree.git (to revision integration) to /tmp/pip-req-build-r4xwy2fp
Running command git version
git version 2.34.1
Running command git clone --filter=blob:none --verbose --progress https://github.com/XuehaiPan/optree.git /tmp/pip-req-build-r4xwy2fp
Cloning into '/tmp/pip-req-build-r4xwy2fp'...
POST git-upload-pack (175 bytes)
POST git-upload-pack (gzip 1072 to 579 bytes)
remote: Enumerating objects: 940, done.
remote: Counting objects: 100% (313/313), done.
remote: Compressing objects: 100% (127/127), done.
remote: Total 940 (delta 228), reused 242 (delta 185), pack-reused 627
Receiving objects: 100% (940/940), 150.04 KiB | 3.41 MiB/s, done.
Resolving deltas: 100% (508/508), done.
Running command git show-ref integration
6835b51e24b9a8aad299ee2950b66bf31a7758a1 refs/remotes/origin/integration
Rev options <RevOptions git: rev='6835b51e24b9a8aad299ee2950b66bf31a7758a1'>, branch_name integration
Running command git symbolic-ref -q HEAD
refs/heads/main
Running command git checkout -b integration --track origin/integration
Switched to a new branch 'integration'
Branch 'integration' set up to track remote branch 'integration' from 'origin'.
Resolved https://github.com/XuehaiPan/optree.git to commit 6835b51e24b9a8aad299ee2950b66bf31a7758a1
Running command git rev-parse HEAD
6835b51e24b9a8aad299ee2950b66bf31a7758a1
Added git+https://github.com/XuehaiPan/optree.git@integration to build tracker '/tmp/pip-build-tracker-5g5opr_9'
Created temporary directory: /tmp/pip-build-env-g_x2tbsp
Running command pip subprocess to install build dependencies
Using pip 23.3.1 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
Collecting setuptools
Obtaining dependency information for setuptools from https://files.pythonhosted.org/packages/bb/26/7945080113158354380a12ce26873dd6c1ebd88d47f5bc24e2c5bb38c16a/setuptools-68.2.2-py3-none-any.whl.metadata
Downloading setuptools-68.2.2-py3-none-any.whl.metadata (6.3 kB)
Collecting pybind11
Obtaining dependency information for pybind11 from https://files.pythonhosted.org/packages/06/55/9f73c32dda93fa4f539fafa268f9504e83c489f460c380371d94296126cd/pybind11-2.11.1-py3-none-any.whl.metadata
Downloading pybind11-2.11.1-py3-none-any.whl.metadata (9.5 kB)
Downloading setuptools-68.2.2-py3-none-any.whl (807 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 807.9/807.9 kB 4.6 MB/s eta 0:00:00
Downloading pybind11-2.11.1-py3-none-any.whl (227 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 227.7/227.7 kB 5.0 MB/s eta 0:00:00
Installing collected packages: setuptools, pybind11
Creating /tmp/pip-build-env-g_x2tbsp/overlay/local/bin
changing mode of /tmp/pip-build-env-g_x2tbsp/overlay/local/bin/pybind11-config to 755
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.
lida 0.0.10 requires fastapi, which is not installed.
lida 0.0.10 requires kaleido, which is not installed.
lida 0.0.10 requires python-multipart, which is not installed.
lida 0.0.10 requires uvicorn, which is not installed.
Successfully installed pybind11-2.11.1 setuptools-68.2.2
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Installing build dependencies ... done
Running command Getting requirements to build wheel
running egg_info
creating optree.egg-info
writing optree.egg-info/PKG-INFO
writing dependency_links to optree.egg-info/dependency_links.txt
writing requirements to optree.egg-info/requires.txt
writing top-level names to optree.egg-info/top_level.txt
writing manifest file 'optree.egg-info/SOURCES.txt'
reading manifest file 'optree.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'optree.egg-info/SOURCES.txt'
Getting requirements to build wheel ... done
Running command pip subprocess to install backend dependencies
Using pip 23.3.1 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10)
Collecting wheel
Obtaining dependency information for wheel from https://files.pythonhosted.org/packages/fa/7f/4c07234086edbce4a0a446209dc0cb08a19bb206a3ea53b2f56a403f983b/wheel-0.41.3-py3-none-any.whl.metadata
Downloading wheel-0.41.3-py3-none-any.whl.metadata (2.2 kB)
Downloading wheel-0.41.3-py3-none-any.whl (65 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 65.8/65.8 kB 1.5 MB/s eta 0:00:00
Installing collected packages: wheel
Creating /tmp/pip-build-env-g_x2tbsp/normal/local/bin
changing mode of /tmp/pip-build-env-g_x2tbsp/normal/local/bin/wheel to 755
Successfully installed wheel-0.41.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Installing backend dependencies ... done
Created temporary directory: /tmp/pip-modern-metadata-0p7swwds
Running command Preparing metadata (pyproject.toml)
running dist_info
creating /tmp/pip-modern-metadata-0p7swwds/optree.egg-info
writing /tmp/pip-modern-metadata-0p7swwds/optree.egg-info/PKG-INFO
writing dependency_links to /tmp/pip-modern-metadata-0p7swwds/optree.egg-info/dependency_links.txt
writing requirements to /tmp/pip-modern-metadata-0p7swwds/optree.egg-info/requires.txt
writing top-level names to /tmp/pip-modern-metadata-0p7swwds/optree.egg-info/top_level.txt
writing manifest file '/tmp/pip-modern-metadata-0p7swwds/optree.egg-info/SOURCES.txt'
reading manifest file '/tmp/pip-modern-metadata-0p7swwds/optree.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file '/tmp/pip-modern-metadata-0p7swwds/optree.egg-info/SOURCES.txt'
creating '/tmp/pip-modern-metadata-0p7swwds/optree-0.9.3.dev32+g6835b51.dist-info'
Preparing metadata (pyproject.toml) ... done
Source in /tmp/pip-req-build-r4xwy2fp has version 0.9.3.dev32+g6835b51, which satisfies requirement optree==0.9.3.dev32+g6835b51 from git+https://github.com/XuehaiPan/optree.git@integration
Removed optree==0.9.3.dev32+g6835b51 from git+https://github.com/XuehaiPan/optree.git@integration from build tracker '/tmp/pip-build-tracker-5g5opr_9'
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from optree==0.9.3.dev32+g6835b51) (4.5.0)
Created temporary directory: /tmp/pip-unpack-h5ry5_js
Building wheels for collected packages: optree
Running command git rev-parse HEAD
6835b51e24b9a8aad299ee2950b66bf31a7758a1
Created temporary directory: /tmp/pip-wheel-emk2r81l
Destination directory: /tmp/pip-wheel-emk2r81l
Running command Building wheel for optree (pyproject.toml)
running bdist_wheel
running build
running build_py
creating build
creating build/lib.linux-x86_64-cpython-310
creating build/lib.linux-x86_64-cpython-310/optree
copying optree/__init__.py -> build/lib.linux-x86_64-cpython-310/optree
copying optree/utils.py -> build/lib.linux-x86_64-cpython-310/optree
copying optree/typing.py -> build/lib.linux-x86_64-cpython-310/optree
copying optree/version.py -> build/lib.linux-x86_64-cpython-310/optree
copying optree/registry.py -> build/lib.linux-x86_64-cpython-310/optree
copying optree/ops.py -> build/lib.linux-x86_64-cpython-310/optree
creating build/lib.linux-x86_64-cpython-310/optree/integration
copying optree/integration/__init__.py -> build/lib.linux-x86_64-cpython-310/optree/integration
copying optree/integration/jax.py -> build/lib.linux-x86_64-cpython-310/optree/integration
copying optree/integration/numpy.py -> build/lib.linux-x86_64-cpython-310/optree/integration
copying optree/integration/torch.py -> build/lib.linux-x86_64-cpython-310/optree/integration
running egg_info
writing optree.egg-info/PKG-INFO
writing dependency_links to optree.egg-info/dependency_links.txt
writing requirements to optree.egg-info/requires.txt
writing top-level names to optree.egg-info/top_level.txt
reading manifest file 'optree.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'optree.egg-info/SOURCES.txt'
copying optree/_C.pyi -> build/lib.linux-x86_64-cpython-310/optree
copying optree/py.typed -> build/lib.linux-x86_64-cpython-310/optree
running build_ext
x86_64-linux-gnu-gcc -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/usr/include/python3.10 -c flagcheck.cpp -o flagcheck.o -std=c++17
/usr/local/bin/cmake /tmp/pip-req-build-r4xwy2fp -DCMAKE_BUILD_TYPE=Release -DCMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE=/tmp/pip-req-build-r4xwy2fp/build/lib.linux-x86_64-cpython-310/optree -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_RELEASE=/tmp/pip-req-build-r4xwy2fp/build/temp.linux-x86_64-cpython-310 -DPYTHON_EXECUTABLE=/usr/bin/python3 -DPYTHON_INCLUDE_DIR=/usr/include/python3.10 -DPYBIND11_CMAKE_DIR=/tmp/pip-build-env-g_x2tbsp/overlay/local/lib/python3.10/dist-packages/pybind11/share/cmake/pybind11
Traceback (most recent call last):
File "/usr/local/bin/cmake", line 5, in <module>
from cmake import cmake
ModuleNotFoundError: No module named 'cmake'
error: command '/usr/local/bin/cmake' failed with exit code 1
error: subprocess-exited-with-error
× Building wheel for optree (pyproject.toml) did not run successfully.
│ exit code: 1
╰─> See above for output.
note: This error originates from a subprocess, and is likely not a problem with pip.
full command: /usr/bin/python3 /usr/local/lib/python3.10/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py build_wheel /tmp/tmpeu0x1mlx
cwd: /tmp/pip-req-build-r4xwy2fp
Building wheel for optree (pyproject.toml) ... error
ERROR: Failed building wheel for optree
Failed to build optree
ERROR: Could not build wheels for optree, which is required to install pyproject.toml-based projects
Exception information:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 180, in exc_logging_wrapper
status = run_func(*args)
File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/req_command.py", line 245, in wrapper
return func(self, options, args)
File "/usr/local/lib/python3.10/dist-packages/pip/_internal/commands/install.py", line 429, in run
raise InstallationError(
pip._internal.exceptions.InstallationError: Could not build wheels for optree, which is required to install pyproject.toml-based projects
Removed build tracker: '/tmp/pip-build-tracker-5g5opr_9'
WARNING: The following packages were previously imported in this runtime:
[_distutils_hack,pkg_resources,setuptools]
You must restart the runtime in order to use newly installed versions.
It complains about cmake.
@patel-zeel You could run the following (adding the option --no-build-isolation):
!python -m pip install --upgrade pip setuptools
!python -m pip install --upgrade wheel cmake
!python -m pip install --no-build-isolation -vvv git+https://github.com/XuehaiPan/optree.git@integration
Thanks for the resolution, @XuehaiPan!
Now, it works for JAX, PyTorch and NumPy.
import torch
import jax.numpy as jnp
import numpy as np
from optree.integration.torch import tree_ravel as torch_tree_ravel
from optree.integration.jax import tree_ravel as jax_tree_ravel
from optree.integration.numpy import tree_ravel as np_tree_ravel
## NumPy
np_pytree = {"dict": {"ones": np.ones(2)}, "list": [np.arange(4)]}
flat_array, unravel_fn = np_tree_ravel(np_pytree)
print(flat_array)
print(unravel_fn(flat_array))
# [1. 1. 0. 1. 2. 3.]
# {'dict': {'ones': array([1., 1.])}, 'list': [array([0, 1, 2, 3])]}
## Torch
torch_pytree = {"dict": {"ones": torch.ones(2)}, "list": [torch.arange(4)]}
flat_array, unravel_fn = torch_tree_ravel(torch_pytree)
print(flat_array)
print(unravel_fn(flat_array))
# tensor([1., 1., 0., 1., 2., 3.])
# {'dict': {'ones': tensor([1., 1.])}, 'list': [tensor([0, 1, 2, 3])]}
## JAX
jax_pytree = {"dict": {"ones": jnp.ones(2)}, "list": [jnp.arange(4)]}
flat_array, unravel_fn = jax_tree_ravel(jax_pytree)
print(flat_array)
print(unravel_fn(flat_array))
# [1. 1. 0. 1. 2. 3.]
# {'dict': {'ones': Array([1., 1.], dtype=float32)}, 'list': [Array([0, 1, 2, 3], dtype=int32)]}
Sounds good. Let's ship this feature.
Required prerequisites
Motivation
Hi,
What does this function do?
This function is motivated by JAX's ravel_pytree, which converts a PyTree into a 1D tensor and provides a function to convert that 1D tensor back to the original PyTree format.
This function is useful for these applications: i) easing the process of defining "Hypernetworks"; ii) Hessian computation for neural networks (e.g., Laplace approximation).
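To make the round trip concrete, here is a minimal NumPy-only sketch of the idea. It is not JAX's or optree's implementation: ravel_flat_dict is a hypothetical helper that only handles a single-level dict of arrays, whereas a real PyTree can nest arbitrarily.

```python
import numpy as np


def ravel_flat_dict(params):
    """Flatten a single-level dict of arrays into one 1D array.

    Hypothetical helper for illustration only; nested containers are not handled.
    """
    keys = sorted(params)  # fix a deterministic leaf order
    shapes = [params[k].shape for k in keys]
    sizes = [params[k].size for k in keys]
    flat = np.concatenate([params[k].ravel() for k in keys])

    def unravel(vector):
        # Split the 1D vector at the cumulative leaf sizes, then restore each shape.
        assert vector.shape == (sum(sizes),), f"Invalid shape {vector.shape}"
        chunks = np.split(vector, np.cumsum(sizes)[:-1])
        return {k: c.reshape(s) for k, c, s in zip(keys, chunks, shapes)}

    return flat, unravel


# Assumed toy parameters: a 3x2 weight matrix and a length-3 bias.
params = {"weight": np.arange(6.0).reshape(3, 2), "bias": np.zeros(3)}
flat, unravel = ravel_flat_dict(params)
restored = unravel(flat)
```

Any other 1D vector of the same length (for example, the output of a hypernetwork's last layer) can be passed to unravel to obtain a dict with the same keys and shapes.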
Challenges in Hypernetworks
In the "Hypernetworks" line of work, we have a hypernetwork, hypernet, which generates the parameters of another network, net. Typically, hypernet's last layer will be a Linear/Dense layer, and thus its output will be in a "flattened" format (1D tensor). It is challenging to convert this flat tensor into the parameters of net.
Challenges in computing Hessians in Laplace approximation
The Laplace approximation of a neural network requires one to compute the Hessian of the negative log joint w.r.t. the parameters and then invert it. The inverse of the Hessian is the covariance matrix of the posterior normal distribution. One has to flatten the parameters to get the Hessian. Also, a parameter sample from this posterior is a flat array that needs to be restructured before being inserted into the network.
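The flatten, Hessian, invert, sample, unflatten flow can be sketched with NumPy alone. Everything here is an assumption made for illustration: the parameters are a toy single-level dict, the negative log joint is a hand-picked quadratic (so its Hessian is known exactly), and the Hessian is taken by finite differences rather than by automatic differentiation as one would do in JAX or PyTorch.

```python
import numpy as np

# Toy "MAP estimate" as a single-level dict of arrays (hypothetical values).
map_params = {"bias": np.array([0.5]), "weight": np.array([1.0, -2.0])}
keys = sorted(map_params)
shapes = [map_params[k].shape for k in keys]
sizes = [map_params[k].size for k in keys]
flat_map = np.concatenate([map_params[k].ravel() for k in keys])


def unravel(vector):
    """Restore the dict structure from a 1D vector."""
    chunks = np.split(vector, np.cumsum(sizes)[:-1])
    return {k: c.reshape(s) for k, c, s in zip(keys, chunks, shapes)}


# Assumed toy negative log joint: a quadratic centred at the MAP estimate,
# so its true Hessian is exactly the matrix A.
A = np.array([[2.0, 0.3, 0.0], [0.3, 1.5, 0.2], [0.0, 0.2, 1.0]])


def neg_log_joint(vector):
    d = vector - flat_map
    return 0.5 * d @ A @ d


def finite_difference_hessian(f, x, eps=1e-4):
    """Central-difference Hessian of a scalar function of a 1D vector."""
    n = x.size
    steps = eps * np.eye(n)
    hess = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            ei, ej = steps[i], steps[j]
            hess[i, j] = (
                f(x + ei + ej) - f(x + ei - ej) - f(x - ei + ej) + f(x - ei - ej)
            ) / (4 * eps**2)
    return 0.5 * (hess + hess.T)  # symmetrize away rounding noise


hessian = finite_difference_hessian(neg_log_joint, flat_map)
covariance = np.linalg.inv(hessian)  # posterior covariance of the flat parameters

# Draw one flat sample from the Gaussian posterior and restore the structure.
rng = np.random.default_rng(0)
flat_sample = rng.multivariate_normal(flat_map, covariance)
sample_params = unravel(flat_sample)
```

The point of the sketch is that the Hessian and the sample both live in the flat space, while the network itself consumes the structured sample_params.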
Solution
Currently, I have written my own version of ravel_pytree using the existing functions in optree:
Here is how it can be used for different applications.
Hypernetworks
We would need to transform the flat output of hypernet into a structure that matches net.named_parameters() and then trigger a functional call with torch.func.functional_call(net, named_parameters, inputs). A code snippet would look like this:
Laplace Approximation
Alternatives
An alternative is to use the manually defined function above or to use some other way.
Additional context
No response