lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Adding memmask to ContinuousTransformerWrapper #227

Closed by pfeatherstone 6 months ago

pfeatherstone commented 6 months ago

@lucidrains could you add memmask to ContinuousTransformerWrapper please? Thanks

pfeatherstone commented 6 months ago

Correct me if I'm wrong, but the only difference between ContinuousTransformerWrapper and TransformerWrapper is that one uses an nn.Linear layer to project its inputs while the other uses an nn.Embedding layer. Is that really the only difference?
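
To illustrate the input-side difference I mean (a minimal sketch with made-up sizes, not the library's actual code):

```python
import torch
import torch.nn as nn

dim, num_tokens, dim_in = 512, 256, 80   # illustrative sizes only

# TransformerWrapper side: discrete token ids -> vectors via an embedding table
token_emb = nn.Embedding(num_tokens, dim)
x_tokens = token_emb(torch.randint(0, num_tokens, (1, 1024)))   # (1, 1024, 512)

# ContinuousTransformerWrapper side: continuous features -> vectors via a linear projection
project_in = nn.Linear(dim_in, dim)
x_features = project_in(torch.randn(1, 1024, dim_in))           # (1, 1024, 512)
```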

In which case, can't we delete ContinuousTransformerWrapper and simply make the constructor's input/output arguments all optional?
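
Roughly something like the following sketch; the UnifiedTransformerWrapper name and the exact argument set are only illustrative, not an actual class in the library:

```python
import torch.nn as nn

class UnifiedTransformerWrapper(nn.Module):   # hypothetical merged wrapper
    def __init__(self, *, dim, attn_layers, num_tokens=None, dim_in=None, dim_out=None):
        super().__init__()
        # pick the input module based on which optional argument was supplied
        if num_tokens is not None:
            self.project_in = nn.Embedding(num_tokens, dim)   # token ids, like TransformerWrapper
        elif dim_in is not None:
            self.project_in = nn.Linear(dim_in, dim)          # continuous features, like ContinuousTransformerWrapper
        else:
            self.project_in = nn.Identity()                   # inputs already at model dimension
        self.attn_layers = attn_layers
        # output projection is optional too
        self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()

    def forward(self, x, **kwargs):
        x = self.project_in(x)
        x = self.attn_layers(x, **kwargs)
        return self.project_out(x)
```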

Add some asserts like:
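
For instance (again just a sketch, with a made-up check_input_args helper rather than anything in the library):

```python
def check_input_args(num_tokens=None, dim_in=None):
    # a merged wrapper should not receive both a vocabulary size and a continuous input dim
    assert num_tokens is None or dim_in is None, \
        'pass either num_tokens (discrete inputs) or dim_in (continuous inputs), not both'

check_input_args(num_tokens=256)   # fine: behaves like TransformerWrapper
check_input_args(dim_in=80)        # fine: behaves like ContinuousTransformerWrapper
```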

Something like that. Then whether it behaves like ContinuousTransformerWrapper or like the normal TransformerWrapper is determined at construction time by which arguments are passed.

Also, it's my understanding that on the output side, dim_out and num_tokens do the same thing: either way you end up with an nn.Linear layer. So you probably don't even need dim_out.
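
As a sketch of what I mean here (sizes made up, variable names not from the library):

```python
import torch.nn as nn

dim, num_tokens, dim_out = 512, 256, 80   # illustrative sizes only

to_logits   = nn.Linear(dim, num_tokens)  # output head when predicting tokens
project_out = nn.Linear(dim, dim_out)     # output head when predicting continuous values

# both heads are a single linear map from the model dimension,
# so one optional "output size" argument would cover both cases
```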

Anyway, this would stop me from creating issues like: "can you add feature X from TransformerWrapper to ContinuousTransformerWrapper" :)

lucidrains commented 6 months ago

@pfeatherstone oops, yeah, I added it

Hmm, maybe at a later date; for now let's keep the two wrappers separate.

lucidrains commented 6 months ago

@pfeatherstone that is the only difference within this wrapper.

Other differences exist in the loss and sampling logic of its autoregressive wrapper.