phbradley / alphafold_finetune

Python code for fine-tuning AlphaFold to perform protein-peptide binding predictions
Apache License 2.0

Unable to execute simple finetune examples due to jnp.linalg.norm error #4

Open snerligit opened 1 year ago

snerligit commented 1 year ago

Please see the error I am facing while running the simple finetune command given in the README:

```
done importing
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
binder_intercepts: [0.80367635, 0.43373787]
cmd: /gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py --data_dir /gstore/scratch/u/nerlis/alphafold_finetune/alphafold/data --binder_intercepts 0.80367635 --binder_intercepts 0.43373787 --freeze_binder --train_dataset /gstore/scratch/u/nerlis/alphafold_finetune/examples/tiny_pmhc_finetune/tiny_example_train.tsv --valid_dataset /gstore/scratch/u/nerlis/alphafold_finetune/examples/tiny_pmhc_finetune/tiny_example_valid.tsv
local_device: cpu ng033
model_name: model_2_ptm
outprefix: testrun
WARNING:tensorflow:From /gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block.
https://github.com/tensorflow/tensorflow/issues/56089
pae: [[0.]]
initial binder params: {'PaeBinderClassifier': {'x_intercept': DeviceArray([[0.80367637, 0.43373787]], dtype=float32), 'slope': DeviceArray([-7.901963], dtype=float32)}}
create_batch_from_dataset_row: DRB1_0901_True_SVAYKAAVGATPEAK_2 unk
WARNING chainbreak: B 29 31 4.174218489729545 examples/tiny_pmhc_finetune/templates/6qzc_MH2_DRA_01010101_DRB1_01010101.pdb
WARNING chainbreak: A 74 275 11.303902954289725 examples/tiny_pmhc_finetune/natives/run135_batch_0660DRB1_0901_True_SVAYKAAVGATPEAK_2_model_1_model_1.pdb
WARNING chainbreak: A 359 560 17.32551615392742 examples/tiny_pmhc_finetune/natives/run135_batch_0660DRB1_0901_True_SVAYKAAVGATPEAK_2_model_1_model_1.pdb
train_epoch: 0 batch: 0
binder_params: {'PaeBinderClassifier': {'slope': DeviceArray([-7.901963], dtype=float32), 'x_intercept': DeviceArray([[0.80367637, 0.43373787]], dtype=float32)}}
not setting num_iter_recycling!!! will do 3 recycles
2023-02-15 16:04:11.006011: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65]
[Compiling module pmap_train_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-02-15 16:09:49.975817: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 7m38.970211666s
[Compiling module pmap_train_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
pae: Traced<ShapedArray(float32[1])>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float32[1])>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float32[1]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x2aacf2d691e0>, in_tracers=(Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>, Traced<ShapedArray(float32[1]):JaxprTrace(level=1/1)>), out_tracer_refs=[<weakref at 0x2aad0069b540; to 'JaxprTracer' at 0x2aad0069bd60>], out_avals=[ShapedArray(float32[1])], primitive=xla_call, params={'device': None, 'backend': None, 'name': 'true_divide', 'donated_invars': (False, False), 'inline': True, 'keep_unused': False, 'call_jaxpr': { lambda ; a:f32[] b:f32[1]. let c:f32[1] = div b a in (c,) }}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x2aad04505d30>, name_stack=NameStack(stack=(Transform(name='jvp'), Scope(name='PaeBinderClassifier')))))
train_epoch_n= 0 0 loss= 0.84698945 structure_flag: False lddt_ca= 1.0 fape= 0.04288465
binder_probs= [[0.4951988 0.5048012]] binder_loss= [0.68359053] peptide_plddt= 108.62619
binder_features= [[0.4313074]] binder_labels= [array([[0., 1.]])]
binder_params= {'PaeBinderClassifier': {'slope': DeviceArray([-7.901963], dtype=float32), 'x_intercept': DeviceArray([[0.80367637, 0.43373787]], dtype=float32)}}
grad accumulate: 1 0
grad update!
1 1
Traceback (most recent call last):
  File "/gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py", line 786, in
    grads_sum = norm_grads_per_example(grads_sum,
  File "/gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py", line 380, in norm_grads_per_example
    total_grad_norm = jnp.linalg.norm([jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/api.py", line 622, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun, *tracers, **params)
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers)
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 490, in norm
    _check_arraylike("jnp.linalg.norm", x)
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py", line 786, in
    grads_sum = norm_grads_per_example(grads_sum,
  File "/gstore/scratch/u/nerlis/alphafold_finetune/run_finetuning.py", line 380, in norm_grads_per_example
    total_grad_norm = jnp.linalg.norm([jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/numpy/linalg.py", line 490, in norm
    _check_arraylike("jnp.linalg.norm", x)
  File "/gstore/home/nerlis/anaconda3/envs/tcrdock_test/lib/python3.8/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.
```

ljq0811 commented 1 year ago

I have the same problem. Have you solved it?

phbradley commented 1 year ago

Hi there, I'm sorry for missing this issue originally, and I appreciate that installing the various dependencies is challenging. Two things might help:

Thanks, Phil