Open okhat opened 1 year ago
Thank you for the great work and release! Will this work with GPT-J (modulo minor edits to the code)?
Hi Omar, thanks for your interest!
GPT-J models use Rotary Position Embedding (RoPE), and from what I've seen, the code looks quite different and requires a slightly different implementation. It would be easier to adapt the code for OPT models.
Main changes needed (for OPT):

1. In the `_adapt_weights` method, replace
   `self.transformer.wpe = GPT2LMHeadModel.from_pretrained(self.config.name_or_path).transformer.wpe`
   with
   `self.model.decoder.embed_tokens = OPTForCausalLM.from_pretrained(self.config.name_or_path).model.decoder.embed_tokens`
   (a sketch of this step follows the list).
2. Adapting `position_ids`: PCW's implementation relies on HF's cache mechanism. In GPT2 models, most of the changes to the inputs (`input_ids`, `attention_mask` and `position_ids`) are handled in the `prepare_inputs_for_generation` method. However, in OPT models the `position_ids` are handled in `OPTLearnedPositionalEmbedding`, so you would need to adapt `prepare_inputs_for_generation` and `OPTLearnedPositionalEmbedding` accordingly (see the second sketch below).
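For illustration, here is a minimal sketch of what that weight-replacement step could look like for OPT; the wrapper class name is hypothetical, and the assignment simply mirrors the replacement line quoted in the list above:

```python
from transformers import OPTForCausalLM


class OPTForCausalLMPCW(OPTForCausalLM):  # hypothetical name for the OPT wrapper
    def _adapt_weights(self):
        # Reload the corresponding embedding module from the pretrained checkpoint,
        # as in the replacement line above (the GPT2 version does this with
        # `transformer.wpe`).
        original = OPTForCausalLM.from_pretrained(self.config.name_or_path)
        self.model.decoder.embed_tokens = original.model.decoder.embed_tokens
```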
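For the position-ids step, a rough, untested sketch is below. It assumes the transformers version (from around the time of this issue) in which `OPTLearnedPositionalEmbedding.forward(attention_mask, past_key_values_length)` derives positions from a cumulative sum over the attention mask and applies an offset of 2; the subclass name and the extra `position_ids` argument are hypothetical additions:

```python
from typing import Optional

import torch
from torch import nn
from transformers.models.opt.modeling_opt import OPTLearnedPositionalEmbedding


class OPTLearnedPositionalEmbeddingPCW(OPTLearnedPositionalEmbedding):  # hypothetical name
    """Learned positional embedding that can also accept explicit position ids,
    e.g. the per-window positions that PCW computes."""

    def forward(
        self,
        attention_mask: torch.LongTensor,
        past_key_values_length: int = 0,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        if position_ids is None:
            # Stock behaviour: positions are inferred from the attention mask.
            return super().forward(attention_mask, past_key_values_length)
        # Explicit positions: do a plain embedding lookup, keeping the offset of 2
        # that the parent class applies to all position ids.
        return nn.Embedding.forward(self, position_ids + self.offset)
```

To actually use this, `prepare_inputs_for_generation` and the OPT decoder's forward pass would also need to thread the PCW `position_ids` down to this module, which is the "adapt accordingly" part of the list above.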
If you would like to implement these changes, we would be happy to review your PR, or alternatively, you may wait until we prioritize this request.
Hope this helps :smiley:
Hi @inbalmai21, I have already written the code for OPT and also fixed a minor issue with GPT2 where B > 5. I would like to contribute to this repo.
Thank you