Closed: talkhanz closed this 1 year ago
The folder structure and some files share code with the Tranception codebase that @pascalnotin suggested. The APT model itself, however, is a pure GPT2 model (GPT2 is subclassed as per the requirement).
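For context, a minimal sketch of what subclassing GPT2 from Hugging Face transformers can look like; the class names and config values below are illustrative assumptions, not the actual APT code:

```python
# Hypothetical sketch of subclassing GPT2 via Hugging Face transformers.
from transformers import GPT2Config, GPT2LMHeadModel

class APTConfig(GPT2Config):
    model_type = "apt"

class APTLMHeadModel(GPT2LMHeadModel):
    config_class = APTConfig

    def __init__(self, config):
        super().__init__(config)
        # any APT-specific changes (e.g. positional encodings) would go here

config = APTConfig(vocab_size=30, n_layer=4, n_head=4, n_embd=128)
model = APTLMHeadModel(config)
```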
The train.py script is used to run the training code. AptTokenizer (from @jamaliki's merge) is being used for the tokenization (I'm assuming this handles AA sequences?).
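Roughly what I have in mind for tokenizing amino-acid sequences is character-level encoding like the sketch below; the real AptTokenizer API may differ, so treat the names and special tokens here as assumptions:

```python
# Rough illustration of character-level tokenization of an amino-acid sequence.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
vocab = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
vocab.update({aa: i + 3 for i, aa in enumerate(AMINO_ACIDS)})

def encode(sequence: str) -> list[int]:
    # map each residue to an id, wrapped in bos/eos tokens
    return [vocab["<bos>"]] + [vocab[aa] for aa in sequence] + [vocab["<eos>"]]

print(encode("MKTAYIAK"))  # [1, 13, 11, 19, 3, 24, 10, 3, 11, 2]
```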
We have checkpoint_steps denoting the number of epochs to wait between persisting the model to disk.
Perplexity is also saved in the metrics dictionary whenever the checkpoint and model state are saved.
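A sketch of how checkpoint_steps and the perplexity metric could fit together; the variable names, paths, and the train_one_epoch helper are assumptions for illustration, not the exact train.py code:

```python
# Save a checkpoint (with perplexity in the metrics dict) every checkpoint_steps epochs.
import math
import os
import torch

checkpoint_steps = 5  # save every 5 epochs

for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(model, dataloader, optimizer)  # assumed helper returning mean NLL
    if (epoch + 1) % checkpoint_steps == 0:
        metrics = {
            "epoch": epoch,
            "loss": epoch_loss,
            "perplexity": math.exp(epoch_loss),  # perplexity = exp(mean negative log-likelihood)
        }
        torch.save(
            {"model_state_dict": model.state_dict(), "metrics": metrics},
            os.path.join(checkpoint_dir, f"checkpoint_{epoch}.pt"),
        )
```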
Also added a trimmed-down sample of the UniRef50 dataset (I think 1000 rows) in the protein_lm/datasets/uniref folder. If anything is missing or incorrect, feel free to reach out.
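Assuming the trimmed sample is stored as a CSV with a "sequence" column (the actual file name and format in protein_lm/datasets/uniref may differ), a quick sanity check could look like:

```python
# Quick look at the trimmed UniRef50 sample; path and column name are assumptions.
import pandas as pd

df = pd.read_csv("protein_lm/datasets/uniref/uniref50_sample.csv")
print(len(df))                               # roughly 1000 rows in the sample
print(df["sequence"].str.len().describe())   # distribution of sequence lengths
```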
My first instinct is that this is a lot of code. A lot of it seems like boilerplate from transformers. Could we try to just import as much as possible?
Hmm, that makes sense. I'll make a minimal version relying on as many imports as possible. Thank you for the feedback!
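One way such a minimal version could lean on transformers instead of custom boilerplate is to reuse GPT2LMHeadModel together with the Trainer API directly; the argument values and train_dataset below are placeholders, not the final train.py:

```python
# Minimal sketch: let transformers' Trainer handle the loop and checkpointing.
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments

model = GPT2LMHeadModel(GPT2Config(vocab_size=30, n_layer=4, n_head=4, n_embd=128))

args = TrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_strategy="epoch",   # checkpointing handled by the Trainer
    logging_steps=50,
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # assumed dataset
trainer.train()
```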
Regarding this, is it necessary to include the JAX options? As far as I understood, we wanted to rely on HF and PyTorch, so the JAX options would not be used anyway, right?
Agree with @Muedi re: JAX -- initial dev should focus on PyTorch / HF since that's the stack most colleagues are familiar with. We can revisit later as we transition to training the larger model architecture.
Thanks @talkhanz. Merging the PR
@pascalnotin I think we want to wait for @talkhanz to clean up the code a bit, no?
Addressed most of the requirements as per the issue here.