rs-station / reciprocalspaceship

Tools for exploring reciprocal space
https://rs-station.github.io/reciprocalspaceship/
MIT License

DataSet.select_dtypes does not support custom dtypes #104

Closed kmdalton closed 2 years ago

kmdalton commented 2 years ago

pd.DataFrame has a method, select_dtypes, which returns columns matching a particular numpy dtype. In the context of rs, it would be natural for this method to distinguish between the custom MTZDtypes. However, this is not the case right now.

Given an example mtz file,

[ins] In [1]: mtz.head()
Out[1]: 
             F(+)      SigF(+)       F(-)      SigF(-)  N(+)  N(-)       high(+)     loc(+)  low(+)     scale(+)       high(-)     loc(-)  low(-)     scale(-)
H K L
0 0 4  0.94140863 0.0060185874 0.94140863 0.0060185874   8.0   8.0 10000000000.0 0.94140863   1e-32 0.0060185874 10000000000.0 0.94140863   1e-32 0.0060185874
    8   1.8974894   0.01334675  1.8974894   0.01334675   8.0   8.0 10000000000.0  1.8974894   1e-32   0.01334675 10000000000.0  1.8974894   1e-32   0.01334675
    12  2.1121132   0.02015744  2.1121132   0.02015744   8.0   8.0 10000000000.0  2.1121132   1e-32   0.02015744 10000000000.0  2.1121132   1e-32   0.02015744
    16   5.133872  0.033373583   5.133872  0.033373583   4.0   4.0 10000000000.0   5.133872   1e-32  0.033373583 10000000000.0   5.133872   1e-32  0.033373583
    20 0.19568625   0.12823802 0.19568625   0.12823802   1.0   1.0 10000000000.0 0.12831146   1e-32   0.17213167 10000000000.0 0.12831146   1e-32   0.17213167

with dtypes

[ins] In [2]: mtz.dtypes
Out[2]: 
F(+)        FriedelSFAmplitude
SigF(+)        StddevFriedelSF
F(-)        FriedelSFAmplitude
SigF(-)        StddevFriedelSF
N(+)                   MTZReal
N(-)                   MTZReal
high(+)                MTZReal
loc(+)                 MTZReal
low(+)                 MTZReal
scale(+)               MTZReal
high(-)                MTZReal
loc(-)                 MTZReal
low(-)                 MTZReal
scale(-)               MTZReal
dtype: object

rs.DataSet.select_dtypes appears to fall back to the numpy dtype. For instance, when I call mtz.select_dtypes("G"), I expect rs to return a DataSet or view containing only the "F(+)" and "F(-)" columns. Instead, I get all the columns backed by np.float32:

[nav] In [3]: mtz.select_dtypes("G")
Out[5]: 
               F(+)      SigF(+)       F(-)      SigF(-)  N(+)  N(-)       high(+)     loc(+)  low(+)     scale(+)       high(-)     loc(-)  low(-)     scale(-)
H  K  L
0  0  4  0.94140863 0.0060185874 0.94140863 0.0060185874   8.0   8.0 10000000000.0 0.94140863   1e-32 0.0060185874 10000000000.0 0.94140863   1e-32 0.0060185874
      8   1.8974894   0.01334675  1.8974894   0.01334675   8.0   8.0 10000000000.0  1.8974894   1e-32   0.01334675 10000000000.0  1.8974894   1e-32   0.01334675
      12  2.1121132   0.02015744  2.1121132   0.02015744   8.0   8.0 10000000000.0  2.1121132   1e-32   0.02015744 10000000000.0  2.1121132   1e-32   0.02015744
      16   5.133872  0.033373583   5.133872  0.033373583   4.0   4.0 10000000000.0   5.133872   1e-32  0.033373583 10000000000.0   5.133872   1e-32  0.033373583
      20 0.19568625   0.12823802 0.19568625   0.12823802   1.0   1.0 10000000000.0 0.12831146   1e-32   0.17213167 10000000000.0 0.12831146   1e-32   0.17213167
...             ...          ...        ...          ...   ...   ...           ...        ...     ...          ...           ...        ...     ...          ...
14 13 19        NaN          NaN 0.55378014   0.08148462   NaN   2.0           NaN        NaN     NaN          NaN 10000000000.0 0.55378014     0.0   0.08148462
   11 20        NaN          NaN  0.6732702   0.09068045   NaN   2.0           NaN        NaN     NaN          NaN 10000000000.0  0.6732702     0.0   0.09068045
   10 20        NaN          NaN  0.8092094   0.08233523   NaN   2.0           NaN        NaN     NaN          NaN 10000000000.0  0.8092094     0.0   0.08233523
   9  20        NaN          NaN  1.2847979   0.06926164   NaN   2.0           NaN        NaN     NaN          NaN 10000000000.0  1.2847979     0.0   0.06926164
   8  20        NaN          NaN   1.344098   0.06747224   NaN   2.0           NaN        NaN     NaN          NaN 10000000000.0   1.344098     0.0   0.06747224

which is all columns in this case.

Making this behave as expected requires either a change to the underlying pandas method or overloading the method in rs. From this perspective, it might be better to raise this issue with the pandas devs. Not sure.

JBGreisman commented 2 years ago

In my mind, this behavior comes from pandas, so I hesitate to overload the method to change it. For example, this method also doesn't differentiate between np.int64 and pd.Int64Dtype (the nullable pandas int64 implementation):

In [17]: df = pd.DataFrame(np.arange(12).reshape(3, 4),
    ...:                   columns=['A', 'B', 'C', 'D'])
    ...: df["A"] = df["A"].astype(pd.Int64Dtype())

In [18]: df.dtypes
Out[18]: 
A    Int64         <----- Note capitalization
B    int64
C    int64
D    int64
dtype: object

In [19]: df.select_dtypes(pd.Int64Dtype())
Out[19]: 
   A  B   C   D
0  0  1   2   3
1  4  5   6   7
2  8  9  10  11

I do agree, though, that their documentation doesn't make it clear that this method returns columns based on the underlying numpy dtype. It is always possible to get the desired selection with a list comprehension, so if we do want this sort of method, I would rather implement it as a custom DataSet method than overload the DataFrame one. Something like this, but with added support for passing the dtype as a str or an object:

def select_mtzdtype(self, dtype):
    return self[[k for k in self if isinstance(self.dtypes[k], dtype)]]
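
The str/object handling mentioned above could look roughly like the sketch below. It is only an illustration, written as a standalone function on a plain pandas DataFrame so that it runs without rs; the name select_by_dtype is hypothetical and not part of pandas or rs. Strings are compared against each column dtype's name, so MTZ one-letter codes such as "G" would need extra handling that is not shown here.

```python
import numpy as np
import pandas as pd


def select_by_dtype(df, dtype):
    """Return the columns of `df` whose dtype matches `dtype` exactly.

    Hypothetical helper (not a pandas or rs API). `dtype` may be:
      - a dtype class, e.g. pd.Int64Dtype
      - a dtype instance, e.g. pd.Int64Dtype()
      - a string compared against the dtype's name, e.g. "Int64"
    """
    if isinstance(dtype, str):
        # Names are case-sensitive: "Int64" (nullable) vs "int64" (numpy)
        def match(d):
            return d.name == dtype
    elif isinstance(dtype, type):
        def match(d):
            return isinstance(d, dtype)
    else:
        # Require the same dtype class before comparing instances, so a
        # numpy dtype is never compared against an extension dtype
        def match(d):
            return type(d) is type(dtype) and d == dtype

    return df[[k for k in df if match(df.dtypes[k])]]


# Same frame as in the example above: "A" is nullable Int64, the rest int64
df = pd.DataFrame(np.arange(12, dtype="int64").reshape(3, 4),
                  columns=["A", "B", "C", "D"])
df["A"] = df["A"].astype(pd.Int64Dtype())

print(select_by_dtype(df, pd.Int64Dtype()).columns.tolist())
print(select_by_dtype(df, "int64").columns.tolist())
```

Unlike select_dtypes, this never falls back to the underlying numpy dtype, so the nullable "A" column and the plain int64 columns are selected separately.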

kmdalton commented 2 years ago

I'm totally happy with your proposed solution.