tensorflow / agents

TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.
Apache License 2.0
2.76k stars 717 forks source link

Type error in PolicySaver.save() #929

Open anmol438 opened 1 month ago

anmol438 commented 1 month ago

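I have a DQN agent (the policy type was lost from this copy of the report) that I train on a gym environment (CartPole-v1). I am using tf_agents 0.16.0 and gym 0.23.0. I save the policy with `tf_agents.policies.policy_saver.PolicySaver` as follows:

```python
policy_saver = PolicySaver(agent.policy)
policy_saver.save('./policy')
```

This fails with the error below. Note that it is raised by the `PolicySaver(...)` constructor (line 1 of the cell), before `save()` is ever called: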

```
TypeError                                 Traceback (most recent call last)
Cell In[8], line 1
----> 1 policy_saver = PolicySaver(agent.policy)
      2 policy_saver.save('./policy')
```

```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tf_agents\policies\policy_saver.py:333, in PolicySaver.__init__(self, policy, batch_size, use_nest_path_signatures, seed, train_step, input_fn_and_spec, metadata)
    326   get_initial_state_fn.get_concrete_function(*get_initial_state_input_specs)
    328 train_step_fn = common.function(
    329     lambda: saved_policy.train_step
    330 ).get_concrete_function()
    331 get_metadata_fn = common.function(
    332     lambda: saved_policy.metadata
--> 333 ).get_concrete_function()
    335 batched_time_step_spec = tf.nest.map_structure(
    336     lambda spec: add_batch_dim(spec, [batch_size]), policy.time_step_spec
    337 )
    338 batched_time_step_spec = cast(ts.TimeStep, batched_time_step_spec)
```

```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:1227, in Function.get_concrete_function(self, *args, **kwargs)
   1225 def get_concrete_function(self, *args, **kwargs):
   1226   # Implements PolymorphicFunction.get_concrete_function.
-> 1227   concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1228   concrete._garbage_collector.release()  # pylint: disable=protected-access
   1229   return concrete
```

```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:1197, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
   1195 if self._variable_creation_config is None:
   1196   initializers = []
-> 1197   self._initialize(args, kwargs, add_initializers_to=initializers)
   1198   self._initialize_uninitialized_variables(initializers)
   1200 if self._created_variables:
   1201   # In this case we have created variables on the first call, so we run the
   1202   # version which is guaranteed to never create variables.
```

```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:695, in Function._initialize(self, args, kwds, add_initializers_to)
    690 self._variable_creation_config = self._generate_scoped_tracing_options(
    691     variable_capturing_scope,
    692     tracing_compilation.ScopeType.VARIABLE_CREATION,
    693 )
    694 # Force the definition of the function for these arguments
--> 695 self._concrete_variable_creation_fn = tracing_compilation.trace_function(
    696     args, kwds, self._variable_creation_config
    697 )
    699 def invalid_creator_scope(*unused_args, **unused_kwds):
    700   """Disables variable creation."""
```

```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\tracing_compilation.py:178, in trace_function(args, kwargs, tracing_options)
    175   args = tracing_options.input_signature
    176   kwargs = {}
--> 178 concrete_function = _maybe_define_function(
    179     args, kwargs, tracing_options
    180 )
    182 if not tracing_options.bind_graph_to_function:
    183   concrete_function._garbage_collector.release()  # pylint: disable=protected-access
```

```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\tracing_compilation.py:283, in _maybe_define_function(args, kwargs, tracing_options)
    281 else:
    282   target_func_type = lookup_func_type
--> 283 concrete_function = _create_concrete_function(
    284     target_func_type, lookup_func_context, func_graph, tracing_options
    285 )
    287 if tracing_options.function_cache is not None:
    288   tracing_options.function_cache.add(
    289       concrete_function, current_func_context
    290   )
```

```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\tracing_compilation.py:331, in _create_concrete_function(function_type, type_context, func_graph, tracing_options)
    328 tracing_options.function_captures.merge_by_ref_with(graph_capture_container)
    330 # Create a new FunctionType including captures and outputs.
--> 331 output_type = trace_type.from_value(
    332     traced_func_graph.structured_outputs, type_context
    333 )
    334 traced_func_type = function_type_lib.FunctionType(
    335     function_type.parameters.values(),
    336     traced_func_graph.function_captures.capture_types,
    337     return_annotation=output_type,
    338 )
    340 concrete_function = concrete_function_lib.ConcreteFunction.from_func_graph(
    341     traced_func_graph,
    342     traced_func_type,
   (...)
    348     shared_func_graph=False,
    349 )
```

```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\core\function\trace_type\trace_type_builder.py:144, in from_value(value, context)
    142 if context.is_legacy_signature and isinstance(value, trace.TraceType):
    143   return value
--> 144 elif isinstance(value, trace.SupportsTracingProtocol):
    145   generated_type = value.__tf_tracing_type__(context)
    146   if not isinstance(generated_type, trace.TraceType):
```

```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\typing_extensions.py:647, in _ProtocolMeta.__instancecheck__(cls, instance)
    645 for attr in cls.__protocol_attrs__:
    646     try:
--> 647         val = inspect.getattr_static(instance, attr)
    648     except AttributeError:
    649         break
```

```
File c:\Users\iitka\.conda\envs\temp\lib\inspect.py:1743, in getattr_static(obj, attr, default)
   1740 dict_attr = _shadowed_dict(klass)
   1741 if (dict_attr is _sentinel or
   1742     type(dict_attr) is types.MemberDescriptorType):
-> 1743     instance_result = _check_instance(obj, attr)
   1744 else:
   1745     klass = obj
```

```
File c:\Users\iitka\.conda\envs\temp\lib\inspect.py:1690, in _check_instance(obj, attr)
   1688 instance_dict = {}
   1689 try:
-> 1690     instance_dict = object.__getattribute__(obj, "__dict__")
   1691 except AttributeError:
   1692     pass
```

```
TypeError: this __dict__ descriptor does not support '_DictWrapper' objects
```

What are the possible reasons and how can I resolve it?

BaLinuss commented 1 week ago

Hi, I encountered the same problem with older code that used to work. Rolling back to tf_agents version 0.16.0 made it work again. Note that you also need TensorFlow 2.12.0 (`tensorflow==2.12.0`) for it to work.

Hope this works for you!