allenai / mmda

multimodal document analysis
Apache License 2.0
158 stars 18 forks

Prompted chat-gpt4 to optimize build_graph_index_overlap using numpy … #238

Closed egork520 closed 1 year ago

egork520 commented 1 year ago

…and it came up with a workable suggestion

Response from OpenAI:

To rewrite the build_graph_index_overlap and build_merged_spans_from_connected_components functions with a runtime speedup using the numpy package, you can vectorize the operations and utilize numpy's advanced indexing capabilities. First, make sure to import numpy:

import numpy as np

Now, you can rewrite the build_graph_index_overlap function as follows:

def build_graph_index_overlap(self):
    """
    Build graph, each node is represented by (start, end) of tuple, with the list of spans. Spans are considered
    overlapping if they are index_distance apart
    """
    span_starts = np.array([span.start for span in self.list_of_spans])
    span_ends = np.array([span.end for span in self.list_of_spans])

    # Compute the distance matrix
    start_diff = np.abs(span_starts[:, np.newaxis] - span_ends)
    end_diff = np.abs(span_ends[:, np.newaxis] - span_starts)
    min_diff = np.minimum(start_diff, end_diff)

    # Find neighboring spans
    neighboring_spans = np.where(min_diff <= self.index_distance)

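    # min_diff is symmetric, so np.where reports every pair as both (i, j) and (j, i).
    # Keeping only i < j adds each edge once per direction instead of twice.
    for i, j in zip(*neighboring_spans):
        if i < j:
            span_i = self.list_of_spans[i]
            span_j = self.list_of_spans[j]
            self.graph[span_i.start, span_i.end].append(span_j)
            self.graph[span_j.start, span_j.end].append(span_i)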
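
The broadcasting step above can be tried in isolation. The following is a minimal sketch, not mmda's code: ToySpan and neighboring_pairs are hypothetical names introduced for illustration, and the sketch keeps only the strict upper triangle of the matrix so that each pair is reported once.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToySpan:
    """Hypothetical stand-in for a span; only start and end are modeled."""
    start: int
    end: int


def neighboring_pairs(spans, index_distance):
    """Return index pairs (i, j) with i < j whose spans lie within index_distance."""
    starts = np.array([s.start for s in spans])
    ends = np.array([s.end for s in spans])

    # A column vector minus a row vector broadcasts to an n x n matrix.
    start_to_end = np.abs(starts[:, np.newaxis] - ends)
    end_to_start = np.abs(ends[:, np.newaxis] - starts)
    min_diff = np.minimum(start_to_end, end_to_start)

    # The matrix is symmetric, so the strict upper triangle (k=1) is enough;
    # this also drops the diagonal, where a span is compared with itself.
    close = np.triu(min_diff <= index_distance, k=1)
    rows, cols = np.nonzero(close)
    return [(int(i), int(j)) for i, j in zip(rows, cols)]


spans = [ToySpan(0, 5), ToySpan(6, 10), ToySpan(30, 40)]
pairs = neighboring_pairs(spans, index_distance=1)
print(pairs)
```

With a distance of 1, only the first two toy spans are neighbors, because the first ends at 5 and the second starts at 6.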

For the build_merged_spans_from_connected_components function, there are no obvious parts that would benefit from numpy. However, the function can be simplified by replacing the innermost loop, which builds the merged spans, with a list comprehension:

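# Module-level imports this method relies on (Span and Box come from mmda itself):
from collections import defaultdict
from itertools import groupby

def build_merged_spans_from_connected_components(self, index):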
    """
    For each of the lists of the connected nodes determined by symbol distance or box distance,
    merge boxes and find, min, max of the index
    """
    if index:
        self.build_graph_index_overlap()
    else:
        self.build_graph_box_overlap()

    nodes_in_comp, number_of_comps = self.get_components()

    # all intervals in each connected component must be merged.
    merged_spans = []
    for comp in range(number_of_comps):
        if nodes_in_comp[comp]:
            spans_by_page = defaultdict(list)
            for pg, page_spans in groupby(
                nodes_in_comp[comp],
                lambda s: s.box.page if s.box is not None else None,
            ):
                spans_by_page[pg].extend(page_spans)

            merged_spans.extend(
                [
                    Span(
                        start=min([span.start for span in page_spans]),
                        end=max([span.end for span in page_spans]),
                        box=Box.small_boxes_to_big_box(
                            [span.box for span in page_spans]
                        ),
                    )
                    for page_spans in spans_by_page.values()
                ]
            )

    return merged_spans
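
The per-page merge inside this function can be sketched without mmda's classes. In the sketch below, ToySpan, its page field, and merge_component are hypothetical names, and box merging is left out. It groups with a defaultdict alone, since itertools.groupby only groups consecutive items that share a key.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class ToySpan:
    """Hypothetical stand-in for a span; page replaces the box's page."""
    start: int
    end: int
    page: Optional[int] = None


def merge_component(component: List[ToySpan]) -> List[ToySpan]:
    """Merge all spans of one connected component into one span per page."""
    spans_by_page = defaultdict(list)
    for span in component:
        spans_by_page[span.page].append(span)

    # One merged span per page: smallest start, largest end.
    return [
        ToySpan(
            start=min(s.start for s in page_spans),
            end=max(s.end for s in page_spans),
            page=page,
        )
        for page, page_spans in spans_by_page.items()
    ]


component = [ToySpan(0, 5, page=0), ToySpan(12, 20, page=1), ToySpan(4, 9, page=0)]
merged = merge_component(component)
print(merged)
```

The two page-0 spans collapse into a single span covering 0 to 9, while the page-1 span is kept on its own.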

This implementation uses numpy broadcasting to compute all pairwise start/end distances at once and find neighboring spans in the build_graph_index_overlap function. The trade-off is memory: each distance matrix is n × n for n spans, so memory use grows quadratically with the number of spans and may exceed that of the original implementation. In exchange, the pairwise comparison runs in vectorized numpy code instead of a Python loop, which should provide a significant runtime speedup.
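
To put a rough number on the memory caveat, the size of one dense pairwise matrix can be estimated up front. This is a back-of-the-envelope sketch that assumes 64-bit integers; pairwise_matrix_bytes is a hypothetical helper, and the vectorized function holds several such matrices at once (start_diff, end_diff, min_diff).

```python
import numpy as np


def pairwise_matrix_bytes(n_spans, dtype=np.int64):
    """Bytes needed for one dense n x n matrix of the given dtype."""
    return n_spans * n_spans * np.dtype(dtype).itemsize


# Sanity check against a real array for a small n.
small = np.zeros((100, 100), dtype=np.int64)
estimate_small = pairwise_matrix_bytes(100)

# Estimate for a document with 10,000 spans (no array is allocated here).
estimate_large = pairwise_matrix_bytes(10_000)
print(estimate_small, small.nbytes, estimate_large)
```

At 10,000 spans a single int64 matrix already takes 800 MB, so very long documents may need chunking or a different neighbor search.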

For the build_merged_spans_from_connected_components function, the optimization mainly focuses on simplifying the code and using list comprehensions.