ApolloResearch / rib

Library for methods related to the Local Interaction Basis (LIB)
MIT License

Nix' magic memory saving changes #338

Closed: stefan-apollo closed this 4 months ago.

stefan-apollo commented 4 months ago

Save memory with better division

Credit to Nix, who came up with this!

Description

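Moving the division by normalization_factor from before the einsum to after it saves VRAM. Dividing first allocates a scaled copy of the full einsum input. Dividing afterwards only touches the einsum output, which is typically much smaller.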
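A minimal sketch of the idea follows. It uses NumPy as a stand-in for the PyTorch code in rib. The einsum subscripts, the (batch, positions, hidden) shapes, and the function names are hypothetical and are not taken from the rib implementation. Both variants compute the same Gram-style contraction and agree up to floating-point rounding. They differ in the size of the temporary that the division creates.

```python
import numpy as np


def gram_divide_before(x: np.ndarray, normalization_factor: float) -> np.ndarray:
    # Divide first: `x / normalization_factor` allocates a temporary as large as x.
    return np.einsum("bpi,bpj->ij", x / normalization_factor, x)


def gram_divide_after(x: np.ndarray, normalization_factor: float) -> np.ndarray:
    # Contract first, then divide only the (hidden, hidden) result.
    return np.einsum("bpi,bpj->ij", x, x) / normalization_factor


def temporary_bytes(x: np.ndarray) -> tuple[int, int]:
    # Bytes allocated by the division in each variant: (divide before, divide after).
    hidden = x.shape[-1]
    return x.nbytes, hidden * hidden * x.itemsize


rng = np.random.default_rng(0)
x = rng.standard_normal((60, 10, 16))  # hypothetical (batch, positions, hidden)
n = x.shape[0] * x.shape[1]
assert np.allclose(gram_divide_before(x, n), gram_divide_after(x, n))
print(temporary_bytes(x))  # (76800, 2048)
```

The saving grows with batch size and context length, because the divide-before temporary scales with batch × positions × hidden. The divide-after temporary depends only on hidden.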

Motivation and Context

Yay, we can use larger batch sizes!

How Has This Been Tested?

Ran pytest. Also ran the following config, which now works but OOM'ed before:

exp_name: tinystories
seed: 0
tlens_pretrained: tiny-stories-1M
tlens_model_path: null
dataset:
  dataset_type: huggingface
  name: roneneldan/TinyStories # or skeskinen/TinyStories-GPT4, but not clear if part of training
  tokenizer_name: EleutherAI/gpt-neo-125M
  return_set: train
  return_set_frac: null
  n_samples: 10000 # avg ~235 toks / story
  n_documents: 10000
  return_set_portion: first
  n_ctx: 100 # needs to be <= 511 for the model to behave reasonably
node_layers:
  - ln1.0
  - ln1_out.0
  - attn_in.0
  - ln2.0
  - ln2_out.0
  - mlp_in.0
  - ln1.1
  - ln1_out.1
  - attn_in.1
  - ln2.1
  - ln2_out.1
  - mlp_in.1
  - ln1.2
  - ln1_out.2
  - attn_in.2
  - ln2.2
  - ln2_out.2
  - mlp_in.2
  - ln1.3
  - ln1_out.3
  - attn_in.3
  - ln2.3
  - ln2_out.3
  - mlp_in.3
  - ln1.4
  - ln1_out.4
  - attn_in.4
  - ln2.4
  - ln2_out.4
  - mlp_in.4
  - ln1.5
  - ln1_out.5
  - attn_in.5
  - ln2.5
  - ln2_out.5
  - mlp_in.5
  - ln1.6
  - ln1_out.6
  - attn_in.6
  - ln2.6
  - ln2_out.6
  - mlp_in.6
  - ln1.7
  - ln1_out.7
  - attn_in.7
  - ln2.7
  - ln2_out.7
  - mlp_in.7
  - unembed
batch_size: 600
gram_batch_size: 500
edge_batch_size: 400
truncation_threshold: 1e-15
rotate_final_node_layer: true
n_intervals: 0
dtype: float64
calculate_edges: true
eval_type: null
basis_formula: jacobian
edge_formula: squared
center: true
n_stochastic_sources_basis_pos: 20
n_stochastic_sources_basis_hidden: 40
n_stochastic_sources_edges: 6
integration_method: gradient
out_dir: /mnt/ssd-rib/tinystories
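The node_layers list above repeats one six-entry pattern for each of the eight blocks listed (0 to 7), then ends with unembed. The sketch below regenerates that list. build_node_layers is a hypothetical helper for illustration only and is not part of rib. It makes the list easier to adapt to a model with a different number of blocks.

```python
def build_node_layers(n_blocks: int) -> list[str]:
    # Hypothetical helper: reproduces the per-block node_layers pattern from the config.
    per_block = ["ln1", "ln1_out", "attn_in", "ln2", "ln2_out", "mlp_in"]
    # Block-major order: all six entries for block 0, then block 1, and so on.
    layers = [f"{name}.{block}" for block in range(n_blocks) for name in per_block]
    # The final node layer is the unembedding.
    layers.append("unembed")
    return layers


node_layers = build_node_layers(8)
print(len(node_layers))  # 49
print(node_layers[:3])   # ['ln1.0', 'ln1_out.0', 'attn_in.0']
```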

Does this PR introduce a breaking change?

No.

stefan-apollo commented 4 months ago

Huh, pytest --runslow produces some failures:

FAILED tests/test_build_graph.py::test_modular_arithmetic_build_graph[(1-alpha)^2-functional-trapezoidal] - AssertionError: Tensor-likes are not close!
FAILED tests/test_build_graph.py::test_modular_arithmetic_build_graph[(1-0)*alpha-functional-trapezoidal] - AssertionError: Tensor-likes are not close!
FAILED tests/test_build_graph.py::test_modular_arithmetic_build_graph[(1-alpha)^2-squared-trapezoidal] - AssertionError: Tensor-likes are not close!
FAILED tests/test_build_graph.py::test_modular_arithmetic_build_graph[(1-0)*alpha-squared-trapezoidal] - AssertionError: Tensor-likes are not close!
FAILED tests/test_build_graph.py::test_modular_arithmetic_build_graph[jacobian-squared-trapezoidal] - AssertionError: Tensor-likes are not close!
FAILED tests/test_build_graph.py::test_modular_arithmetic_build_graph[jacobian-squared-gauss-legendre] - AssertionError: Tensor-likes are not close!
stefan-apollo commented 4 months ago

Oh, these also fail on main (on an A100).

stefan-apollo commented 4 months ago

Okay, ignore these failing tests: they pass on an A6000. Tracked in #339.