i think it'll start with breaking out the initialization of v into a separate v for each prompt, just before the for i in trange(n) loop:
v = torch.randn_like(x) * sigma
vs = [v.clone() for _ in prompts] # per-prompt momentum
the complexity comes from taking the step, which generates a new velocity. maybe i need to project the new velocity onto the old one as a component? rotate the velocity with PCA and use the principal basis?
or maybe the hessian is what i need? maybe i could even cheat a bit and rotate the hessian towards the prompts in proportion to their weights, i.e. orient the second-order curvature?
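to make the projection idea concrete, here's a rough sketch of what "project the new velocity onto the old one" could look like per prompt. the helper names (project_onto, blend_velocities) and the weights list are made up for illustration, and blending the projections by normalized prompt weight is just one guess at how to recombine them, not something the sampler already does.

```python
import torch

def project_onto(v_new: torch.Tensor, v_old: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Component of v_new that lies along v_old (plain vector projection)."""
    # Treat both tensors as flat vectors: <v_new, v_old> / <v_old, v_old>.
    scale = (v_new * v_old).sum() / (v_old * v_old).sum().clamp_min(eps)
    return scale * v_old

def blend_velocities(vs_new, vs_old, weights):
    """Hypothetical helper: weight-normalized sum of each prompt's projected velocity."""
    total = sum(weights)
    return sum(
        (w / total) * project_onto(v_n, v_o)
        for v_n, v_o, w in zip(vs_new, vs_old, weights)
    )
```

inside the loop this would be called once per step with the freshly computed per-prompt velocities and the stored vs. note the projection throws away everything orthogonal to the old velocity, so it may damp exploration more than intended.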