iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.47k stars 548 forks

Allow flags to be set with greater flexibility #17659

Closed daveliddell closed 1 week ago

daveliddell commented 1 week ago

Changes to the Python bindings so that flags set with iree.runtime.flags.parse_flags can take effect after a driver has already been created, not only before the first driver is created. Also includes fixes for bugs exposed while developing this feature.

daveliddell commented 1 week ago

Please hold off on reviewing; I will rework the client of this API so that it doesn't use the driver cache, which lets me get rid of the unnecessary new features.

daveliddell commented 1 week ago

Snippet of ugly but functional use of this new feature:

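import iree.runtime as ireert


class vmfbRunner:
    def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None):
        flags = []
        clean_driver = False
        # A newly parsed global flag would not reach a previously cached
        # driver, so request a clean driver in that case.
        if extra_plugin:
            ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}")
            clean_driver = True
        # The second argument is the cache-bypass option proposed in this draft.
        haldriver = ireert.get_driver(device, clean_driver)
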
daveliddell commented 1 week ago

General note: in an offline discussion, we decided to keep the cache, so this PR is still a go. Marking it ready for review again.

daveliddell commented 1 week ago

New sample usage:

import iree.runtime as ireert


class vmfbRunner:
    def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None):
        flags = []

        # If an extra plugin is requested, add a global flag to load the plugin
        # and create the driver using the non-caching creation function, as
        # the caching creation function may return a driver created before
        # the flag was set, which would ignore it.
        # (create_hal_driver is the non-caching creation function; its import
        # is not shown in the original snippet.)
        if extra_plugin:
            ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}")
            haldriver = create_hal_driver(device)
        else:
            # No plugin requested: create the driver with the caching
            # creation function.
            haldriver = ireert.get_driver(device)

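
The branching in the sample above can be modeled in isolation. This is a sketch under stated assumptions, not IREE code: create_driver_uncached plays the role of the non-caching creation function, get_cached_driver plays the role of the caching getter, and make_driver is a hypothetical helper that mirrors the sample's if/else. As before, drivers are assumed to read the global flags only when constructed.

```python
"""Sketch: bypass the driver cache only when a new global flag is set.

All names are hypothetical stand-ins, not IREE's API.
"""
from typing import Optional

_GLOBAL_FLAGS: dict = {}
_DRIVER_CACHE: dict = {}


class FakeDriver:
    def __init__(self, name: str) -> None:
        self.name = name
        # Assumption: configuration is read from the global flags once, here.
        self.flags = dict(_GLOBAL_FLAGS)


def set_flag(key: str, value: str) -> None:
    _GLOBAL_FLAGS[key] = value


def create_driver_uncached(name: str) -> FakeDriver:
    """Always builds a new driver, so it sees the current flags."""
    return FakeDriver(name)


def get_cached_driver(name: str) -> FakeDriver:
    """Returns the driver created on first use; later flags are not seen."""
    if name not in _DRIVER_CACHE:
        _DRIVER_CACHE[name] = create_driver_uncached(name)
    return _DRIVER_CACHE[name]


def make_driver(device: str, extra_plugin: Optional[str] = None) -> FakeDriver:
    # Mirrors the sample: bypass the cache only when a flag has just been
    # set that an already-cached driver would not pick up.
    if extra_plugin:
        set_flag("executable_plugin", extra_plugin)
        return create_driver_uncached(device)
    return get_cached_driver(device)


plain = make_driver("local-task")
plugged = make_driver("local-task", extra_plugin="my_plugin.so")
again = make_driver("local-task")

print(plugged.flags.get("executable_plugin"))           # my_plugin.so
print(again is plain, "executable_plugin" in again.flags)  # True False
```

One consequence visible in the model: the plugin flag stays set globally, so any later uncached driver would also load the plugin, while the cached driver keeps its original configuration.
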
ScottTodd commented 1 week ago

The CI failure is preexisting and unrelated to this change; merging through it. https://github.com/iree-org/iree/actions/runs/9512652449/job/26221447494?pr=17659

[ RUN      ] NCCLDynamicSymbolsTest.CreateFromSystemLoader
iree/runtime/src/iree/hal/drivers/hip/dynamic_symbols_test.cc:92: Failure
Expected equality of these values:
  21803
  nccl_version
    Which is: 21806
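
The two integers in the failing assertion look like NCCL version codes. Decoding them, assuming the NCCL_VERSION encoding from nccl.h (major*1000 + minor*100 + patch up to 2.8, major*10000 + minor*100 + patch afterwards), shows the test expected NCCL 2.18.3 while the system loader found 2.18.6. That is consistent with the failure depending on the runner's installed NCCL rather than on this PR. decode_nccl_version is a hypothetical helper written for this illustration.

```python
from typing import Tuple


# Hypothetical helper: split an NCCL version code into (major, minor, patch).
def decode_nccl_version(code: int) -> Tuple[int, int, int]:
    # Assumed nccl.h encoding: codes for 2.9 and later use major*10000,
    # so they are always >= 10000; older codes use major*1000.
    if code >= 10000:
        return code // 10000, (code // 100) % 100, code % 100
    return code // 1000, (code // 100) % 10, code % 100


print(decode_nccl_version(21803))  # (2, 18, 3): expected by the test
print(decode_nccl_version(21806))  # (2, 18, 6): found on the CI machine
```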