Open rbavery opened 1 year ago
@rbavery Well, first I'm actually very delighted to see a 4x improvement on a 3090 even though we mostly targeted the A100 :D Thank you for trying to run this and producing these numbers!
Next, in order to get a traceback, can you try running with --capture_output False? I suspect that the kernel we generate just doesn't work on a 3090.
woops, I thought True would output logs! Yeah cool to see big improvements apply to the 3090.
here are the logs for the failed int8 run. kinda weird, it raises an OOM error, but int8 should save memory? and the next experiment for sparse does not raise OOM
→ python run_experiments.py 8 vit_b ../ ../../segment-anything experiments_data --run-experiments --num-workers 8 --capture_output False
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [09:24<00:00, 1.10it/s]
sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,epilogue_fusion_first,use_compile_decoder,use_nested_tensor,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
vit_b,8,19934,82,10.454507324852766,95.65252277577636,0.5335680786450683,False,None,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [04:13<00:00, 2.44it/s]
vit_b,8,10003,41,23.958872166619145,41.73819172478647,0.5420768795834657,False,torch.bfloat16,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [03:43<00:00, 2.77it/s]
vit_b,8,8152,33,31.393271025276487,31.853960015662075,0.5425282997311265,max-autotune,torch.bfloat16,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [03:01<00:00, 3.40it/s]
vit_b,8,4671,19,39.20810820050168,25.504928595029863,0.5363043800669344,max-autotune,torch.bfloat16,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.44s)
creating index...
index created!
100%|█████████████████████████████████████████| 619/619 [03:01<00:00, 3.41it/s]
vit_b,8,4671,19,39.23741711519246,25.485877346722877,0.5363043800669344,max-autotune,torch.bfloat16,None,False,False,False,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.46s)
creating index...
index created!
0%| | 0/619 [00:00<?, ?it/s]/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
return _nested.nested_tensor(
100%|█████████████████████████████████████████| 619/619 [02:58<00:00, 3.47it/s]
vit_b,8,4671,19,40.273361025850825,24.83030903127544,0.5355440758154442,max-autotune,torch.bfloat16,None,False,False,True,True,True,8,619,4952,None,None
loading annotations into memory...
Done (t=0.46s)
creating index...
index created!
0%| | 0/619 [00:17<?, ?it/s]
Traceback (most recent call last):
File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 446, in <module>
fire.Fire(run)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 409, in run
results, avg_ms_per_img, num_batches, num_images = runner(build_results,
File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 228, in identity_runner
return fn(*args, **kwargs)
File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 199, in build_results
_ = batch_runner(predictor, batch, batch_size, pad_input_image_batch)
File "/home/rave/segment-anything-fast/experiments/eval_combo.py", line 74, in build_results_batch_nested
features_batch = encoder(input_image_batch)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 408, in _fn
return fn(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 569, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 675, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
return _compile(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 599, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 516, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
transformations(instructions, code_options)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 150, in _fn
return fn(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 481, in transform
tracer.run()
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2124, in run
super().run()
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 819, in run
and self.step()
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 782, in step
getattr(self, inst.opname)(inst)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2239, in RETURN_VALUE
self.output.compile_subgraph(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 879, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1024, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1096, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1077, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/__init__.py", line 1616, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1022, in compile_fx
return compile_fx(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1239, in compile_fx
return aot_autograd(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 4926, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 4466, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2991, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2130, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1171, in fw_compiler_base
return inner_compile(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/debug.py", line 303, in inner
return fn(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 394, in compile_fx_inner
compiled_graph = fx_codegen_and_compile(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 605, in fx_codegen_and_compile
graph.run(*example_inputs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/graph.py", line 445, in run
return super().run(*args)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/fx/interpreter.py", line 138, in run
self.env[node] = self.run_node(node)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/graph.py", line 749, in run_node
result = super().run_node(n)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/fx/interpreter.py", line 195, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/graph.py", line 585, in call_function
return target(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/fx_passes/post_grad.py", line 957, in fused_int_mm_mul
return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/kernel/mm.py", line 305, in tuned_fused_int_mm_mul
return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 989, in autotune_select_algorithm
return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 746, in __call__
timings = self.lookup(
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 287, in lookup
timings = benchmark(choices)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 737, in autotune
return make_benchmark_fn()(choices)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 846, in benchmark_in_current_process
timing = benchmark_choice_in_current_process(choice)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 836, in benchmark_choice_in_current_process
result = choice.benchmark(*example_inputs, out=out)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 602, in benchmark
return self.bmreq.benchmark(*args, output_tensor=out)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/autotune_process.py", line 442, in benchmark
out = do_bench(fn)
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/_inductor/utils.py", line 166, in do_bench
return triton_do_bench(*args, **kwargs)[0]
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/triton/testing.py", line 106, in do_bench
fn()
File "<string>", line 78, in triton_mm
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/triton/compiler/compiler.py", line 704, in __getattribute__
self._init_handles()
File "/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/triton/compiler/compiler.py", line 693, in _init_handles
raise OutOfResources(self.shared, max_shared, "shared memory")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
OutOfResources: out of resource: shared memory, Required: 147456, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/sparse/semi_structured.py:92: UserWarning: The PyTorch API of SparseSemiStructuredTensor is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.sparse module for further information about the project.
warnings.warn(
loading annotations into memory...
Done (t=0.44s)
creating index...
index created!
0%| | 0/619 [00:00<?, ?it/s]/home/rave/mambaforge/envs/sam-fast/lib/python3.10/site-packages/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
return _nested.nested_tensor(
100%|█████████████████████████████████████████| 619/619 [02:41<00:00, 3.84it/s]
vit_b,8,4870,20,46.89488723882429,21.3242862682928,0.4862656356911103,max-autotune,torch.bfloat16,sparse,False,False,True,True,True
@rbavery Mhm, so that's likely because it's trying a configuration for which the 3090 doesn't have enough shared memory. cc @HDCharles for looking into guarding this.
I'm not sure we can promise a timeline for when this will be done. Also, I suspect that it wouldn't help too much on a 3090, to be honest. Do you have a specific performance target in mind or are you just testing the waters? Thank you for running the experiments and giving us feedback!
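A rough sanity check on the numbers in the error message: a software-pipelined matmul kernel keeps roughly one tile of each input operand in shared memory per pipeline stage. The sketch below uses that approximation. The function name and the block sizes are hypothetical, chosen only because they reproduce the reported 147456 bytes, and the actual config the autotuner tried is not shown in the log.

```python
def estimated_shared_memory_bytes(block_m, block_n, block_k, num_stages, elem_bytes=1):
    """Approximate shared memory for a pipelined matmul tile.

    Assumption: one (block_m x block_k) tile of A and one (block_k x block_n)
    tile of B are buffered per stage. The real allocation made by the
    compiler can differ from this estimate.
    """
    tile_a = block_m * block_k * elem_bytes
    tile_b = block_k * block_n * elem_bytes
    return (tile_a + tile_b) * num_stages


# Values copied from the OutOfResources message in the traceback.
REQUIRED_BYTES = 147456
HARDWARE_LIMIT_BYTES = 101376

# Hypothetical int8 config (1 byte per element), not taken from the log.
three_stages = estimated_shared_memory_bytes(128, 256, 128, num_stages=3)
two_stages = estimated_shared_memory_bytes(128, 256, 128, num_stages=2)

print(three_stages, three_stages <= HARDWARE_LIMIT_BYTES)
print(two_stages, two_stages <= HARDWARE_LIMIT_BYTES)
```

Under these assumptions the three-stage variant matches the required size in the error and exceeds the 3090 limit, while the same tile with two stages would fit. That is consistent with the hint in the message that reducing block sizes or `num_stages` may help.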
oh totally just testing the waters, sorry should have said! this isn't urgent for me, I'm more curious than anything and I'll probably just look to export the current fastest option
@cpuhrsch Hi, is the 3090 OK? I also encountered the same error.
@qingfengmingyue - We did focus on A100 here where the consumer cards are the 40 series. 3090 is the consumer counterpart of V100 which kernels like FlashAttention etc. don't support. You can definitely run this on a 3090, as rbavery shows above and it's showing some gains (~4x), but they're not as strong (~8x) as for A100.
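For reference, the speedups on the 3090 can be read off the img_s(avg) column of the logs above. A minimal sketch, with the throughput values copied from four of the result rows; the dictionary keys are informal labels for this example, not names used by the benchmark script:

```python
# img_s(avg) values copied from the vit_b, batch_size=8 rows in the log.
images_per_second = {
    "baseline (no compile, no half)": 10.454507324852766,
    "bfloat16": 23.958872166619145,
    "bfloat16 + max-autotune + nested tensor": 40.273361025850825,
    "bfloat16 + max-autotune + nested tensor + sparse": 46.89488723882429,
}

baseline = images_per_second["baseline (no compile, no half)"]

# Speedup of each configuration relative to the baseline run.
speedups = {name: rate / baseline for name, rate in images_per_second.items()}

for name, factor in speedups.items():
    print(f"{name}: {factor:.2f}x")
```

Note that the sparse row also reports a lower mIoU (about 0.486 versus about 0.536 for the non-sparse nested tensor run), so its extra throughput comes with an accuracy cost in this experiment.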
Just a note, I looked into flash attn support and it looks like the original FlashAttention 2 repo supports 3090 now: https://github.com/Dao-AILab/flash-attention#installation-and-features
there wasn't an error traceback, just the results for each method
int8,0.4473809043566386,local-fork,2.2.0.dev20231117+cu121,ERROR