Correct me if I'm wrong, but the only difference between ContinuousTransformerWrapper and TransformerWrapper is that one uses an nn.Linear layer to project inputs while the other uses an nn.Embedding layer at the input. Those should be the only differences, right?
In which case, can't we delete ContinuousTransformerWrapper and simply make dim_in, dim_out, and num_tokens all optional? Add some asserts like:

assert dim_in ^ num_tokens, "either project input or embed tokens, not both"
assert dim_out ^ num_tokens, "either project output or predict logits, not both"

Something like that. Then you can determine at runtime whether it behaves more like ContinuousTransformerWrapper or like a normal TransformerWrapper.
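To make the idea concrete, here is a minimal sketch of what such a merged wrapper could look like. The class name UnifiedTransformerWrapper is made up for illustration, the positional embedding is simplified, and only the attn_layers convention mirrors x-transformers; it is not a drop-in replacement for either existing wrapper.

```python
import torch
from torch import nn

class UnifiedTransformerWrapper(nn.Module):
    # hypothetical merged wrapper: behaves like ContinuousTransformerWrapper when
    # dim_in / dim_out are given, and like TransformerWrapper when num_tokens is given
    def __init__(self, *, attn_layers, max_seq_len, dim_in=None, dim_out=None, num_tokens=None):
        super().__init__()
        dim = attn_layers.dim

        # None-safe versions of the asserts suggested above
        assert (dim_in is not None) ^ (num_tokens is not None), \
            "either project input or embed tokens, not both"
        assert (dim_out is not None) ^ (num_tokens is not None), \
            "either project output or predict logits, not both"

        self.max_seq_len = max_seq_len
        self.pos_emb = nn.Embedding(max_seq_len, dim)  # simplified positional embedding

        # input side: nn.Linear for continuous features, nn.Embedding for token ids
        self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Embedding(num_tokens, dim)

        # output side: a single nn.Linear either way
        out_features = dim_out if dim_out is not None else num_tokens
        self.project_out = nn.Linear(dim, out_features)

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        n = x.shape[1]
        x = self.project_in(x)
        x = x + self.pos_emb(torch.arange(n, device=x.device))
        x = self.attn_layers(x, **kwargs)
        x = self.norm(x)
        return self.project_out(x)
```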
Also, it's my understanding that at the output, dim_out and num_tokens are doing the same thing: either way you use an nn.Linear layer. So you probably don't even need dim_out.
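Under that scheme, both regimes would be selected purely by which constructor arguments are passed, e.g. (again using the hypothetical class sketched above; Decoder is the real x-transformers import):

```python
from x_transformers import Decoder

# token mode, like TransformerWrapper
token_model = UnifiedTransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# continuous mode, like ContinuousTransformerWrapper
continuous_model = UnifiedTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)
```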
Anyway, this will stop me creating issues like: "can you add feature X from TransformerWrapper to ContinuousTransformerWrapper" :)
@pfeatherstone oops, yeah, I added it
hmm, maybe at a later date; for now let us keep it separate
@pfeatherstone that is the only difference within this wrapper; other differences exist in the loss and sampling, for its autoregressive wrapper
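For reference, a rough illustration of the kind of loss difference being referred to; this is not the actual autoregressive wrapper code, just the usual pattern of cross entropy on logits for the token case versus a regression loss on vectors for the continuous case:

```python
import torch.nn.functional as F

def token_ar_loss(model, token_ids):
    # token_ids: (batch, seq_len) of ints; train with cross entropy on next-token logits
    inp, target = token_ids[:, :-1], token_ids[:, 1:]
    logits = model(inp)                                  # (b, n, num_tokens)
    return F.cross_entropy(logits.transpose(1, 2), target)

def continuous_ar_loss(model, vectors):
    # vectors: (batch, seq_len, dim_in) of floats; train with MSE on the next vector
    inp, target = vectors[:, :-1], vectors[:, 1:]
    pred = model(inp)                                    # (b, n, dim_out)
    return F.mse_loss(pred, target)
```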
@lucidrains could you add memmask to ContinuousTransformerWrapper please? Thanks