waltsims / k-wave-python

A Python interface to k-Wave GPU accelerated binaries
https://k-wave-python.readthedocs.io/en/latest/
GNU General Public License v3.0

@dataclass usage inconsistent #115

Open bshieh-bfly opened 1 year ago

bshieh-bfly commented 1 year ago

The @dataclass decorator is not applied consistently across the project. For example, kSource is decorated with @dataclass, but its attributes are declared as plain class attributes rather than annotated fields, so it gains neither keyword-argument instantiation nor a generated repr. Several other classes use @dataclass with no benefit. kWaveMedium, on the other hand, uses @dataclass in the standard way. It would be nice for every class with a MATLAB k-Wave analog to be implemented as a proper dataclass.
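
To illustrate the difference, here is a minimal sketch. The class names `SourceLike` and `MediumLike` and their attributes are hypothetical stand-ins, not the actual k-wave-python definitions.

```python
from dataclasses import dataclass, fields


@dataclass
class SourceLike:
    # Hypothetical example: no type annotations, so these are plain class
    # attributes. @dataclass finds no fields and generates an __init__
    # that takes no arguments and a repr that shows nothing.
    p0 = None
    p_mask = None


@dataclass
class MediumLike:
    # Hypothetical example: annotated attributes become dataclass fields,
    # so they show up in the generated __init__ and repr.
    sound_speed: float
    density: float = 1000.0


if __name__ == "__main__":
    print(fields(SourceLike))             # empty tuple: no fields detected
    print(MediumLike(sound_speed=1500.0))  # repr lists both fields
    try:
        SourceLike(p0=1)                  # no generated keyword arguments
    except TypeError as err:
        print("TypeError:", err)
```

In the first class the decorator is effectively a no-op for instantiation and repr, which is the inconsistency described above.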

faridyagubbayli commented 1 year ago

You make a valid point! We considered using dataclasses when we first implemented these classes. However, given how much MATLAB-style code there is, dataclasses weren't always feasible or suitable. We're committed to progressively adopting better Python practices and moving away from the MATLAB-style implementation. While dataclasses may still be impractical in some instances, we aim to reach that milestone in the coming weeks. I suggest keeping this issue open until we achieve this goal.

bshieh-bfly commented 1 year ago

I appreciate all the hard work that has gone into this project!

If for some reason it is difficult to implement some of the classes as dataclasses, you could also consider encapsulating the current classes to expose a simplified interface. I'm using something like this:

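from dataclasses import dataclass, field

import numpy as np
# Assumed import path for the wrapped class; adjust to your k-wave-python version.
from kwave.kgrid import kWaveGrid as _kWaveGrid

@dataclass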
class KWaveGrid3D:

    # attributes required for init
    Nx: int
    dx: float
    Ny: int
    dy: float
    Nz: int
    dz: float

    # only editable post init attribute
    t_array: np.ndarray = field(init=False, repr=False)

    # general read-only attributes
    k: np.ndarray = field(init=False, repr=False)
    k_max: float = field(init=False, repr=False)
    Nt: int = field(init=False)
    dt: float = field(init=False)
    dim: int = field(init=False)
    total_grid_points: int = field(init=False)

    # x dimension read-only attributes
    x: np.ndarray = field(init=False, repr=False)
    x_vec: np.ndarray = field(init=False, repr=False)
    x_size: float = field(init=False)
    kx: np.ndarray = field(init=False, repr=False)
    kx_vec: np.ndarray = field(init=False, repr=False)
    kx_max: float = field(init=False)

    # y dimension read-only attributes
    y: np.ndarray = field(init=False, repr=False)
    y_vec: np.ndarray = field(init=False, repr=False)
    y_size: float = field(init=False)
    ky: np.ndarray = field(init=False, repr=False)
    ky_vec: np.ndarray = field(init=False, repr=False)
    ky_max: float = field(init=False)

    # z dimension read-only attributes
    z: np.ndarray = field(init=False, repr=False)
    z_vec: np.ndarray = field(init=False, repr=False)
    z_size: float = field(init=False)
    kz: np.ndarray = field(init=False, repr=False)
    kz_vec: np.ndarray = field(init=False, repr=False)
    kz_max: float = field(init=False)

    # encapsulated kWaveGrid object
    _grid: _kWaveGrid = field(default=None, init=False, repr=False)

    def __post_init__(self):
        self._grid = _kWaveGrid([self.Nx, self.Ny, self.Nz],
                                [self.dx, self.dy, self.dz])

    @property
    def k(self):
        return self._grid.k

    @k.setter
    def k(self, val):
        pass

    @property
    def k_max(self):
        return self._grid.k_max

    @k_max.setter
    def k_max(self, val):
        pass

    @property
    def t_array(self):
        return self._grid.t_array

    @t_array.setter
    def t_array(self, val):
        if self._grid is not None:
            self._grid.t_array = val

    @property
    def Nt(self):
        return self._grid.Nt

    @Nt.setter
    def Nt(self, val):
        pass

    @property
    def dt(self):
        return self._grid.dt

    @dt.setter
    def dt(self, val):
        pass

    @property
    def dim(self):
        return self._grid.dim

    @dim.setter
    def dim(self, val):
        pass

    @property
    def total_grid_points(self):
        return self._grid.total_grid_points

    @total_grid_points.setter
    def total_grid_points(self, val):
        pass

    @property
    def x(self):
        return self._grid.x

    @x.setter
    def x(self, val):
        pass

    @property
    def x_vec(self):
        return self._grid.x_vec

    @x_vec.setter
    def x_vec(self, val):
        pass

    @property
    def x_size(self):
        # x_size mistakenly not defined in kWaveGrid
        return self._grid.size[0]

    @x_size.setter
    def x_size(self, val):
        pass

    @property
    def kx(self):
        return self._grid.kx

    @kx.setter
    def kx(self, val):
        pass

    @property
    def kx_vec(self):
        return self._grid.k_vec[0]

    @kx_vec.setter
    def kx_vec(self, val):
        pass

    @property
    def kx_size(self):
        return len(self.kx_vec)

    @kx_size.setter
    def kx_size(self, val):
        pass

    @property
    def kx_max(self):
        return self.k_max[0]

    @kx_max.setter
    def kx_max(self, val):
        pass

    @property
    def y(self):
        return self._grid.y

    @y.setter
    def y(self, val):
        pass

    @property
    def y_vec(self):
        return self._grid.y_vec

    @y_vec.setter
    def y_vec(self, val):
        pass

    @property
    def y_size(self):
        return self._grid.y_size

    @y_size.setter
    def y_size(self, val):
        pass

    @property
    def ky(self):
        return self._grid.ky

    @ky.setter
    def ky(self, val):
        pass

    @property
    def ky_vec(self):
        # y is the second axis, so index 1 (index 2 is the z axis)
        return self._grid.k_vec[1]

    @ky_vec.setter
    def ky_vec(self, val):
        pass

    @property
    def ky_size(self):
        return len(self.ky_vec)

    @ky_size.setter
    def ky_size(self, val):
        pass

    @property
    def ky_max(self):
        return self.k_max[1]

    @ky_max.setter
    def ky_max(self, val):
        pass

    @property
    def z(self):
        return self._grid.z

    @z.setter
    def z(self, val):
        pass

    @property
    def z_vec(self):
        return self._grid.z_vec

    @z_vec.setter
    def z_vec(self, val):
        pass

    @property
    def z_size(self):
        return self._grid.z_size

    @z_size.setter
    def z_size(self, val):
        pass

    @property
    def kz(self):
        return self._grid.kz

    @kz.setter
    def kz(self, val):
        pass

    @property
    def kz_vec(self):
        return self._grid.k_vec[2]

    @kz_vec.setter
    def kz_vec(self, val):
        pass

    @property
    def kz_size(self):
        return len(self.kz_vec)

    @kz_size.setter
    def kz_size(self, val):
        pass

    @property
    def kz_max(self):
        return self.k_max[2]

    @kz_max.setter
    def kz_max(self, val):
        pass

    @property
    def _kwave_object(self):
        return self._grid

    @_kwave_object.setter
    def _kwave_object(self, val):
        pass

    def makeTime(self, c, cfl=0.3, t_end=None):
        return self._grid.makeTime(c, cfl, t_end)
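
One caveat worth checking with this pattern: when a property reuses the name of an annotated field, the property object replaces the `field(...)` declaration in the class body, so `init=False` and `repr=False` are lost and the attribute becomes an ordinary init and repr field whose default is the property object. The sketch below shows this, and one possible alternative that keeps only the constructor arguments as fields and exposes derived values as plain properties. `_StubGrid`, `ShadowedGrid`, and `WrappedGrid` are hypothetical names made up for illustration; `_StubGrid` merely stands in for the wrapped grid class.

```python
from dataclasses import dataclass, field, fields


class _StubGrid:
    """Hypothetical stand-in for the wrapped grid object."""

    def __init__(self, N, spacing):
        self.N = list(N)
        self.spacing = list(spacing)

    @property
    def size(self):
        # Physical extent along each axis.
        return [n * d for n, d in zip(self.N, self.spacing)]


@dataclass
class ShadowedGrid:
    """Field and property share a name, as in the wrapper above."""
    Nx: int
    dx: float
    x_size: float = field(init=False, repr=False)  # replaced by the property below

    @property
    def x_size(self):
        return self.Nx * self.dx

    @x_size.setter
    def x_size(self, val):
        pass  # silently ignores writes, including the one made by __init__


@dataclass
class WrappedGrid:
    """Alternative: derived values are properties only, not fields."""
    Nx: int
    dx: float
    _grid: _StubGrid = field(init=False, repr=False, compare=False)

    def __post_init__(self):
        self._grid = _StubGrid([self.Nx], [self.dx])

    @property
    def x_size(self):
        # No setter, so assignment raises AttributeError.
        return self._grid.size[0]


if __name__ == "__main__":
    # x_size is still an init/repr field despite the field(...) declaration.
    print([(f.name, f.init, f.repr) for f in fields(ShadowedGrid)])
    print(WrappedGrid(4, 0.5), WrappedGrid(4, 0.5).x_size)
```

With the second layout the repr stays short and the derived attributes are genuinely read-only, at the cost of no longer listing them among the dataclass fields.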

waltsims commented 1 year ago

This looks very clean and could be interesting to incorporate! You can start a PR, and we can see if we can get the dataclass usage more consistent across the project.