aai-institute / pyDVL

pyDVL is a library of stable implementations of algorithms for data valuation and influence function computation
https://pydvl.org
GNU Lesser General Public License v3.0

Error while trying to run "compute_influences" for TorchLogisticRegression #441

Closed: ntheol closed this issue 9 months ago

ntheol commented 9 months ago

Hello,

I tried to run the synthetic dataset notebook, and in the cell where the influences are computed I get this error:

ValueError                                Traceback (most recent call last)
C:\Users\NIKOLA~1\AppData\Local\Temp/ipykernel_9640/2616014207.py in <module>
----> 6 influence_values = compute_influences(
      7     differentiable_model=TorchTwiceDifferentiable(model, F.binary_cross_entropy),
      8     training_data=train_data_loader,

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\pydvl\influence\general.py in compute_influences(differentiable_model, training_data, test_data, input_data, inversion_method, influence_type, hessian_regularization, progress, **kwargs)
--> 309 influence_factors, _ = compute_influence_factors(

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\pydvl\influence\general.py in compute_influence_factors(model, training_data, test_data, inversion_method, hessian_perturbation, progress, **kwargs)
--> 119 return solve_hvp(

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\pydvl\influence\inversion.py in solve_hvp(inversion_method, model, training_data, b, hessian_perturbation, **kwargs)
---> 67 return InversionRegistry.call(

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\pydvl\influence\inversion.py in call(cls, inversion_method, model, training_data, b, hessian_perturbation, **kwargs)
--> 203 return cls.get(type(model), inversion_method)(

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\pydvl\influence\inversion.py in wrapper(*args, **kwargs)
--> 156 return func(*args, **kwargs)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\pydvl\influence\torch\torch_differentiable.py in solve_linear(model, training_data, b, hessian_perturbation)
--> 525 hessian = model.hessian(torch.cat(all_x), torch.cat(all_y))

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\pydvl\influence\torch\torch_differentiable.py in hessian(self, x, y)
--> 159 return torch.func.hessian(model_func)(params)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\eager_transforms.py in wrapper_fn(*args)
-> 1128 results = vmap(push_jvp, randomness=randomness)(basis)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\vmap.py in wrapped(*args, **kwargs)
--> 434 return _flat_vmap(

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\vmap.py in fn(*args, **kwargs)
---> 39 return f(*args, **kwargs)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
--> 619 batched_outputs = func(*batched_inputs, **kwargs)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\eager_transforms.py in push_jvp(basis)
-> 1119 output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\vmap.py in fn(*args, **kwargs)
---> 39 return f(*args, **kwargs)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\eager_transforms.py in _jvp_with_argnums(func, primals, tangents, argnums, strict, has_aux)
--> 965 result_duals = func(*duals)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\eager_transforms.py in wrapper_fn(*args)
--> 489 vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\vmap.py in fn(*args, **kwargs)
---> 39 return f(*args, **kwargs)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\_functorch\eager_transforms.py in _vjp_with_argnums(func, argnums, has_aux, *primals)
--> 291 primals_out = func(*primals)

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\pydvl\influence\torch\torch_differentiable.py in model_func(param)
--> 154 return self.loss(outputs, y.to(self.device))

c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
-> 3089 raise ValueError(

ValueError: Using a target size (torch.Size([100])) that is different to the input size (torch.Size([100, 1])) is deprecated. Please ensure they have the same size.

I am pretty sure that I ran the notebook exactly as it is. Let me know if you need any additional information.

Thank you!

Xuzzo commented 9 months ago

Hello, thanks for the report. Which version of torch are you using?

ntheol commented 9 months ago

Hello, I have torch 2.0.1

AnesBenmerzoug commented 9 months ago

Thanks @ntheol for reporting this issue. We are currently updating and fixing the Influence Functions notebooks so it's highly appreciated. Could you please switch to the fix/435-broken-imports branch and try running the notebook again?
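If you installed pyDVL from PyPI, one way to try the branch (a suggestion; the repository URL is inferred from the project header above) is to install it directly from GitHub:

pip install "git+https://github.com/aai-institute/pyDVL.git@fix/435-broken-imports"

If you work from a local clone instead, checking out the branch and reinstalling with pip install -e . should have the same effect.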

@schroedk Do you perhaps know what could have caused this error?

ntheol commented 9 months ago

@AnesBenmerzoug I ran it without any problems on the fix/435-broken-imports branch.

AnesBenmerzoug commented 9 months ago

@ntheol That's great! We will hopefully release a new patch version next week. Until then, you should be able to try things out on that branch.

ntheol commented 9 months ago

Thanks!

ntheol commented 9 months ago

Hello again! I tried to run a notebook of my own with a different dataset and, unfortunately, I got the same error again on the fix branch. I have also just started getting an ImportError about TorchTwiceDifferentiable that I didn't have before.

Xuzzo commented 9 months ago

Could you paste the traces for both errors?

ntheol commented 9 months ago

ImportError                               Traceback (most recent call last)
C:\Users\NIKOLA~1\AppData\Local\Temp/ipykernel_7504/193257936.py in <module>
      8 import torch.nn.functional as F
      9 import matplotlib.pyplot as plt
---> 10 from pydvl.influence import compute_influences, TorchTwiceDifferentiable
     11 from support.shapley import (
     12     synthetic_classification_dataset,

ImportError: cannot import name 'TorchTwiceDifferentiable' from 'pydvl.influence' (c:\Users\Nikolas Theol\AppData\Local\Programs\Python\Python38\lib\site-packages\pydvl\influence\__init__.py)

ntheol commented 9 months ago

ValueError                                Traceback (most recent call last)
C:\Users\NIKOLA~1\AppData\Local\Temp/ipykernel_4032/3598293409.py in <module>
----> 1 train_influences = compute_influences(
      2     TorchTwiceDifferentiable(nn_model, F.binary_cross_entropy),
      3     training_data=training_data_loader,
      4     test_data=test_data_loader,
      5     influence_type="up",

[The intermediate frames are identical to those in the trace above, passing through compute_influence_factors, solve_hvp, model.hessian and torch's vmap/jvp machinery down to binary_cross_entropy in torch\nn\functional.py.]

ValueError: Using a target size (torch.Size([684646])) that is different to the input size (torch.Size([684646, 1])) is deprecated. Please ensure they have the same size.

Xuzzo commented 9 months ago

It looks like you are not on the correct commit. Could you check again that you pulled the latest version of fix/435-broken-imports?
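For example (a generic sketch, assuming a local clone that was installed into site-packages):

git checkout fix/435-broken-imports
git pull origin fix/435-broken-imports
git log -1 --oneline   # confirm you are on the branch's latest commit
pip install -e .       # reinstall so the installed package picks up the changes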

ntheol commented 9 months ago

Yeah you were right, my fault. Thanks!

schroedk commented 9 months ago

@ntheol the issue is due to how torch's binary cross entropy function handles shapes: it matters whether your tensors have shape (n,) or (n, 1). Here is a minimal example:

import torch

y_test = torch.rand((5,))    # targets, shape (n,)
y_pred = torch.rand((5, 1))  # model output, shape (n, 1)

# binary_cross_entropy(input, target) raises the observed error because
# the target size (5,) differs from the input size (5, 1)
torch.nn.functional.binary_cross_entropy(y_pred, y_test)

The DataLoader definition in the notebook did not account for this. It is fixed in #436.
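For anyone hitting this with their own data, here is a minimal sketch of that kind of fix (illustrative names, not the exact change from #436): reshape the targets to (n, 1) so they match the model's output shape before building the DataLoader:

import torch
from torch.utils.data import DataLoader, TensorDataset

x = torch.rand((100, 2))                 # features, shape (n, d)
y = torch.randint(0, 2, (100,)).float()  # labels, shape (n,)

# Reshape the targets to (n, 1) so they match the model output's shape,
# which is what binary_cross_entropy requires.
train_data = TensorDataset(x, y.unsqueeze(-1))
train_data_loader = DataLoader(train_data, batch_size=32)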