JayZhang42 / FederatedGPT-Shepherd

Shepherd: A foundational framework enabling federated instruction tuning for large language models
https://arxiv.org/pdf/2305.05644.pdf
Apache License 2.0
199 stars 31 forks source link

The uploaded model appears to be an untrained version #8

Open mid2doubao opened 6 months ago

mid2doubao commented 6 months ago

In initiate_local_training, self.params_dict_new is recorded, but the parameters stored in it have already been passed through detach(), so only a transient snapshot of the parameter state is recorded. As a result, new_adapter_weight, which terminate_local_training saves to the corresponding path for aggregation, contains untrained parameters.

def initiate_local_training(self):
    """Prepare ``self.model`` for a round of local training.

    Steps, in order:

    1. Turn off ``use_cache`` on the model config.
    2. Store in ``self.params_dict_old`` a deep copy of every named
       parameter whose name contains ``"default"`` (detached first).
    3. Store in ``self.params_dict_new`` the detached parameters themselves,
       without copying.
    4. Replace ``self.model.state_dict`` with a bound callable that ignores
       its arguments and returns
       ``get_peft_model_state_dict(model, self.params_dict_new, "default")``.

    NOTE(review): assuming the parameters are ``torch`` tensors,
    ``Tensor.detach()`` returns a tensor sharing storage with the original,
    so ``params_dict_new`` entries alias the live parameters while
    ``params_dict_old`` entries are independent copies -- confirm this is
    the intended distinction.
    """
    self.model.config.use_cache = False

    def _detached_adapter_params():
        # Fresh ordered mapping on every call; only names containing
        # "default" are kept.
        return OrderedDict(
            (param_name, tensor.detach())
            for param_name, tensor in self.model.named_parameters()
            if "default" in param_name
        )

    # Independent copy taken before training starts.
    self.params_dict_old = copy.deepcopy(_detached_adapter_params())
    # Detached but not copied.
    self.params_dict_new = _detached_adapter_params()

    def _adapter_state_dict(instance, *_, **__):
        # Looks up self.params_dict_new at call time, not at definition time.
        return get_peft_model_state_dict(instance, self.params_dict_new, "default")

    # Bind as a method on this model instance only, shadowing the class's
    # state_dict for this object.
    self.model.state_dict = _adapter_state_dict.__get__(self.model, type(self.model))
def terminate_local_training(self, epoch, local_dataset_len_dict, previously_selected_clients_set):
    """Save this client's adapter weights and reload the stored old ones.

    Args:
        epoch: Used (via ``str(epoch)``) as a directory component of the
            output path.
        local_dataset_len_dict: Mapping updated in place with
            ``len(self.local_train_dataset)`` under ``self.client_id``.
        previously_selected_clients_set: Set of client ids; not mutated, a
            new set including ``self.client_id`` is returned instead.

    Returns:
        A 4-tuple ``(self.model, local_dataset_len_dict,
        updated_client_set, self.client_id)``.

    Side effects:
        Writes ``self.model.state_dict()`` to
        ``<output_dir>/<epoch>/local_output_<client_id>/pytorch_model.bin``
        and then loads ``self.params_dict_old`` into the model through
        ``set_peft_model_state_dict``.
    """
    client_id = self.client_id
    local_dataset_len_dict[client_id] = len(self.local_train_dataset)

    # Grab the weights before touching the filesystem, as the original did.
    trained_weights = self.model.state_dict()
    client_dir = os.path.join(
        self.output_dir,
        str(epoch),
        "local_output_{}".format(client_id),
    )
    os.makedirs(client_dir, exist_ok=True)
    torch.save(trained_weights, client_dir + "/pytorch_model.bin")

    # presumably rolls the model back to its pre-training adapter weights;
    # verify against the peft helpers' semantics.
    stored_old_weights = get_peft_model_state_dict(self.model, self.params_dict_old, "default")
    set_peft_model_state_dict(self.model, stored_old_weights, "default")

    # Union builds a new set; the caller's set object is left untouched.
    updated_client_set = previously_selected_clients_set | {client_id}

    return self.model, local_dataset_len_dict, updated_client_set, client_id
mid2doubao commented 6 months ago

@JayZhang42