minnervva / torchdetscan

This is a tool for finding non-deterministic functions in your pytorch code.
https://github.com/minnervva/torchdetscan
MIT License
1 stars 1 forks source link

deepmd kernel isolation #19

Open mtaillefumier opened 9 months ago

mtaillefumier commented 9 months ago

Isolate the deepmd GPU kernels that use floating-point `atomicAdd`, study them, and compare their results to their CPU counterparts. Random input data is enough for this comparison. Then remove the atomic operations and test again.

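Atomic operations are present because these kernels have multiple levels of parallelization. Contrary to my initial thoughts, the atomic add is needed because one parallelization level is over the neighbors: each center atom adds a contribution into the force of each of its neighbors, and different center atoms can share the same neighbor, so several threads may update the same force entry. The atomics can be removed completely with a fairly contained change, provided each force entry is accumulated by a single thread.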

Example (current kernel):

// Scatters each center atom's derivative contributions onto the forces of its
// neighbors.
//
// Thread mapping (as read from the index math below):
//   blockIdx.x                             -> center row in nlist; frame = row / nloc
//   blockIdx.y * blockDim.x + threadIdx.x  -> neighbor slot (threads past nnei exit)
//   threadIdx.y                            -> component of in_deriv / force
//
// Negative nlist entries are skipped (presumably padding; confirm with caller).
// atomicAdd is required because different center rows can name the same
// neighbor. The order of those floating-point adds is therefore not fixed, and
// results can vary between runs.
template <typename FPTYPE>
__global__ void force_deriv_wrt_neighbors_r(FPTYPE* force,
                                            const FPTYPE* net_deriv,
                                            const FPTYPE* in_deriv,
                                            const int* nlist,
                                            const int nloc,
                                            const int nall,
                                            const int nnei) {
  const int_64 center = blockIdx.x;
  const unsigned int slot = blockIdx.y * blockDim.x + threadIdx.x;
  const unsigned int comp = threadIdx.y;
  // One descriptor entry per neighbor slot.
  const int ndescrpt = nnei * 1;

  if (slot < nnei) {
    const int neighbor = nlist[center * nnei + slot];
    if (neighbor >= 0) {
      const int_64 frame = center / nloc;
      const FPTYPE contrib = net_deriv[center * ndescrpt + slot] *
                             in_deriv[center * ndescrpt * 3 + slot * 3 + comp];
      FPTYPE* target = force + frame * nall * 3 + neighbor * 3 + comp;
      atomicAdd(target, contrib);
    }
  }
}

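The neighbor slots are parallelized over `blockIdx.y`/`threadIdx.x`, with `threadIdx.y` selecting the Cartesian component, and every center atom (`blockIdx.x`) adds into the forces of its neighbors. Because different center atoms can have the same neighbor `j_idx`, several threads may update the same `force` entry concurrently. That's why the atomic add is used.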

A better way to do this, without atomics:

// Deterministic, atomic-free version of force_deriv_wrt_neighbors_r.
//
// The atomic kernel scatters: each center atom adds into the force of each of
// its neighbors, and because different center atoms can share a neighbor those
// adds collide. Simply dropping the atomics would be a data race. This kernel
// gathers instead: every thread owns one receiving atom j of one frame and sums,
// in a fixed (center row, neighbor slot) order, every contribution whose nlist
// entry equals j. No two threads write the same force entry, so no atomics are
// needed and the result is reproducible from run to run.
//
// Launch layout (DIFFERENT from the atomic kernel):
//   gridDim.y  = number of frames (blockIdx.y is the frame index)
//   x dimension: any size; threads stride over the nall atoms of the frame,
//                e.g. grid.x = ceil(nall / blockDim.x)
//   blockDim.y: threads with threadIdx.y != 0 exit, so reusing the old
//               (x, 3) block shape cannot double-count; blockDim.y = 1 avoids
//               wasting those threads. blockDim.z and gridDim.z must be 1.
//
// Indexing (same as the atomic kernel; rows = frame * nloc + local atom):
//   nlist[row * nnei + s]               neighbor atom of slot s; negative
//                                       entries never match any j >= 0
//   net_deriv[row * ndescrpt + s]
//   in_deriv[(row * ndescrpt + s) * 3 + c]
//   force[(frame * nall + j) * 3 + c]   accumulated into with +=
//
// Cost: each thread scans all nloc * nnei slots of its frame. That is
// acceptable for isolating and validating the kernel against the CPU version.
// A production version would likely first build an inverse neighbor list.
template <typename FPTYPE>
__global__ void force_deriv_wrt_neighbors_r(FPTYPE* force,
                                            const FPTYPE* net_deriv,
                                            const FPTYPE* in_deriv,
                                            const int* nlist,
                                            const int nloc,
                                            const int nall,
                                            const int nnei) {
  // Only one thread per (x-position, frame) may own an atom; see launch notes.
  if (threadIdx.y != 0) {
    return;
  }
  const int_64 kk = blockIdx.y;  // frame index (kept 64-bit: kk * nall * 3)
  const int ndescrpt = nnei * 1;
  const int_64 row_begin = kk * nloc;
  const int_64 row_end = row_begin + nloc;
  const int_64 stride = (int_64)gridDim.x * blockDim.x;

  for (int_64 j = (int_64)blockIdx.x * blockDim.x + threadIdx.x; j < nall;
       j += stride) {
    FPTYPE f_x = FPTYPE(0);
    FPTYPE f_y = FPTYPE(0);
    FPTYPE f_z = FPTYPE(0);
    bool found = false;
    // Fixed iteration order -> deterministic floating-point sum.
    for (int_64 idx = row_begin; idx < row_end; ++idx) {
      const int* row = nlist + idx * nnei;
      for (int s = 0; s < nnei; ++s) {
        if (row[s] != j) {
          continue;
        }
        const FPTYPE w = net_deriv[idx * ndescrpt + s];
        const FPTYPE* d = in_deriv + idx * ndescrpt * 3 + s * 3;
        f_x += w * d[0];
        f_y += w * d[1];
        f_z += w * d[2];
        found = true;
      }
    }
    // Atoms that nobody lists as a neighbor are left untouched, as in the
    // atomic kernel.
    if (found) {
      FPTYPE* out = force + kk * nall * 3 + j * 3;
      out[0] += f_x;
      out[1] += f_y;
      out[2] += f_z;
    }
  }
}

However, it requires different launch parameters, so it should be tested.