tk-rusch / unicornn

Official code for UnICORNN (ICML 2021)
27 stars 3 forks source link

fixes for cupy 9.x #2

Closed tayalkshitij closed 2 years ago

tayalkshitij commented 2 years ago

Which part of the code should I delete, and where should I insert `mod = cupy.RawModule(code=UnICORNN_CODE, options=('--std=c++11',), name_expressions=('unicornn_fwd', 'unicornn_bwd'))`? This is the current class:

class UnICORNN_compile:
    """Compile the UnICORNN CUDA source and hand out its two kernels.

    ``UnICORNN_CODE`` is compiled once, while this class body is being
    executed. The resulting kernel handles are cached in ``_DEVICE2FUNC``,
    keyed by the value of ``torch.cuda.current_device()``.
    """

    # Built at class-definition time, so every instance shares one compile.
    _UnICORNN_PROG = Program(UnICORNN_CODE, 'unicornn_prog.cu')
    _UnICORNN_PTX = _UnICORNN_PROG.compile()
    # device -> (stream wrapper, forward kernel, backward kernel)
    _DEVICE2FUNC = {}

    def __init__(self):
        super().__init__()

    def compile_functions(self):
        """Load the compiled kernels for the current device and cache them.

        Returns:
            A ``(stream, fwd_func, bwd_func)`` tuple. ``stream`` is a small
            namedtuple whose ``ptr`` field holds the ``cuda_stream`` value of
            the PyTorch stream that is current when this method runs.
        """
        device_id = torch.cuda.current_device()

        # Load the compiled program text (encoded to bytes) into a fresh module.
        ptx_bytes = bytes(self._UnICORNN_PTX.encode())
        module = function.Module()
        module.load(ptx_bytes)
        forward_kernel = module.get_function('unicornn_fwd')
        backward_kernel = module.get_function('unicornn_bwd')

        # Wrap the raw PyTorch stream handle in an object exposing ``.ptr``.
        StreamHandle = namedtuple('Stream', ['ptr'])
        torch_stream = torch.cuda.current_stream()
        stream = StreamHandle(ptr=torch_stream.cuda_stream)

        # The stream wrapper is cached together with the kernels, so later
        # get_functions() calls return the stream captured here.
        entry = (stream, forward_kernel, backward_kernel)
        self._DEVICE2FUNC[device_id] = entry
        return entry

    def get_functions(self):
        """Return the cached tuple for the current device, compiling on a miss."""
        cached = self._DEVICE2FUNC.get(torch.cuda.current_device())
        if cached:
            return cached
        return self.compile_functions()
tk-rusch commented 2 years ago

Hi,

In the `network.py` file, simply change the import

from pynvrtc.compiler import Program

to

import cupy as cp

Then change the whole UnICORNN_compile class to

class UnICORNN_compile:
    """Compile the UnICORNN CUDA source with CuPy and hand out its kernels.

    Kernel handles are cached in ``_DEVICE2FUNC``, keyed by the value of
    ``torch.cuda.current_device()``, so the source is only compiled the
    first time a given device asks for them.
    """

    # device -> (stream wrapper, forward kernel, backward kernel)
    _DEVICE2FUNC = {}

    def __init__(self):
        super().__init__()

    def compile_functions(self):
        """Build the kernels for the current device and cache them.

        Returns:
            A ``(stream, fwd_func, bwd_func)`` tuple. ``stream`` is a small
            namedtuple whose ``ptr`` field holds the ``cuda_stream`` value of
            the PyTorch stream that is current when this method runs.
        """
        device_id = torch.cuda.current_device()

        # Both kernel names are passed as name expressions and then looked
        # up on the module by the same names.
        module = cp.RawModule(
            code=UnICORNN_CODE,
            options=('--std=c++11',),
            name_expressions=('unicornn_fwd', 'unicornn_bwd'),
        )
        forward_kernel = module.get_function('unicornn_fwd')
        backward_kernel = module.get_function('unicornn_bwd')

        # Wrap the raw PyTorch stream handle in an object exposing ``.ptr``.
        StreamHandle = namedtuple('Stream', ['ptr'])
        torch_stream = torch.cuda.current_stream()
        stream = StreamHandle(ptr=torch_stream.cuda_stream)

        # The stream wrapper is cached together with the kernels, so later
        # get_functions() calls return the stream captured here.
        entry = (stream, forward_kernel, backward_kernel)
        self._DEVICE2FUNC[device_id] = entry
        return entry

    def get_functions(self):
        """Return the cached tuple for the current device, compiling on a miss."""
        cached = self._DEVICE2FUNC.get(torch.cuda.current_device())
        if cached:
            return cached
        return self.compile_functions()

I hope this helps. Let me know if there's still an issue.