facebook / Ax

Adaptive Experimentation Platform
https://ax.dev
MIT License
2.37k stars 308 forks

[FEATURE REQUEST]: Allow to specify the trial_index when calling ax_client.attach_trial() #2852

Open ricvolpi opened 1 week ago

ricvolpi commented 1 week ago

Motivation

Often we already have an ID for the experiment we are attaching, and it can be convenient to reuse it, for example when the parameters come from a database.

Describe the solution you'd like to see implemented in Ax.

Being able to call ax_client.attach_trial(parameters=MY_PARAMETERS, index=MY_INDEX)

Describe any alternatives you've considered to the above solution.

from ax.core.arm import Arm
from ax.core.trial import Trial

def attach_custom_trial(ax_client, parameters, index):
    # Create the trial directly so that the caller chooses its index.
    custom_trial = Trial(experiment=ax_client.experiment, index=index)
    custom_trial = custom_trial.add_arm(Arm(parameters=parameters, name=f'{index}_0'))
    custom_trial.mark_running(no_runner_required=True)
    return custom_trial

Judging by ax_client.experiment.trials, this solution seems to create a Trial identical to the one created by running ax_client.attach_trial(parameters); the only differences are the Trial's index and the Arm's name (which is my desired behavior). I still need to ensure everything works properly, though.

Is this related to an existing issue in Ax or another repository? If so please include links to those Issues here.

No response

CristianLara commented 1 week ago

Hello @ricvolpi, thank you for the feature request.

This would deviate from our existing behavior, where trial indices form contiguous integer ranges. Could you share more details on the specific benefits of controlling the trial index value?

ricvolpi commented 1 week ago

hey @CristianLara, thanks for following up.

If the initial points (i.e. the training set) come from a DB of experiments, each of which generally has its own ID, and we want to save the newly suggested parameters (sampled from the posterior) to the same DB, or to a DB with the same structure, it comes in handy to use the experiment ID as the trial ID. I'm successfully using the above solution in my experiments.

lena-kashtelyan commented 5 days ago

@ricvolpi, @CristianLara, I think a good solution here might be to allow setting the arm names during attach_trial? @CristianLara is right that Ax doesn't allow customizing trial indices; they are always assigned internally within Ax, for many good reasons. However, customizing arm names is done quite commonly and will work here if that is a satisfactory solution.
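
A minimal sketch of that workflow, assuming attach_trial accepts an arm_name keyword (worth checking against the installed Ax version). StubClient and attach_db_row are hypothetical names, used so the snippet runs without Ax; the stub only mimics internally assigned, sequential trial indices and the (parameters, trial_index) return value.

```python
from typing import Any, Dict, Optional, Tuple

Params = Dict[str, Any]


class StubClient:
    """Hypothetical stand-in for AxClient: indices are assigned internally (0, 1, 2, ...)."""

    def __init__(self) -> None:
        self.arm_names: Dict[int, str] = {}  # trial index -> arm name

    def attach_trial(
        self, parameters: Params, arm_name: Optional[str] = None
    ) -> Tuple[Params, int]:
        index = len(self.arm_names)  # next contiguous index; the caller cannot choose it
        self.arm_names[index] = arm_name if arm_name is not None else f"{index}_0"
        return parameters, index


def attach_db_row(client: Any, db_id: str, parameters: Params) -> int:
    """Attach one database row, carrying its external ID as the arm name."""
    _, trial_index = client.attach_trial(parameters=parameters, arm_name=db_id)
    return trial_index


client = StubClient()
first = attach_db_row(client, "exp-1042", {"x": 0.1})
second = attach_db_row(client, "exp-0007", {"x": 0.7})
print(first, second, client.arm_names)
```

The database ID travels with the arm, while the trial index stays whatever the client assigns.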

ricvolpi commented 4 days ago

@lena-kashtelyan thanks for following up on this.

That would definitely solve my problem. Is there any function that retrieves the trial index from the arm name? That's simple to implement by looping over ax_client.experiment.trials, but I'm wondering if there's anything like ax_client.get_trial_id_from_arm_name(ARM_NAME) that I could call before completing trials, etc.

mpolson64 commented 1 day ago

Unfortunately no, but you're right that it should be very straightforward to implement. Something like the following should do the trick.

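from ax.service.ax_client import AxClient

def foo(client: AxClient, arm_name: str) -> int:
    # Scan all trials; assumes single-arm Trials, as created via AxClient.
    for i, trial in client.experiment.trials.items():
        if trial.arm is not None and trial.arm.name == arm_name:
            return i

    raise KeyError(f"No trial found with arm name {arm_name!r}")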
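
If the lookup is needed many times, the loop above could be replaced by a dictionary built once. This is only a sketch: StubArm, StubTrial, StubExperiment and build_arm_index are hypothetical names that mimic just experiment.trials and trial.arm.name, so it runs without Ax installed.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class StubArm:
    name: str


@dataclass
class StubTrial:
    arm: Optional[StubArm]  # None mimics a trial that has no arm yet


@dataclass
class StubExperiment:
    trials: Dict[int, StubTrial] = field(default_factory=dict)


def build_arm_index(experiment: Any) -> Dict[str, int]:
    """Map each arm name to the index of the first trial that uses it."""
    index: Dict[str, int] = {}
    for trial_index, trial in experiment.trials.items():
        if trial.arm is None:
            continue  # nothing to look up for this trial
        # Keep the first match, mirroring the early return of the loop version.
        index.setdefault(trial.arm.name, trial_index)
    return index


experiment = StubExperiment(
    trials={
        0: StubTrial(StubArm("exp-1042")),
        1: StubTrial(StubArm("exp-0007")),
        2: StubTrial(None),
    }
)
arm_index = build_arm_index(experiment)
print(arm_index["exp-0007"])  # 1
```

The dictionary goes stale when new trials are attached, so it would need rebuilding after each attach.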

ricvolpi commented 2 hours ago

I have this working in my code, so feel free to close this if it is not of interest to others. If it is, I'm happy to help.