AlphaBetaGamma96 opened this issue 1 year ago
So it seems the solution is to place the vmap call within the memory_efficient_fusion call, like so:
ps_elocal_fusion = memory_efficient_fusion(vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0)))
Then just call:
ps_elocal_fusion(params, x) #works now.
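To make clear what the composed callable computes before it is handed to memory_efficient_fusion, here is a minimal sketch using a hypothetical toy function (toy_scalar_fn) in place of local_kinetic_from_log_vmap, and a single parameter tensor in place of the real network parameters. It uses the torch.func namespace of newer PyTorch releases; in functorch 0.x the same functions are functorch.grad and functorch.vmap. grad(..., argnums=0) differentiates with respect to params for one sample, and vmap(..., in_dims=(None, 0)) shares params across the batch while mapping over the leading dimension of x.

```python
import torch
from torch.func import grad, vmap  # functorch.grad / functorch.vmap in functorch 0.x

# Hypothetical toy stand-in for local_kinetic_from_log_vmap:
# a scalar-valued function of a parameter vector and ONE sample x.
def toy_scalar_fn(params, x):
    return torch.sin(params @ x)

# grad(..., argnums=0): gradient w.r.t. params for a single sample.
# vmap(..., in_dims=(None, 0)): params is shared (not mapped), x is mapped
# over its leading (batch) dimension, giving one gradient per sample.
per_sample_grad = vmap(grad(toy_scalar_fn, argnums=0), in_dims=(None, 0))

torch.manual_seed(0)
params = torch.randn(3)
xs = torch.randn(5, 3)  # batch of 5 samples, 3 features each

grads = per_sample_grad(params, xs)
print(grads.shape)  # torch.Size([5, 3])
```

In the arrangement reported to work in this thread, it is this whole composed callable (vmap wrapping grad) that is passed to memory_efficient_fusion, rather than fusing an inner function and applying vmap afterwards.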
However, it's about an order of magnitude slower than the version without memory_efficient_fusion:
ps_elocal = vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0)) #0.454 (s)
ps_elocal_fusion = memory_efficient_fusion(vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0))) #5.804 (s)
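One thing worth ruling out when comparing these two numbers is whether the measurement includes one-off costs: AOT-style compilers typically trace and compile on the first call, and CUDA kernel launches are asynchronous. The helper below is a minimal sketch under those assumptions (the name benchmark and its warmup/iters parameters are hypothetical, not part of functorch); it discards warm-up calls and synchronizes CUDA before reading the clock.

```python
import time
import torch

def benchmark(fn, *args, warmup=3, iters=10):
    """Mean wall-clock seconds per call of fn(*args), excluding warm-up calls."""
    # Warm-up calls absorb one-off costs such as tracing/compilation on first use.
    for _ in range(warmup):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure warm-up work has finished
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # CUDA launches are async; wait before stopping the clock
    return (time.perf_counter() - start) / iters
```

Usage would be, for example, benchmark(ps_elocal, params, x) versus benchmark(ps_elocal_fusion, params, x). If the gap shrinks after warm-up, the difference was mostly compilation; if it persists, the fused graph itself is slower.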
EDIT: versions for reference:
PyTorch version: 1.13.0.dev20220820
CUDA version: 11.6
FuncTorch version: 0.3.0a0+86a9049
cc @Chillee @anijain2305
Any thoughts? In particular, why does memory_efficient_fusion make the final case slower?
I thought I'd also mention that memory_efficient_fusion fails if a Python scalar (float) is included. For example, using this function (which differs from the original by using a literal -0.5 rather than -0.5 * factor):
def local_kinetic_from_log_vmap(params, x):
    d2logabs_dx2, dlogabs_dx = jacrev(func=calc_dlogabs_dx, argnums=1, has_aux=True)(params, x)
    _local_kinetic = -0.5*(d2logabs_dx2.diagonal(0,-4,-2).sum() + dlogabs_dx.pow(2).sum())
    return _local_kinetic
raises the following error:
RuntimeError: aten::_to_copy() Expected a value of type 'Tensor' for argument 'self' but instead found type 'float'.
Position: 0
Value: -0.5
Declaration: aten::_to_copy(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int? memory_format=None) -> Tensor
Cast error details: Unable to cast -0.5 to Tensor
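One untested workaround, assuming the failure is triggered only by the Python float literal inside the fused function: keep the constant out of the function passed to memory_efficient_fusion and apply it to the result afterwards. Because grad is linear, scaling the per-sample gradients by -0.5 afterwards gives the same result as differentiating -0.5 * f. The sketch below checks that equivalence without any fusion, using a hypothetical toy function (kinetic_terms) in place of the real kinetic-energy terms.

```python
import torch
from torch.func import grad, vmap  # functorch.grad / functorch.vmap in functorch 0.x

# Hypothetical toy stand-in for the kinetic-energy terms, WITHOUT the -0.5 factor.
def kinetic_terms(params, x):
    return (params * x).pow(2).sum()

# Same function with the scalar folded in (the form that failed under fusion).
def kinetic_with_scalar(params, x):
    return -0.5 * kinetic_terms(params, x)

per_sample = vmap(grad(kinetic_terms, argnums=0), in_dims=(None, 0))
per_sample_scaled = vmap(grad(kinetic_with_scalar, argnums=0), in_dims=(None, 0))

torch.manual_seed(0)
params = torch.randn(4)
xs = torch.randn(6, 4)

# grad is linear, so the constant can be applied after the (possibly fused) call.
# Under fusion this would be -0.5 * memory_efficient_fusion(per_sample)(params, xs).
scaled_outside = -0.5 * per_sample(params, xs)
scaled_inside = per_sample_scaled(params, xs)
print(torch.allclose(scaled_outside, scaled_inside))  # True
```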
Hi all,
I've tried to improve the speed of my code using memory_efficient_fusion; however, it leads to a Tensor.requires_grad_() error and I have no idea why. The error is as follows:
I've attached a 'minimal' reproducible example of this behaviour below. I've tried a few different things but nothing seems to have worked. I did see in #840 that memory_efficient_fusion is called within a context manager; however, when using that I get the same error.
Thanks in advance!
EDIT: When I tried running it, it tried to use the networkx package, but that wasn't installed by default. So I had to install it manually (which wasn't a problem); I'm just not sure whether installing from source should also install those packages!