pytorch-labs / segment-anything-fast

A batched offline inference oriented version of segment-anything
Apache License 2.0

error for int8 inference on Nvidia 3090 #88

Open rbavery opened 9 months ago

rbavery commented 9 months ago
python run_experiments.py 8 vit_b ../ ../../segment-anything experiments_data --run-experiments --num-workers 8 --capture_output True

there wasn't an error traceback, just the results for each method

int8,0.4473809043566386,local-fork,2.2.0.dev20231117+cu121,ERROR

| technique | time | sam_commit_name | pytorch_version | sam_model_type | batch_size | memory(MiB) | memory(%) | img_s(avg) | batch_ms(avg)/batch_size | mIoU | use_compile | use_half | compress | epilogue_fusion_first | use_compile_decoder | use_nested_tensor | use_rel_pos | pad_input_image_batch | num_workers | num_batches | num_images | profile_path | memory_path |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| fp32 | 9.687259 | default | 2.2.0.dev20231117+cu121 | vit_b | 8 | 19934 | 82 | 10.270197 | 97.369114 | 0.5335680786450683 | False | None | None | False | False | False | True | True | 8 | 619 | 4952 | None | None |
| bf16 | 4.469203 | codesign | 2.2.0.dev20231117+cu121 | vit_b | 8 | 10003 | 41 | 23.441896 | 42.658666 | 0.5420768795834657 | False | torch.bfloat16 | None | False | False | False | True | True | 8 | 619 | 4952 | None | None |
| compile | 5.114826 | codesign | 2.2.0.dev20231117+cu121 | vit_b | 8 | 8159 | 33 | 30.676929 | 32.597788 | 0.5425282997311265 | max-autotune | torch.bfloat16 | None | False | False | False | True | True | 8 | 619 | 4952 | None | None |
| SDPA | 3.309412 | sdpa-decoder | 2.2.0.dev20231117+cu121 | vit_b | 8 | 4858 | 20 | 38.438522 | 26.015569 | 0.5363043800669344 | max-autotune | torch.bfloat16 | None | False | False | False | True | True | 8 | 619 | 4952 | None | None |
| Triton | 3.219514 | local-fork | 2.2.0.dev20231117+cu121 | vit_b | 8 | 4671 | 19 | 38.492990 | 25.978756 | 0.5363043800669344 | max-autotune | torch.bfloat16 | None | False | False | False | True | True | 8 | 619 | 4952 | None | None |
| NT | 3.210428 | local-fork | 2.2.0.dev20231117+cu121 | vit_b | 8 | 4671 | 19 | 39.489851 | 25.322962 | 0.5355440758154442 | max-autotune | torch.bfloat16 | None | False | False | True | True | True | 8 | 619 | 4952 | None | None |
| int8 | 0.447381 | local-fork | 2.2.0.dev20231117+cu121 | ERROR | | | | | | | | | | | | | | | | | | | |
| sparse | 3.139057 | local-fork | 2.2.0.dev20231117+cu121 | vit_b | 8 | 4969 | 20 | 46.041364 | 21.719600 | 0.4862656356911103 | max-autotune | torch.bfloat16 | sparse | False | False | False | True | True | 8 | 619 | 4952 | None | None |

cpuhrsch commented 9 months ago

@rbavery Well, first I'm actually very delighted to see a 4x improvement on a 3090 even though we mostly targeted the A100 :D Thank you for trying to run this and producing these numbers!
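For reference, the speedup can be recomputed from the img_s(avg) column of the results posted above. This is only a small sketch; the dictionary and the helper name `speedups` are made up for illustration, and the numbers are copied from that table.

```python
# img_s(avg) values copied from the results table in this issue.
IMG_PER_SEC = {
    "fp32": 10.270197,
    "bf16": 23.441896,
    "compile": 30.676929,
    "SDPA": 38.438522,
    "Triton": 38.492990,
    "NT": 39.489851,
    "sparse": 46.041364,
}


def speedups(throughput, baseline="fp32"):
    """Return each technique's throughput relative to the baseline row."""
    base = throughput[baseline]
    return {name: value / base for name, value in throughput.items()}


relative = speedups(IMG_PER_SEC)
for name, factor in relative.items():
    print(f"{name:8s} {factor:.2f}x")
```

The sparse row comes out at roughly 4.5x over fp32, but its mIoU is lower (0.486 vs 0.534); the fastest row without that drop, NT, is roughly 3.8x.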

Next, to get a traceback, can you try running with --capture_output False? I suspect that the kernel we generate just doesn't work on a 3090.
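A rough sketch of why capturing output hides the traceback, assuming run_experiments.py launches each experiment as a child process (an assumption, not checked against the script). The helper `run_experiment` and the failing command are hypothetical; only `subprocess.run` and its `capture_output` argument are real.

```python
import subprocess
import sys

# Hypothetical stand-in for an experiment that crashes inside the child process.
FAILING_CMD = [sys.executable, "-c", "raise RuntimeError('kernel failed')"]


def run_experiment(cmd, capture_output):
    """Run one experiment; return (succeeded, captured stderr or None)."""
    result = subprocess.run(cmd, capture_output=capture_output, text=True)
    # With capture_output=False, result.stderr is None because the child's
    # traceback was written straight to the terminal instead of being stored.
    return result.returncode == 0, result.stderr


ok, stderr = run_experiment(FAILING_CMD, capture_output=True)
# A driver that only records success or failure would log just this marker.
row = "ok" if ok else "ERROR"
print(row)
```

With `capture_output=True` the traceback only exists in `stderr`, so unless the driver prints it, all that shows up is the ERROR marker in the results row.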

rbavery commented 9 months ago

woops, I thought True would output logs! Yeah cool to see big improvements apply to the 3090.

here are the logs for the failed int8 run. kinda weird, it raises an OOM error, but int8 should save memory? and the next experiment for sparse does not raise OOM

→ python run_experiments.py 8 vit_b ../ ../../segment-anything experiments_data --run-experiments --num-workers 8 --capture_output False
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [09:24<00:00,  1.10it/s]
sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
vit_b,8,19934,82,10.454507324852766,95.65252277577636,0.5335680786450683,False,None,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [04:13<00:00,  2.44it/s]
vit_b,8,10003,41,23.958872166619145,41.73819172478647,0.5420768795834657,False,torch.bfloat16,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [03:43<00:00,  2.77it/s]
vit_b,8,8152,33,31.393271025276487,31.853960015662075,0.5425282997311265,max-autotune,torch.bfloat16,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [03:01<00:00,  3.40it/s]
vit_b,8,4671,19,39.20810820050168,25.504928595029863,0.5363043800669344,max-autotune,torch.bfloat16,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.44s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [03:01<00:00,  3.41it/s]
vit_b,8,4671,19,39.23741711519246,25.485877346722877,0.5363043800669344,max-autotune,torch.bfloat16,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.46s)
creating index...
index created!
  0%|                                                   | 0/619 [00:00<?, ?it/s]/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
  return _nested.nested_tensor(
100%|█████████████████████████████████████████| 619/619 [02:58<00:00,  3.47it/s]
vit_b,8,4671,19,40.273361025850825,24.83030903127544,0.5355440758154442,max-autotune,torch.bfloat16,None,False,False,True,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.46s)
creating index...
index created!
  0%|                                                   | 0/619 [00:17<?, ?it/s]
Traceback (most recent call last):
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 446, in <module>
    fire.Fire(run)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 409, in run
    results, avg_ms_per_img, num_batches, num_images = runner(build_results,
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 228, in identity_runner
    return fn(*args, **kwargs)
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 199, in build_results
    _ = batch_runner(predictor, batch, batch_size, pad_input_image_batch)
  File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 74, in build_results_batch_nested
    features_batch = encoder(input_image_batch)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 408, in _fn
    return fn(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 569, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 675, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 599, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 516, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 150, in _fn
    return fn(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 481, in transform
    tracer.run()
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2124, in run
    super().run()
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in run
    and self.step()
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 782, in step
    getattr(self, inst.opname)(inst)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2239, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 879, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1024, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1096, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1077, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/__init__.py", line 1616, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1022, in compile_fx
    return compile_fx(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1239, in compile_fx
    return aot_autograd(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 4926, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 4466, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2991, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2130, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1171, in fw_compiler_base
    return inner_compile(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/debug.py", line 303, in inner
    return fn(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 394, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 605, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/graph.py", line 445, in run
    return super().run(*args)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/graph.py", line 749, in run_node
    result = super().run_node(n)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/fx/interpreter.py", line 195, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/graph.py", line 585, in call_function
    return target(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/fx_passes/post_grad.py", line 957, in fused_int_mm_mul
    return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/kernel/mm.py", line 305, in tuned_fused_int_mm_mul
    return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 989, in autotune_select_algorithm
    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 746, in __call__
    timings = self.lookup(
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 287, in lookup
    timings = benchmark(choices)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 737, in autotune
    return make_benchmark_fn()(choices)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 846, in benchmark_in_current_process
    timing = benchmark_choice_in_current_process(choice)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 836, in benchmark_choice_in_current_process
    result = choice.benchmark(*example_inputs, out=out)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 602, in benchmark
    return self.bmreq.benchmark(*args, output_tensor=out)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/autotune_process.py", line 442, in benchmark
    out = do_bench(fn)
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/utils.py", line 166, in do_bench
    return triton_do_bench(*args, **kwargs)[0]
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/triton/testing.py", line 106, in do_bench
    fn()
  File "<string>", line 78, in triton_mm
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/triton/compiler/compiler.py", line 704, in __getattribute__
    self._init_handles()
  File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/triton/compiler/compiler.py", line 693, in _init_handles
    raise OutOfResources(self.shared, max_shared, "shared memory")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
OutOfResources: out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/sparse/semi_structured.py:92: UserWarning: The PyTorch API of SparseSemiStructuredTensor is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.sparse module for further information about the project.
  warnings.warn(
loading annotations into memory...
Done (t=0.44s)
creating index...
index created!
  0%|                                                   | 0/619 [00:00<?, ?it/s]/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
  return _nested.nested_tensor(
100%|█████████████████████████████████████████| 619/619 [02:41<00:00,  3.84it/s]
vit_b,8,4870,20,46.89488723882429,21.3242862682928,0.4862656356911103,max-autotune,torch.bfloat16,sparse,False,False,True,True,True
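Note that the message in the traceback is about shared memory, the small on-chip memory available to each thread block, not the device memory reported in the memory(MiB) column. A minimal sketch that pulls the numbers out of the message; the regex and the helper name `parse_out_of_resources` are just for illustration.

```python
import re

# Error text copied from the traceback above.
MESSAGE = (
    "OutOfResources: out of resource: shared memory, "
    "Required: 147456, Hardware limit: 101376."
)

PATTERN = re.compile(
    r"out of resource: (?P<resource>[\w ]+), "
    r"Required: (?P<required>\d+), Hardware limit: (?P<limit>\d+)"
)


def parse_out_of_resources(message):
    """Return (resource, required_bytes, limit_bytes), or None if no match."""
    match = PATTERN.search(message)
    if match is None:
        return None
    return match["resource"], int(match["required"]), int(match["limit"])


resource, required, limit = parse_out_of_resources(MESSAGE)
print(f"{resource}: needs {required / 1024:.0f} KiB, limit {limit / 1024:.0f} KiB")
# prints "shared memory: needs 144 KiB, limit 99 KiB"
```

So the kernel asks for 144 KiB of shared memory per block while the card allows 99 KiB, which is unrelated to how much device memory int8 weights save.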
cpuhrsch commented 9 months ago

@rbavery Mhm, so that's likely because it's trying a configuration for which the 3090 doesn't have enough shared memory. cc @HDCharles for looking into guarding this.
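A sketch of what such a guard could look like, under a rough model where each pipeline stage of a tiled matmul buffers one A tile and one B tile in shared memory. The `MMConfig` type, the candidate list, and both helpers are hypothetical and not taken from Inductor; the limit is the value from the traceback.

```python
from typing import NamedTuple


class MMConfig(NamedTuple):
    block_m: int
    block_n: int
    block_k: int
    num_stages: int


def estimate_shared_memory(cfg, bytes_per_element=1):
    """Approximate shared-memory bytes: per stage, one MxK tile plus one KxN tile."""
    tile_elems = cfg.block_m * cfg.block_k + cfg.block_k * cfg.block_n
    return cfg.num_stages * tile_elems * bytes_per_element


# Hardware limit reported in the traceback for the 3090.
HARDWARE_LIMIT = 101376

# Illustrative candidates only (assumed, not the real autotuning list).
CANDIDATES = [
    MMConfig(64, 64, 32, 2),
    MMConfig(128, 128, 32, 2),
    MMConfig(64, 64, 64, 3),
    MMConfig(128, 256, 128, 3),
    MMConfig(256, 128, 128, 3),
]


def usable_configs(configs, limit, bytes_per_element=1):
    """Keep only configs whose estimated shared memory fits under the limit."""
    return [
        cfg for cfg in configs
        if estimate_shared_memory(cfg, bytes_per_element) <= limit
    ]


usable = usable_configs(CANDIDATES, HARDWARE_LIMIT)
for cfg in CANDIDATES:
    status = "ok" if cfg in usable else "skip"
    print(cfg, estimate_shared_memory(cfg), status)
```

Under this model, a 128x256x128 tile with 3 stages and 1-byte (int8) operands needs exactly 147456 bytes, the "Required" value in the error, which suggests (but does not prove) that a configuration of that shape was being benchmarked. Dropping the same tile to 2 stages gives 98304 bytes, which fits under the limit, in line with the hint in the error message about reducing `num_stages`.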

I'm not sure we can promise a timeline for this. Also, I suspect that it wouldn't help too much on a 3090, to be honest. Do you have a specific performance target in mind, or are you just testing the waters? Thank you for running the experiments and giving us feedback!

rbavery commented 9 months ago

oh totally just testing the waters, sorry should have said! this isn't urgent for me, I'm more curious than anything and I'll probably just look to export the current fastest option

qingfengmingyue commented 9 months ago

@cpuhrsch Hi, is the 3090 OK? I also encountered the same error.

cpuhrsch commented 9 months ago

@qingfengmingyue - We focused on the A100 here. The 3090 is a consumer Ampere card with less shared memory per block than the A100 (the traceback above reports a hardware limit of 101376 bytes), so kernels tuned for the A100, including some FlashAttention-style kernels, may not fit on it. You can definitely run this on a 3090, as rbavery shows above, and it's showing some gains (~4x), but they're not as strong as on the A100 (~8x).

rbavery commented 9 months ago

Just a note, I looked into flash attn support and it looks like the original FlashAttention 2 repo supports 3090 now: https://github.com/Dao-AILab/flash-attention#installation-and-features
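As a quick sanity check, here is a sketch that compares compute capabilities, assuming the linked README's statement that FlashAttention-2 targets Ampere, Ada, and Hopper GPUs translates to a minimum capability of 8.0. The table and the helper name are illustrative; at runtime `torch.cuda.get_device_capability()` returns the same kind of `(major, minor)` tuple for the current device.

```python
# Assumed minimum, derived from "Ampere or newer" in the FlashAttention-2 README.
MIN_FLASH_ATTN2_CAPABILITY = (8, 0)

# CUDA compute capabilities of the GPUs mentioned in this thread, plus H100.
KNOWN_CAPABILITIES = {
    "V100": (7, 0),
    "A100": (8, 0),
    "RTX 3090": (8, 6),
    "RTX 4090": (8, 9),
    "H100": (9, 0),
}


def likely_supports_flash_attn2(gpu_name):
    """True if the GPU's compute capability meets the assumed minimum."""
    return KNOWN_CAPABILITIES[gpu_name] >= MIN_FLASH_ATTN2_CAPABILITY


for gpu in KNOWN_CAPABILITIES:
    print(f"{gpu:9s} {likely_supports_flash_attn2(gpu)}")
```

The 3090 (8.6) is in the same Ampere generation as the A100 (8.0), so it passes this check, while the V100 (7.0) does not.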