datadreamer-dev / DataDreamer

DataDreamer: Prompt. Generate Synthetic Data. Train & Align Models.   🤖💤
https://datadreamer.dev
MIT License
724 stars, 39 forks

AssertionError when passing a single device to VLLM #31

Status: Open. kaifronsdal opened this issue 3 days ago.

kaifronsdal commented 3 days ago

When I try to create a VLLM instance on a single device, DataDreamer raises an AssertionError:

File ******/lib/python3.11/site-packages/datadreamer/llms/vllm.py:73, in VLLM.model(self)
     70 @cached_property
     71 def model(self) -> Any:
     72     env = os.environ.copy()
---> 73     assert isinstance(self.device, list)
     74     env.update(get_device_env_variables(self.device))
     75     kwargs = self.kwargs.copy()

AssertionError:

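This appears to be caused by the last if statement in validate_device (datadreamer/utils/device_utils.py), which unwraps single-element lists: a device list such as [0] is returned as the bare value 0, so the isinstance(self.device, list) assertion in VLLM.model fails.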

def validate_device(
    device: None | int | str | torch.device | list[int | str | torch.device],
) -> None | int | str | torch.device | list[int | str | torch.device]:
    if isinstance(device, list):  # pragma: no cover
        use_cpu_as_backup, true_device_ids = get_true_device_ids(device)
        if len(true_device_ids) == 0:
            if use_cpu_as_backup:
                device = "cpu"
            else:
                raise RuntimeError(
                    f"The device list you specified ({device}) could not be found on this system."
                )
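    # Unwraps single-element lists, e.g. [0] -> 0. VLLM.model later asserts
    # that self.device is a list, so a single-device list fails that check.
    if isinstance(device, list) and len(device) == 1:  # pragma: no cover
        device = device[0]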
    return device
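The interaction can be checked without DataDreamer or a GPU. The sketch below is a simplified stand-in, not the library's code: strip_singleton mirrors only the final if statement of validate_device, vllm_device_check mirrors only the assertion at vllm.py:73 (returning a bool instead of raising), and as_device_list is a hypothetical helper showing one possible fix, which re-wraps a scalar device into a list before the check.

```python
from typing import Any, Optional


def strip_singleton(device: Any) -> Any:
    # Mirrors only the last if statement of validate_device:
    # a one-element list is replaced by its element, e.g. [0] -> 0.
    if isinstance(device, list) and len(device) == 1:
        device = device[0]
    return device


def vllm_device_check(device: Any) -> bool:
    # Mirrors the assertion at vllm.py:73, returning False instead of raising.
    return isinstance(device, list)


def as_device_list(device: Any) -> Optional[list[Any]]:
    # Hypothetical fix: re-wrap a scalar device (int, str, torch.device)
    # into a one-element list. Lists pass through unchanged, and None is
    # left as-is so the existing assertion still rejects a missing device.
    if device is None or isinstance(device, list):
        return device
    return [device]


# Sketch of how VLLM.model might use the helper (not the library's code):
#
#     env = os.environ.copy()
#     device = as_device_list(self.device)
#     assert isinstance(device, list)
#     env.update(get_device_env_variables(device))

for requested in ([0, 1], [0]):
    validated = strip_singleton(requested)
    print(
        f"{requested!r} -> {validated!r}: "
        f"original check {vllm_device_check(validated)}, "
        f"with as_device_list {vllm_device_check(as_device_list(validated))}"
    )
# [0, 1] -> [0, 1]: original check True, with as_device_list True
# [0] -> 0: original check False, with as_device_list True
```

Normalizing at the call site in VLLM.model, rather than removing the unwrapping from validate_device, is only one option. It assumes that other callers may rely on receiving a scalar device, which this issue does not confirm.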