aliutkus / torchinterp1d

1D interpolation for pytorch
BSD 3-Clause "New" or "Revised" License

Problems reproducing values given by np.interp() #12

Open JCBrouwer opened 3 years ago

JCBrouwer commented 3 years ago

Hello, I'm trying to rewrite some histogram matching code in pytorch which relies on some 1D interpolations.

I've noticed that while most of the values in my torchinterp1d result match what I expect, a couple of them are off by an order of magnitude.

Here's some code to reproduce the issue:

import numpy as np
import torch
from torchinterp1d import Interp1d

interp1d = Interp1d()

# histogram matching with numpy

random_state = np.random.RandomState(12345)  #  not all seeds have this issue, but this is one that does

bins = 64

target = random_state.normal(size=(128 * 128)) * 2   #  some random data between about -8 and 8
source = random_state.normal(size=(128 * 128)) * 2
matched = np.empty_like(target)

lo = min(target.min(), source.min())
hi = max(target.max(), source.max())

target_hist_np, bin_edges_np = np.histogram(target, bins=bins, range=[lo, hi])
source_hist_np, _ = np.histogram(source, bins=bins, range=[lo, hi])

target_cdf_np = target_hist_np.cumsum()
target_cdf_np = target_cdf_np / target_cdf_np[-1]

source_cdf_np = source_hist_np.cumsum()
source_cdf_np = source_cdf_np / source_cdf_np[-1]

remapped_cdf_np = np.interp(target_cdf_np, source_cdf_np, bin_edges_np[1:])

matched_np = np.interp(target, bin_edges_np[1:], remapped_cdf_np, left=0, right=bins)

# now with pytorch

target = torch.from_numpy(target)
source = torch.from_numpy(source)

target_hist = torch.histc(target, bins, lo, hi)
source_hist = torch.histc(source, bins, lo, hi)

assert np.allclose(target_hist_np, target_hist.numpy())
assert np.allclose(source_hist_np, source_hist.numpy())

target_cdf = target_hist.cumsum(0)
target_cdf = target_cdf / target_cdf[-1]

assert np.allclose(target_cdf_np, target_cdf.numpy())

source_cdf = source_hist.cumsum(0)
source_cdf = source_cdf / source_cdf[-1]

assert np.allclose(source_cdf_np, source_cdf.numpy())

bin_edges = torch.linspace(lo, hi, bins + 1)

assert np.allclose(bin_edges_np, bin_edges.numpy())

remapped_cdf = interp1d(source_cdf, bin_edges[1:], target_cdf).squeeze()
# ^^^ first positions of this have -100 values all of a sudden?!

print(remapped_cdf_np)
print(remapped_cdf.numpy())
assert np.allclose(remapped_cdf_np, remapped_cdf.numpy())  # fails

matched = interp1d(bin_edges[1:], remapped_cdf, target)

assert np.allclose(matched_np, matched.numpy())

The above code gives me the output:

[-8.04819874 -8.04819874 -8.04819874 -7.03412467 -6.52708763 -6.34600297
 -6.27356911 -6.10455677 -5.89329133 -5.55526664 -5.28932075 -5.00597652
 -4.81837282 -4.66795183 -4.43309052 -4.17367044 -3.93438144 -3.670879
 -3.44365304 -3.19894227 -2.97056192 -2.7420723  -2.47732906 -2.21208839
 -1.96009338 -1.69844422 -1.44815496 -1.20431557 -0.94311239 -0.68723275
 -0.44108403 -0.18912467  0.055417    0.30790917  0.5585027   0.81660576
  1.0688232   1.33458219  1.60847022  1.85890728  2.12938742  2.38900627
  2.66416974  2.93036861  3.17321839  3.41920686  3.64490881  3.92116168
  4.1785585   4.43336298  4.75240842  4.99895072  5.34133486  5.58747586
  5.77523205  5.9234885   6.12066957  6.29370601  6.36613988  6.52911607
  7.6699494   7.6699494   7.6699494   7.92346792]
[-1.37849957e+02 -1.37849957e+02 -1.37849957e+02 -7.28813667e+00
 -6.78085313e+00 -6.34605302e+00 -6.27363935e+00 -6.10459291e+00
 -5.89331551e+00 -5.55530036e+00 -5.28934614e+00 -5.00599635e+00
 -4.81837965e+00 -4.66795432e+00 -4.43309173e+00 -4.17367124e+00
 -3.93438180e+00 -3.67087919e+00 -3.44365297e+00 -3.19894198e+00
 -2.97056138e+00 -2.74207334e+00 -2.47732997e+00 -2.21208787e+00
 -1.96009280e+00 -1.69844372e+00 -1.44815450e+00 -1.20431574e+00
 -9.43111794e-01 -6.87232221e-01 -4.41083541e-01 -1.89125193e-01
  5.54165081e-02  3.07908676e-01  5.58502213e-01  8.16605249e-01
  1.06882271e+00  1.33458229e+00  1.60847020e+00  1.85890732e+00
  2.12938745e+00  2.38900625e+00  2.66416942e+00  2.93036823e+00
  3.17321798e+00  3.41920633e+00  3.64490848e+00  3.92116086e+00
  4.17855749e+00  4.43336134e+00  4.75240455e+00  4.99894268e+00
  5.34132004e+00  5.58744816e+00  5.77521847e+00  5.92348256e+00
  6.12062090e+00  6.29366587e+00  6.36607955e+00  6.52899272e+00
  6.90914693e+00  6.90914693e+00  6.90914693e+00  7.92322078e+00]
Traceback (most recent call last):
  File "histmatch.py", line 256, in <module>
    assert np.allclose(remapped_cdf_np, remapped_cdf.numpy())  # fails
AssertionError

The second array holds the values from torchinterp1d, while the first holds the values from np.interp for the same inputs (as evidenced by the earlier asserts not triggering). Note that the argument order for torchinterp1d is slightly different from that of np.interp, but I believe they should produce the same result.
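To make the argument mapping explicit, here is a minimal pure-Python sketch of piecewise-linear interpolation. It is not the code of either library: `interp_clamped` and `interp_xy_first` are hypothetical helper names, and the clamping at the ends follows np.interp's documented defaults (return `fp[0]` on the left and `fp[-1]` on the right).

```python
from bisect import bisect_right


def interp_clamped(x, xp, fp):
    """Piecewise-linear interpolation with np.interp-style argument order.

    x:  query points
    xp: increasing sample positions
    fp: sample values at xp
    Queries outside [xp[0], xp[-1]] return fp[0] or fp[-1].
    """
    out = []
    for q in x:
        if q <= xp[0]:
            out.append(fp[0])
        elif q >= xp[-1]:
            out.append(fp[-1])
        else:
            i = bisect_right(xp, q) - 1  # xp[i] <= q < xp[i + 1]
            t = (q - xp[i]) / (xp[i + 1] - xp[i])
            out.append(fp[i] + t * (fp[i + 1] - fp[i]))
    return out


def interp_xy_first(x, y, xnew):
    """Same computation with the (x, y, xnew) order used in this post."""
    return interp_clamped(xnew, x, y)
```

With this mapping, `interp_clamped(q, xp, fp)` and `interp_xy_first(xp, fp, q)` are the same call, which is the equivalence I am assuming between the two libraries.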

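In fact, most of the printed values agree. Take the last value of the array, for example: 7.92322078e+00 is pretty close to 7.92346792. The same holds for most values in the array, with the first 3 as the glaring exception: they are more than an order of magnitude lower than expected (around -138 instead of about -8). A few others are also off by smaller amounts: the 4th and 5th values (-7.29 and -6.78 vs. -7.03 and -6.53) and the three before the last (6.91 vs. 7.67).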

To be concrete, these two lines give different results for the same inputs:

remapped_cdf_np = np.interp(x=target_cdf_np, xp=source_cdf_np, fp=bin_edges_np[1:])
remapped_cdf = interp1d(x=source_cdf, y=bin_edges[1:], xnew=target_cdf).squeeze()

What's going on here? Is there a way to exactly reproduce numpy's results with pytorch?
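One thing I could check is how queries outside the sample range are handled. The sketch below is pure Python with hypothetical helpers (`interp_extrapolating`, `clamp_queries`), not torchinterp1d's code. It assumes an implementation that extends the first and last segments linearly instead of clamping, which I have not confirmed is what torchinterp1d does. Under that assumption, a steep first segment can turn a slightly out-of-range query into a value near -100, and clamping the queries first restores np.interp-style end values.

```python
def interp_extrapolating(x, y, xnew):
    """Linear interpolation that extends the end segments beyond the data.

    Assumes x is strictly increasing. Hypothetical helper for illustration.
    """
    out = []
    last = len(x) - 2  # index of the last segment
    for q in xnew:
        # Find the segment containing q, clamped to the first/last segment.
        i = 0
        while i < last and q >= x[i + 1]:
            i += 1
        slope = (y[i + 1] - y[i]) / (x[i + 1] - x[i])
        out.append(y[i] + slope * (q - x[i]))
    return out


def clamp_queries(xnew, lo, hi):
    """Clip query points to [lo, hi] so no extrapolation takes place."""
    return [min(max(q, lo), hi) for q in xnew]


# Made-up toy data: the first segment is very narrow, so its slope is ~1000.
x = [0.0, 0.001, 0.5, 1.0]
y = [-8.0, -7.0, -6.0, -5.0]
queries = [-0.1, 0.75]

extrapolated = interp_extrapolating(x, y, queries)
clamped = interp_extrapolating(x, y, clamp_queries(queries, x[0], x[-1]))
```

Here the out-of-range query lands near -108 when extrapolated and at -8.0 after clamping, while the in-range query gives -5.5 in both cases. In PyTorch the equivalent clamping step would be `torch.clamp` on the query tensor before the interpolation call.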