tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone
https://tensorflow.org
Apache License 2.0

PyTorch solves Gym environments faster than TensorFlow using the same training implementation and network architecture #61825

Open LuisFMCuriel opened 10 months ago

LuisFMCuriel commented 10 months ago

Issue type

Performance

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

source

TensorFlow version

tf2.13

Custom code

Yes

OS platform and distribution

Ubuntu 22.04.2 LTS

Mobile device

No response

Python version

3.10.12

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

11.8

GPU model and memory

T4 GPU

Current behavior?

I have encountered a performance difference between my PyTorch and TensorFlow implementations of the Double Deep Q-Network (DDQN) algorithm in a Gym environment. Both implementations share identical DDQN architectures, exploration routines, and training flows. However, the PyTorch implementation consistently converges faster, requiring fewer episodes to solve the environment.

Details:

Network and Training Flow: Both the PyTorch and TensorFlow implementations use the same dense neural network architecture (two hidden layers of 512 and 128 units) and the same training procedure. The measured time for the weight-optimization step is nearly identical between the two.

Exploration: Identical exploration strategies are used in both implementations, ensuring consistent agent behavior; both use epsilon-greedy exploration.
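
For reference, here is a minimal sketch of the setup described in the two points above, written for the TensorFlow side. It is my reading of the description, not the exact notebook code, and n_states/n_actions are placeholders for whatever the Gym environment provides (e.g. 4 and 2 for CartPole-v1):

import numpy as np
import tensorflow as tf

def build_q_network(n_states, n_actions):
    # Dense Q-network with two hidden layers of 512 and 128 units.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(n_states,)),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(n_actions),  # linear output: one Q-value per action
    ])

def select_action(model, state, epsilon, n_actions):
    # Epsilon-greedy: explore with probability epsilon, otherwise act greedily.
    if np.random.rand() < epsilon:
        return np.random.randint(n_actions)
    q_values = model(state[np.newaxis, :].astype(np.float32))
    return int(tf.argmax(q_values, axis=1)[0])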

The optimization code for each framework is shown below.

TensorFlow:

@tf.function
def optimize_model(self,
                   experiences,
                   max_gradient_norm = float('inf')):
    states, actions, rewards, next_states, is_terminals = experiences
    #batch_size = len(is_terminals)

    with tf.GradientTape() as tape:

        # We get the argmax (or maximum action index using the online network)
        argmax_a_q_sp = tf.argmax(self.online_model(next_states), axis=1)
        # Then we use the target model to calculate the estimated Q-values
        q_sp = tf.stop_gradient(self.target_model(next_states))

        # And extract, for each row, the Q-value at the action selected by the online model
        # (batch_dims=1 pairs row i of q_sp with argmax_a_q_sp[i])
        max_a_q_sp = tf.expand_dims(tf.gather(q_sp, argmax_a_q_sp, axis=1, batch_dims=1), axis=1)
        # Then we start computing the loss value
        target_q_sa = rewards + (self.gamma * max_a_q_sp * (1 - is_terminals))

        # Flatten the column_indices tensor
        actions_ = tf.reshape(actions, [-1])

        # Use tf.range to create row indices
        row_indices = tf.range(actions_.shape[0])
        row_indices = tf.cast(row_indices, tf.int64)
        # Create combined indices
        combined_indices = tf.stack([row_indices, actions_], axis=1)

        # Gather elements from the second tensor using combined indices
        q_sa = tf.gather_nd(self.online_model(states), combined_indices)
        q_sa = tf.reshape(q_sa, (-1, 1))

        #q_sa = tf.gather(self.online_model(states), actions, axis=1)
        td_error = q_sa - target_q_sa
        value_loss = tf.reduce_mean(tf.square(td_error) * 0.5)
    variables = self.online_model.trainable_variables
    gradients = tape.gradient(value_loss, variables)

    self.value_optimizer.apply_gradients(zip(gradients, variables))
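
As a side note, the explicit (row, action) index construction above can be collapsed into a single batched gather. The lines below are a sketch rather than code from the notebook, but they should produce the same q_sa tensor, since batch_dims=1 pairs row i of the Q-value matrix with the i-th action:

        # Batched gather: no explicit row-index tensor needed.
        q_sa = tf.gather(self.online_model(states), tf.reshape(actions, [-1]),
                         axis=1, batch_dims=1)
        q_sa = tf.reshape(q_sa, (-1, 1))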

PyTorch:

def optimize_model(self,
                   experiences,
                   max_gradient_norm = float('inf')):
    states, actions, rewards, next_states, is_terminals = experiences
    batch_size = len(is_terminals)
    # We get the argmax (or maximum action index using the online network)
    argmax_a_q_sp = self.online_model(next_states).max(1)[1]
    # Then we use the target model to calculate the estimated Q-values
    q_sp = self.target_model(next_states).detach()
    # And extract the max value using the index gotten with the online model
    max_a_q_sp = q_sp[np.arange(batch_size), argmax_a_q_sp].unsqueeze(1)

    # Then we start computing the loss value
    target_q_sa = rewards + (self.gamma * max_a_q_sp * (1 - is_terminals))
    q_sa = self.online_model(states).gather(1, actions)

    td_error = q_sa - target_q_sa
    value_loss = td_error.pow(2).mul(0.5).mean()
    self.value_optimizer.zero_grad()
    value_loss.backward()
    #torch.nn.utils.clip_grad_norm_(self.online_model.parameters(), max_gradient_norm)
    self.value_optimizer.step()
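
One caveat when comparing the per-update timings: both frameworks launch GPU work asynchronously, so a plain wall-clock measurement around optimize_model can under-report the real cost unless the device is synchronized first. The helper below is a sketch of a safer measurement for the PyTorch side (agent and experiences are placeholders for the objects in the notebook); on the TensorFlow side the same effect can be obtained by returning value_loss from the tf.function and calling .numpy() on it before stopping the clock:

import time
import torch

def time_torch_update(agent, experiences):
    # Wait for previously queued CUDA kernels before starting the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    agent.optimize_model(experiences)
    # Block until the update's kernels have actually finished before stopping the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start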

I'm seeking guidance on potential factors that might explain this performance gap. Could variations in internal library optimizations, autograd systems, GPU utilization, or other factors play a role in this discrepancy? Any insights or suggestions for further investigation would be greatly appreciated.
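
One avenue worth checking, offered as a hypothesis rather than a confirmed cause: in TensorFlow 2, every per-step action selection that calls the Keras model eagerly pays Python dispatch overhead, which accumulates over thousands of environment steps even when the compiled optimize_model itself is fast. The sketch below compares an eager forward pass with a tf.function-compiled one on a single state; compiled_forward, compare_forward_latency and state_dim are hypothetical names, not part of the notebook:

import time
import numpy as np
import tensorflow as tf

@tf.function
def compiled_forward(model, state):
    # Same forward pass as the eager call, but traced into a graph once.
    return model(state)

def compare_forward_latency(model, state_dim, n_calls=1000):
    state = tf.constant(np.random.rand(1, state_dim), dtype=tf.float32)
    # Warm-up: builds the model and triggers tracing of the compiled path.
    model(state)
    compiled_forward(model, state)
    start = time.perf_counter()
    for _ in range(n_calls):
        model(state).numpy()              # eager call, as in a typical act() loop
    eager_s = time.perf_counter() - start
    start = time.perf_counter()
    for _ in range(n_calls):
        compiled_forward(model, state).numpy()
    compiled_s = time.perf_counter() - start
    print(f"eager: {eager_s:.3f}s   tf.function: {compiled_s:.3f}s   ({n_calls} calls)")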

Standalone code to reproduce the issue

You can find the code in this GitHub repository: [GitHub Repository](https://github.com/LuisFMCuriel/SwarmyMcQLearny).

Alternatively, you can easily access an example by following this link to a Colab notebook: 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LuisFMCuriel/SwarmyMcQLearny/blob/main/notebooks/SwarmyMcQLearny.ipynb)

Relevant log output

No response

sushreebarsa commented 10 months ago

Hello, @LuisFMCuriel ! Could you please have a look at this gist, where I tried to replicate the reported behaviour? Please confirm the result. Thank you!

LuisFMCuriel commented 9 months ago

Hi @sushreebarsa ! Thanks for the quick reply. Yes, I confirm the result. Running each example for the same number of episodes (60, for instance), the TensorFlow implementation has an elapsed time of 00:00:37 whereas the PyTorch implementation has an elapsed time of only 00:00:09, while reaching roughly the same reward (these numbers vary since no seed is set, but the behaviour is the same every time you run it). Is this a problem with my implementation of the gradient-optimization step? The program reports the optimization time, and it looks approximately the same for both frameworks, so I am not sure where the timing bottleneck is.

Thanks again for the time!!