Open VinaySingh561 opened 2 months ago
I had the same problem. Did you solve it?
After downgrading these libraries, the errors no longer occurred: e3nn-jax==0.20.6, dm-haiku==0.0.12, jax==0.4.31, jaxlib==0.4.31.
Thanks for your help.
On Fri, Oct 4, 2024 at 5:18 AM AI4TE @.***> wrote:
After downgrading these libraries, the errors no longer occurred: e3nn-jax==0.20.6, dm-haiku==0.0.12, jax==0.4.31, jaxlib==0.4.31.
— Reply to this email directly, view it on GitHub https://github.com/atomicarchitects/phonax/issues/1#issuecomment-2392521586, or unsubscribe https://github.com/notifications/unsubscribe-auth/AUQD7SG2GGRDODHT7OYX6YDZZXJUNAVCNFSM6AAAAABOZV3KAOVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDGOJSGUZDCNJYGY . You are receiving this because you authored the thread.Message ID: @.***>
Hi, I am trying to run Tutorial_new_model_training but I am getting an error in step 3 [construct and initialize the NequIP energy model]. Please help me resolve the following error: `--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[8], line 1 ----> 1 model_fn, params, num_message_passing = NequIP_JAXMD_model( 2 r_max=r_max, 3 atomic_energies_dict={}, 4 train_graphs=train_loader.graphs, 5 initialize_seed=config["model"]["seed"], 6 num_species = config["model"]["num_species"], 7 use_sc = True, 8 graph_net_steps = config["model"]["num_layers"], 9 hidden_irreps = config["model"]["internal_irreps"], 10 nonlinearities = {'e': 'swish', 'o': 'tanh'}, 11 save_dir_name = save_dir_name, 12 reload = config["initialization"]['reload'] if 'reload' in config["initialization"] else None, 13 ) 15 print("num_params:", sum(p.size for p in jax.tree_util.tree_leaves(params))) 17 predictor = jax.jit( 18 lambda w, g: predict_energy_forces_stress(lambda x: model_fn(w, x), g) 19 )
File ~/phonax/phonax/phonax/nequip_model.py:716, in NequIP_JAXMD_model(r_max, atomic_energies_dict, train_graphs, initialize_seed, scaling, atomic_energies, avg_num_neighbors, avg_r_min, num_species, path_normalization, gradient_normalization, learnable_atomic_energies, radial_basis, radial_envelope, save_dir_name, reload, **kwargs) 713 return node_energies 715 if (initialize_seed is not None) and reload is None: --> 716 params = jax.jit(model.init)( 717 jax.random.PRNGKey(initialize_seed), 718 jnp.zeros((1, 3)), 719 jnp.array([16]), 720 jnp.array([0]), 721 jnp.array([0]), 722 ) 723 elif reload is not None: 724 with open(f"{reload}/params.pkl", "rb") as f:
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/transform.py:166, in without_state.<locals>.init_fn(*args, **kwargs)
165 def init_fn(*args, **kwargs) -> hk.MutableParams:
--> 166 params, state = f.init(*args, **kwargs)
167 if state:
168 raise base.NonEmptyStateError(
169 "If your transformed function uses
hk.{get,set}_state
then use " 170 "hk.transform_with_state
.")File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/transform.py:422, in transform_with_state..init_fn(rng, *args, *kwargs)
420 with base.new_context(rng=rng) as ctx:
421 try:
--> 422 f(args, **kwargs)
423 except jax.errors.UnexpectedTracerError as e:
424 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e
File ~/phonax/phonax/phonax/nequip_model.py:696, in NequIP_JAXMD_model.<locals>.model(vectors, node_z, senders, receivers)
689 if hk.running_init():
690 logging.info(
691 "model: "
692 f"hidden_irreps={nequip.hidden_irreps} "
693 f"sh_irreps={nequip.sh_irreps} ",
694 )
--> 696 contributions = nequip(
697 vectors, node_z, senders, receivers
698 ) # [n_nodes, num_interactions, 0e]
699 node_energies = contributions[:, 0]
701 node_energies = mean + std * node_energies
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:464, in wrap_method..wrapped(self, *args, *kwargs)
461 if method_name != "call":
462 f = jax.named_call(f, name=method_name)
--> 464 out = f(args, kwargs)
466 # Module names are set in the constructor. If
f
is the constructor then 467 # its name will only be set after**f
has run. For methods other 468 # than__init__
we need the name before running in order to wrap their 469 # execution withnamed_call
. 470 if module_name is None:File ~/.conda/envs/phonax/lib/python3.10/contextlib.py:79, in ContextDecorator.call..inner(*args, kwds)
76 @wraps(func)
77 def inner(*args, *kwds):
78 with self._recreate_cm():
---> 79 return func(args, kwds)
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:305, in run_interceptors(bound_method, method_name, self, orig_class, *args, *kwargs) 303 """Runs any method interceptors or the original method.""" 304 if not interceptor_stack: --> 305 return bound_method(args, **kwargs) 307 ctx = MethodContext(module=self, 308 method_name=method_name, 309 orig_method=bound_method, 310 orig_class=orig_class) 311 interceptor_stack_copy = interceptor_stack.clone()
File ~/phonax/phonax/phonax/nequip_model.py:466, in NequIPEnergyModel.__call__(self, vectors, node_specie, senders, receivers) 464 # convolutions 465 for _ in range(self.graph_net_steps): --> 466 h_node = NequIPConvolution( 467 hidden_irreps=hidden_irreps, 468 use_sc=self.use_sc, 469 nonlinearities=self.nonlinearities, 470 radial_net_nonlinearity=self.radial_net_nonlinearity, 471 radial_net_n_hidden=self.radial_net_n_hidden, 472 radial_net_n_layers=self.radial_net_n_layers, 473 num_basis=self.num_basis, 474 avg_num_neighbors=self.avg_num_neighbors, 475 scalar_mlp_std=self.scalar_mlp_std 476 )(h_node, 477 node_attrs, 478 edge_sh, 479 edge_src, 480 edge_dst, 481 embedded_dr_edge 482 ) 484 # output block, two Linears that decay dimensions from h to h//2 to 1 485 for mul, ir in h_node.irreps:
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:464, in wrap_method..wrapped(self, *args, *kwargs)
461 if method_name != "call":
462 f = jax.named_call(f, name=method_name)
--> 464 out = f(args, kwargs)
466 # Module names are set in the constructor. If
f
is the constructor then 467 # its name will only be set after**f
has run. For methods other 468 # than__init__
we need the name before running in order to wrap their 469 # execution withnamed_call
. 470 if module_name is None:File ~/.conda/envs/phonax/lib/python3.10/contextlib.py:79, in ContextDecorator.call..inner(*args, kwds)
76 @wraps(func)
77 def inner(*args, *kwds):
78 with self._recreate_cm():
---> 79 return func(args, kwds)
File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:305, in run_interceptors(bound_method, method_name, self, orig_class, *args, *kwargs) 303 """Runs any method interceptors or the original method.""" 304 if not interceptor_stack: --> 305 return bound_method(args, **kwargs) 307 ctx = MethodContext(module=self, 308 method_name=method_name, 309 orig_method=bound_method, 310 orig_class=orig_class) 311 interceptor_stack_copy = interceptor_stack.clone()
File ~/phonax/phonax/phonax/nequip_model.py:336, in NequIPConvolution.__call__(self, node_features, node_attributes, edge_sh, edge_src, edge_dst, edge_embedded) 333 # self-connection, similar to a resnet-update that sums the output from 334 # the TP to chemistry-weighted h 335 if self.use_sc: --> 336 h = h + self_connection 338 # gate nonlinearity, applied to gate data, consisting of: 339 # a) regular scalars, 340 # b) gate scalars, and 341 # c) non-scalars to be gated 342 # in this order 343 gate_fn = partial( 344 e3nn.gate, 345 even_act=get_nonlinearity_by_name(self.nonlinearities['e']), (...) 348 odd_gate_act=get_nonlinearity_by_name(self.nonlinearities['o']) 349 )
File ~/.conda/envs/phonax/lib/python3.10/site-packages/e3nn_jax/_src/irreps_array.py:311, in IrrepsArray.__add__(self, other) 306 raise ValueError( 307 f"IrrepsArray({self.irreps}, shape={self.shape}) + scalar is not equivariant." 308 ) 310 if self.irreps != other.irreps: --> 311 raise ValueError( 312 f"IrrepsArray({self.irreps}, shape={self.shape}) + IrrepsArray({other.irreps}) is not equivariant." 313 ) 315 zero_flags = tuple(x and y for x, y in zip(self.zero_flags, other.zero_flags)) 316 chunks = None
ValueError: IrrepsArray(36x0e+12x1o+8x2e, shape=(1, 112)) + IrrepsArray(16x0e+12x0e+8x0e+12x1o+8x2e) is not equivariant.`