HazyResearch / hippo-code

Apache License 2.0

Need to call from Java; need some guidance in porting it. #1

Closed: xwaeaewcrhomesysplug closed this issue 3 years ago

xwaeaewcrhomesysplug commented 4 years ago

I want to call the function from Java. Of course, there are many libraries around to help with that. However, I am not experienced in C++, so I need your help navigating the code.

For the best performance we could use JNI directly, but that would probably take a long time, so I decided to use JavaCPP: https://github.com/bytedeco/javacpp

They have a few useful resources: https://github.com/bytedeco/javacpp/wiki/Mapping-Recipes and https://github.com/bytedeco/javacpp-presets/wiki/Create-New-Presets

However, it seems that your C++ code depends on the PyTorch C++ library, so I went and took the torch part. I think it is in csrc/torch, which contains custom_class.h, custom_class_details.h, extension.h, script.h, and csrc (a folder with more header files).

So I would like your help making sure everything required is there. Of course, I will handle the Java JAR and the JavaCPP setup myself, but I am kind of lost in the C++ setup.

So, in summary, can you help me check:
1. Do I have the C++ code needed to compile and set everything up?
2. How do I refer to and call your C++ functions? Following "Specifying names to use in Java" in https://github.com/bytedeco/javacpp/wiki/Mapping-Recipes, is it something like "full::namespace::FunctionNameInCPP"? What would it be in your case? I am not familiar with C++ syntax.

Of course, the Java port is somewhat unrelated to this repo, so you may just ignore it and close the issue. However, if you plan to do a Java port in the future, feel free to invite me to the repo.

tridao commented 4 years ago

Thanks for checking out our code!

Re: which PyTorch C++ header files: we use extension.h, a catch-all header that includes all the other PyTorch header files.

I actually don't know how to call the C++ code from languages other than Python. Our workflow has been running python setup.py install to install the extension, then calling it from Python (https://pytorch.org/tutorials/advanced/cpp_extension.html).

I think it might actually be faster to reimplement the functionality in Java rather than getting Java to call C++. Our C++ code is for inference only (only the forward pass is implemented, not the backward pass) and runs on CPU (no CUDA code). It is there to demonstrate that HiPPO-LegS can be implemented efficiently; our training and experimentation are in Python. If you want to use HiPPO-LegS for inference, you can just reimplement the trapezoidal function (https://github.com/HazyResearch/hippo-code/blob/master/csrc/hippolegs.cpp#L89). It boils down to a few for loops over a 3D array.
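As a rough illustration of what "a few for loops" could look like in Java, here is a minimal sketch. It is not a port of hippolegs.cpp and may differ from it in indexing and scaling. It assumes the HiPPO-LegS matrices from the HiPPO paper (A[n][k] = sqrt((2n+1)(2k+1)) for n > k, A[n][n] = n + 1, zero above the diagonal; B[n] = sqrt(2n+1)) and a trapezoidal step of the scaled ODE c'(t) = -(1/t) A c(t) + (1/t) B f(t) with unit spacing, so step k solves (I + A/(2k)) c_k = (I - A/(2k)) c_{k-1} + (1/k) B f_{k-1}. The class name LegsTrapezoidal, its method names, and the [batch][length] input / [batch][length][N] output layout are hypothetical choices.

```java
// Hypothetical sketch of a HiPPO-LegS recurrence discretized with the
// trapezoidal (bilinear) rule. NOT a port of csrc/hippolegs.cpp; the class,
// method names, and array layout are illustrative assumptions.
final class LegsTrapezoidal {

    // Assumed LegS transition matrix (lower triangular):
    // A[i][j] = sqrt((2i+1)(2j+1)) for i > j, A[i][i] = i + 1, 0 above the diagonal.
    static double[][] transitionA(int n) {
        double[][] a = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                a[i][j] = Math.sqrt((2.0 * i + 1) * (2.0 * j + 1));
            }
            a[i][i] = i + 1;
        }
        return a;
    }

    // Assumed LegS input vector: B[i] = sqrt(2i+1).
    static double[] inputB(int n) {
        double[] b = new double[n];
        for (int i = 0; i < n; i++) {
            b[i] = Math.sqrt(2.0 * i + 1);
        }
        return b;
    }

    // input: [batch][length] scalar signals; returns [batch][length][n] coefficients.
    // Step k (1-based) solves (I + A/(2k)) c_k = (I - A/(2k)) c_{k-1} + B f_{k-1} / k.
    static double[][][] forward(double[][] input, int n) {
        double[][] a = transitionA(n);
        double[] b = inputB(n);
        double[][][] out = new double[input.length][][];
        double[] rhs = new double[n];
        for (int s = 0; s < input.length; s++) {
            int length = input[s].length;
            out[s] = new double[length][n];
            double[] c = new double[n]; // state starts at zero
            for (int k = 1; k <= length; k++) {
                double h = 1.0 / (2.0 * k);
                double f = input[s][k - 1];
                // Explicit half: rhs = (I - h A) c + 2h B f (only j <= i is nonzero in A).
                for (int i = 0; i < n; i++) {
                    double ac = 0.0;
                    for (int j = 0; j <= i; j++) {
                        ac += a[i][j] * c[j];
                    }
                    rhs[i] = c[i] - h * ac + 2.0 * h * b[i] * f;
                }
                // Implicit half: forward substitution with the lower-triangular (I + h A).
                // c[j] for j < i already holds the new values when row i is solved.
                for (int i = 0; i < n; i++) {
                    double sum = rhs[i];
                    for (int j = 0; j < i; j++) {
                        sum -= h * a[i][j] * c[j];
                    }
                    c[i] = sum / (1.0 + h * a[i][i]);
                }
                System.arraycopy(c, 0, out[s][k - 1], 0, n);
            }
        }
        return out;
    }
}
```

For example, LegsTrapezoidal.forward(new double[][] {{1, 1, 1}}, 4) runs a constant signal through 4 coefficients. Because A is lower triangular, the first coefficient only depends on itself, and under these assumptions it equals 2k/(2k+1) after step k, approaching the input value 1. Each step here is a dense O(N^2) triangular solve; if you later move to OpenBLAS, that inner solve is the part a triangular-solve routine would replace.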

xwaeaewcrhomesysplug commented 4 years ago

Thank you for the reply.

Your description of the setup is good info. I think I will copy that workflow and analyze what I end up with.

Also, thanks for pointing out the implementation. I did go through the code and was kind of shocked at how neat it is. I originally thought your code would call CUDA or use some specific torch functions.

Since you suggest reimplementing it directly in Java, I will consider that. Of course, if the Java version is not fast enough, I may vectorize it and use JavaCPP to call OpenBLAS.

Thanks for the suggestions, the info, and your time answering the question. You can close this issue. Of course, if you keep it open, I will let you know when I finish the implementation. But I am also implementing other parts, so it will probably take a month.