replicate / cog-flux

Cog inference for flux models
https://replicate.com/black-forest-labs/flux-dev
Apache License 2.0

optimize safety checker #15

Closed by technillogue 1 month ago

technillogue commented 1 month ago

the list comprehension on line 257 accounts for 10% of the time spent in run_safety_checker, and is easy to optimize

profile
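
For context, one way a change like this can help (a minimal sketch, not necessarily the exact change in this PR; the standalone `postprocess` function and its signature are assumptions) is to do the rearrange/clamp/scale once for the whole batch on the GPU and make a single host transfer, instead of looping over images in Python:

```python
# Hypothetical batched rewrite of the per-image list comprehension; the
# function name and signature are assumptions, not code from this repo.
import torch
from einops import rearrange
from PIL import Image

def postprocess(x: torch.Tensor, num_outputs: int):
    # x: (batch, c, h, w) floats in [-1, 1]
    # Scale and clamp the whole batch in one GPU op, then copy to host once.
    batch = (
        127.5 * (rearrange(x[:num_outputs], "b c h w -> b h w c").clamp(-1, 1) + 1.0)
    ).byte().cpu()
    np_images = list(batch.numpy())  # list of (h, w, c) uint8 arrays
    images = [Image.fromarray(img) for img in np_images]
    return images, np_images
```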

daanelson commented 1 month ago

though before I merge, @technillogue can you show the delta in inference time with this change? just a run with and a run without on H100s.

fofr commented 1 month ago

H100 test, fixed seed, all images safe

Before:

Safety checker time: 0.558
Safety checker time: 0.489
Safety checker time: 0.513
Safety checker time: 0.491
Safety checker time: 0.493

Mean: 508.8 ms

After:

Safety checker time: 0.483
Safety checker time: 0.493
Safety checker time: 0.501
Safety checker time: 0.549
Safety checker time: 0.480

Mean: 501.2 ms

fofr commented 1 month ago

Doing:

```python
start_time = time.time()

np_images = [
    (127.5 * (rearrange(x[i], "c h w -> h w c").clamp(-1, 1) + 1.0)).cpu().byte().numpy()
    for i in range(num_outputs)
]
images = [Image.fromarray(img) for img in np_images]
has_nsfw_content = [False] * len(images)
if not disable_safety_checker:
    _, has_nsfw_content = self.run_safety_checker(images, np_images)  # always on gpu

end_time = time.time()
print(f"Safety checker time: {end_time - start_time:.3f}")
```
technillogue commented 1 month ago

oh, maybe it is an improvement? Testing on a brev PCIe H100 gave 1.59 s ± 0.01 vs 1.60 s ± 0.03, and prod gave me https://replicate.com/p/82ha9ntbmxrm20chwzy8cdgxbm vs https://replicate.com/p/4dfwkj1n81rm20chx01ssf1qp0. But if it's a real 6 ms improvement, that's still worthwhile.
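
For reference, mean ± std numbers like those come from a simple repeat-and-average harness; a minimal sketch (the `benchmark` helper and `run_once` callable are assumptions, not code from this repo):

```python
# Minimal benchmarking sketch (assumption: run_once() wraps whatever is being
# timed, e.g. the generate + safety-checker call; not code from this repo).
import statistics
import time

def benchmark(run_once, n: int = 10) -> tuple[float, float]:
    times = []
    for _ in range(n):
        start = time.time()
        run_once()
        times.append(time.time() - start)
    # report mean and sample standard deviation, in seconds
    return statistics.mean(times), statistics.stdev(times)

# usage (hypothetical):
# mean_s, std_s = benchmark(lambda: predict(prompt="..."))
# print(f"{mean_s:.2f}s +- {std_s:.2f}s")
```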

daanelson commented 1 month ago

great, thanks for getting the numbers @fofr @technillogue! good to merge.