ml-research / self-expanding-neural-networks

Self-Expanding Neural Networks
MIT License
37 stars, 5 forks

copy() argument mismatch #1

Open tarolling opened 9 months ago

tarolling commented 9 months ago

Description

While running experiment 4, I got a very specific TypeError from a copy() call (see the stack trace below). I did some digging but couldn't find any similar reports, so I'm submitting an issue to see what could be done (or what I did wrong). I haven't changed any files except for the log and data files produced during the run.

Setup

Platform: WSL (Ubuntu 22.04) with VSC
Setup: Followed the Docker instructions for the MLP version of SENN; everything built correctly.

Stack Trace

root@cb530e4c3f76:/senn# python experiment4.py --name example_exp_4
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /datasets/MNIST/raw/train-images-idx3-ubyte.gz
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9912422/9912422 [00:01<00:00, 5386572.61it/s]
Extracting /datasets/MNIST/raw/train-images-idx3-ubyte.gz to /datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /datasets/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 93324879.68it/s]
Extracting /datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 5157487.35it/s]
Extracting /datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 11179887.77it/s]
Extracting /datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /datasets/MNIST/raw

Compiling data tranch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 47999/48000 [00:28<00:00, 1677.96it/s]
Compiling data tranch: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 9999/10000 [00:05<00:00, 1763.20it/s]
contents: [10, 10]
Traceback (most recent call last):
  File "experiment4.py", line 1291, in <module>
    main()
  File "experiment4.py", line 723, in main
    solver = Solver(cfg, template, task, next(key), example)
  File "experiment4.py", line 167, in __init__
    self.state = self.model.apply(self.state, 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 1682, in apply
    return apply(
  File "/usr/local/lib/python3.8/dist-packages/flax/core/scope.py", line 998, in wrapper
    y = fn(root, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 2307, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 467, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 967, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/senn/nets.py", line 537, in restrict_params
    return state.copy({
jax._src.traceback_util.UnfilteredStackTrace: TypeError: copy() takes no arguments (1 given)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "experiment4.py", line 1291, in <module>
    main()
  File "experiment4.py", line 723, in main
    solver = Solver(cfg, template, task, next(key), example)
  File "experiment4.py", line 167, in __init__
    self.state = self.model.apply(self.state, 
  File "/senn/nets.py", line 537, in restrict_params
    return state.copy({
TypeError: copy() takes no arguments (1 given)
root@cb530e4c3f76:/senn# python --version
Python 3.8.10
root@cb530e4c3f76:/senn# python experiment4.py --name example_exp_4
Compiling data tranch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 47999/48000 [00:29<00:00, 1634.44it/s]
Compiling data tranch: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 9999/10000 [00:05<00:00, 1814.47it/s]
contents: [10, 10]
Traceback (most recent call last):
  File "experiment4.py", line 1291, in <module>
    main()
  File "experiment4.py", line 723, in main
    solver = Solver(cfg, template, task, next(key), example)
  File "experiment4.py", line 167, in __init__
    self.state = self.model.apply(self.state, 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 1682, in apply
    return apply(
  File "/usr/local/lib/python3.8/dist-packages/flax/core/scope.py", line 998, in wrapper
    y = fn(root, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 2307, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 467, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/module.py", line 967, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/senn/nets.py", line 537, in restrict_params
    return state.copy({
jax._src.traceback_util.UnfilteredStackTrace: TypeError: copy() takes no arguments (1 given)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "experiment4.py", line 1291, in <module>
    main()
  File "experiment4.py", line 723, in main
    solver = Solver(cfg, template, task, next(key), example)
  File "experiment4.py", line 167, in __init__
    self.state = self.model.apply(self.state, 
  File "/senn/nets.py", line 537, in restrict_params
    return state.copy({
TypeError: copy() takes no arguments (1 given)
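For anyone landing here, a minimal sketch of what the message itself means. Python's built-in `dict.copy()` takes no arguments, whereas Flax's `FrozenDict.copy(add_or_replace)` accepts a mapping of entries to add or replace. So the error indicates that `state` reached `restrict_params` in nets.py as a plain `dict` rather than a `FrozenDict`. As far as I know, newer Flax releases return plain dicts from `init` by default, which would fit a version mismatch, but I have not confirmed which Flax version the Docker image installs. The `merge_state` helper below is hypothetical (not part of the repo or of Flax); it does the same shallow add-or-replace without relying on either `copy` signature.

```python
from collections.abc import Mapping


def merge_state(state: Mapping, updates: Mapping) -> Mapping:
    """Hypothetical helper: shallow add-or-replace of `updates` into `state`.

    Neither argument is mutated. The result keeps the type of `state`,
    assuming that type can be built from a plain dict (true for `dict`).
    """
    merged = {**state, **updates}
    return type(state)(merged)


# A plain dict standing in for a model state (contents are illustrative only).
state = {"params": {"w": 1.0}, "batch_stats": {"mean": 0.0}}

# dict.copy() rejects arguments, reproducing the TypeError from the trace.
try:
    state.copy({"params": {"w": 2.0}})
except TypeError as err:
    print("TypeError:", err)

new_state = merge_state(state, {"params": {"w": 2.0}})
print(new_state["params"], state["params"])  # updated copy vs. untouched original
```

Pinning Flax to a release that still returns `FrozenDict` would be the other route, but I can't say which exact version the repo was developed against.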

xierongpytorch commented 6 months ago

Hi! I have probably configured the same environment and hit the same problem. I would be grateful if you found a solution.

RaiiZen1 commented 2 months ago

Same problem here.

tarolling commented 1 month ago

An update on this: I'm fairly certain the issue is that the Python package versions in the Docker image and in requirements.txt do not resolve consistently. In a nutshell, PyTorch and TensorFlow don't play well together: reproducing the environment step by step with Poetry revealed that they require conflicting versions of NVIDIA CUDA library packages (here, nvidia-nccl-cu12).

Because no versions of torchvision match >0.19.1,<0.20.0
 and torchvision (0.19.1) depends on torch (2.4.1), torchvision (>=0.19.1,<0.20.0) requires torch (2.4.1).
(1) So, because torch (2.4.1) depends on nvidia-nccl-cu12 (2.20.5), torchvision (>=0.19.1,<0.20.0) requires nvidia-nccl-cu12 (2.20.5).

    Because no versions of tensorflow match >2.17.0,<2.18.0rc0 || >2.18.0rc0,<2.18.0rc1 || >2.18.0rc1,<3.0.0
 and tensorflow[and-cuda] (2.18.0rc0) depends on nvidia-nccl-cu12 (2.21.5), tensorflow[and-cuda] (>2.17.0,<2.18.0rc1 || >2.18.0rc1,<3.0.0) requires nvidia-nccl-cu12 (2.21.5).
    And because tensorflow[and-cuda] (2.17.0) depends on nvidia-nccl-cu12 (2.19.3)
 and tensorflow[and-cuda] (2.18.0rc1) depends on nvidia-nccl-cu12 (2.21.5), tensorflow[and-cuda] (>=2.17.0,<3.0.0) requires nvidia-nccl-cu12 (2.19.3 || 2.21.5).
    And because torchvision (>=0.19.1,<0.20.0) requires nvidia-nccl-cu12 (2.20.5) (1), torchvision (>=0.19.1,<0.20.0) is incompatible with tensorflow[and-cuda] (>=2.17.0,<3.0.0)
    So, because sernn depends on both tensorflow[and-cuda] (^2.17.0) and torchvision (^0.19.1), version solving failed.
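The resolver's argument boils down to an empty intersection. As a hand-transcribed sketch: the version sets below are copied from the Poetry output above, and the `common_versions` helper is hypothetical (not part of Poetry).

```python
# nvidia-nccl-cu12 versions each top-level requirement allows,
# transcribed by hand from the Poetry output (illustration only).
ALLOWED_NCCL = {
    "torchvision>=0.19.1,<0.20.0": {"2.20.5"},  # via torch 2.4.1
    "tensorflow[and-cuda]>=2.17.0,<3.0.0": {"2.19.3", "2.21.5"},
}


def common_versions(allowed: dict) -> set:
    """Return the versions acceptable to every requirement (empty set = conflict)."""
    sets = list(allowed.values())
    return set.intersection(*sets) if sets else set()


shared = common_versions(ALLOWED_NCCL)
if shared:
    print("compatible nvidia-nccl-cu12:", sorted(shared))
else:
    print("no nvidia-nccl-cu12 version satisfies both requirements; resolution fails")
```

Since torch 2.4.1 pins a single NCCL version that neither TensorFlow line accepts, no install order can satisfy both.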

This is at least one of the problems I found; it may or may not be related to the original error. I suspect, however, that JAX got caught in the middle of this (JAX ships its own CUDA packages) and the versions didn't align, producing this error. I don't want to hunt down the combination of NVIDIA toolkit and Python versions needed to make torch and TF coexist, but since PyTorch is only used here for loading datasets, I may modify my fork so that things actually run.
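When comparing environments (for example, the Docker image against a fresh Poetry install), it can help to record exactly which versions of the relevant packages are installed. A minimal sketch using the standard-library importlib.metadata (available since Python 3.8); the package list is an assumption based on the libraries mentioned in this thread.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions mentioned in this thread; adjust as needed (assumed list).
PACKAGES = ["jax", "jaxlib", "flax", "torch", "torchvision", "tensorflow", "nvidia-nccl-cu12"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:>18}: {ver or 'not installed'}")
```

Pasting this output alongside a stack trace makes it much easier to tell whether two reporters are hitting the same version combination.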