ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

boolean mask or filter? #246

Open wjessup opened 8 months ago

wjessup commented 8 months ago

Seems like all methods return arrays of the same length. I need to be able to filter and return an array with fewer elements than what went in. What's the best way to do so now?

awni commented 8 months ago

Can you share more details on your use case?

Indeed, boolean indices are not well supported because we don't know the values when you make the call, and hence we don't know the size of the output. We've found that in almost all cases there is a suitable workaround, but it would be great to understand yours in a little more detail.
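To illustrate the point about output sizes, here is a small NumPy sketch (boolean indexing behaves the same way in NumPy): two masks with identical shapes can produce outputs of different sizes, so a framework cannot know the result's shape until the mask's values have actually been computed.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0])

# Two masks with the same shape but different contents...
m1 = np.array([True, False, True, False])
m2 = np.array([True, True, True, False])

# ...produce outputs of different sizes, so the output shape
# depends on the mask's *values*, not just its shape.
print(a[m1].shape)  # (2,)
print(a[m2].shape)  # (3,)
```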

wjessup commented 8 months ago

Sure, the code I was replacing was from a DQN with replay memory. We need to select only the non-terminated cases to make predictions, so:

    terminated_batch = torch.tensor([False, True, False])  # did it terminate or not?
    state_batch = torch.tensor([[.3, .2], [.1, .4], [.6, .7]])
    only_falses = state_batch[~terminated_batch]  # can use mx.logical_not(terminated_batch) in mlx.
    print("only falses = ", only_falses)

    only falses = tensor([[0.3000000119, 0.2000000030], [0.6000000238, 0.6999999881]])

awni commented 8 months ago

Got it, and what do you do with only_falses?

If it's not too costly you can use a mask instead of selecting the subset?

I see it would be nice to support boolean indexing, and there are some options for how we can do it (some harder than others). Anything you can add about how the computation with only_falses proceeds would be helpful for understanding the trade-offs!

wjessup commented 8 months ago

For simplicity I didn't show the part where we're selecting random samples from the replay memory; I'm just showing example tensors instead.

The idea here is that in the replay memory batch, some of the memories lead to failed episodes and some don't.

We want to filter out the episodes that failed, and then send the non-failed episodes into the DQNet to get the state+1 value predictions. There's no reason to run predictions if they failed, so we need a way to filter.

terminated_batch = torch.tensor([False, True, False]) # did it terminate or not?
state_batch = torch.tensor([[.3, .2], [.1, .4], [.6, .7]])
non_final_states = state_batch[~terminated_batch] # can use mx.logical_not(terminated_batch) in mlx.

BATCH_SIZE = 3
next_state_values = torch.zeros(BATCH_SIZE, device=device).unsqueeze(1)
next_state_values[~terminated_batch] = target_net(non_final_states).max(1).values # target_net is the deep Q network
print(next_state_values)

[{some prediction}, 0, {some prediction}]
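For reference, the mask-based alternative awni mentioned above can be sketched in NumPy (the mx.where and mx.logical_not calls in MLX mirror their NumPy counterparts): run the network on every state, accepting some wasted compute on terminated ones, then zero out the terminated entries with a where, so no array shape ever depends on the mask's values. Here `target_net` is a dummy stand-in for the real Q network.

```python
import numpy as np

def target_net(states):
    # Dummy stand-in: fake per-state value predictions.
    return states.sum(axis=1)

terminated_batch = np.array([False, True, False])
state_batch = np.array([[0.3, 0.2], [0.1, 0.4], [0.6, 0.7]])

# Predict for every state (wasted work on terminated ones),
# then zero out the terminated entries with where.
all_values = target_net(state_batch)                  # shape (3,)
next_state_values = np.where(terminated_batch, 0.0, all_values)
```

This trades a little extra compute for fully static shapes, which is exactly the kind of workaround that sidesteps boolean indexing.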

awni commented 8 months ago

Got it. We need boolean masking for this. But it looks like you don't need to differentiate / transform through this line:

non_final_states = state_batch[~terminated_batch] # can use mx.logical_not(terminated_batch) in mlx.

correct?

For now, a simple option is to cast to numpy, get the states, and cast back:

inputs = mx.array(np.array(state_batch)[mx.logical_not(terminated_batch)])
wjessup commented 8 months ago

no need to differentiate.

the numpy method looks good. will try it out!

mzbac commented 8 months ago

I encountered a similar issue when trying to implement top_p for LLM sampling. I got an error that boolean indices are not supported, for example:

    sorted_probs = mx.sort(probs)[::-1]
    sorted_indices = mx.argsort(probs)[::-1]
    cumulative_probs = mx.cumsum(sorted_probs)

    mask = cumulative_probs <= p
    sorted_indices_to_keep = sorted_indices[mask] # error on here

    remaining_probs = probs[sorted_indices_to_keep]
    remaining_probs /= mx.sum(remaining_probs)

Given that sampling happens once per generated token, I am not sure casting to np and then back to an mx array is an efficient approach. Wondering if there is a better way to support this.

awni commented 8 months ago

It's unlikely that going to NumPy will slow things down much for just that step.

But you may be able to get this to work using where:

top_probs = mx.where(cumulative_probs > p, sorted_probs, zeros)
sorted_tok = mx.random.categorical(mx.log(top_probs))
tok = sorted_indices[sorted_tok]
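The idea above can be made concrete with a NumPy sketch (np.where, np.cumsum, and np.argsort are the analogues of the mx.* calls): instead of boolean-indexing a variable-length nucleus, zero out the excluded probabilities so every array keeps a fixed shape, then sample in the sorted space and map back.

```python
import numpy as np

probs = np.array([0.1, 0.5, 0.15, 0.25])
p = 0.8

order = np.argsort(probs)[::-1]       # token ids, most probable first
sorted_probs = probs[order]
cumulative = np.cumsum(sorted_probs)

# Keep a token if the cumulative mass *before* it is still below p;
# this always retains at least the most probable token.
keep = cumulative - sorted_probs < p
top_probs = np.where(keep, sorted_probs, 0.0)

# Renormalize and sample in the sorted space, then map back to token ids.
top_probs /= top_probs.sum()
sorted_tok = np.random.default_rng(0).choice(len(probs), p=top_probs)
tok = order[sorted_tok]
```

All shapes here are fixed by the vocabulary size, so nothing depends on how many tokens fall inside the nucleus.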
mzbac commented 8 months ago

`where` works great! Thanks for the advice. In case anyone is having trouble with the top_p implementation, here is my solution.

    def sample(logits):
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        else:
            probs = mx.softmax(logits / temp, axis=-1)

            sorted_probs = mx.sort(probs)[::-1]
            sorted_indices = mx.argsort(probs)[::-1]
            cumulative_probs = mx.cumsum(sorted_probs, axis=-1)

            top_probs = mx.where(
                cumulative_probs > 0.95, sorted_probs, mx.zeros_like(sorted_probs) ## TODO hardcoded P for now
            )
            sorted_tok = mx.random.categorical(mx.log(top_probs))
            tok = sorted_indices.squeeze(0)[sorted_tok]
            return tok
arnoldnyan commented 5 months ago

Hi @awni, what if I need to call boolean mask or filter in C++?

awni commented 5 months ago

You can use `where(cond, x, y)` in C++. But maybe you can give an example of what you are trying to do?

arnoldnyan commented 5 months ago

Suppose I have an array of bools as a mask and want to filter another array of the same shape. It doesn't need to be differentiable. Just similar to this use case: https://github.com/ml-explore/mlx/issues/246#issuecomment-1868335885. But in C++, we cannot convert to numpy.

awni commented 5 months ago

I recommend trying to find an equivalent computation that does not require you to have an array whose shape depends on another array's data. Based on your comment I can't help you much because I don't know the full computation.

But usually this is doable with a combination of masks and mx.where and some other operations.

The way to look at it is find the point in the computation which you need to use the boolean mask and then the next point where you know the output shape definitively (i.e. it does not depend on input data). Then try to replace that stretch of the compute with something that uses masks. (Feel free to share it here, maybe we can help).
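As one concrete instance of replacing a boolean-indexed stretch with a fixed-shape equivalent, here is a NumPy sketch (the same rewrite works with MLX's where and arithmetic ops in Python or C++): a mean over selected elements, whose boolean-indexed form has a data-dependent shape, becomes a masked sum divided by the mask count.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
mask = np.array([True, False, True, True])

# Data-dependent shape: x[mask].mean() — output size depends on the mask.
# Fixed-shape equivalent: zero out unselected values, normalize by count.
masked_mean = (x * mask).sum() / mask.sum()
```

Both forms compute the same value, but the masked version never materializes an array whose size depends on the mask's contents.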