PPPLDeepLearning / plasma-python

PPPL deep learning disruption prediction package
http://tigress-web.princeton.edu/~alexeys/docs-web/html/

Support for float16 training #15

Closed ASvyatkovskiy closed 7 years ago

ASvyatkovskiy commented 7 years ago

This PR adds the ability to train FRNN with half-precision (float16) floats. Main changes (a combined runnable sketch follows the list):

  1. Introduce a custom MPI datatype of 2 contiguous bytes and register it in mpi4py's type dictionary:

    mpi_float16 = MPI.BYTE.Create_contiguous(2).Commit()
    MPI._typedict['e'] = mpi_float16
  2. Define a custom reduction operation for the new type:

    def sum_f16_cb(buffer_a, buffer_b, t):
        assert t == mpi_float16
        array_a = np.frombuffer(buffer_a, dtype='float16')
        array_b = np.frombuffer(buffer_b, dtype='float16')
        array_b += array_a
  3. Register it as an MPI Op:

    mpi_sum_f16 = MPI.Op.Create(sum_f16_cb, commute=True)
  4. Introduce a switch in the code based on the Keras floatx setting:

    if K.floatx() == 'float16':
        self.comm.Allreduce(arr, arr_global, op=mpi_sum_f16)
    else:
        self.comm.Allreduce(arr, arr_global, op=MPI.SUM)

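To show how the four pieces fit together, here is a minimal self-contained sketch (an illustration, not the PR's actual code): it assumes only mpi4py and numpy, replaces `K.floatx()` with a plain variable so Keras is not required, and can be run with e.g. `mpirun -np 4 python sketch.py`. The custom op is needed because MPI has no predefined 16-bit floating-point type, so `MPI.SUM` cannot be applied to the 2-byte datatype directly.

    import numpy as np
    from mpi4py import MPI

    # 1. float16 is not a predefined MPI datatype, so describe it as two
    #    contiguous bytes and register it under numpy's 'e' typecode so
    #    mpi4py can infer it from float16 arrays.
    mpi_float16 = MPI.BYTE.Create_contiguous(2).Commit()
    MPI._typedict['e'] = mpi_float16

    # 2. Custom elementwise sum: MPI sees opaque byte pairs, so reinterpret
    #    both buffers as float16 arrays and accumulate into the second one.
    def sum_f16_cb(buffer_a, buffer_b, t):
        assert t == mpi_float16
        array_a = np.frombuffer(buffer_a, dtype='float16')
        array_b = np.frombuffer(buffer_b, dtype='float16')
        array_b += array_a

    # 3. Register the callback as a commutative reduction operation.
    mpi_sum_f16 = MPI.Op.Create(sum_f16_cb, commute=True)

    # 4. Switch on the working precision, as in the PR.
    floatx = 'float16'  # stand-in for K.floatx()
    comm = MPI.COMM_WORLD

    arr = np.ones(8, dtype=floatx) * comm.Get_rank()
    arr_global = np.empty_like(arr)

    if floatx == 'float16':
        comm.Allreduce(arr, arr_global, op=mpi_sum_f16)
    else:
        comm.Allreduce(arr, arr_global, op=MPI.SUM)

    if comm.Get_rank() == 0:
        # Each element should equal 0 + 1 + ... + (size - 1).
        print(arr_global)
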
Other side changes:

  1. Move configuration handling into plasma/conf_parser.py; general cleanup.