xorbitsai / xorbits

Scalable Python DS & ML, in an API compatible & lightning fast way.
https://xorbits.io
Apache License 2.0

ENH: Modify `JAX` fusion optimization #755

Closed. JiaYaobo closed this pull request 7 months ago.

JiaYaobo commented 7 months ago

Modifies the original PR https://github.com/xorbitsai/xorbits/pull/440, which does not make full use of JAX's JIT compilation.
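To illustrate what making full use of JAX compilation can mean, here is a minimal sketch. It assumes the fused subgraph is a chain of element-wise operations. The idea is to wrap the whole chain in one `jax.jit` call so that XLA receives a single program, instead of dispatching each `jax.numpy` op on its own. `fused_chain` and its ops are hypothetical. They are not taken from `python/xorbits/_mars/tensor/fuse/jax.py` or from PR #440.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical fused chain of element-wise ops, standing in for what a
# fusion pass might collect from an expression such as sin(x) * 2 + y.
def fused_chain(x, y):
    t0 = jnp.sin(x)
    t1 = t0 * 2.0
    return t1 + y


# jax.jit traces the whole chain once and hands XLA a single program,
# instead of dispatching each jnp operation separately.
run_compiled = jax.jit(fused_chain)

x = np.linspace(0.0, 1.0, 5, dtype=np.float32)
y = np.ones(5, dtype=np.float32)

eager = fused_chain(x, y)      # op-by-op dispatch
compiled = run_compiled(x, y)  # one compiled XLA computation
print(np.allclose(eager, compiled))

# Inspect the single traced program that jit compiles.
print(jax.make_jaxpr(fused_chain)(x, y))
```

Both paths compute the same values. With the compiled path, XLA is free to optimize across the whole chain, for example by fusing the element-wise ops.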

What do these changes do?

This PR re-implements the core `_evaluate` function with reference to https://jax.readthedocs.io/en/latest/autodidax.html#part-2-jaxprs, adding a non-recursive and fully jittable `_eval` function.
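The following is a minimal sketch of the idea. It is modeled on `eval_jaxpr` from the autodidax tutorial and assumes the fused nodes can be listed in topological order. An environment dict plus a single flat loop evaluates the graph without recursion. Because every step is a `jax.numpy` call, the whole evaluator can be wrapped in `jax.jit`. `Eqn`, `eval_graph`, and the example graph are hypothetical names for illustration. They are not the PR's actual `_eval`.

```python
from typing import Any, Callable, Dict, List, NamedTuple, Sequence

import jax
import jax.numpy as jnp
import numpy as np


class Eqn(NamedTuple):
    """One node of a hypothetical fused graph, shaped like a jaxpr equation."""

    op: Callable[..., Any]  # a jax.numpy function, e.g. jnp.add
    invars: Sequence[str]   # names of the input variables
    outvar: str             # name of the output variable


def eval_graph(eqns: List[Eqn], invars: Sequence[str],
               outvars: Sequence[str], *args: Any) -> List[Any]:
    # Environment mapping variable names to values, as in autodidax's eval_jaxpr.
    env: Dict[str, Any] = dict(zip(invars, args))
    # Equations are assumed to be topologically sorted, so one flat loop
    # replaces a recursive walk over the graph.
    for eqn in eqns:
        vals = [env[name] for name in eqn.invars]
        env[eqn.outvar] = eqn.op(*vals)
    return [env[name] for name in outvars]


# Hypothetical graph for c = sin(x + y) * x.
eqns = [
    Eqn(jnp.add, ("x", "y"), "a"),
    Eqn(jnp.sin, ("a",), "b"),
    Eqn(jnp.multiply, ("b", "x"), "c"),
]


# The loop runs at trace time, so jit sees only the jnp calls it emits.
@jax.jit
def fused(x, y):
    return eval_graph(eqns, ("x", "y"), ("c",), x, y)


x = np.arange(4, dtype=np.float32)
y = np.ones(4, dtype=np.float32)
(out,) = fused(x, y)
print(out)
```

The Python control flow here depends only on the graph structure, not on array values. Tracing therefore unrolls the loop into one jaxpr, which is what makes the evaluator jittable.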

Check code requirements

JiaYaobo commented 7 months ago

@ChengjieLi28 @UranusSeven @aresnow1 PTAL

RandomY-2 commented 7 months ago

https://doc.xorbits.io/en/latest/development/contributing.html#autofixing-formatting-errors

Lint fails because of isort. You can either run isort or set up pre-commit to auto-fix styling issues.

codecov[bot] commented 7 months ago

Codecov Report

Attention: 5 lines in your changes are missing coverage. Please review.

Comparison is base (1964357) 93.59% compared to head (a31297e) 93.59%.

| Files | Patch % | Lines |
|---|---|---|
| python/xorbits/_mars/tensor/fuse/jax.py | 90.19% | 2 Missing and 3 partials :warning: |
Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             main     #755   +/-   ##
=======================================
  Coverage   93.59%   93.59%
=======================================
  Files        1059     1059
  Lines       79786    79797   +11
  Branches    16506    16508    +2
=======================================
+ Hits        74673    74687   +14
+ Misses       3439     3430    -9
- Partials     1674     1680    +6
```

| Flag | Coverage Δ | |
|---|---|---|
| unittests | `93.48% <88.46%> (+<0.01%)` | :arrow_up: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=xorbitsai#carryforward-flags-in-the-pull-request-comment) to find out more.


ChengjieLi28 commented 7 months ago

I also think we should add an option to enable JAX optimization. What do you think, @JiaYaobo?

ChengjieLi28 commented 7 months ago


Never mind, we can follow the previous behaviour.