google-research / jaxpruner

wrapped optimizer with params #5

Closed louieworth closed 1 year ago

louieworth commented 1 year ago

Following "Modification #1" in the quick_start.ipynb document, I created a wrapped optimizer with the params info from the model, but it raises an error that I do not know how to fix.

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/site-packages/jaxpruner/base_updater.py", line 225, in init_fn
    sparse_state = self.init_state(params)
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/site-packages/jaxpruner/base_updater.py", line 167, in init_state
    target_sparsities = self.sparsity_distribution_fn(params)
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/site-packages/jaxpruner/sparsity_distributions.py", line 72, in uniform
    if isinstance(params, chex.Array):
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/typing.py", line 720, in __instancecheck__
    return self.__subclasscheck__(type(obj))
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/typing.py", line 723, in __subclasscheck__
    raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks
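The failing call is isinstance(params, chex.Array) in jaxpruner's uniform sparsity distribution. On Python 3.9, chex.Array is a subscripted generic (presumably a typing.Union of array types), and isinstance rejects those. A minimal reproduction, independent of jaxpruner (ArrayLike here is a hypothetical stand-in for chex.Array):

# Python 3.9: isinstance against a typing.Union raises the same TypeError.
from typing import Union

ArrayLike = Union[int, float]  # stand-in for the Union behind chex.Array
isinstance(1.0, ArrayLike)
# TypeError: Subscripted generics cannot be used with class and instance checks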

Source code

# first create a pruner with some default settings
actor_optimizer = optax.adam(learning_rate=actor_lr)
actor_optimizer = pruner.wrap_optax(actor_optimizer)
actor = Model.create(actor_def,
                     inputs=[actor_key, observations],
                     tx=actor_optimizer)
    @classmethod
    def create(cls,
               model_def: nn.Module,
               inputs: Sequence[jnp.ndarray],
               tx: Optional[optax.GradientTransformation] = None) -> 'Model':
        variables = model_def.init(*inputs)

        _, params = variables.pop('params')
        ###########
        # where the bug is: tx.init(params) invokes jaxpruner's init_fn
        if tx is not None:
            opt_state = tx.init(params)
        else:
            opt_state = None
        ###############

        return cls(step=1,
                   apply_fn=model_def.apply,
                   params=params,
                   tx=tx,
                   opt_state=opt_state)

    def __call__(self, *args, **kwargs):
        return self.apply_fn({'params': self.params}, *args, **kwargs)

    def apply_gradient(
            self,
            loss_fn: Optional[Callable[[Params], Any]] = None,
            grads: Optional[Any] = None,
            has_aux: bool = True) -> Union[Tuple['Model', Any], 'Model']:
        assert loss_fn is not None or grads is not None, \
            'Either a loss function or grads must be specified.'
        if grads is None:
            grad_fn = jax.grad(loss_fn, has_aux=has_aux)
            if has_aux:
                grads, aux = grad_fn(self.params)
            else:
                grads = grad_fn(self.params)
        else:
            # aux is only computed via grad_fn, so it must not be requested
            # when grads are passed in directly.
            assert not has_aux, \
                'When grads are provided, expects no aux outputs.'

        updates, new_opt_state = self.tx.update(grads, self.opt_state,
                                                self.params)
        new_params = optax.apply_updates(self.params, updates)

        new_model = self.replace(step=self.step + 1,
                                 params=new_params,
                                 opt_state=new_opt_state)
        if has_aux:
            return new_model, aux
        else:
            return new_model

This is source code from jaxrl (jaxrl/agents/sac/sac_learner.py). By the way, since I am really new to JAX and sparse NNs, would you mind providing some guidance (example code) on constructing a SAC pruning example? It would be highly appreciated.
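For reference, my understanding from the README and quick_start.ipynb is that the full wiring looks roughly like the sketch below; the config values are placeholders, not tuned for SAC:

import jaxpruner
import ml_collections
import optax

# Sketch following the jaxpruner README/quick start; sparsity and schedule
# values are placeholders.
sparsity_config = ml_collections.ConfigDict()
sparsity_config.algorithm = 'magnitude'
sparsity_config.update_freq = 10
sparsity_config.update_start_step = 200
sparsity_config.update_end_step = 1000
sparsity_config.sparsity = 0.8
sparsity_config.dist_type = 'erk'

pruner = jaxpruner.create_updater_from_config(sparsity_config)
actor_optimizer = pruner.wrap_optax(optax.adam(learning_rate=actor_lr))

# Inside Model.apply_gradient, after optax.apply_updates, the masks are
# re-applied with:
#   new_params = pruner.post_gradient_update(new_params, new_opt_state)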

Version

jax = 0.4.11
python = 3.9.1
stepp1 commented 1 year ago

Hi. Maybe check #3. It seems that the issue was fixed by upgrading to Python 3.11.
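(For context: Python 3.10 added isinstance() support for typing.Union objects, which is the kind of alias behind chex.Array, so the check in sparsity_distributions.py no longer raises on 3.10+.)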

evcu commented 1 year ago

+1, thanks @stepp1

> Hi. Maybe check #3. It seems that the issue was fixed by upgrading to Python 3.11.

This should resolve the issue.

louieworth commented 1 year ago

Thanks