atomicarchitects / phonax

MIT License
16 stars 1 forks source link

Problem running Tutorial_new_model_training.ipynb #1

Open VinaySingh561 opened 2 months ago

VinaySingh561 commented 2 months ago

Hi, I am trying to run Tutorial_new_model_training.ipynb, but I am getting an error in step 3 [construct and initialize the NequIP energy model]. Please help me resolve the following error: `--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[8], line 1 ----> 1 model_fn, params, num_message_passing = NequIP_JAXMD_model( 2 r_max=r_max, 3 atomic_energies_dict={}, 4 train_graphs=train_loader.graphs, 5 initialize_seed=config["model"]["seed"], 6 num_species = config["model"]["num_species"], 7 use_sc = True, 8 graph_net_steps = config["model"]["num_layers"], 9 hidden_irreps = config["model"]["internal_irreps"], 10 nonlinearities = {'e': 'swish', 'o': 'tanh'}, 11 save_dir_name = save_dir_name, 12 reload = config["initialization"]['reload'] if 'reload' in config["initialization"] else None, 13 ) 15 print("num_params:", sum(p.size for p in jax.tree_util.tree_leaves(params))) 17 predictor = jax.jit( 18 lambda w, g: predict_energy_forces_stress(lambda x: model_fn(w, x), g) 19 )

File ~/phonax/phonax/phonax/nequip_model.py:716, in NequIP_JAXMD_model(r_max, atomic_energies_dict, train_graphs, initialize_seed, scaling, atomic_energies, avg_num_neighbors, avg_r_min, num_species, path_normalization, gradient_normalization, learnable_atomic_energies, radial_basis, radial_envelope, save_dir_name, reload, **kwargs) 713 return node_energies 715 if (initialize_seed is not None) and reload is None: --> 716 params = jax.jit(model.init)( 717 jax.random.PRNGKey(initialize_seed), 718 jnp.zeros((1, 3)), 719 jnp.array([16]), 720 jnp.array([0]), 721 jnp.array([0]), 722 ) 723 elif reload is not None: 724 with open(f"{reload}/params.pkl", "rb") as f:

[... skipping hidden 11 frame]

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/transform.py:166, in without_state..init_fn(*args, kwargs) 165 def init_fn(*args, *kwargs) -> hk.MutableParams: --> 166 params, state = f.init(args, kwargs) 167 if state: 168 raise base.NonEmptyStateError( 169 "If your transformed function uses hk.{get,set}_state then use " 170 "hk.transform_with_state.")

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/transform.py:422, in transform_with_state..init_fn(rng, *args, *kwargs) 420 with base.new_context(rng=rng) as ctx: 421 try: --> 422 f(args, **kwargs) 423 except jax.errors.UnexpectedTracerError as e: 424 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

File ~/phonax/phonax/phonax/nequip_model.py:696, in NequIP_JAXMD_model.<locals>.model(vectors, node_z, senders, receivers) 689 if hk.running_init(): 690 logging.info( 691 "model: " 692 f"hidden_irreps={nequip.hidden_irreps} " 693 f"sh_irreps={nequip.sh_irreps} ", 694 ) --> 696 contributions = nequip( 697 vectors, node_z, senders, receivers 698 ) # [n_nodes, num_interactions, 0e] 699 node_energies = contributions[:, 0] 701 node_energies = mean + std * node_energies

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:464, in wrap_method..wrapped(self, *args, *kwargs) 461 if method_name != "call": 462 f = jax.named_call(f, name=method_name) --> 464 out = f(args, kwargs) 466 # Module names are set in the constructor. If f is the constructor then 467 # its name will only be set after** f has run. For methods other 468 # than __init__ we need the name before running in order to wrap their 469 # execution with named_call. 470 if module_name is None:

File ~/.conda/envs/phonax/lib/python3.10/contextlib.py:79, in ContextDecorator.call..inner(*args, kwds) 76 @wraps(func) 77 def inner(*args, *kwds): 78 with self._recreate_cm(): ---> 79 return func(args, kwds)

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:305, in run_interceptors(bound_method, method_name, self, orig_class, *args, *kwargs) 303 """Runs any method interceptors or the original method.""" 304 if not interceptor_stack: --> 305 return bound_method(args, **kwargs) 307 ctx = MethodContext(module=self, 308 method_name=method_name, 309 orig_method=bound_method, 310 orig_class=orig_class) 311 interceptor_stack_copy = interceptor_stack.clone()

File ~/phonax/phonax/phonax/nequip_model.py:466, in NequIPEnergyModel.__call__(self, vectors, node_specie, senders, receivers) 464 # convolutions 465 for _ in range(self.graph_net_steps): --> 466 h_node = NequIPConvolution( 467 hidden_irreps=hidden_irreps, 468 use_sc=self.use_sc, 469 nonlinearities=self.nonlinearities, 470 radial_net_nonlinearity=self.radial_net_nonlinearity, 471 radial_net_n_hidden=self.radial_net_n_hidden, 472 radial_net_n_layers=self.radial_net_n_layers, 473 num_basis=self.num_basis, 474 avg_num_neighbors=self.avg_num_neighbors, 475 scalar_mlp_std=self.scalar_mlp_std 476 )(h_node, 477 node_attrs, 478 edge_sh, 479 edge_src, 480 edge_dst, 481 embedded_dr_edge 482 ) 484 # output block, two Linears that decay dimensions from h to h//2 to 1 485 for mul, ir in h_node.irreps:

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:464, in wrap_method..wrapped(self, *args, *kwargs) 461 if method_name != "call": 462 f = jax.named_call(f, name=method_name) --> 464 out = f(args, kwargs) 466 # Module names are set in the constructor. If f is the constructor then 467 # its name will only be set after** f has run. For methods other 468 # than __init__ we need the name before running in order to wrap their 469 # execution with named_call. 470 if module_name is None:

File ~/.conda/envs/phonax/lib/python3.10/contextlib.py:79, in ContextDecorator.call..inner(*args, kwds) 76 @wraps(func) 77 def inner(*args, *kwds): 78 with self._recreate_cm(): ---> 79 return func(args, kwds)

File ~/.conda/envs/phonax/lib/python3.10/site-packages/haiku/_src/module.py:305, in run_interceptors(bound_method, method_name, self, orig_class, *args, *kwargs) 303 """Runs any method interceptors or the original method.""" 304 if not interceptor_stack: --> 305 return bound_method(args, **kwargs) 307 ctx = MethodContext(module=self, 308 method_name=method_name, 309 orig_method=bound_method, 310 orig_class=orig_class) 311 interceptor_stack_copy = interceptor_stack.clone()

File ~/phonax/phonax/phonax/nequip_model.py:336, in NequIPConvolution.__call__(self, node_features, node_attributes, edge_sh, edge_src, edge_dst, edge_embedded) 333 # self-connection, similar to a resnet-update that sums the output from 334 # the TP to chemistry-weighted h 335 if self.use_sc: --> 336 h = h + self_connection 338 # gate nonlinearity, applied to gate data, consisting of: 339 # a) regular scalars, 340 # b) gate scalars, and 341 # c) non-scalars to be gated 342 # in this order 343 gate_fn = partial( 344 e3nn.gate, 345 even_act=get_nonlinearity_by_name(self.nonlinearities['e']), (...) 348 odd_gate_act=get_nonlinearity_by_name(self.nonlinearities['o']) 349 )

File ~/.conda/envs/phonax/lib/python3.10/site-packages/e3nn_jax/_src/irreps_array.py:311, in IrrepsArray.add(self, other) 306 raise ValueError( 307 f"IrrepsArray({self.irreps}, shape={self.shape}) + scalar is not equivariant." 308 ) 310 if self.irreps != other.irreps: --> 311 raise ValueError( 312 f"IrrepsArray({self.irreps}, shape={self.shape}) + IrrepsArray({other.irreps}) is not equivariant." 313 ) 315 zero_flags = tuple(x and y for x, y in zip(self.zero_flags, other.zero_flags)) 316 chunks = None

ValueError: IrrepsArray(36x0e+12x1o+8x2e, shape=(1, 112)) + IrrepsArray(16x0e+12x0e+8x0e+12x1o+8x2e) is not equivariant.`

AI4TE commented 1 month ago

I had the same problem. Did you solve it?

AI4TE commented 1 month ago

After downgrading these libraries, the errors no longer occurred: e3nn-jax==0.20.6, dm-haiku==0.0.12, jax==0.4.31, jaxlib==0.4.31.

VinaySingh561 commented 1 month ago

Thanks for your help.

On Fri, Oct 4, 2024 at 5:18 AM AI4TE @.***> wrote:

After downgrading these libraries, the errors no longer occurred: e3nn-jax==0.20.6, dm-haiku==0.0.12, jax==0.4.31, jaxlib==0.4.31.

— Reply to this email directly, view it on GitHub https://github.com/atomicarchitects/phonax/issues/1#issuecomment-2392521586, or unsubscribe https://github.com/notifications/unsubscribe-auth/AUQD7SG2GGRDODHT7OYX6YDZZXJUNAVCNFSM6AAAAABOZV3KAOVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDGOJSGUZDCNJYGY . You are receiving this because you authored the thread.Message ID: @.***>