Closed Edenzzzz closed 2 months ago
@Edenzzzz thanks for your interest.
to keep it orthogonal during training, you just have to wrap the layer as done here: https://github.com/stanfordnlp/pyreft/blob/main/pyreft/interventions.py#L25
pytorch reparameterizes the weight and re-runs the orthogonalization automatically on every forward pass (please refer to the torch parametrizations tutorial for the different reparameterizations).
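A minimal sketch of what this looks like, using `torch.nn.utils.parametrizations.orthogonal` on a plain `nn.Linear` (a stand-in for the intervention's rotation layer, not the actual pyreft code): the parametrization recomputes an orthogonal weight every time `layer.weight` is accessed, so it stays orthogonal even after optimizer steps.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

# Stand-in for the rotation layer; pyreft wraps its projection similarly.
layer = orthogonal(nn.Linear(8, 8, bias=False))

# Take a training step -- the raw underlying parameter changes freely.
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(4, 8)
loss = layer(x).pow(2).sum()
loss.backward()
opt.step()

# Accessing .weight re-applies the parametrization, so W is still orthogonal.
W = layer.weight
print(torch.allclose(W @ W.T, torch.eye(8), atol=1e-5))
```

So there is no explicit "re-orthogonalize" call in the training loop; the parametrization handles it on every weight access.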
on intuition: NoReFT is another ReFT variant that removes the orthogonalization. There are a bunch of other variants in interventions.py right now -- feel free to check them out.
quick summary of what we found in the ablation study: LoReFT still performs the best, but other methods such as NoReFT, which drops the orthogonality constraint, also work pretty well.
marking this issue as closed for now --- feel free to reopen or open another issue in the future.
Hi, thanks for the inspiring work! I'm a bit confused though: the code initializes the projection to be orthogonal, but I can't find where it is kept orthogonal during training. Without orthogonality, we can't recover the input even if Wh + b is the identity, right? Thanks!