dmarx / notebooks

Misc notebooks I wanted to put in tracking

[klmc2] per-prompt momentum #5

Open dmarx opened 1 year ago

dmarx commented 1 year ago

I think it'll start with breaking out the initialization of `v` into a separate `v` for each prompt, just before the `for i in trange(n)` loop:

    v = torch.randn_like(x) * sigma
    vs = [v.clone() for _ in prompts]  # per-prompt momentum, all starting from the same initial draw
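Here is a minimal sketch of how the per-prompt velocities might be used inside the loop. It assumes a semi-implicit Euler step of underdamped Langevin dynamics rather than the actual KLMC2 integrator. It also assumes each prompt supplies its own energy gradient. `per_prompt_momentum_step`, `grads`, `weights`, `h`, `gamma` and `temperature` are hypothetical names that do not come from the notebook. Each prompt's velocity is updated only by that prompt's gradient, and the sample moves along the weighted sum of the velocities.

```python
import torch

def per_prompt_momentum_step(x, vs, grads, weights, h=0.1, gamma=1.0, temperature=0.0):
    """One semi-implicit Euler step of underdamped Langevin dynamics with one velocity per prompt.

    x        current sample
    vs       per-prompt velocities, each shaped like x
    grads    per-prompt energy gradients evaluated at x (hypothetical inputs)
    weights  per-prompt weights, assumed to sum to 1
    """
    new_vs = []
    for v, g in zip(vs, grads):
        # friction damps each velocity; each prompt's own gradient pushes its velocity downhill
        v = v - h * gamma * v - h * g
        if temperature > 0:
            # Langevin noise on the velocity, scaled by sqrt(2 * gamma * T * h)
            v = v + (2 * gamma * temperature * h) ** 0.5 * torch.randn_like(v)
        new_vs.append(v)
    # the sample moves along the weighted combination of the updated per-prompt velocities
    v_total = sum(w * v for w, v in zip(weights, new_vs))
    return x + h * v_total, new_vs

# toy usage: two prompts whose energies are quadratics centred at different points
mus = [torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])]
weights = [0.5, 0.5]
x = torch.zeros(2)
vs = [torch.zeros_like(x) for _ in mus]
for _ in range(2000):
    grads = [x - mu for mu in mus]  # gradient of 0.5 * ||x - mu||^2
    x, vs = per_prompt_momentum_step(x, vs, grads, weights)
```

In this toy case `x` settles at the weighted mean of the two centres. Each prompt's velocity stays nonzero and records that prompt's remaining pull, and only their weighted sum vanishes.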

The complexity comes from taking the step, which generates a new velocity. Maybe I need to project the new velocity onto the old one and keep that as a component? Or rotate the velocities with PCA and work in the principal basis?
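Both ideas can be made concrete with a rough sketch. The helper names are hypothetical. `split_along` splits a freshly computed velocity into the part parallel to the old velocity and an orthogonal remainder. `principal_basis` runs PCA with `torch.linalg.svd` on the stacked per-prompt velocities and returns orthonormal principal directions that the velocities can be rotated into. How the pieces would then be recombined is left open, as in the question above.

```python
import torch

def split_along(v_new, v_old, eps=1e-12):
    """Decompose v_new into a component parallel to v_old and an orthogonal remainder."""
    a, b = v_new.flatten(), v_old.flatten()
    coef = torch.dot(a, b) / (torch.dot(b, b) + eps)  # scalar projection coefficient
    parallel = (coef * b).view_as(v_new)
    return parallel, v_new - parallel

def principal_basis(vs):
    """Principal directions of a stack of per-prompt velocities (PCA via SVD)."""
    V = torch.stack([v.flatten() for v in vs])  # (n_prompts, n_features)
    V = V - V.mean(dim=0, keepdim=True)         # center before PCA
    _, S, Vh = torch.linalg.svd(V, full_matrices=False)
    return Vh, S  # rows of Vh are orthonormal directions, ordered by singular value S

v_old = torch.tensor([1.0, 0.0, 0.0])
v_new = torch.tensor([2.0, 3.0, 0.0])
parallel, orthogonal = split_along(v_new, v_old)  # [2, 0, 0] and [0, 3, 0]

vs = [torch.tensor([1.0, 1.0, 0.0]), torch.tensor([-1.0, -1.0, 0.0]), torch.tensor([2.0, 2.0, 0.0])]
Vh, S = principal_basis(vs)
stacked = torch.stack([v.flatten() for v in vs])
coords = stacked @ Vh.T  # velocities expressed in the principal basis
```

Because `Vh` is orthogonal here, `coords @ Vh` rotates back to the original velocities. Note that PCA on only a handful of prompts gives at most `n_prompts - 1` meaningful directions after centering.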

dmarx commented 1 year ago

Or maybe the Hessian is what I need? Maybe I could even cheat a bit and rotate the Hessian toward the prompts in proportion to their weights, i.e. orient the second-order curvature?
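One possible reading of "rotating the Hessian toward the prompts in proportion to their weights" is to weight each prompt's curvature. The sketch below follows that reading. It assumes each prompt has a scalar energy `E_k`, which is hypothetical; the notebook's actual guidance terms may be shaped differently. `weighted_hvp` is an illustrative helper built on `torch.autograd.functional.hvp`. It computes Hessian-vector products instead of the full Hessian, which matters when `x` is an image-sized latent.

```python
import torch
from torch.autograd.functional import hvp

def weighted_hvp(energies, weights, x, v):
    """Hessian-vector product of sum_k w_k * E_k(x), assembled from per-prompt HVPs.

    The Hessian is linear in the function, so H[sum_k w_k E_k] = sum_k w_k H[E_k]:
    weighting each prompt's curvature is the same as taking the curvature of the
    weighted energy. Using HVPs avoids ever materialising the full Hessian.
    """
    total = torch.zeros_like(x)
    for energy, w in zip(energies, weights):
        _, Hv = hvp(energy, x, v)  # returns (energy(x), H @ v)
        total = total + w * Hv
    return total

# toy per-prompt energies with diagonal Hessians diag(4, 1) and diag(1, 9)
a1, a2 = torch.tensor([4.0, 1.0]), torch.tensor([1.0, 9.0])
energies = [lambda x: 0.5 * (a1 * x * x).sum(), lambda x: 0.5 * (a2 * x * x).sum()]
x = torch.tensor([0.3, -0.2])
v = torch.tensor([1.0, 1.0])
Hv = weighted_hvp(energies, [0.25, 0.75], x, v)  # diag(1.75, 7.0) @ v
```

The linearity also shows the limit of this reading. Weighting per-prompt curvatures gives exactly the curvature of the weighted energy, so any real "rotation" of the Hessian would need something beyond a weighted sum.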