kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.27k stars, 890 forks

Pre-trained weights for transfer learning #165

Closed: paramjeet2021 closed this issue 2 years ago

paramjeet2021 commented 2 years ago

I would like to train a model, but training it from scratch is constrained by data availability, cost, and time.

Is there a way to use the GPT-J-6B weights as pre-trained weights and train only the last few layers of this architecture, i.e. apply transfer learning with a limited dataset? This is commonly done for word embeddings and computer vision tasks, so I am curious whether transfer learning can be applied here as well.

kingoflolz commented 2 years ago

See #157
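
For context, the general idea asked about above (freeze most pre-trained layers and update only the last few) can be sketched in plain JAX. This is a toy illustration only, not this repository's API and not necessarily what #157 describes: the parameter layout, `init_params`, `trainable_mask`, and `sgd_step` are all hypothetical names, and the random weights stand in for a real checkpoint.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_layers=4, dim=8):
    # Hypothetical stand-in for loading pre-trained weights.
    keys = jax.random.split(key, n_layers)
    return {
        f"layer_{i}": {"w": jax.random.normal(k, (dim, dim)) * 0.1,
                       "b": jnp.zeros((dim,))}
        for i, k in enumerate(keys)
    }


def forward(params, x):
    # Apply layers in name order (fine for fewer than 10 layers).
    for name in sorted(params):
        layer = params[name]
        x = jnp.tanh(x @ layer["w"] + layer["b"])
    return x


def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)


def trainable_mask(params, n_trainable=1):
    # True for leaves in the last n_trainable layers, False elsewhere.
    names = sorted(params)
    trainable = set(names[-n_trainable:])
    return {
        name: jax.tree_util.tree_map(lambda _: name in trainable, params[name])
        for name in names
    }


def sgd_step(params, mask, x, y, lr=0.1):
    # Compute gradients for everything, but only apply them where mask is True.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(
        lambda p, g, m: jnp.where(m, p - lr * g, p), params, grads, mask
    )


params = init_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
y = jnp.full((16, 8), 0.5)  # toy regression target

mask = trainable_mask(params, n_trainable=1)
new_params = params
for _ in range(10):
    new_params = sgd_step(new_params, mask, x, y)
```

Masking the update keeps the example short, but gradients are still computed for the frozen layers. For a model the size of GPT-J-6B, one would likely want to avoid that cost as well, for example by differentiating only with respect to the trainable subset of parameters.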