bigscience-workshop / petals

🌸 Run LLMs at home, BitTorrent-style. Fine-tuning and inference up to 10x faster than offloading
https://petals.dev
MIT License

Petals doesn't deal with server failure properly #587

Open oldcpple opened 6 days ago

oldcpple commented 6 days ago

Hi there, we'd like to report our findings from testing Petals' fault tolerance.

We noticed that the current implementation of the `step` method in the class `_ServerInferenceSession` from `inference_session.py` contains the following logic:

if self.history is None:
    self.history = inputs
elif self.history.shape[1] == self._position:
    self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)

assert self.history.shape[1] == self._position + n_input_tokens, (
    f"Broken input cache: span={self.span} shape={self.history.shape} "
    f"position={self._position} n_input_tokens={n_input_tokens}"
)
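In the non-failure case this invariant holds at every step, because the cache grows by exactly `n_input_tokens` each time. Here is a minimal standalone sketch of that cache-update logic (mock variables only, not Petals' API; the hidden size of 16 is arbitrary for illustration):

```python
import torch

# Mock of the per-session cache state (initialized as in _ServerInferenceSession).
history = None
position = 0

def cache_step(inputs: torch.Tensor, n_input_tokens: int) -> None:
    """Append the newest tokens to the local cache, mirroring step()."""
    global history, position
    if history is None:
        history = inputs
    elif history.shape[1] == position:
        history = torch.cat([history, inputs[:, -n_input_tokens:]], dim=1)
    # The invariant the assert checks; it holds as long as no failover occurs.
    assert history.shape[1] == position + n_input_tokens
    position += n_input_tokens

# Prefill with 5 prompt tokens, then decode 3 tokens one at a time.
cache_step(torch.zeros(1, 5, 16), n_input_tokens=5)
for _ in range(3):
    cache_step(torch.zeros(1, 1, 16), n_input_tokens=1)
print(history.shape)  # torch.Size([1, 8, 16])
```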

where the attributes `self.history` and `self._position` are initialized to `None` and `0` respectively when a `_ServerInferenceSession` object is created. The problem is that when a server fails, Petals replaces it with another server serving the same blocks. However, the replacement server session is freshly initialized when it joins the inference session, so its `_position` attribute is `0`.

In the method `_update_sequence` of the class `InferenceSession`, the new server session's history is assigned the history of the failed server session: `updated_sessions[0].history = self._server_sessions[server_idx].history`. During autoregressive inference, `n_input_tokens` is always `1`, so the assertion `assert self.history.shape[1] == self._position + n_input_tokens` almost always fails on the recovered session with a "Broken input cache" error: `history.shape[1]` equals the full history length, while the new session's `_position` is still `0`.
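The mismatch can be reproduced with a minimal mock of the failover scenario (hypothetical classes for illustration, not Petals' API):

```python
import torch

class MockServerSession:
    """Minimal stand-in for _ServerInferenceSession's cache state."""
    def __init__(self):
        self.history = None
        self._position = 0

# The old session has processed 8 tokens before failing.
failed = MockServerSession()
failed.history = torch.zeros(1, 8, 16)
failed._position = 8

# _update_sequence copies the history into a freshly joined session,
# but the new session's _position is still 0.
replacement = MockServerSession()
replacement.history = failed.history

# On the next decode step, n_input_tokens == 1, so the cache check fails:
n_input_tokens = 1
ok = replacement.history.shape[1] == replacement._position + n_input_tokens
print(ok)  # False: 8 != 0 + 1 -> "Broken input cache"
```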

One possible solution is as follows: delete the assert statement so that no exception is thrown during the recovery process, then change the last few lines of the `step` method to:

        self._position += n_input_tokens
        s1 = outputs[0].shape[1]
        if self.recover:
            # The recovered session replayed its whole history, so outputs[0]
            # covers all past positions; return only the last token's states.
            return outputs[0][0:1, s1 - 1:s1, :]
        return outputs[0]

Here `self.recover` is a newly defined attribute of the class `_ServerInferenceSession`, representing whether this server session is recovering from a failed one. It is initialized to `False` and set to `True` in the method `_update_sequence`. This change tackles the following problem: if the assert statement is simply deleted, the value of `outputs[0]` returned by the recovered session is a tensor of shape `[1, (number of its history inputs), 8192]` instead of the expected `[1, 1, 8192]`.
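The slicing in the fix can be checked in isolation. A minimal sketch (mock tensors; a hidden size of 16 stands in for 8192):

```python
import torch

# After replaying its history, the recovered server returns hidden states for
# all replayed positions, shape [1, history_len, hidden]:
outputs = [torch.randn(1, 8, 16)]

recover = True  # mock of the proposed self.recover flag
s1 = outputs[0].shape[1]
# Keep only the last position, preserving all three dimensions:
result = outputs[0][0:1, s1 - 1:s1, :] if recover else outputs[0]
print(result.shape)  # torch.Size([1, 1, 16])
```

Using the slice `s1 - 1:s1` rather than plain indexing keeps the singleton sequence dimension, so the recovered session's output has the same rank as a normal session's.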

By testing tens of examples, we believe this change works properly when dealing with server failures: the final outputs in cases where some servers fail are the same as those where no server fails.

Please check whether you can reproduce these problems. Many thanks.