This PR addresses #32 to allow Fused Ops to also support LoRA dropout.

The key strategy is to pass the dropout module into `LoRA.apply`, call it inside `matmul_lora`, and return the dropped-out `X` back to the top level so it can be saved for the backward pass:
- `LoRA.apply(..., dropout)`: the dropout module is passed into `apply`.
- `matmul_lora(..., dropout)` is called with the dropout, inside which:
  - `X = dropout(X)`
  - `dropout.X = X`  # saved for returning to the calling level
- `ctx.save_for_backward(..., dropout.X)` then saves the dropped-out `X` for the backward pass (and we delete it off `dropout` once it is no longer needed).
- Inside `backward`, we use `dropout.X` instead of `X`.
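A minimal, illustrative sketch of that flow in plain PyTorch (the `LoRA_Linear` class name, the tensor shapes, and the exact signatures here are assumptions for illustration, not the actual fused-ops code):

```python
import torch

def matmul_lora(X, W, A, B, s, dropout=None):
    # W: frozen weight (out, d); A: (r, d), B: (out, r) LoRA adapters; s: scaling.
    if dropout is not None:
        X = dropout(X)        # apply LoRA dropout to the input
        dropout.X = X         # stash the dropped-out X for the calling level
    return X @ W.t() + s * ((X @ A.t()) @ B.t())

class LoRA_Linear(torch.autograd.Function):
    # Hypothetical stand-in for the fused op's apply(); the real fused kernels
    # do considerably more work.

    @staticmethod
    def forward(ctx, X, W, A, B, s, dropout=None):
        out = matmul_lora(X, W, A, B, s, dropout)
        # Save the dropped-out X (when dropout is active) so backward sees
        # exactly the activations that produced the forward output.
        X_saved = dropout.X if dropout is not None else X
        ctx.save_for_backward(X_saved, W, A, B)
        ctx.s = s
        if dropout is not None:
            del dropout.X     # no longer needed once saved on ctx
        return out

    @staticmethod
    def backward(ctx, dY):
        X, W, A, B = ctx.saved_tensors   # X here is already the dropped-out X
        s = ctx.s
        dX = dY @ W + s * (dY @ B) @ A   # grad w.r.t. the dropped-out X
        dA = s * (dY @ B).t() @ X        # grad for LoRA A
        dB = s * dY.t() @ (X @ A.t())    # grad for LoRA B
        # Note: propagating dX back through the dropout mask itself is omitted
        # in this sketch.
        return dX, None, dA, dB, None, None
```

Usage would then look like `LoRA_Linear.apply(X, W, A, B, s, torch.nn.Dropout(0.1))`, with `dropout=None` recovering the current no-dropout behaviour.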