Closed: tayalkshitij closed this issue 2 years ago.
Hi,
In the `network.py` file, simply replace the line `from pynvrtc.compiler import Program` with `import cupy as cp`. Then replace the whole `UnICORNN_compile` class with the following:
class UnICORNN_compile():
    """Compile the UnICORNN CUDA kernels with CuPy and cache them per device.

    ``UnICORNN_CODE`` is compiled with ``cupy.RawModule``. The resulting
    ``unicornn_fwd`` / ``unicornn_bwd`` kernels are stored in a class-level
    cache keyed by the current CUDA device index, so each device compiles
    at most once.
    """

    # Class-level cache shared by every instance. It maps
    # device index -> (stream wrapper, forward kernel, backward kernel).
    _DEVICE2FUNC = {}

    # Minimal wrapper exposing a raw CUDA stream handle as ``.ptr``.
    # It is defined once here instead of re-creating the namedtuple type on
    # every compile.
    _Stream = namedtuple('Stream', ['ptr'])

    def __init__(self):
        super(UnICORNN_compile, self).__init__()

    def compile_functions(self):
        """Compile the kernels for the current CUDA device and cache them.

        Returns:
            tuple: ``(current_stream, fwd_func, bwd_func)``. ``current_stream``
            wraps the handle of PyTorch's current CUDA stream at compile time.
            ``fwd_func`` and ``bwd_func`` are the compiled kernels.
        """
        device = torch.cuda.current_device()
        mod = cp.RawModule(
            code=UnICORNN_CODE,
            options=('--std=c++11',),
            name_expressions=('unicornn_fwd', 'unicornn_bwd'),
        )
        fwd_func = mod.get_function('unicornn_fwd')
        bwd_func = mod.get_function('unicornn_bwd')
        # NOTE(review): the stream handle is captured once, at compile time,
        # and then reused from the cache. CuPy raw kernels launch on CuPy's
        # current stream, so confirm how callers use this wrapper.
        current_stream = self._Stream(
            ptr=torch.cuda.current_stream().cuda_stream)
        self._DEVICE2FUNC[device] = (current_stream, fwd_func, bwd_func)
        return current_stream, fwd_func, bwd_func

    def get_functions(self):
        """Return the cached kernels for the current device.

        The kernels are compiled (and cached) on first use for each device.
        """
        res = self._DEVICE2FUNC.get(torch.cuda.current_device())
        if res is not None:
            return res
        return self.compile_functions()
I hope this helps. Let me know if there's still an issue.
Which part of the code should I delete, and where should I insert the following line?
# Build the CuPy module holding the forward/backward UnICORNN kernels.
# NOTE: if cupy was imported as `import cupy as cp` (as suggested earlier in
# this thread), this call must be spelled `cp.RawModule`.
mod = cupy.RawModule(
    code=UnICORNN_CODE,
    options=('--std=c++11',),
    name_expressions=('unicornn_fwd', 'unicornn_bwd'),
)