jeizenga / wfalm

Refinements of the WFA alignment algorithm with better complexity
MIT License

[Discussion] Potential linear-space and alternative low-memory WFA #4

Closed lh3 closed 2 years ago

lh3 commented 2 years ago

CC @smarco. This is going to be a long post. I will start with edit distance and then discuss WFA at the end.

Linear-space O(nd) for edit distance

Myers' 1986 O(nd) paper gives a linear-space algorithm inspired by Hirschberg's algorithm. I haven't fully understood Myers' algorithm. Here is my nasty version:

function FindAln(T,P) {
  d = EditDist(T,P); // standard edit distance, without traceback
  if (d < D) return OndAlign(T,P); // for a small constant threshold D, standard O(nd) with traceback
  (t,p) = SplitByAln(T,P,floor(d/2)); // find the point in the alignment that reaches distance floor(d/2)
  return concat(FindAln(T[0:t],P[0:p]), FindAln(T[t:|T|],P[p:|P|]));
}

To implement SplitByAln(T,P,e), we store the diagonals at edit distance e. For iterations at larger edit distances, we monitor how diagonals can be traced back to the diagonals at e. When we reach the end of the alignment, we can split it at distance e. The algorithm has the same time and space complexity as EditDist(): O(nd) time and O(n) space. FindAln() is recursive; the total number of elements on the stack is O(n), so the overall space complexity is linear. FindAln() effectively runs EditDist() on the full-length strings four times. I guess Myers' algorithm probably only needs two passes.
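For concreteness, the EditDist() building block is the standard furthest-reaching-point wavefront. A minimal Python sketch (illustrative only, not the actual lv89 code; diagonal k = i - j):

```python
NEG = -10**9  # sentinel for "diagonal not reached yet"

def edit_dist(a, b):
    """O(nd)-time edit distance via furthest-reaching points per diagonal."""
    n, m = len(a), len(b)
    f = {0: 0}  # diagonal k = i - j -> furthest row i reached
    d = 0
    while True:
        for k in list(f):  # greedily extend exact matches along each diagonal
            i, j = f[k], f[k] - k
            while i < n and j < m and a[i] == b[j]:
                i += 1
                j += 1
            f[k] = i
        if f.get(n - m, NEG) >= n:  # reached cell (n, m)
            return d
        d += 1
        nf = {}
        for k in range(-d, d + 1):
            i = max(f.get(k, NEG) + 1,      # substitution
                    f.get(k - 1, NEG) + 1,  # consume a char of a only
                    f.get(k + 1, NEG))      # consume a char of b only
            i = min(i, n, m + k)            # clamp to the matrix boundary
            if i >= max(0, k):
                nf[k] = i
        f = nf
```

SplitByAln() would additionally record, for each diagonal, which diagonal at distance e it traces back to.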

Fast low-memory O(nd) for edit distance

Suppose the edit distance between two strings is d. Similar to SplitByAln(), we can split the full alignment into ceil(d/q) blocks such that each block, except the last block, has edit distance exactly q. Then we can apply the standard O(nd) traceback to each block and then concatenate these partial alignments into the final alignment. The blocking step has O(nd) time complexity and O(d^2/q) space complexity. The base alignment step has O(nq) time complexity and O(q^2) space complexity. When q<<d, the time spent on the base alignment step will be negligible.
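The concatenation step can be sketched as follows, assuming the blocking step has already produced the block boundaries (`cuts` below is a hypothetical input: (t_pos, p_pos) pairs lying on an optimal path). Each block is then re-aligned independently, here with a plain quadratic DP for simplicity:

```python
def nw(t, p):
    """Full-DP edit distance with traceback; fine for small blocks."""
    n, m = len(t), len(p)
    D = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        D[i][0] = i
    for j in range(m + 1):
        D[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            D[i][j] = min(D[i - 1][j] + 1, D[i][j - 1] + 1,
                          D[i - 1][j - 1] + (t[i - 1] != p[j - 1]))
    i, j, ops = n, m, []
    while i > 0 or j > 0:  # trace back one optimal path
        if i > 0 and j > 0 and D[i][j] == D[i - 1][j - 1] + (t[i - 1] != p[j - 1]):
            ops.append('M' if t[i - 1] == p[j - 1] else 'X'); i -= 1; j -= 1
        elif i > 0 and D[i][j] == D[i - 1][j] + 1:
            ops.append('D'); i -= 1
        else:
            ops.append('I'); j -= 1
    return D[n][m], ''.join(reversed(ops))

def align_by_blocks(t, p, cuts):
    """Re-align each block separately and concatenate the partial CIGARs."""
    dist, cigar, prev = 0, [], (0, 0)
    for cut in list(cuts) + [(len(t), len(p))]:
        d, c = nw(t[prev[0]:cut[0]], p[prev[1]:cut[1]])
        dist += d
        cigar.append(c)
        prev = cut
    return dist, ''.join(cigar)
```

As long as the cuts sit on an optimal path, the summed block distances equal the full edit distance.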

I have implemented this algorithm in my lv89 repo. As expected, with d=100000 and q=1000, the low-memory mode is as fast as the standard traceback mode while using ~5% of the memory.

Neither the linear-space nor this low-memory algorithm can be directly applied to affine-gap penalties. This leads to the next section.

Possible improvement to wfalm

If I understand your preprint correctly, you keep "stripes" of scores but you don't record traceback positions between adjacent stripes. If you record traceback positions, you can find the CIGAR between two stripes without recomputing all the cells. This will be faster and use less memory, although the time and space complexity (if you use your dynamic schedule) remain the same.

I believe it is even theoretically possible to implement a linear-space WFA. Nonetheless, linear-space algorithms are typically twice as slow. The low-memory algorithm is probably more practical.

smarco commented 2 years ago

In fact, I believe that DAligner from Myers already used the sampling strategy to save "checkpoints" and retrieve the CIGAR linking them afterwards. However, I never managed to find where this strategy was published. He was aware of the problem of memory consumption and he mitigated it using this sampled-wavefronts technique.

Now, as you already know, WFA is based on O(ND) ideas. Like the O(ND) algorithm, the WFA can be adapted to use O(s) memory in O(ns*log(s)) time. The biWFA (WIP) works the alignment in both directions (i.e., forward and backward) using 2 "stripes" of wavefronts, one from each end. The delicate part is to compute the juncture/collision between the wavefronts that guarantees the optimal-score CIGAR. Beyond that, the approach is similar to that of O(ND) or Hirschberg. This method can be used to divide the problem into smaller parts that are feasible to align using the regular WFA or a low-memory mode.

That said, in my experience, people tend to end up using an aggressive heuristic strategy (e.g., a long band or some pruning), sacrificing accuracy for a large gain in performance. For that reason, the homology of the input sequences plays a fundamental role. An evaluation using highly similar sequences could make the WFA shine, whereas an evaluation at >30% error rate could make the WFA consume a lot of resources (i.e., memory and time). O(s) memory can be of great help to reduce the problem to a more manageable size, to be later combined with other methods. In any case, it is always better to try and see the results.

smarco commented 2 years ago

@lh3, I was having a closer look at the lv89 repo.

  1. I thought that bounding paths in the wavefront using min-max bounds was too loose to actually work. I have seen that the sequences you are aligning have a large length difference (i.e., 50kb). In those cases, I can see that this exact pruning makes a lot of sense. Moreover, I don't see why the same pruning couldn't be applied to gap-linear, gap-affine, gap-affine-2p, etc. I will take this as something to fix in WFA2-lib. Thanks.

  2. The alignment strategy you implement in lv89 for each block is based on piggybacking the alignment operations (i.e., the CIGAR) as you compute the wavefronts forward. This is the same approach WFA2-lib uses in its low-memory modes, but applied to blocks of the wavefront. But here is my question: whether you use piggybacking or regular WFA + backtrace, you always need to recompute the block. The piggybacking further reduces memory usage, but you still need to recompute the space between stripes. I don't see how you can find the CIGAR between two stripes without recomputing all the cells between stripes.

lh3 commented 2 years ago

The effectiveness of strategy 1 is heavily influenced by inputs. It occasionally leads to a noticeable speedup but it is more often a waste of time. You can see that lv89 does this check every 32 cycles to avoid checking too frequently.

but you still need to recompute the space between stripes. I don't see how you can find the CIGAR between two stripes without recomputing all cells between stripes.

It is easier to explain with edit distance (or linear gap penalties). Suppose you know the base alignment. You can split the full alignment into blocks such that each block gives you an edit distance of exactly p. In this case, you know the query and target subsequences involved in each block. You can reconstruct the base alignment by re-aligning the subsequences separately in each block. For each block, the time and space complexity is O(n_k*p) and O(p^2), respectively, where n_k is the sequence length in the k-th block. My implementation has an assertion to check that the edit distance from the reconstructed alignment is identical to the one resulting from the blocking step. It works for the few sequence pairs I have tested.

It is harder to apply this strategy to affine gap penalties. In that case, we need to know both sequences and the starting state in each block. I feel wfalm's current implementation might simplify this step, so I posted my suggestion here.

smarco commented 2 years ago

The effectiveness of strategy 1 is heavily influenced by inputs. It occasionally leads to a noticeable speedup [...]

Your benchmarks show that it can have a large performance impact when aligning sequences of different lengths (diff=|N-M|). I guess that triggering the check once every 32 steps is a balance. WFA2-lib does something similar with its heuristics (e.g., the adaptive wavefront), applying them once every s steps. The ultimate goal should be to reduce from the ends of the wavefront (i.e., [lo...hi]). Maybe one could estimate the method's effectiveness based on the diff, lo, and hi (i.e., the deviation from the finish diagonal). I need to explore this further.

You can reconstruct the base alignment by re-aligning the subsequences separately in each block.

Right.

It is harder to apply this strategy to affine gap penalties. In that case, we need to know both sequences and the starting state in each block. I feel wfalm's current implementation might simplify this step, so I posted my suggestion here.

I guess that the same idea applies for other metrics like gap-affine. The partial alignments (belonging to the optimal solutions) cannot make jumps in score larger than scope=max{O+E,X}. Thus, storing stripes of length scope guarantees that the optimal path has to "make a jump" in that stripe. Therefore, you can reconstruct the optimal path from the last stripe to the first stripe.

jeizenga commented 2 years ago

Sorry to be silent on this until now, I've been largely offline for the last several days.

I experimented early on with a "meet-in-the-middle" algorithm with O(s) memory, but I was unable to get it to work. The correctness proof in the O(ND) paper (Lemma 3) does not seem to apply to WFA, since several of its internal claims depend on the version of edit distance used in that paper, which does not allow substitution edits.

The issue I ran into in my prototype is that the greedy approach of taking as long a match as possible in each iteration can cause the diagonals to shoot past each other in a way that I couldn't figure out how to recover from efficiently. For example, if you align strings with the structure XRRY and XRY then the diagonals starting from the beginning will greedily match XR and the diagonals starting from the end greedily match RY, but the diagonals are offset by |R|. To detect that the wavefronts have met, you need to remember at least as many DP iterations as the penalty of a gap with length |R|. By increasing the length of R, you can cause the two directions to miss each other for any constant number of DP iterations that are maintained in the working set.
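A toy version of that structure (the strings X, R, Y below are arbitrary choices of mine for illustration):

```python
def greedy_fwd(a, b):
    """Greedy match run from the front: the 0-cost f.r. point on the forward diagonal 0."""
    i = 0
    while i < min(len(a), len(b)) and a[i] == b[i]:
        i += 1
    return i

def greedy_rev(a, b):
    """Greedy match run from the back: the 0-cost f.r. point of the reverse direction."""
    i = 0
    while i < min(len(a), len(b)) and a[len(a) - 1 - i] == b[len(b) - 1 - i]:
        i += 1
    return i

X, R, Y = "GG", "ACGT" * 5, "TT"
a, b = X + R + R + Y, X + R + Y          # the XRRY / XRY structure

fwd, rev = greedy_fwd(a, b), greedy_rev(a, b)
# each direction greedily swallows a full copy of R...
assert fwd == len(X) + len(R) and rev == len(R) + len(Y)
# ...so together they cover more of the short string than it has characters,
# while sitting on diagonals offset by |R|
assert fwd + rev > len(b)
```

Growing R grows the offset, so no fixed window of retained wavefronts suffices.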

@smarco I'm interested to see how you've accomplished a bi-directional version. Given my own experiments, I'm not terribly surprised that it was only possible with a hit to the time complexity. I've worked out an O(s log s) memory variant with O(sN log s) run time (fully implemented but not included in the preprint), but if you've managed the same with O(s) memory, it's probably of even more interest.

My guess is that any bi-directional WFA will require a fully optimal alignment to function reliably (but let me know if this isn't the case for biWFA). My reasoning is that pruning away the wrong diagonals can cause the meet-in-the-middle step to fail catastrophically, whereas the standard WFA only suffers some alignment inaccuracy. One benefit of the O(s log s) memory variant I mentioned is that it has the same behavior as standard WFA in this respect.

RagnarGrootKoerkamp commented 2 years ago

@jeizenga The problem you're running into sounds very similar to the reason why meet-in-the-middle A* doesn't (always) work (well).

MITM for Dijkstra works well, and you can stop as soon as a node is expanded from both ends. This is guaranteed to happen since nodes are explored by order of g (=distance from start/end). For A* (assuming a consistent heuristic), you can still only stop as soon as a node is expanded from both ends, but as you say, it can happen that two long parallel paths miss each other and are only joined 'at the end'.

[edit: disregard the remainder; see the posts below] However, the approach used in Hirschberg'75 and Myers'88 is not to split on equal distance, but instead to align to the middle of one of the sequences (see fig 2 in Myers'88). This way overshooting isn't possible, because the greedy matching is 'broken'/stopped at a fixed column.

I suppose splitting at the sequence midpoint instead of the distance midpoint doesn't give the guaranteed 2x speedup that comes from the smaller 2 * (s/2)^2 term when splitting at half distance, but it has the same complexity in the end.

jeizenga commented 2 years ago

Maybe I'm missing an obvious fix, but the problem seems more fundamental to me here because of the greedy longest-common-extension involved in WFA. For instance, this could be the execution of a forward and backward WFA:

0   8 7 6 5 4 3 2 1 0        
  0                   0      
    0                   0    
      0                   0  
        0 1 2 3 4 5 6 7 8   0

When you take matches greedily, it's possible for the forward WFA to end up "behind" the backward WFA so that they never actually meet (within any fixed constant number of wavefronts).

Lemma 3 from the O(ND) paper goes into some detail proving that, for the no-substitutions edit distance, you are still guaranteed to find a cell on an optimal traceback. All that is necessary is for the forward and backward WFA to be on the same diagonal. However, there are parts of the proof that aren't valid if you allow substitution operations. That said, it might be possible to modify the proof to provide a similar guarantee; I haven't spent much time trying.

RagnarGrootKoerkamp commented 2 years ago

(Note: I'm excluding any complications introduced by the gap-affine scoring here. Surely those can be solved, but it will be tricky.)

In this particular case, I don't think there is a problem. As soon as you get to distance 5, you have this situation:

 0 1 2 3 4 5 6 7 8 9 (diagonal numbers)
 0         5 4 3 2 1 0        
   0         .         0      
     0         .         0    
       0         .         0  
         0 1 2 3 4 5         0
           9 8 7 6 5 4 3 2 1 0 (reverse diagonal numbers)

At this point, the wavefronts on diagonal 5 have crossed, because M_{5,5} + M'_{5,n-m-5} = 9 + 9 = 18 >= 14 = n. (There may be an off-by-one there but you get the idea.)

So you don't get an exact collision in a specific location (where M_{s,k} + M'_{s,n-m-k} = n), but anyway you can easily detect when the wavefronts overlap.

More generally, you can look at it this way: WFA is just an efficient model of running Dijkstra on the edit graph, since the order in which WFA expands nodes (by distance to the start) is the same as Dijkstra. We know that bidirectional/MITM Dijkstra is done as soon as a state is expanded from both directions. WFA has 'expanded' a node u(i,j) on diagonal k as soon as M_{s,k} >= i. When M_{s,k} + M'_{s, n-m-k} >= n, this implies that there is some node on diagonal k that was expanded from both directions.
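For unit-cost edit distance, that stopping rule can be turned directly into a working meet-in-the-middle distance computation. A self-contained Python sketch (my own illustration, ignoring the gap-affine complications; it alternates the two wavefronts and returns d_f + d_r at the first overlap):

```python
NEG = -10**9  # sentinel for "diagonal not reached yet"

def extend(f, a, b):
    """Greedily extend matches along each diagonal (k = i - j)."""
    for k in list(f):
        i, j = f[k], f[k] - k
        while i < len(a) and j < len(b) and a[i] == b[j]:
            i += 1
            j += 1
        f[k] = i

def expand(f, d, n, m):
    """One wavefront step: furthest-reaching points at distance d from those at d-1."""
    nf = {}
    for k in range(-d, d + 1):
        i = max(f.get(k, NEG) + 1,      # substitution
                f.get(k - 1, NEG) + 1,  # consume a char of the first string only
                f.get(k + 1, NEG))      # consume a char of the second string only
        i = min(i, n, m + k)            # clamp to the matrix boundary
        if i >= max(0, k):
            nf[k] = i
    return nf

def bi_edit_dist(a, b):
    """Meet-in-the-middle edit distance: stop when the wavefronts overlap."""
    n, m = len(a), len(b)
    ra, rb = a[::-1], b[::-1]
    f, r = {0: 0}, {0: 0}
    extend(f, a, b)
    extend(r, ra, rb)
    df = dr = 0
    while True:
        # forward diagonal k maps to reverse diagonal (n - m) - k;
        # overlap on k means M_{df,k} + M'_{dr,(n-m)-k} >= n
        if any(i + r.get((n - m) - k, NEG) >= n for k, i in f.items()):
            return df + dr
        if df <= dr:  # alternate which side advances
            df += 1
            f = expand(f, df, n, m)
            extend(f, a, b)
        else:
            dr += 1
            r = expand(r, dr, n, m)
            extend(r, ra, rb)
```

Since the sum df + dr grows by one per step and the overlap test fires exactly when df + dr reaches the optimal distance, the first detection is optimal.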

It should also work for splitting halfway along a sequence, with a slight modification: for each diagonal, just stop computing additional values as soon as the furthest-reaching point crosses the halfway row/column, and when that happens, check whether the furthest-reaching point on the same diagonal coming from the other side has also already crossed that column, in which case you have found a midpoint.

smarco commented 2 years ago

Ah, the parallelogram! I have it in plenty of my notebook pages. At some point, I thought it was not possible because of this.

I think that the key idea here is to note that the WFA computes f.r. paths with score s for every diagonal. In layman's terms, the most advanced cell in the diagonal having score s. Even if 2 paths don't exactly intersect (i.e., they just overlap), you can figure out the value of the DP-matrix cell where they overlap.

Imagine that you advance both WFs (forward and reverse) and then, for example (wlog), you advance the reverse-WF one step and they overlap with scores s_f and s_r for the first time. Then, you would like to know the values of the DP-cells in that reverse-WF's path for the forward-WF. But you don't, because the forward-WF went further and its greedy nature is making things complicated.

      A  A  A  A  A  A  A  A 
   0 12 10  8  0
A     0           0
A        0           0
A           0           0
A              0  8 10 12  0

But, in fact, you do know the value of those cells, because you know the most advanced cell that has score s_f. Moreover, you know the least advanced cell with score s_f, because matches are scored 0 (i.e., extension along the diagonal). Note also that you know the f.r. points for 0..s_f-1 from before the wavefronts intersected (s < s_f). In fact, you know all the f.r. points on that overlapping diagonal from 0..s_f (because you have computed them, although you don't explicitly need all of them). Also note that, because of the way the WFA penalties are defined, the scores across the diagonal are monotonically increasing. Hence, there is no possibility that a "not-computed" score is the one at the intersection.

      A  A  A  A  A  A  A  A 
   0  8 10 12 14 16 18 20 22
A  8  0  8 10 12 14 16 18 20
A 10  8  0  8 10 12 14 16 18
A 12 10  8  0  8 10 12 14 16
A 14 12 10  8  0  8 10 12 14

Nonetheless, be careful, because they can be "meeting" (as @RagnarGrootKoerkamp likes to say) on a D or I path, so you need to check those too. Thus, that part of the code is delicate. To make things trickier, if they meet on a D or I path, you have to subtract o (i.e., the gap-open penalty) from the score of the intersection. That means the algorithm cannot stop at s_f and s_r; it has to go on a few more steps to make sure that the best junction/intersection cannot be improved in score (a few more optimization steps will do, though).

It is important to note that you cannot combine alignments (i.e., CIGARs) from both forward and reverse WFs, because these are not compatible in all situations (surprisingly, in some cases they are). What you have is a point that is on the optimal alignment path (presumably at score ~ s/2). Then you have to proceed again, aligning the remaining parts/sectors like O(ND) or Hirschberg; hence the log(s) term. But then again, IMO performing 2 bidirectional WFs (~ 2·(n·s/2 + s^2/4)) seems faster than letting a single forward WF get it all done (~ (n·s + s^2)), because here we have been discussing alignments with high alignment error (s>>n).

RagnarGrootKoerkamp commented 2 years ago

It is important to note that you cannot combine alignments (i.e., CIGARs) from both forward and reverse WFs, because these are not compatible in all situations

@smarco could you elaborate on this? I must be missing something, since to me intuitively it is always possible to just join the alignments (for gap-linear alignment, anyway).

Then you have to proceed again aligning the remaining parts/sectors like O(ND) or Hirschberg. Hence, the log(s) term.

I don't understand where this log(s) term here and earlier in the thread comes from. As explained in Myers'86, splitting on half distance D gives a recursion T(P, D) = PD + T(P_1, D/2) + T(P_2, D/2) for some P_1+P_2 <= P, and this solves to T(P, D) <= 3PD. Thus, while the number of layers in the recursion is logarithmic, the overall time is still O(PD). (You could also look at it this way: first you do a full alignment of width D. Then you do a number of smaller alignments covering the entire N, but at width D/2. Then at D/4, .... This sums to 2ND = O(ND) in total.)
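Spelled out, writing P_i^{(l)} for the subproblem sizes at recursion level l (which sum to at most P per level):

```latex
T(P,D) \;\le\; \sum_{l \ge 0} \Big(\sum_i P_i^{(l)}\Big)\,\frac{D}{2^l}
       \;\le\; \sum_{l \ge 0} P \cdot \frac{D}{2^l}
       \;=\; PD \sum_{l \ge 0} 2^{-l}
       \;=\; 2PD \;=\; O(PD).
```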

smarco commented 2 years ago

could you elaborate on this? I must be missing something, since to me intuitively it is always possible to just join the alignments (for gap-linear alignment, anyway).

Sure. Imagine we want to align Q="AAAAA" against T="AA". The forward-WF computes A_f="MMDDD" and the reverse-WF A_r="DDDMM". There is no breakpoint that allows pasting together a prefix of A_f and a suffix of A_r giving an optimal alignment (excluding empty prefix/suffix). The problem, again, is the extended Ms due to the strategy used by the WFA.
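This is easy to check exhaustively for the example (a quick script of mine; here `M` consumes a char of both Q and T, `D` a char of Q only, `I` a char of T only):

```python
Q, T = "AAAAA", "AA"
A_f, A_r = "MMDDD", "DDDMM"  # forward and reverse optimal alignments

def consumed(ops):
    """Characters of (Q, T) consumed by an alignment string."""
    return (sum(op in "MD" for op in ops), sum(op in "MI" for op in ops))

# paste every nonempty prefix of A_f to every nonempty suffix of A_r
valid = [A_f[:i] + A_r[j:]
         for i in range(1, len(A_f) + 1)
         for j in range(len(A_r))
         if consumed(A_f[:i] + A_r[j:]) == (len(Q), len(T))]
# no combination even consumes the right number of characters of Q and T
assert valid == []
```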

I don't understand where this log(s) term here and earlier in the thread comes from. As explained in Myers'86, splitting on half distance D gives a recursion T(P, D) = PD + T(P_1, D/2) + T(P_2, D/2) for some P_1+P_2 <= P, and this solves to T(P, D) <= 3PD. Thus, while the number of layers in the recursion is logarithmic, the overall time is still O(PD). (You could also look at it this way: first you do a full alignment of width D. Then you do a number of smaller alignments covering the entire N, but at width D/2. Then at D/4, .... This sums to 2ND = O(ND) in total.)

Ok, I think you are right. My bound was too loose (I didn't do the math, my bad). It can be bounded to prove that it is indeed O(ns+s^2), like the original WFA (I did that summation now).

RagnarGrootKoerkamp commented 2 years ago

Sure. Imagine we want to align Q="AAAAA" against T="AA". The forward-WF computes A_f="MMDDD" and the reverse-WF A_r="DDDMM". There is no breakpoint that allows pasting together a prefix of A_f and a suffix of A_r giving an optimal alignment (excluding empty prefix/suffix). The problem, again, is the extended Ms due to the strategy used by the WFA.

Right, if the furthest points overlap, you can't 'just' join the alignments. Still, this may be resolved locally. My idea is that if the overlap has size l>0, you can 'just' remove the first l matches (Ms) from the alignment string of the second half.

So suppose you have this picture again:

 S         V 4 3 2 1 0        
   0         .         0      
     0         W         0    
       0         .         0  
         0 1 2 3 4 U         T=X

We have optimal paths S->U and V->T. The question is whether the optimal path U->T is always a 'simplification' of the path V->T. For now my feeling is that yes, this is true, but I'm not convinced yet. If U->V are matches, this is fine. Otherwise, for any W on V->U, the last step from T to W is to the left. (It's not possible to get to W!=U from below, or the previous wavefront would already have overlapped.) Using that the distance d(W, T) is equal for all W, I think this argument can be repeated a few times until a diagonal is reached where there is a run of matching characters of length l.

That means that after the optimal path T->V crosses the row of U in some point X, it consists of a number of matches (move top-left), followed by a number of indels (moves to the left). In that case, it is not worse to just skip the matches and directly go from X to U using indels.

smarco commented 2 years ago

Right, if the furthest points overlap, you can't 'just' join the alignments.

But look at the bright side, after the first bidirectional alignment, you already know a juncture in the optimal alignment path and the exact score of that alignment. For score-only alignments, it is a great improvement.

The question is whether the optimal path U->T is always a 'simplification' of the path V->T

I'm not entirely sure what you have in mind here. But, if you are willing to renounce the optimal solution, you can always try to "paste" the two paths together somehow (the approximation might be good). Although... I don't like it as much as finding the juncture and proceeding to solve the subproblems (that is very elegant). It is also not so slow: for the few datasets I have tested, it seems to be ~2x slower than the regular WFA (as @lh3 was estimating), while using a ridiculously small amount of memory.

RagnarGrootKoerkamp commented 2 years ago

You're right: we don't really need a way to paste paths since the entire point is low memory and for that you need the recursion. But I think some theorem about how you could paste them would still be cool, it may give some insight into the structure of the furthest reaching points (that could eventually lead to more speedup).

smarco commented 2 years ago

I think some theorem about how you could paste them would still be cool, it may give some insight into the structure of the furthest reaching points (that could eventually lead to more speedup).

:-) agree.

lh3 commented 2 years ago

I implemented the low-memory algorithm in my first post. I noticed that both wfa and wfalm probably use 5*4 bytes per traceback entry (PS: with 2-piece affine gap penalties). This can be reduced to 1 byte without affecting performance. I guess wfa and wfalm would also use less memory if they implemented a similar strategy.
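For illustration, one way such a 1-byte packing could look for 2-piece affine penalties (a hypothetical layout, not the actual encoding in lv89/wfa/wfalm): the M source has at most 5 possibilities (mismatch, I1, I2, D1, D2), which fits in 3 bits, and each of the four gap components only needs an open-vs-extend bit:

```python
# hypothetical 1-byte traceback entry for 2-piece affine WFA
M_FROM_X, M_FROM_I1, M_FROM_I2, M_FROM_D1, M_FROM_D2 = range(5)

def pack(m_src, i1_ext, i2_ext, d1_ext, d2_ext):
    """m_src: 0..4 (3 bits); each *_ext flag: 1 = gap extends, 0 = gap opens."""
    return m_src | i1_ext << 3 | i2_ext << 4 | d1_ext << 5 | d2_ext << 6

def unpack(byte):
    """Recover the five traceback fields from one byte."""
    return (byte & 7, byte >> 3 & 1, byte >> 4 & 1, byte >> 5 & 1, byte >> 6 & 1)
```

A whole wavefront of entries then fits in a single bytearray instead of five 32-bit arrays.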

lh3 commented 2 years ago

I am closing this thread as the linear-space BiWFA is both faster (when long INDELs are not close to sequence ends) and more lightweight than miniwfa. Congrats on the great work!

smarco commented 2 years ago

Thanks, @lh3. Your support of the WFA has been (and will be) key to its success.