pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Big difference between pytorch xla and pytorch on cpu #4170

Open houghtonweihu opened 1 year ago

houghtonweihu commented 1 year ago

I am training a reinforcement learning model. The same code produces very different outcomes on TPU and CPU. I trained each model for 100 steps; below are the last few steps. First, TPU (XLA). The columns are: step, user id, max reward, actual reward:

step  user  max_reward  actual_reward
88    0     92.0        23.0
88    1     58.0         9.0
88    2     34.0         3.0
88    3     33.0        10.0
88    4     35.0         4.0
89    0     92.0        22.0
89    1     58.0         6.0
89    2     34.0         2.0
89    3     33.0         5.0
89    4     35.0         5.0
90    0     92.0        17.0
90    1     58.0         4.0
90    2     34.0         2.0
90    3     33.0         3.0
90    4     35.0         3.0
91    0     92.0        13.0
91    1     58.0         4.0
91    2     34.0         4.0
91    3     33.0         3.0
91    4     35.0         3.0
92    0     92.0        11.0
92    1     58.0         1.0
92    2     34.0         3.0
92    3     33.0         2.0
92    4     35.0         2.0
93    0     92.0         8.0
93    1     58.0         5.0
93    2     34.0         3.0
93    3     33.0         2.0
93    4     35.0         1.0
94    0     92.0         8.0
94    1     58.0         0.0
94    2     34.0         1.0
94    3     33.0         5.0
94    4     35.0         1.0
95    0     92.0         6.0
95    1     58.0         4.0
95    2     34.0         2.0
95    3     33.0         2.0
95    4     35.0         3.0
96    0     92.0         8.0
96    1     58.0         3.0
96    2     34.0         4.0
96    3     33.0         4.0
96    4     35.0         4.0
97    0     92.0         4.0
97    1     58.0         7.0
97    2     34.0         2.0
97    3     33.0         3.0
97    4     35.0         1.0
98    0     92.0         6.0
98    1     58.0         7.0
98    2     34.0         2.0
98    3     33.0         1.0
98    4     35.0         3.0
99    0     92.0         2.0
99    1     58.0         4.0
99    2     34.0         2.0
99    3     33.0         1.0
99    4     35.0         1.0

This is not learning at all on TPU with XLA.

Now the same code on CPU:

step  user  max_reward  actual_reward
88    0     92.0        88.0
88    1     58.0        54.0
88    2     34.0        30.0
88    3     33.0        29.0
88    4     35.0        32.0
89    0     92.0        86.0
89    1     58.0        54.0
89    2     34.0        32.0
89    3     33.0        31.0
89    4     35.0        32.0
90    0     92.0        84.0
90    1     58.0        57.0
90    2     34.0        31.0
90    3     33.0        30.0
90    4     35.0        34.0
91    0     92.0        86.0
91    1     58.0        54.0
91    2     34.0        32.0
91    3     33.0        31.0
91    4     35.0        34.0
92    0     92.0        86.0
92    1     58.0        56.0
92    2     34.0        32.0
92    3     33.0        31.0
92    4     35.0        33.0
93    0     92.0        89.0
93    1     58.0        52.0
93    2     34.0        31.0
93    3     33.0        31.0
93    4     35.0        35.0
94    0     92.0        87.0
94    1     58.0        58.0
94    2     34.0        32.0
94    3     33.0        29.0
94    4     35.0        34.0
95    0     92.0        87.0
95    1     58.0        54.0
95    2     34.0        32.0
95    3     33.0        31.0
95    4     35.0        32.0
96    0     92.0        86.0
96    1     58.0        55.0
96    2     34.0        30.0
96    3     33.0        29.0
96    4     35.0        32.0
97    0     92.0        89.0
97    1     58.0        52.0
97    2     34.0        33.0
97    3     33.0        31.0
97    4     35.0        34.0
98    0     92.0        87.0
98    1     58.0        50.0
98    2     34.0        32.0
98    3     33.0        32.0
98    4     35.0        32.0
99    0     92.0        90.0
99    1     58.0        52.0
99    2     34.0        32.0
99    3     33.0        32.0
99    4     35.0        34.0

This shows that, with the same code, the model is learning and converging on CPU.

JackCaoG commented 1 year ago

Yeah, it doesn't seem like it is learning. I don't think it is a TPU issue, but maybe something in the model code. Did you have a chance to look at https://github.com/pytorch/xla/blob/master/API_GUIDE.md#running-on-a-single-xla-device (I assume you train on a single device)?

If you can give a minimal repro, the debugging will be easier.
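A minimal sketch of the kind of repro that helps here (the tiny linear model and `train_losses` helper are hypothetical, not from the original code): run the same seeded training loop on CPU and record per-step losses, then point the same loop at `xm.xla_device()` and diff the two traces to localize where they diverge. The `xm` calls from torch_xla are shown only in comments so the snippet runs on plain PyTorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_losses(device, steps=20, seed=0):
    """Run a tiny deterministic training loop and return per-step losses."""
    torch.manual_seed(seed)  # identical init and data on every backend
    model = nn.Linear(4, 1).to(device)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(64, 4)
    y = x.sum(dim=1, keepdim=True)  # a target a linear model can learn
    losses = []
    for _ in range(steps):
        opt.zero_grad()
        loss = F.mse_loss(model(x.to(device)), y.to(device))
        loss.backward()
        opt.step()
        # On XLA, cut the lazy graph here with xm.mark_step(),
        # or replace opt.step() with xm.optimizer_step(opt).
        losses.append(loss.item())
    return losses

cpu_losses = train_losses("cpu")
print(cpu_losses[0], cpu_losses[-1])  # loss should drop if the loop learns
# With torch_xla installed, the same loop targets the TPU:
#   import torch_xla.core.xla_model as xm
#   tpu_losses = train_losses(xm.xla_device())
# The first step at which tpu_losses diverges from cpu_losses points at the bug.
```

A flat loss curve on XLA but not on CPU often means the optimizer step is never materialized (e.g. a missing `xm.mark_step()` / `xm.optimizer_step()`), which would match the "not learning at all" symptom above.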