DJNing opened 4 months ago
A quick fix for running the code with the latest tyro: change the following line
https://github.com/SkyworkAI/Gamba/blob/995d8d7ef7054213b457b71d3f9a060dd027be5c/core/options.py#L110
to AllConfigs = Options
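Why this works, sketched under assumptions: line 110 presumably builds AllConfigs with tyro.extras.subcommand_type_from_defaults, which turns the named presets in config_defaults into CLI subcommands and, judging by the len(defaults) error reported in this thread, rejects too few presets. With a single preset there is nothing to choose between, so handing tyro the Options dataclass directly is equivalent. The Options fields, the "gamba" preset, the threshold of two, and the make_all_configs helper below are hypothetical stand-ins, not the repo's code.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class Options:
    # Hypothetical stand-in fields; the real Options dataclass in core/options.py has many more.
    input_size: int = 256
    num_views: int = 4


# One named preset, mirroring a config_defaults / config_doc pair of dicts (names assumed).
config_defaults: Dict[str, Options] = {"gamba": Options()}
config_doc: Dict[str, str] = {"gamba": "default settings for gamba"}


def make_all_configs(defaults: Dict[str, Options], docs: Dict[str, str]) -> Any:
    """Pick the type to hand to tyro.cli().

    Subcommands only make sense with several presets (assumed minimum: two).
    With a single preset, fall back to the plain dataclass, which is what
    `AllConfigs = Options` does."""
    if len(defaults) >= 2:
        import tyro  # imported lazily: only needed when there are real subcommands

        return tyro.extras.subcommand_type_from_defaults(defaults, docs)
    return Options


AllConfigs = make_all_configs(config_defaults, config_doc)
```

The scripts would presumably keep calling tyro.cli(AllConfigs) unchanged. One caveat: if the single preset overrides any field defaults, AllConfigs = Options drops those overrides; tyro.cli also accepts a default= argument, which should let you pass the preset instance instead (worth checking against your tyro version).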
Thanks for your detailed solution; we will update our environment requirements accordingly.
After applying DJNing's fix, some bugs were resolved, but I then found that some CUDA-dependent packages (such as xFormers) could not be used. When I installed xFormers with pip, it uninstalled cuda11.8 + torch 2.1.0 and replaced them with nvidia-nccl-cu12 + torch 2.4.0. This raises quite a few questions and I am not sure how to proceed. Could you please give me some guidance?
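When pip installs a package such as xFormers, it may upgrade torch to whatever version that package's wheel requires, which is likely what the switch from torch 2.1.0 to 2.4.0 described here is. A quick way to spot such drift is to compare installed versions against the ones you expect. A minimal sketch using the standard-library importlib.metadata; the helper names and the EXPECTED pin table are hypothetical. In a Python shell with torch imported, torch.version.cuda additionally reports which CUDA version the installed torch was built against.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Versions the environment is expected to keep (hypothetical table; torch 2.1.0 is from this thread).
EXPECTED: Dict[str, str] = {"torch": "2.1.0"}


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def drifted(expected: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Map every package whose installed version differs from the expected one to what is installed."""
    result: Dict[str, Optional[str]] = {}
    for package, wanted in expected.items():
        found = installed_version(package)
        # CUDA wheels carry a local suffix such as "2.1.0+cu118"; compare only the part before "+".
        if found is None or found.split("+", 1)[0] != wanted:
            result[package] = found
    return result


if __name__ == "__main__":
    print(drifted(EXPECTED) or "pinned versions are intact")
```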
Sure, hgdzhx. xFormers is actually not necessary for inference with the current code, so you can skip installing it.
If you want to install xFormers for the speed-up, a workaround is to build it from source.
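Since xFormers is optional for inference, code that wants the speed-up only when it is available can guard the import. A minimal sketch of that pattern; the XFORMERS_AVAILABLE flag and attention_backend helper are hypothetical and not part of the Gamba code. The "pytorch" branch stands for a fallback such as torch.nn.functional.scaled_dot_product_attention, which ships with torch 2.x.

```python
# Hypothetical flag: True only if xFormers can be imported in this environment.
try:
    import xformers.ops  # noqa: F401  (only needed for the optional speed-up)

    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False


def attention_backend() -> str:
    """Name the attention implementation to use: xFormers if importable, else plain PyTorch."""
    return "xformers" if XFORMERS_AVAILABLE else "pytorch"


if __name__ == "__main__":
    print("attention backend:", attention_backend())
```

Guarding the import this way keeps a missing or uninstalled xFormers from breaking inference, so there is no need to let pip pull in a different torch just to satisfy it.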
Thank you sincerely for your reply. While fixing and reviewing the code further, I found that the open-source code only supports image input in the "S3" format with fixed camera parameters. Could you release the complete code? I find your work very innovative and will keep following it. Have a nice day.
Thanks a lot for the great work! However, I encountered some problems while running the demo.
Need help with an error when loading the configuration
I am trying to run the test based on the code, but I hit an error in: https://github.com/SkyworkAI/Gamba/blob/995d8d7ef7054213b457b71d3f9a060dd027be5c/core/options.py#L110
It says the input argument for defaults (config_defaults) should have len(defaults) > 2. Since I am not familiar with tyro, could you explain a bit how to solve this?
Extra dependencies for the repo:
pip install timm jaxtyping omegaconf typeguard colorama
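Before running the demo, a quick check can list which of these extra packages are still missing. A minimal sketch using importlib.util.find_spec; it assumes each pip name matches its import name, which holds for the five packages above, and missing_modules is a hypothetical helper.

```python
import importlib.util
from typing import Iterable, List

# The extra packages listed above; for each one the pip name and the import name coincide.
EXTRA_DEPS = ["timm", "jaxtyping", "omegaconf", "typeguard", "colorama"]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(EXTRA_DEPS)
    if missing:
        print("pip install " + " ".join(missing))
    else:
        print("All extra dependencies are installed.")
```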
Minor change for mamba_ssm:
For the latest version of mamba_ssm (2.2.2), the following line: https://github.com/SkyworkAI/Gamba/blob/995d8d7ef7054213b457b71d3f9a060dd027be5c/core/gambaformer.py#L8
should be changed to
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
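If the code has to run against both mamba_ssm 2.x and older releases, one option is to try the module paths in order. A minimal sketch; the helper import_first is hypothetical, and the older path mamba_ssm.ops.triton.layernorm is an assumption based on mamba_ssm 1.x naming, so check it against the version you have installed.

```python
import importlib
from typing import Any, Sequence, Tuple


def import_first(modules: Sequence[str], names: Sequence[str]) -> Tuple[Any, ...]:
    """Fetch `names` from the first module path in `modules` that imports successfully.

    Raises ImportError listing every path tried if none of them can be imported."""
    errors = []
    for module_path in modules:
        try:
            module = importlib.import_module(module_path)
        except ImportError as exc:  # also covers ModuleNotFoundError
            errors.append(f"{module_path}: {exc}")
            continue
        return tuple(getattr(module, name) for name in names)
    raise ImportError("none of the candidate modules could be imported:\n" + "\n".join(errors))


# Hypothetical usage in core/gambaformer.py: newer 2.x path first, then the assumed 1.x path.
# RMSNorm, layer_norm_fn, rms_norm_fn = import_first(
#     ["mamba_ssm.ops.triton.layer_norm", "mamba_ssm.ops.triton.layernorm"],
#     ["RMSNorm", "layer_norm_fn", "rms_norm_fn"],
# )
```

Keeping the fallback in one helper means the version-specific path lives in a single list instead of nested try/except blocks.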