AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Add Gemma2-27b #843

Closed ZhaoyueCheng closed 2 months ago

ZhaoyueCheng commented 2 months ago

Description

Add Gemma2-27B support, tested using the KL divergence between the golden logits and the output logits from MaxText.
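For readers unfamiliar with this kind of check, here is a minimal sketch of the comparison that the logs below report: a maximum absolute logit difference plus a per-position KL divergence between the golden and model distributions. It is not the MaxText test code. The function names, the toy data, and the KL direction (golden relative to model) are assumptions made for illustration, and it uses plain Python lists rather than JAX arrays so that it runs anywhere.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one position's logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def kl_divergence(golden_logits, model_logits):
    """KL(golden || model) in nats for one position (direction is an assumption)."""
    log_p = log_softmax(golden_logits)
    log_q = log_softmax(model_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def compare(golden, model):
    """Return (max abs logit difference, per-position KL list, max KL).

    `golden` and `model` are lists of rows, one row of vocabulary logits
    per token position.
    """
    max_diff = max(
        abs(g - m)
        for g_row, m_row in zip(golden, model)
        for g, m in zip(g_row, m_row)
    )
    kl = [kl_divergence(g_row, m_row) for g_row, m_row in zip(golden, model)]
    return max_diff, kl, max(kl)


# Hypothetical toy data: 2 token positions, vocabulary of 3.
golden = [[-1.0, 2.0, 0.5], [0.0, 0.0, 3.0]]
model = [[-1.1, 2.05, 0.4], [0.0, 0.0, 3.0]]

max_diff, kl, max_kl = compare(golden, model)
print("Max Numerical Difference", max_diff)
print("KL divergence =", kl, ", max KL divergence =", max_kl)
```

A large raw logit difference can coexist with a small KL divergence. Softmax is invariant to a constant shift of all logits, and differences on very low-probability tokens contribute little to the sum. This is one reason to compare the distributions rather than only the raw logits.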

Tests

Gemma2-9B logs

INFO 2024-08-24T02:55:08.922897784Z [resource.labels.containerName: jax-tpu] Comparing forward pass for golden data index = 0
INFO 2024-08-24T02:55:08.922971322Z [resource.labels.containerName: jax-tpu] config.global_batch_size_to_train_on=1
INFO 2024-08-24T02:55:08.959527619Z [resource.labels.containerName: jax-tpu] prompt="I love to" raw ids=[ 2 235285 2182 577], logits.shape = (4, 256128)
INFO 2024-08-24T02:55:09.064874814Z [resource.labels.containerName: jax-tpu] ids=[[ 2 235285 2182 577]], decoder_segment_ids = [[1. 1. 1. 1.]], decoder_positions= [[0 1 2 3]]
INFO 2024-08-24T02:56:12.358334773Z [resource.labels.containerName: jax-tpu] golden_logits[0]=array([-29.983728, -19.59722 , -19.804455, ..., -29.88443 , -29.890678,
INFO 2024-08-24T02:56:12.358389803Z [resource.labels.containerName: jax-tpu] -29.887852], dtype=float32)
INFO 2024-08-24T02:56:12.358396874Z [resource.labels.containerName: jax-tpu] full_train_logits[0, 0, :]=array([-29.983795, -19.5924 , -19.85103 , ..., -29.884758, -29.890398,
INFO 2024-08-24T02:56:12.358400854Z [resource.labels.containerName: jax-tpu] -29.888012], dtype=float32)
INFO 2024-08-24T02:56:12.360461065Z [resource.labels.containerName: jax-tpu] Max Numerical Difference 0.13658976554870605
INFO 2024-08-24T02:56:12.554640613Z [resource.labels.containerName: jax-tpu] golden_probabilities[0]=Array([8.9334685e-11, 2.8961672e-06, 2.3540865e-06, ..., 9.8660850e-11,
INFO 2024-08-24T02:56:12.554677806Z [resource.labels.containerName: jax-tpu] 9.8046293e-11, 9.8323835e-11], dtype=float32)
INFO 2024-08-24T02:56:12.555723035Z [resource.labels.containerName: jax-tpu] model_probabilities[0]=Array([8.8320636e-11, 2.8773211e-06, 2.2216029e-06, ..., 9.7515461e-11,
INFO 2024-08-24T02:56:12.555736106Z [resource.labels.containerName: jax-tpu] 9.6967205e-11, 9.7198666e-11], dtype=float32)
INFO 2024-08-24T02:56:12.982979137Z [resource.labels.containerName: jax-tpu] KL divergence = [2.3383072e-04 1.5704559e-06 1.5070462e-05 7.0545022e-05], max KL divergence = 0.00023383072402793914
INFO 2024-08-24T02:56:12.983030062Z [resource.labels.containerName: jax-tpu] Checking KL Divergence between train distribution and golden distribution
INFO 2024-08-24T02:56:13.011343425Z [resource.labels.containerName: jax-tpu] Comparing forward pass for golden data index = 1
INFO 2024-08-24T02:56:13.011382532Z [resource.labels.containerName: jax-tpu] config.global_batch_size_to_train_on=1
INFO 2024-08-24T02:56:13.053889448Z [resource.labels.containerName: jax-tpu] prompt="Today is a" raw ids=[ 2 15528 603 476], logits.shape = (4, 256128)
INFO 2024-08-24T02:56:13.054927675Z [resource.labels.containerName: jax-tpu] ids=[[ 2 15528 603 476]], decoder_segment_ids = [[1. 1. 1. 1.]], decoder_positions= [[0 1 2 3]]
INFO 2024-08-24T02:57:16.749277894Z [resource.labels.containerName: jax-tpu] golden_logits[0]=array([-29.983728, -19.59722 , -19.804455, ..., -29.88443 , -29.890678,
INFO 2024-08-24T02:57:16.749324254Z [resource.labels.containerName: jax-tpu] -29.887852], dtype=float32)
INFO 2024-08-24T02:57:16.749371291Z [resource.labels.containerName: jax-tpu] full_train_logits[0, 0, :]=array([-29.983795, -19.5924 , -19.85103 , ..., -29.884758, -29.890398,
INFO 2024-08-24T02:57:16.749418203Z [resource.labels.containerName: jax-tpu] -29.888012], dtype=float32)
INFO 2024-08-24T02:57:16.751782697Z [resource.labels.containerName: jax-tpu] Max Numerical Difference 0.106536865234375
INFO 2024-08-24T02:57:16.757568997Z [resource.labels.containerName: jax-tpu] golden_probabilities[0]=Array([8.9334685e-11, 2.8961672e-06, 2.3540865e-06, ..., 9.8660850e-11,
INFO 2024-08-24T02:57:16.757592725Z [resource.labels.containerName: jax-tpu] 9.8046293e-11, 9.8323835e-11], dtype=float32)
INFO 2024-08-24T02:57:16.758484693Z [resource.labels.containerName: jax-tpu] model_probabilities[0]=Array([8.8320636e-11, 2.8773211e-06, 2.2216029e-06, ..., 9.7515461e-11,
INFO 2024-08-24T02:57:16.758499401Z [resource.labels.containerName: jax-tpu] 9.6967205e-11, 9.7198666e-11], dtype=float32)
INFO 2024-08-24T02:57:16.768582644Z [resource.labels.containerName: jax-tpu] KL divergence = [2.33830724e-04 1.14613670e-06 2.12977957e-05 1.27497115e-05], max KL divergence = 0.00023383072402793914
INFO 2024-08-24T02:57:16.768616136Z [resource.labels.containerName: jax-tpu] Checking KL Divergence between train distribution and golden distribution
INFO 2024-08-24T02:57:16.768629909Z [resource.labels.containerName: jax-tpu] Comparing forward pass for golden data index = 2
INFO 2024-08-24T02:57:16.768634283Z [resource.labels.containerName: jax-tpu] config.global_batch_size_to_train_on=1
INFO 2024-08-24T02:57:16.839345940Z [resource.labels.containerName: jax-tpu] prompt="What is the" raw ids=[ 2 1841 603 573], logits.shape = (4, 256128)
INFO 2024-08-24T02:57:16.840900605Z [resource.labels.containerName: jax-tpu] ids=[[ 2 1841 603 573]], decoder_segment_ids = [[1. 1. 1. 1.]], decoder_positions= [[0 1 2 3]]
INFO 2024-08-24T02:58:20.211019091Z [resource.labels.containerName: jax-tpu] golden_logits[0]=array([-29.983728, -19.59722 , -19.804455, ..., -29.88443 , -29.890678,
INFO 2024-08-24T02:58:20.211084776Z [resource.labels.containerName: jax-tpu] -29.887852], dtype=float32)
INFO 2024-08-24T02:58:20.211091368Z [resource.labels.containerName: jax-tpu] full_train_logits[0, 0, :]=array([-29.983795, -19.5924 , -19.85103 , ..., -29.884758, -29.890398,
INFO 2024-08-24T02:58:20.211095715Z [resource.labels.containerName: jax-tpu] -29.888012], dtype=float32)
INFO 2024-08-24T02:58:20.215266464Z [resource.labels.containerName: jax-tpu] Max Numerical Difference 0.12917780876159668
INFO 2024-08-24T02:58:20.220954264Z [resource.labels.containerName: jax-tpu] golden_probabilities[0]=Array([8.9334685e-11, 2.8961672e-06, 2.3540865e-06, ..., 9.8660850e-11,
INFO 2024-08-24T02:58:20.220976215Z [resource.labels.containerName: jax-tpu] 9.8046293e-11, 9.8323835e-11], dtype=float32)
INFO 2024-08-24T02:58:20.221868005Z [resource.labels.containerName: jax-tpu] model_probabilities[0]=Array([8.8320636e-11, 2.8773211e-06, 2.2216029e-06, ..., 9.7515461e-11,
INFO 2024-08-24T02:58:20.221883117Z [resource.labels.containerName: jax-tpu] 9.6967205e-11, 9.7198666e-11], dtype=float32)
INFO 2024-08-24T02:58:20.231609851Z [resource.labels.containerName: jax-tpu] KL divergence = [2.3383072e-04 2.6830924e-06 8.2429669e-06 1.2013030e-05], max KL divergence = 0.00023383072402793914

Gemma2-27B logs

INFO 2024-08-24T02:19:38.844676047Z [resource.labels.containerName: jax-tpu] prompt="I love to" raw ids=[ 2 235285 2182 577], logits.shape = (4, 256128)
INFO 2024-08-24T02:19:38.955571970Z [resource.labels.containerName: jax-tpu] ids=[[ 2 235285 2182 577]], decoder_segment_ids = [[1. 1. 1. 1.]], decoder_positions= [[0 1 2 3]]
INFO 2024-08-24T02:21:40.763377342Z [resource.labels.containerName: jax-tpu] golden_logits[0]=array([ 7.635116, 21.448622, -0.50029 , ..., 7.501218, 7.490741,
INFO 2024-08-24T02:21:40.763435646Z [resource.labels.containerName: jax-tpu] 7.638074], dtype=float32)
INFO 2024-08-24T02:21:40.763454931Z [resource.labels.containerName: jax-tpu] full_train_logits[0, 0, :]=array([ 7.6197963 , 21.492254 , -0.48434365, ..., 7.489437 ,
INFO 2024-08-24T02:21:40.763459182Z [resource.labels.containerName: jax-tpu] 7.4893174 , 7.620025 ], dtype=float32)
INFO 2024-08-24T02:21:40.765758918Z [resource.labels.containerName: jax-tpu] Max Numerical Difference 11.813161849975586
INFO 2024-08-24T02:21:40.964641821Z [resource.labels.containerName: jax-tpu] golden_probabilities[0]=Array([1.58940448e-11, 1.58622024e-05, 4.65663563e-15, ...,
INFO 2024-08-24T02:21:40.964680570Z [resource.labels.containerName: jax-tpu] 1.39021945e-11, 1.37572740e-11, 1.59411338e-11], dtype=float32)
INFO 2024-08-24T02:21:40.965612216Z [resource.labels.containerName: jax-tpu] model_probabilities[0]=Array([1.56618295e-11, 1.65796064e-05, 4.73433142e-15, ...,
INFO 2024-08-24T02:21:40.965626899Z [resource.labels.containerName: jax-tpu] 1.37476350e-11, 1.37459835e-11, 1.56654134e-11], dtype=float32)
INFO 2024-08-24T02:21:41.394815935Z [resource.labels.containerName: jax-tpu] KL divergence = [0.0001245 0.00349709 0.03192982 0.08636895], max KL divergence = 0.08636894822120667
INFO 2024-08-24T02:21:41.394913483Z [resource.labels.containerName: jax-tpu] Checking KL Divergence between train distribution and golden distribution
INFO 2024-08-24T02:21:41.433645877Z [resource.labels.containerName: jax-tpu] Comparing forward pass for golden data index = 1
INFO 2024-08-24T02:21:41.433708685Z [resource.labels.containerName: jax-tpu] config.global_batch_size_to_train_on=1
INFO 2024-08-24T02:21:41.471897226Z [resource.labels.containerName: jax-tpu] prompt="Today is a" raw ids=[ 2 15528 603 476], logits.shape = (4, 256128)
INFO 2024-08-24T02:21:41.473050876Z [resource.labels.containerName: jax-tpu] ids=[[ 2 15528 603 476]], decoder_segment_ids = [[1. 1. 1. 1.]], decoder_positions= [[0 1 2 3]]
INFO 2024-08-24T02:23:44.106026713Z [resource.labels.containerName: jax-tpu] golden_logits[0]=array([ 7.635116, 21.448622, -0.50029 , ..., 7.501218, 7.490741,
INFO 2024-08-24T02:23:44.106088540Z [resource.labels.containerName: jax-tpu] 7.638074], dtype=float32)
INFO 2024-08-24T02:23:44.106103025Z [resource.labels.containerName: jax-tpu] full_train_logits[0, 0, :]=array([ 7.6197963 , 21.492254 , -0.48434365, ..., 7.489437 ,
INFO 2024-08-24T02:23:44.106107817Z [resource.labels.containerName: jax-tpu] 7.4893174 , 7.620025 ], dtype=float32)
INFO 2024-08-24T02:23:44.107721642Z [resource.labels.containerName: jax-tpu] Max Numerical Difference 5.898859977722168
INFO 2024-08-24T02:23:44.119024791Z [resource.labels.containerName: jax-tpu] golden_probabilities[0]=Array([1.58940448e-11, 1.58622024e-05, 4.65663563e-15, ...,
INFO 2024-08-24T02:23:44.119055225Z [resource.labels.containerName: jax-tpu] 1.39021945e-11, 1.37572740e-11, 1.59411338e-11], dtype=float32)
INFO 2024-08-24T02:23:44.120063608Z [resource.labels.containerName: jax-tpu] model_probabilities[0]=Array([1.56618295e-11, 1.65796064e-05, 4.73433142e-15, ...,
INFO 2024-08-24T02:23:44.120076241Z [resource.labels.containerName: jax-tpu] 1.37476350e-11, 1.37459835e-11, 1.56654134e-11], dtype=float32)
INFO 2024-08-24T02:23:44.129684548Z [resource.labels.containerName: jax-tpu] KL divergence = [0.0001245 0.00413377 0.02779528 0.06672263], max KL divergence = 0.06672263145446777
INFO 2024-08-24T02:23:44.129725210Z [resource.labels.containerName: jax-tpu] Checking KL Divergence between train distribution and golden distribution
INFO 2024-08-24T02:23:44.129731726Z [resource.labels.containerName: jax-tpu] Comparing forward pass for golden data index = 2
INFO 2024-08-24T02:23:44.129736674Z [resource.labels.containerName: jax-tpu] config.global_batch_size_to_train_on=1
INFO 2024-08-24T02:23:44.167353427Z [resource.labels.containerName: jax-tpu] prompt="What is the" raw ids=[ 2 1841 603 573], logits.shape = (4, 256128)
INFO 2024-08-24T02:23:44.168459346Z [resource.labels.containerName: jax-tpu] ids=[[ 2 1841 603 573]], decoder_segment_ids = [[1. 1. 1. 1.]], decoder_positions= [[0 1 2 3]]
INFO 2024-08-24T02:25:45.133158682Z [resource.labels.containerName: jax-tpu] golden_logits[0]=array([ 7.635116, 21.448622, -0.50029 , ..., 7.501218, 7.490741,
INFO 2024-08-24T02:25:45.133219599Z [resource.labels.containerName: jax-tpu] 7.638074], dtype=float32)
INFO 2024-08-24T02:25:45.133225103Z [resource.labels.containerName: jax-tpu] full_train_logits[0, 0, :]=array([ 7.6197963 , 21.492254 , -0.48434365, ..., 7.489437 ,
INFO 2024-08-24T02:25:45.133229223Z [resource.labels.containerName: jax-tpu] 7.4893174 , 7.620025 ], dtype=float32)
INFO 2024-08-24T02:25:45.134766205Z [resource.labels.containerName: jax-tpu] Max Numerical Difference 3.721693515777588
INFO 2024-08-24T02:25:45.140202194Z [resource.labels.containerName: jax-tpu] golden_probabilities[0]=Array([1.58940448e-11, 1.58622024e-05, 4.65663563e-15, ...,
INFO 2024-08-24T02:25:45.140241423Z [resource.labels.containerName: jax-tpu] 1.39021945e-11, 1.37572740e-11, 1.59411338e-11], dtype=float32)
INFO 2024-08-24T02:25:45.141244353Z [resource.labels.containerName: jax-tpu] model_probabilities[0]=Array([1.56618295e-11, 1.65796064e-05, 4.73433142e-15, ...,
INFO 2024-08-24T02:25:45.141253794Z [resource.labels.containerName: jax-tpu] 1.37476350e-11, 1.37459835e-11, 1.56654134e-11], dtype=float32)
INFO 2024-08-24T02:25:45.151952823Z [resource.labels.containerName: jax-tpu] KL divergence = [0.0001245 0.0040167 0.01386897 0.023892 ], max KL divergence = 0.02389199659228325