huggingface / pixparse

Pixel Parsing. A reproduction of OCR-free end-to-end document understanding models with open data

[Suggestion] Remove crop_margin dependency on cv2 #22

Open molbap opened 12 months ago

molbap commented 12 months ago

Currently we depend on cv2 for Nougat's crop margin. What do you think about something like this? I just replaced the two cv2 calls with:

import numpy as np
import torch

def pythonfindNonZero(gray):
    # (row, col) indices of nonzero pixels, swapped to (x, y) order like cv2.findNonZero
    non_zero_indices = np.column_stack(np.nonzero(gray))
    idxvec = non_zero_indices[:, [1, 0]]
    return idxvec

def pythonBoundingRect(coords):
    # (x, y, w, h) of the points; here w and h are max - min, without the +1 that cv2 adds
    min_vals = np.min(coords, axis=0).astype(int)
    max_vals = np.max(coords, axis=0).astype(int)
    return min_vals[0], min_vals[1], max_vals[0] - min_vals[0], max_vals[1] - min_vals[1]

class pythonCropMargin:
    def __init__(self):
        pass

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            raise NotImplementedError("tensor inputs are not supported, pass a PIL image")
        data = np.array(img.convert("L")).astype(np.uint8)
        max_val = data.max()
        min_val = data.min()
        if max_val == min_val:
            # uniform image: nothing to crop
            return img
        # stretch contrast to [0, 255], then mark dark ("ink") pixels
        data = (data - min_val) / (max_val - min_val) * 255
        gray = 255 * (data < 200).astype(np.uint8)

        coords = pythonfindNonZero(gray)
        a, b, w, h = pythonBoundingRect(coords)
        return img.crop((a, b, w + a, h + b))

This is slower than cv2 (21 ms on average on my machine vs. 12 ms for the cv2 implementation). It does rely on numpy. The resulting images differ slightly (by one pixel on either axis), which induces a slightly different mean/std, but overall they look similar.
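The one-pixel difference is consistent with the first pythonBoundingRect returning max - min for the width and height, whereas cv2.boundingRect on a point set counts both end pixels (max - min + 1). Since PIL's crop box treats right/lower as exclusive, the first version drops the last ink column and row. A minimal numpy sketch of the two conventions (the helper names are hypothetical):

```python
import numpy as np

def bbox_exclusive(coords):
    # Width/height as in the first pythonBoundingRect: max - min (last pixel not counted).
    mn = coords.min(axis=0)
    mx = coords.max(axis=0)
    return int(mn[0]), int(mn[1]), int(mx[0] - mn[0]), int(mx[1] - mn[1])

def bbox_inclusive(coords):
    # Width/height counting both end pixels: max - min + 1 (cv2.boundingRect convention).
    x, y, w, h = bbox_exclusive(coords)
    return x, y, w + 1, h + 1

# 5x5 mask with ink in columns 1..3 and rows 2..4.
mask = np.zeros((5, 5), dtype=bool)
mask[2:5, 1:4] = True
coords = np.column_stack(np.nonzero(mask))[:, [1, 0]]  # (x, y) order, as in pythonfindNonZero

print(bbox_exclusive(coords))  # (1, 2, 2, 2)
print(bbox_inclusive(coords))  # (1, 2, 3, 3)
```

Cropping with (x, y, x + w, y + h) from the exclusive version would keep only columns 1..2 and rows 2..3, which is one pixel short on each axis.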

molbap commented 12 months ago

After checking, these two versions seem to produce the same output as cv2.findNonZero and cv2.boundingRect:

import numpy as np

def pythonfindNonZero(src):
    # Mirrors cv2.findNonZero: (x, y) points with shape (N, 1, 2)
    assert len(src.shape) == 2, "Input must be a 2D array"
    non_zero_indices = np.column_stack(np.nonzero(src))
    idxvec = non_zero_indices[:, [1, 0]]
    idxvec = idxvec.reshape(-1, 1, 2)
    return idxvec

def pythonBoundingRect(coords):
    # Mirrors cv2.boundingRect on points: width/height include both end pixels (+1)
    min_vals = np.min(coords, axis=(0, 1)).astype(int)
    max_vals = np.max(coords, axis=(0, 1)).astype(int)

    return min_vals[0], min_vals[1], max_vals[0] - min_vals[0] + 1, max_vals[1] - min_vals[1] + 1

I still need to test it against the BoundingRect and findNonZero test suites.
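This does not replace OpenCV's own test suites, but a quick randomized comparison against a brute-force pixel scan could catch off-by-one errors before that. The reference function and harness below are hypothetical, not taken from the OpenCV suites; the two helpers are copied from the version above.

```python
import numpy as np

def pythonfindNonZero(src):
    # Same as the second version: (x, y) points shaped (N, 1, 2)
    assert len(src.shape) == 2, "Input must be a 2D array"
    return np.column_stack(np.nonzero(src))[:, [1, 0]].reshape(-1, 1, 2)

def pythonBoundingRect(coords):
    min_vals = np.min(coords, axis=(0, 1)).astype(int)
    max_vals = np.max(coords, axis=(0, 1)).astype(int)
    return min_vals[0], min_vals[1], max_vals[0] - min_vals[0] + 1, max_vals[1] - min_vals[1] + 1

def brute_force_rect(mask):
    # Independent reference: visit every pixel with plain Python loops.
    xs, ys = [], []
    for y in range(mask.shape[0]):
        for x in range(mask.shape[1]):
            if mask[y, x]:
                xs.append(x)
                ys.append(y)
    return min(xs), min(ys), max(xs) - min(xs) + 1, max(ys) - min(ys) + 1

def check_random(trials=200, seed=0):
    rng = np.random.default_rng(seed)
    for _ in range(trials):
        h, w = rng.integers(1, 40, size=2)
        mask = rng.random((h, w)) < rng.uniform(0.01, 0.5)
        if not mask.any():
            continue  # empty masks make np.min raise; handled separately
        got = pythonBoundingRect(pythonfindNonZero(mask))
        expected = brute_force_rect(mask)
        assert got == expected, (got, expected)
    return True

check_random()
```

Both helpers raise on an all-zero mask (np.min of an empty array). In pythonCropMargin this cannot happen: the early return handles uniform images, and otherwise the darkest pixel is normalized to 0, which always passes the < 200 threshold.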

rwightman commented 12 months ago

@molbap nice, I'm surprised it's only ~2x slower. Does that hold as the image gets larger? It might be worth adding for the 'better' transforms so they're fully cv2-free.
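To see how the numpy path scales with resolution, one could time it on synthetic pages of increasing size. This is a sketch under assumptions: the synthetic page layout and helper names are hypothetical, and only the numpy path is timed. Where OpenCV is installed, the same loop could also time cv2.boundingRect(cv2.findNonZero(mask.astype(np.uint8))) on the same mask for a side-by-side ratio.

```python
import timeit
import numpy as np

def find_non_zero(mask):
    # (x, y) points shaped (N, 1, 2), as in pythonfindNonZero
    return np.column_stack(np.nonzero(mask))[:, [1, 0]].reshape(-1, 1, 2)

def bounding_rect(coords):
    # cv2-style (x, y, w, h) with inclusive width/height
    mn = coords.min(axis=(0, 1))
    mx = coords.max(axis=(0, 1))
    return int(mn[0]), int(mn[1]), int(mx[0] - mn[0] + 1), int(mx[1] - mn[1] + 1)

def make_page(size, seed=0):
    # Hypothetical synthetic page: white margins around a block of random "ink".
    rng = np.random.default_rng(seed)
    page = np.full((size, size), 255, dtype=np.uint8)
    lo, hi = size // 8, size - size // 8
    page[lo:hi, lo:hi] = rng.integers(0, 256, size=(hi - lo, hi - lo), dtype=np.uint8)
    return page

def bench(sizes=(500, 1000, 2000), repeat=5):
    # Best-of-`repeat` wall time in milliseconds for each square page size.
    results = {}
    for size in sizes:
        mask = make_page(size) < 200
        times = timeit.repeat(lambda m=mask: bounding_rect(find_non_zero(m)),
                              number=1, repeat=repeat)
        results[size] = min(times) * 1000
    return results

for size, ms in bench().items():
    print(f"{size}x{size}: {ms:.2f} ms")
```

Taking the minimum over repeats reduces noise from other processes; the absolute numbers depend entirely on the machine.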

rwightman commented 12 months ago

gray = 255 * (data < 200).astype(np.uint8) ... couldn't findNonZero be more efficient if this were left as a bool array, without the * 255? EDIT: so just gray = data < 200
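np.nonzero only checks whether each element is nonzero, so the boolean mask yields exactly the same coordinates while skipping the cast and the multiply. A quick check on random stand-in data (the data is an assumption, not a real page):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(64, 48)).astype(np.float64)  # stand-in for the normalized page

gray_u8 = 255 * (data < 200).astype(np.uint8)  # current version: cast, then multiply
gray_bool = data < 200                          # suggested version: keep the bool mask

# True and 255 are both "nonzero", so the (x, y) coordinates are identical.
coords_u8 = np.column_stack(np.nonzero(gray_u8))[:, [1, 0]]
coords_bool = np.column_stack(np.nonzero(gray_bool))[:, [1, 0]]
assert np.array_equal(coords_u8, coords_bool)
```

This shortcut only applies to the numpy path; cv2.findNonZero expects a single-channel 8-bit image, so the cv2 version would still need the uint8 conversion.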

molbap commented 12 months ago

Sure. There are actually quite a few optimizations done in switch-case statements in the C++ code, and we can add a few of them. I checked with 2000x2000 images but nothing larger. The memory footprint increases due to the buffer optimization in numpy's nonzero function; I'll try higher resolutions.
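One way to avoid the memory growth from collecting every nonzero coordinate is to reduce the mask to row and column projections and take the first and last hit on each axis. This is a swapped-in technique, not the C++ switch-case optimizations and not necessarily what the transformers version does; the function name is hypothetical.

```python
import numpy as np

def bounding_rect_projection(mask):
    """Hypothetical alternative: cv2-style (x, y, w, h) from row/column projections.

    Works from two 1-D projections instead of building one coordinate pair
    per ink pixel, so the extra memory does not grow with the amount of ink.
    """
    cols = np.flatnonzero(mask.any(axis=0))  # x positions containing any ink
    rows = np.flatnonzero(mask.any(axis=1))  # y positions containing any ink
    if cols.size == 0:
        raise ValueError("mask has no nonzero pixels")
    x, y = int(cols[0]), int(rows[0])
    return x, y, int(cols[-1]) - x + 1, int(rows[-1]) - y + 1

mask = np.zeros((6, 8), dtype=bool)
mask[1, 2] = True
mask[4, 6] = True
print(bounding_rect_projection(mask))  # (2, 1, 5, 4)
```

The result uses the same inclusive width/height convention as the corrected pythonBoundingRect, so it could be swapped in behind the same interface.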

molbap commented 12 months ago

The bool optimization shaves 10% off the runtime, it seems.

molbap commented 11 months ago

In the source for https://github.com/huggingface/transformers/pull/25942, there is a modified version of crop_margin that is about as fast as cv2 and yields the same results: https://github.com/NielsRogge/transformers/blob/d5e0590067069a19c57c97d76cd1a70e86ef8c5f/src/transformers/models/nougat/image_processing_nougat.py#L144C1-L193C21