Fixed continuous GPU memory growth by changing `weights = attn` to `weights = attn.detach()`.
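A minimal sketch of why this fix works (variable names here are illustrative, not the project's actual code): keeping a live reference to `attn` retains the entire autograd graph that produced it across iterations, so GPU memory grows; `.detach()` stores the values without the graph.

```python
import torch

# `attn` is part of an autograd graph (requires_grad propagates through softmax).
attn = torch.softmax(torch.randn(4, 4, requires_grad=True), dim=-1)

# Storing `attn` directly would keep its computation graph alive.
# Detaching keeps the values but drops the graph reference.
weights = attn.detach()

assert attn.requires_grad        # original tensor is still in the graph
assert not weights.requires_grad # stored copy is graph-free
```

Because `detach()` returns a view sharing the same storage, this costs no extra memory; it only severs the autograd bookkeeping.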
Improved execution speed and reduced memory consumption by replacing the masked MLP operation in `FeatureEmbed` with sparse matrix multiplication, implemented in a new `SparseLinear` class.
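A hedged sketch of the idea behind such a `SparseLinear` class (this is not the project's actual implementation; the constructor signature and internals are assumptions): when most of a linear layer's weights are masked out, storing only the surviving entries in sparse COO form and using `torch.sparse.mm` avoids computing over the zeroed entries.

```python
import torch

class SparseLinear(torch.nn.Module):
    """Illustrative sketch: a masked dense linear replaced by sparse matmul."""

    def __init__(self, weight, mask):
        super().__init__()
        # Keep only the unmasked weights in sparse COO form.
        sparse_w = (weight * mask).to_sparse()
        self.register_buffer("indices", sparse_w.indices())
        self.values = torch.nn.Parameter(sparse_w.values())
        self.out_features, self.in_features = weight.shape

    def forward(self, x):
        # Rebuild the sparse weight and multiply: (out, in) @ (in, batch).
        w = torch.sparse_coo_tensor(
            self.indices, self.values, (self.out_features, self.in_features)
        )
        return torch.sparse.mm(w, x.t()).t()

# Check equivalence against the masked dense operation.
w = torch.randn(6, 4)
mask = (torch.rand(6, 4) > 0.7).float()
layer = SparseLinear(w, mask)
x = torch.randn(3, 4)
dense_out = x @ (w * mask).t()
assert torch.allclose(layer(x), dense_out, atol=1e-5)
```

The speed and memory win grows with sparsity: only the nonzero weights are stored and multiplied, instead of a full dense matrix that is mostly zeros after masking.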
`torch==1.7.1` is no longer required; the code works with `torch==2.1.0`.