google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0
33 stars, 14 forks

Fix quantization and jax lint check #52

FanhaiLu1 closed this 4 months ago

FanhaiLu1 commented 4 months ago

After this PR: Your code has been rated at 10.00/10

Before this PR:

Your code has been rated at 5.73/10

jax_test.py:21:0: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:23:2: W0621: Redefining name 'f' from outer scope (line 200) (redefined-outer-name)
jax_test.py:35:2: W0621: Redefining name 'A' from outer scope (line 174) (redefined-outer-name)
jax_test.py:24:4: R1705: Unnecessary "else" after "return", remove the "else" and de-indent the code inside it (no-else-return)
jax_test.py:35:2: C0115: Missing class docstring (missing-class-docstring)
jax_test.py:40:4: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:35:2: R0903: Too few public methods (1/2) (too-few-public-methods)
jax_test.py:44:2: E0102: function already defined line 23 (function-redefined)
jax_test.py:54:0: C0413: Import "from jax.sharding import PositionalSharding" should be placed at the top of the module (wrong-import-position)
jax_test.py:55:0: C0413: Import "from jax.experimental import mesh_utils" should be placed at the top of the module (wrong-import-position)
jax_test.py:58:0: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:58:0: R0914: Too many local variables (22/15) (too-many-locals)
jax_test.py:89:30: W0613: Unused argument 'caches_v' (unused-argument)
jax_test.py:89:50: W0613: Unused argument 'val' (unused-argument)
jax_test.py:104:6: E1101: Instance of 'int32' has no 'block_until_ready' member (no-member)
jax_test.py:140:12: E1101: Instance of 'int32' has no 'block_until_ready' member (no-member)
jax_test.py:142:29: E1102: func is not callable (not-callable)
jax_test.py:124:8: W0612: Unused variable 'i' (unused-variable)
jax_test.py:150:0: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:158:2: W0621: Redefining name 'f' from outer scope (line 200) (redefined-outer-name)
jax_test.py:151:2: C0415: Import outside toplevel (torch) (import-outside-toplevel)
jax_test.py:152:2: C0415: Import outside toplevel (torch_xla2) (import-outside-toplevel)
jax_test.py:153:2: C0415: Import outside toplevel (torch_xla2.extra) (import-outside-toplevel)
jax_test.py:165:7: E1101: Module 'torch_xla2.tensor' has no 'XLAFunctionMode' member (no-member)
jax_test.py:171:0: C0413: Import "from flax import struct" should be placed at the top of the module (wrong-import-position)
jax_test.py:174:0: C0115: Missing class docstring (missing-class-docstring)
jax_test.py:179:2: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:174:0: R0903: Too few public methods (1/2) (too-few-public-methods)
jax_test.py:183:0: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:183:0: C0103: Function name "flatten_A" doesn't conform to snake_case naming style (invalid-name)
jax_test.py:187:0: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:187:0: C0103: Function name "unflatten_A" doesn't conform to snake_case naming style (invalid-name)
jax_test.py:188:2: C0415: Import outside toplevel (pdb) (import-outside-toplevel)
jax_test.py:190:2: W1515: Leaving functions creating breakpoints in production code is not recommended (forgotten-debug-statement)
jax_test.py:187:16: W0613: Unused argument 'aux_data' (unused-argument)
jax_test.py:196:0: W0404: Reimport 'functools' (imported line 18) (reimported)
jax_test.py:196:0: C0413: Import "import functools" should be placed at the top of the module (wrong-import-position)
jax_test.py:200:0: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:205:0: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:211:0: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:211:0: R0914: Too many local variables (24/15) (too-many-locals)
jax_test.py:230:44: W0613: Unused argument 'head_indexes' (unused-argument)
jax_test.py:237:44: W0613: Unused argument 'head_indexes' (unused-argument)
jax_test.py:271:6: E1101: Instance of 'int32' has no 'block_until_ready' member (no-member)
jax_test.py:274:2: W0127: Assigning the same variable 'update_indexes' to itself (self-assigning-variable)
jax_test.py:297:12: E1101: Instance of 'int32' has no 'block_until_ready' member (no-member)
jax_test.py:299:19: E1102: func is not callable (not-callable)
jax_test.py:219:2: W0612: Unused variable 'caches_v' (unused-variable)
jax_test.py:287:8: W0612: Unused variable 'i' (unused-variable)
jax_test.py:291:8: W0612: Unused variable 'key' (unused-variable)
jax_test.py:306:0: C0116: Missing function or method docstring (missing-function-docstring)
jax_test.py:307:2: C0415: Import outside toplevel (torch_xla2) (import-outside-toplevel)
jax_test.py:308:2: C0415: Import outside toplevel (torch) (import-outside-toplevel)
jax_test.py:18:0: C0411: standard import "functools" should be placed before third party imports "jax", "jax.numpy" (wrong-import-order)
jax_test.py:196:0: C0411: standard import "functools" should be placed before third party imports "jax", "jax.numpy", "jax.sharding.PositionalSharding", "jax.experimental.mesh_utils", "flax.struct" (wrong-import-order)
jax_test.py:171:0: W0611: Unused struct imported from flax (unused-import)
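Most of the report above comes from a handful of recurring patterns. As a rough illustration only (this is not the code from the PR), the stdlib-only sketch below shows how several of them are commonly resolved. Standard-library imports are placed once, at the top, before any third-party imports (C0411, C0413, W0404). Classes and functions get docstrings (C0115/C0116). The pytree flatten/unflatten pair uses snake_case names (C0103). An argument that a callback must accept but does not use gets an underscore prefix (W0613). `_` is used for an unused loop variable (W0612). There is no `else` after `return` (R1705). `CacheState`, `flatten_cache_state`, `unflatten_cache_state`, `pick`, `repeat_call` and `square` are hypothetical names. In the real test, the flatten/unflatten pair would presumably be passed to `jax.tree_util.register_pytree_node`, which expects `flatten(obj) -> (children, aux_data)` and `unflatten(aux_data, children)`. That call is omitted so the sketch runs without JAX installed.

```python
"""Lint-friendly versions of patterns pylint flagged in jax_test.py (illustrative sketch)."""

# Standard-library imports only, each imported once and grouped at the top.
import dataclasses
import functools


@dataclasses.dataclass
class CacheState:
    """Hypothetical container standing in for the flagged class ``A``."""

    keys: list
    values: list


def flatten_cache_state(state):
    """Split ``state`` into (children, aux_data), following JAX's pytree convention."""
    return (state.keys, state.values), None


def unflatten_cache_state(_aux_data, children):
    """Rebuild a CacheState; ``_aux_data`` is unused, so the underscore silences W0613."""
    keys, values = children
    return CacheState(keys=keys, values=values)


def pick(flag, first, second):
    """Return ``first`` if ``flag`` is truthy, else ``second``, without an ``else`` after ``return``."""
    if flag:
        return first
    return second


def repeat_call(func, times):
    """Call ``func`` ``times`` times and return the last result (``None`` if never called)."""
    result = None
    for _ in range(times):  # ``_`` marks the loop variable as intentionally unused
        result = func()
    return result


@functools.lru_cache(maxsize=None)
def square(value):
    """Cached helper that actually uses the single top-level ``functools`` import."""
    return value * value
```

Round-tripping a `CacheState` through `flatten_cache_state` and then `unflatten_cache_state` should return an equal object. This is the property JAX relies on when it rebuilds pytrees after a transformation.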