databricks / megablocks

Apache License 2.0

OSError: Stale file handle with dMoE #106

Open Muennighoff opened 2 months ago

Muennighoff commented 2 months ago

I am getting the error below on the first step of multinode training with dMoE. Meanwhile, multinode MoE training and single-node dMoE both work fine. Any ideas what the problem might be? Thanks!

File "/env/lib/conda/llmoe/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/stk/backend/autocast.py", line 27, in decorate_fwd
    return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/megablocks/ops/scatter.py", line 28, in forward
    return kernels.scatter(
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/megablocks/backend/kernels.py", line 208, in scatter
    return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/megablocks/backend/kernels.py", line 190, in padded_scatter
    _padded_copy[(indices.shape[0],)](
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 122, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/triton/testing.py", line 102, in do_bench
    fn()
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 110, in kernel_call
    self.fn.run(
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/triton/runtime/jit.py", line 532, in run
    self.cache[device][key] = compile(
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/triton/compiler/compiler.py", line 503, in compile
    metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
File "/env/lib/conda/llmoe/lib/python3.10/site-packages/triton/runtime/cache.py", line 90, in get_group
    grp_data = json.load(f)
File "/env/lib/conda/llmoe/lib/python3.10/json/__init__.py", line 293, in load
    return loads(fp.read(),
OSError: [Errno 116] Stale file handle

mvpatel2000 commented 2 months ago

@Muennighoff Hm... this looks like an error in triton... what version are you running on? It could be an issue on their end

Muennighoff commented 2 months ago

I tried with 2.1.0 & 2.2.0 & 2.3.0 and get it everywhere

mvpatel2000 commented 2 months ago

> I tried with 2.1.0 & 2.2.0 & 2.3.0 and get it everywhere

Hm... would you mind providing a minimal repro? It seems to work fine on my end, so I'm wondering if it's a setup thing.
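
The traceback above ends inside Triton's on-disk kernel cache (`triton/runtime/cache.py`), while it reads a JSON metadata file. One thing that may be worth trying, assuming the cache directory sits on a filesystem shared across nodes (such as NFS), which this thread does not confirm, is to give each process its own node-local cache through the `TRITON_CACHE_DIR` environment variable. The sketch below is only an illustration: the helper name `use_local_triton_cache`, the temp-directory base path, and the use of `LOCAL_RANK` (set by launchers such as `torchrun`) are assumptions, not part of megablocks or Triton.

```python
import os
import tempfile


def use_local_triton_cache(base_dir=None):
    """Point Triton's kernel cache at a per-process directory.

    Hypothetical helper: assumes `base_dir` (default: the system temp
    directory) is on node-local storage, not a shared network mount.
    """
    base_dir = base_dir or tempfile.gettempdir()
    # LOCAL_RANK is set by launchers such as torchrun; fall back to "0".
    local_rank = os.environ.get("LOCAL_RANK", "0")
    cache_dir = os.path.join(base_dir, f"triton_cache_rank{local_rank}")
    os.makedirs(cache_dir, exist_ok=True)
    # Triton reads this variable when it sets up its cache directory.
    os.environ["TRITON_CACHE_DIR"] = cache_dir
    return cache_dir


cache_dir = use_local_triton_cache()
```

To be safe, call this at the top of the training script, before importing megablocks or anything else that compiles Triton kernels. If the error persists with a node-local cache, the shared-filesystem assumption is probably wrong and a minimal repro would still be needed.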