learning-at-home / hivemind

Decentralized deep learning in PyTorch. Built to train models on thousands of volunteers across the world.

Fine-tuning BERT on GLUE with hivemind #443

Closed elricwan closed 2 years ago

elricwan commented 2 years ago

Describe the bug: While running the hivemind ALBERT experiment, we have one monitor peer and two worker peers. One of the worker peers is training fine.

But the other peer is stuck at downloading parameters from a peer. My guess is that the cause is the training speed: if the first node trains too fast, the other node cannot join and stays stuck downloading parameters. Can we limit the training speed, or force the first node to wait for the others to join?

To Reproduce: If applicable, please create a minimal script that reproduces the problem for you. It would be great to include script outputs as well.

If we change ALBERT to BERT in the example, each iteration becomes faster, and then the new worker cannot join the training.


borzunov commented 2 years ago

Hi @elricwan!

I am afraid you're using an old version of hivemind (not really 1.1.0.dev0) and an old version of the ALBERT example. I say that because the current version has a different log format and prints different outputs.

Can you please upgrade hivemind to the latest version and try again? Maybe something is wrong with your environment or repository, so please try installing hivemind from pip (or from this repo) once again.

The version with the old log format is currently 4+ months old, and we've rewritten a large part of our code during this time.

borzunov commented 2 years ago

If you installed via pip, please try pip install --upgrade hivemind
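If it helps, you can also double-check which version Python actually imports (a minimal sketch; it assumes the package exposes __version__, as recent hivemind releases do):

    # Print the hivemind version that gets imported, to catch stale installs.
    import hivemind

    print(hivemind.__version__)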

elricwan commented 2 years ago

Thank you for the quick response. I updated to the new version, and here is the log:

Jan 10 16:06:25.529 [INFO] Peer is out of sync | 2402/266628 [00:15<19:02, 231.33it/s]
Jan 10 16:06:25.533 [INFO] Downloading parameters from peer QmfUbz3GUdCfjjNcXNyWAShHETS3HzFPJtnAG41XsHs4jo
Jan 10 16:06:28.686 [INFO] Finished downloading state from QmfUbz3GUdCfjjNcXNyWAShHETS3HzFPJtnAG41XsHs4jo
Jan 10 16:06:29.992 [INFO] bert_experiment accumulated 0 samples for epoch #367 from 1 peers. ETA 0.77 sec (refresh in 0.50 sec)
Jan 10 16:06:30.001 [INFO] Step #362
Jan 10 16:06:30.001 [INFO] Your current contribution: 0 samples
Jan 10 16:06:30.001 [INFO] Performance: 0.000 samples/sec
Jan 10 16:06:30.001 [INFO] Local loss: 9.23920
{'loss': 9.3288, 'learning_rate': 0.000127424, 'epoch': 0.03}
Jan 10 16:06:30.077 [INFO] Peer is out of sync | 2403/266628 [00:19<19:02, 231.33it/s]
Jan 10 16:06:30.080 [INFO] Downloading parameters from peer QmfUbz3GUdCfjjNcXNyWAShHETS3HzFPJtnAG41XsHs4jo
Jan 10 16:06:33.525 [INFO] Finished downloading state from QmfUbz3GUdCfjjNcXNyWAShHETS3HzFPJtnAG41XsHs4jo
Jan 10 16:06:34.772 [INFO] bert_experiment accumulated 24 samples for epoch #371 from 1 peers. ETA 0.72 sec (refresh in 0.50 sec) | 2403/266628 [00:24<19:02, 231.33it/s]
Jan 10 16:06:34.781 [INFO] Step #367
Jan 10 16:06:34.781 [INFO] Your current contribution: 0 samples
Jan 10 16:06:34.781 [INFO] Performance: 0.000 samples/sec
Jan 10 16:06:34.781 [INFO] Local loss: 9.32880
{'loss': 9.383, 'learning_rate': 0.00012918400000000002, 'epoch': 0.03}
Jan 10 16:06:34.854 [INFO] Peer is out of sync

elricwan commented 2 years ago

In this example, the speed of the first peer is 8.5 iter/s, which I think is too fast for the second peer to join. The target_batch_size I chose is 64, and the per_device_train_batch_size is 4.

If I change the per_device_train_batch_size to 16 and target_batch_size to 256, then the second peer is able to join.

borzunov commented 2 years ago

Thanks for providing the numbers and the logs!

You're right, the 1st peer is doing steps too quickly. By the time the 2nd peer finishes downloading the current model state, that state is already considered too old (the 1st peer has already made several steps since then), so the 2nd peer starts downloading it again.

I think target_batch_size = 64 is just too small for distributed training with hivemind to start making sense. The "goal" of the collaboration is to accumulate target_batch_size as fast as possible and make an optimizer step (without waiting for lagging peers). In this case, it's faster for the 1st peer to accumulate the batch by itself than to synchronize its state with others and wait for their help.

Even if the 2nd peer managed to download the initial state, it wouldn't be able to keep its state synchronized efficiently later (synchronizing the state is at least 2x slower than downloading the initial one). Thus, the current behavior is correct from this point of view.
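To make the timing concrete, here is a rough back-of-the-envelope sketch using the numbers from this thread (8.5 iter/s, per_device_train_batch_size=4, target_batch_size=64, and the roughly 3-4 sec state downloads visible in the log); the exact figures are only illustrative:

    # Back-of-the-envelope estimate based on the numbers reported in this thread.
    iters_per_sec = 8.5          # training speed reported for the 1st peer
    batch_per_iter = 4           # per_device_train_batch_size
    target_batch_size = 64       # samples accumulated before each optimizer step
    state_download_sec = 3.5     # approximate state download time seen in the log

    iters_per_epoch = target_batch_size / batch_per_iter     # 16 iterations
    epoch_duration_sec = iters_per_epoch / iters_per_sec     # ~1.9 seconds per epoch
    print(f"~{epoch_duration_sec:.1f} sec per epoch vs ~{state_download_sec} sec to download state")
    # The 1st peer completes one or more epochs while the 2nd peer is still downloading,
    # so the freshly downloaded state is already out of sync again.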

hivemind.Optimizer is designed for training with large batches (e.g., the default target_batch_size is 4096 in the ALBERT example). We focus on large-batch training because it's optimal for training over relatively slow networks, such as the Internet (so the peers don't need to synchronize their state too often).

Large-batch training is common for large transformers. Usually, you can compensate for the less frequent optimizer steps with a larger learning rate (you can jump farther because the larger batch gives a more accurate direction). We usually use the LAMB paper to choose the correct learning rate for a specific batch size (see Table 4 for BERT); that's how we chose the learning rate for the example.
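If you do scale the batch size up, a common starting point is square-root learning-rate scaling (the heuristic the LAMB paper builds on); a minimal sketch, where the baseline values are placeholders to replace with your own:

    import math

    # Square-root LR scaling heuristic for LAMB-style large-batch training.
    # The baseline values below are placeholders, not recommendations.
    base_lr = 0.00176         # learning rate tuned for the baseline batch size
    base_batch_size = 4096    # batch size that base_lr was tuned for
    new_batch_size = 16384    # the larger global batch you plan to train with

    scaled_lr = base_lr * math.sqrt(new_batch_size / base_batch_size)
    print(f"suggested starting LR for batch {new_batch_size}: {scaled_lr:.5f}")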

elricwan commented 2 years ago

I see, thank you for the quick response. I was trying to run a GLUE experiment with hivemind; that's why I reduced the target batch size. It would be great if you could add some examples for that when you have time. Thank you.

elricwan commented 2 years ago

Besides, I have a quick question about the optimizer. When I define opt as the example shows:

    opt = lambda params: Lamb(
        params,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        clamp_value=training_args.clamp_value,
        debias=True,
    )

    no_decay = ["bias", "LayerNorm.weight"]
    params = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    scheduler = lambda opt: get_linear_schedule_with_warmup(
        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
    )

And I set use_local_updates=False, as shown below:

    optimizer = Optimizer(
        dht=dht,
        run_id=collaboration_args.experiment_prefix,
        target_batch_size=adjusted_target_batch_size,
        batch_size_per_step=total_batch_size_per_step,
        optimizer=opt,
        params=params,
        scheduler=scheduler,
        matchmaking_time=collaboration_args.matchmaking_time,
        averaging_timeout=collaboration_args.averaging_timeout,
        use_local_updates=False,
        offload_optimizer=True,
        delay_optimizer_step=True,
        delay_grad_averaging=True,
        client_mode=collaboration_args.client_mode,
        grad_compression=Float16Compression(),
        state_averaging_compression=Float16Compression(),
        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
        tracker_opts=asdict(tracker_args),
        verbose=True,
    )

Then the model cannot learn; the GLUE MRPC evaluation result is shown below:

***** eval metrics *****
  epoch                   =       5.99
  eval_accuracy           =     0.6838
  eval_combined_score     =      0.748
  eval_f1                 =     0.8122
  eval_loss               =     0.6513
  eval_runtime            = 0:00:01.11
  eval_samples            =        408
  eval_samples_per_second =    365.517
  eval_steps_per_second   =     91.379
Jan 12 13:49:18.235 [INFO] Sending goodbye to peers...
Jan 12 13:49:18.235 [INFO] No longer reporting progress for glue_experiment
Jan 12 13:49:18.235 [INFO] No longer fetching glue_experiment_progress
Jan 12 13:49:18.236 [INFO] Waiting for delayed updates to finish...
Jan 12 13:49:18.310 [INFO] Shutting down averagers...
Jan 12 13:49:18.472 [INFO] Optimizer is shut down

However, when I define my own optimizer:

    no_decay = ["bias", "LayerNorm.weight"]
    params = [
        {
            "params": [p for n, p in model.named_parameters() if n in no_decay],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in no_decay],
            "weight_decay": 0.0,
        },
    ]

I must set use_local_updates=True, as below:

    optimizer = Optimizer(
        dht=dht,                                                  
        run_id=collaboration_args.experiment_prefix,              
        batch_size_per_step=total_batch_size_per_step,            
        target_batch_size=adjusted_target_batch_size,             
        optimizer=opt,                                            
        scheduler=scheduler,                                      
        use_local_updates=True,                                   
        matchmaking_time=collaboration_args.matchmaking_time,     
        averaging_timeout=collaboration_args.averaging_timeout,   
        verbose=True                                              
    )

Then the model works, as shown below:

  epoch                   =       5.99
  eval_accuracy           =     0.8456
  eval_combined_score     =     0.8704
  eval_f1                 =     0.8952
  eval_loss               =     0.7145
  eval_runtime            = 0:00:01.00
  eval_samples            =        408
  eval_samples_per_second =    407.919
  eval_steps_per_second   =     101.98
Jan 12 13:55:14.625 [INFO] Sending goodbye to peers...
Jan 12 13:55:14.626 [INFO] No longer fetching glue_experiment_progress
Jan 12 13:55:14.626 [INFO] No longer reporting progress for glue_experiment
Jan 12 13:55:14.626 [INFO] Shutting down averagers...
Jan 12 13:55:14.678 [INFO] Optimizer is shut down

So my question is: in regular cases, should we define our own optimizer and set use_local_updates to True? Or did I miss anything in the first optimizer setup that makes my model not learn? Thank you!

justheuristic commented 2 years ago

Hi!

Q: In regular cases, should we define our own optimizer and set use_local_updates to be true?

Here's a rule of thumb that worked for me so far:

A) If your training task can be tuned for very large batch sizes, it's best to set use_local_updates=False and train with a sufficiently large target_batch_size. In this mode, peers (1) accumulate gradients over that batch size, then (2) run all-reduce and average their gradients, and finally (3) perform an optimizer step using the globally averaged gradients. This is equivalent to regular data-parallel training and hence easier to tune.

By "tuned for very large batch sizes" i mean reducing the number of optimizer steps, increasing the learning rate, etc. so as to train in a relatively few optimizer steps with very large batches. Since you're using LAMB, you can use the tables from the original paper to figure out how to scale up.

For instance, the original paper includes a table for training BERT from scratch with different batch sizes (image not reproduced here).

B) If for some reason your task needs small batch sizes but a very large number of updates, you can set use_local_updates=True. In that case, each peer will run optimizer steps on its local gradients and only average model parameters after a certain number of iterations. This is equivalent to asynchronous decentralized SGD and is typically more difficult to tune.
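For reference, here is a minimal sketch of the two modes side by side. It reuses the dht, opt, params, scheduler, and total_batch_size_per_step objects from the snippets above and only spells out the arguments that differ between the modes; everything not shown is left at its default:

    from hivemind import Optimizer

    # Mode A: large global batch, synchronous gradient averaging
    # (equivalent to regular data-parallel training).
    optimizer_a = Optimizer(
        dht=dht,                                        # DHT connected to the other peers
        run_id="glue_experiment",
        batch_size_per_step=total_batch_size_per_step,
        target_batch_size=4096,                         # large, e.g. the ALBERT example default
        optimizer=opt,                                  # the Lamb factory defined earlier
        params=params,
        scheduler=scheduler,
        use_local_updates=False,                        # average gradients, then step together
        verbose=True,
    )

    # Mode B: small batches and frequent local steps, parameters averaged in the background
    # (asynchronous decentralized SGD; typically harder to tune).
    optimizer_b = Optimizer(
        dht=dht,
        run_id="glue_experiment",
        batch_size_per_step=total_batch_size_per_step,
        target_batch_size=64,
        optimizer=opt,
        params=params,
        scheduler=scheduler,
        use_local_updates=True,                         # step on local gradients, average params later
        verbose=True,
    )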

Q: do I miss anything for the first optimizer setup that make my model not working?

If target_batch_size=64 and per_device_train_batch_size=4, use_local_updates=False will make 16 times fewer optimizer steps. My best guess so far is that your learning rate, warmup, etc. were tuned for the small per-step batches of use_local_updates=True and did not scale well to target_batch_size=64.
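To put a rough number on that 16x factor, here is a quick sketch (the MRPC train-set size of ~3.7k examples and the ~6 epochs are my approximations based on the logs above):

    # Rough step-count arithmetic for the GLUE MRPC run discussed above.
    train_examples = 3668     # approximate MRPC training-set size
    num_epochs = 6
    per_device_batch = 4
    target_batch_size = 64

    # Steps a scheduler tuned per local mini-batch would expect:
    trainer_steps = train_examples // per_device_batch * num_epochs       # ~5500
    # Global optimizer steps that actually happen with use_local_updates=False:
    optimizer_steps = train_examples * num_epochs // target_batch_size    # ~340
    print(trainer_steps, optimizer_steps)  # about a 16x difference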

However, this is only an educated guess. I can make a more accurate prediction if you provide the full code.

Q: Hello guys, I have a quick question regarding the Optimizer in hivemind. When I set use_local_updates=True, it works fine. But when I set use_local_updates=False with the lambda optimizer, the lr is always 0 and the model cannot learn.

That is quite curious. Can you please elaborate on the learning rate being 0? Your previous post suggests that the model is still somewhat better than random with use_local_updates=False (or am I missing something?).

elricwan commented 2 years ago

Hi there, thank you for the quick response. The learning rate is not zero; I printed it wrong, sorry about the confusion. But the model is not training. Here is my entire code (most of it is the same as the example, but I added two more classes for the data and model arguments from the Hugging Face GitHub).

In the run_training_monitor.py file:

import time
from dataclasses import asdict, dataclass, field
from ipaddress import ip_address
from typing import Optional

import requests
import torch
import wandb
from torch_optimizer import Lamb
from transformers import BertConfig, BertForMaskedLM, HfArgumentParser

import hivemind
from hivemind.optim.state_averager import TrainingStateAverager
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)

from datasets import load_dataset, load_metric

import utils
from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

@dataclass
class TrainingMonitorArguments(BaseTrainingArguments):
    """
    Note: You might want to have several initial peers so that if one dies,
    new workers still can join the collaboration via alive initial peers' addresses.
    Specify initial_peers argument for that purpose
    """

    use_google_dns: bool = field(
        default=False,
        metadata={
            "help": "Use Google DNS to determine the public IP address of this machine (and add it to --announce_maddrs)"
        },
    )
    refresh_period: float = field(default=30, metadata={"help": "Period (in seconds) for fetching the keys from DHT"})
    wandb_project: Optional[str] = field(
        default=None, metadata={"help": "Name of Weights & Biases project to report the training progress to"}
    )
    save_checkpoint_step_interval: int = field(
        default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
    )
    model_config_path: str = field(
        default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
        metadata={"help": "Path to the model config"},
    )
    repo_path: Optional[str] = field(
        default=None, metadata={"help": "Path to local repository to store the model and optimizer states"}
    )
    repo_url: Optional[str] = field(
        default=None, metadata={"help": "URL of Hugging Face Hub repository to upload the model and optimizer states"}
    )
    upload_interval: Optional[float] = field(
        default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
    )
    store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        if self.task_name is not None:
            self.task_name = self.task_name.lower()
            if self.task_name not in task_to_keys.keys():
                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
        elif self.dataset_name is not None:
            pass
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            train_extension = self.train_file.split(".")[-1]
            assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            validation_extension = self.validation_file.split(".")[-1]
            assert (
                validation_extension == train_extension
            ), "`validation_file` should have the same extension (csv or json) as `train_file`."

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

class CheckpointHandler:
    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DataTrainingArguments,
        monitor_args: TrainingMonitorArguments,
        optimizer_args: OptimizerArguments,
        averager_args: AveragerArguments,
        dht: hivemind.DHT,
    ):
        self.save_checkpoint_step_interval = monitor_args.save_checkpoint_step_interval
        self.repo_path = monitor_args.repo_path
        self.repo_url = monitor_args.repo_url
        self.upload_interval = monitor_args.upload_interval
        self.previous_step = -1

        # arbitrarily set a num_labels value
        num_labels = 3
        config = AutoConfig.from_pretrained(
            model_args.config_name if model_args.config_name else model_args.model_name_or_path,
            num_labels=num_labels,
            finetuning_task=data_args.task_name,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        opt = Lamb(
            optimizer_grouped_parameters,
            lr=0.00176,
            weight_decay=0.01,
            clamp_value=10000.0,
            debias=True,
        )

        self.state_averager = TrainingStateAverager(
            dht=dht,
            optimizer=opt,
            prefix=experiment_prefix,
            state_compression=hivemind.Float16Compression(),
            bandwidth=optimizer_args.bandwidth,
            client_mode=optimizer_args.client_mode,
            start=True,
            **asdict(averager_args),
        )
        self.previous_timestamp = time.time()

    def is_time_to_save_state(self, cur_step):
        if self.save_checkpoint_step_interval is None:
            return False
        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
            return True
        else:
            return False

    def save_state(self, cur_step):
        logger.info("Saving state from peers")
        self.state_averager.load_state_from_peers()
        self.previous_step = cur_step

    def is_time_to_upload(self):
        if self.repo_path is None:
            return False
        elif time.time() - self.previous_timestamp >= self.upload_interval:
            return True
        else:
            return False

    def upload_checkpoint(self, current_loss):
        logger.info("Saving optimizer")
        torch.save(self.state_averager.optimizer.state_dict(), f"{self.repo_path}/optimizer_state.pt")
        self.previous_timestamp = time.time()
        logger.info("Started uploading to Model Hub")
        self.model.push_to_hub(
            repo_name=self.repo_path,
            repo_url=self.repo_url,
            commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
        )
        logger.info("Finished uploading to Model Hub")

if __name__ == "__main__":
    #parser = HfArgumentParser((TrainingMonitorArguments, OptimizerArguments, AveragerArguments))
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingMonitorArguments, OptimizerArguments, AveragerArguments))
    model_args, data_args, monitor_args, optimizer_args, averager_args = parser.parse_args_into_dataclasses()

    if monitor_args.use_google_dns:
        request = requests.get("https://api.ipify.org")
        request.raise_for_status()

        address = request.text
        logger.info(f"Received public IP address of this machine: {address}")
        version = ip_address(address).version
        monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]

    experiment_prefix = monitor_args.experiment_prefix
    validators, local_public_key = utils.make_validators(experiment_prefix)

    dht = hivemind.DHT(
        start=True,
        initial_peers=monitor_args.initial_peers,
        record_validators=validators,
        use_ipfs=monitor_args.use_ipfs,
        host_maddrs=monitor_args.host_maddrs,
        announce_maddrs=monitor_args.announce_maddrs,
        identity_path=monitor_args.identity_path,
    )
    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)

    if monitor_args.wandb_project is not None:
        wandb.init(project=monitor_args.wandb_project)

    current_step = 0
    if monitor_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(model_args, data_args, monitor_args, optimizer_args, averager_args, dht)

    while True:
        metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
        if metrics_dict is not None:
            metrics_dict = metrics_dict.value
            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)

            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")

                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")

                current_step = latest_step
                alive_peers = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0

                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps
                current_loss = sum_loss / sum_mini_steps
                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")

                if monitor_args.wandb_project is not None:
                    wandb.log(
                        {
                            "loss": current_loss,
                            "alive peers": alive_peers,
                            "samples": num_samples,
                            "performance": sum_perf,
                            "step": latest_step,
                        }
                    )

                if monitor_args.store_checkpoints:
                    if checkpoint_handler.is_time_to_save_state(current_step):
                        checkpoint_handler.save_state(current_step)
                        if checkpoint_handler.is_time_to_upload():
                            checkpoint_handler.upload_checkpoint(current_loss)
        logger.debug("Peer is still alive...")
        time.sleep(monitor_args.refresh_period)

I run the code with command:

MODEL_PATH=bert-base-uncased
TASK_NAME=MRPC

CUDA_VISIBLE_DEVICES=0 python run_training_monitor.py \
--experiment_prefix glue_experiment \
--wandb_project glue_wandb \
--model_name_or_path ${MODEL_PATH} \
--tokenizer_name bert-base-uncased \
--task_name $TASK_NAME \
--max_seq_length 128

In the run_trainer.py file:

import os
import pickle
import sys
from dataclasses import asdict
from pathlib import Path

import torch
import transformers
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch_optimizer import Lamb
from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed, AdamW
from transformers import BertForMaskedLM, BertConfig, AutoTokenizer
from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer import Trainer
from transformers.trainer_utils import is_main_process

from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import (
    ModelTrainingArguments,
    AveragerArguments,
    CollaborationArguments,
    DatasetArguments,
    ProgressTrackerArguments,
)

import logging
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import load_dataset, load_metric

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from transformers.trainer_pt_utils import get_parameter_names
from torch import nn

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

# from run_glue.py
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        if self.task_name is not None:
            self.task_name = self.task_name.lower()
            if self.task_name not in task_to_keys.keys():
                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
        elif self.dataset_name is not None:
            pass
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            train_extension = self.train_file.split(".")[-1]
            assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            validation_extension = self.validation_file.split(".")[-1]
            assert (
                validation_extension == train_extension
            ), "`validation_file` should have the same extension (csv or json) as `train_file`."

# from run_glue.py
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

def setup_transformers_logging(process_rank: int):
    if is_main_process(process_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.disable_default_handler()
        transformers.utils.logging.enable_propagation()

# def get_model(training_args, config, tokenizer):
#     # Find latest checkpoint in output_dir
#     output_dir = Path(training_args.output_dir)
#     logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
#     latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

#     if latest_checkpoint_dir is not None:
#         logger.info(f"Loading model from {latest_checkpoint_dir}")
#         #model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
#         model = BertForMaskedLM.from_pretrained(latest_checkpoint_dir)
#     else:
#         logger.info(f"Training from scratch")
#         #model = AlbertForPreTraining(config)
#         model = BertForMaskedLM(config)
#         model.resize_token_embeddings(len(tokenizer))

#     return model

class CollaborativeCallback(transformers.TrainerCallback):
    """
    This callback monitors and reports collaborative training progress.
    In case of a catastrophic failure, it can also revert training to a backup.
    """

    def __init__(
        self,
        dht: DHT,
        optimizer: Optimizer,
        model: torch.nn.Module,
        local_public_key: bytes,
        statistics_expiration: float,
        backup_every_steps: int,
    ):
        super().__init__()
        self.model = model
        self.dht, self.optimizer = dht, optimizer
        self.local_public_key = local_public_key
        self.statistics_expiration = statistics_expiration
        self.last_reported_collaboration_step = -1
        self.samples = 0
        self.steps = 0
        self.loss = 0
        self.total_samples_processed = 0
        self.backup_every_steps = backup_every_steps
        self.latest_backup = self.backup_state()

    def on_train_begin(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        logger.info("Loading state from peers")
        self.optimizer.load_state_from_peers()

    def on_step_end(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        control.should_log = True
        if not self.params_are_finite():
            self.restore_from_backup(self.latest_backup)
            return control

        local_progress = self.optimizer.local_progress

        if state.log_history:
            self.loss += state.log_history[-1]["loss"]
            self.steps += 1

            if self.optimizer.local_epoch != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.optimizer.local_epoch
                self.total_samples_processed += self.samples
                samples_per_second = local_progress.samples_per_second
                statistics = utils.LocalMetrics(
                    step=self.optimizer.local_epoch,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
                    mini_steps=self.steps,
                )
                logger.info(f"Step #{self.optimizer.local_epoch}")
                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                logger.info(f"Performance: {samples_per_second:.3f} samples/sec")
                if self.steps:
                    logger.info(f"Local loss: {self.loss / self.steps:.5f}")
                if self.optimizer.local_epoch % self.backup_every_steps == 0:
                    self.latest_backup = self.backup_state()

                self.loss = 0
                self.steps = 0
                if self.optimizer.is_synchronized_with_peers():
                    self.dht.store(
                        key=self.optimizer.run_id + "_metrics",
                        subkey=self.local_public_key,
                        value=statistics.dict(),
                        expiration_time=get_dht_time() + self.statistics_expiration,
                        return_future=True,
                    )

        self.samples = local_progress.samples_accumulated

        return control

    @torch.no_grad()
    def params_are_finite(self):
        for param in self.model.parameters():
            if not torch.all(torch.isfinite(param)):
                return False
        return True

    @torch.no_grad()
    def backup_state(self) -> bytes:
        return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})

    @torch.no_grad()
    def restore_from_backup(self, backup: bytes):
        state = pickle.loads(backup)
        self.model.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])

class NoOpScheduler(LRSchedulerBase):
    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""

    def get_lr(self):
        return [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        self._last_lr = self.get_lr()

    def state_dict(self):
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")

def main():
    parser = HfArgumentParser(
        (
            ModelArguments,
            DataTrainingArguments,
            ModelTrainingArguments,
            CollaborationArguments,
            AveragerArguments,
            ProgressTrackerArguments,
        )
    )
    model_args, data_args, training_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()
    # Setup logging, from run_glue.py
    # logging.basicConfig(
    #     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    #     datefmt="%m/%d/%Y %H:%M:%S",
    #     handlers=[logging.StreamHandler(sys.stdout)],
    # )

    # log_level = training_args.get_process_log_level()
    # logger.setLevel(log_level)
    # datasets.utils.logging.set_verbosity(log_level)
    # transformers.utils.logging.set_verbosity(log_level)
    # transformers.utils.logging.enable_default_handler()
    # transformers.utils.logging.enable_explicit_format()

    logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")

    setup_transformers_logging(training_args.local_rank)
    logger.info(f"Training/evaluation parameters:\n{training_args}")

    # Detecting last checkpoint. from run_glue.py; TrainingArguments.output_dir: The output directory where the model predictions and checkpoints will be written.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None: #  data_args.task_name: 'mrpc'
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
        # max_steps_warmup is used for the scheduler
        num_train_dataset = len(raw_datasets['train'])
        max_steps_warmup = ( num_train_dataset//training_args.per_device_train_batch_size ) * training_args.num_train_epochs
    elif data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
        )
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if training_args.do_predict:
            if data_args.test_file is not None:
                train_extension = data_args.train_file.split(".")[-1]
                test_extension = data_args.test_file.split(".")[-1]
                assert (
                    test_extension == train_extension
                ), "`test_file` should have the same extension (csv or json) as `train_file`."
                data_files["test"] = data_args.test_file
            else:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

        for key in data_files.keys():
            logger.info(f"load a local file for {key}: {data_files[key]}")

        if data_args.train_file.endswith(".csv"):
            # Loading a dataset from local csv files
            raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)
        else:
            # Loading a dataset from local json files
            raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer, from run_glue.py
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model.to(training_args.device)

    # Preprocessing the raw_datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
        else:
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None and not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    if label_to_id is not None:
        model.config.label2id = label_to_id
        model.config.id2label = {id: label for label, id in config.label2id.items()}

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)

        # Map labels to IDs (not necessary for GLUE tasks)
        if label_to_id is not None and "label" in examples:
            result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
        return result

    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

    if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
        if "test" not in raw_datasets and "test_matched" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))

    # Log a few random samples from the training set:
    if training_args.do_train:
        for index in random.sample(range(len(train_dataset)), 3):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Get the metric function
    if data_args.task_name is not None:
        metric = load_metric("glue", data_args.task_name)
    else:
        metric = load_metric("accuracy")

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        if data_args.task_name is not None:
            result = metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # set up optimizers
    no_decay = ["bias", "LayerNorm.weight"]
    params = [
        {
            "params": [p for n, p in model.named_parameters() if n in no_decay],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in no_decay],
            "weight_decay": 0.0,
        },
    ]

    #opt = AdamW(params, lr=training_args.learning_rate)

    opt = lambda params: Lamb(
        params,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        clamp_value=training_args.clamp_value,
        debias=True,
    )

    training_args.max_steps = int(max_steps_warmup)

    # scheduler = get_linear_schedule_with_warmup(
    #     opt, 
    #     num_warmup_steps=training_args.warmup_steps, 
    #     num_training_steps=training_args.max_steps
    # )

    scheduler = lambda opt: get_linear_schedule_with_warmup(
        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
    )

    validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)

    dht = DHT(
        start=True,
        initial_peers=collaboration_args.initial_peers,
        client_mode=collaboration_args.client_mode,
        record_validators=validators,
        use_ipfs=collaboration_args.use_ipfs,
        host_maddrs=collaboration_args.host_maddrs,
        announce_maddrs=collaboration_args.announce_maddrs,
        identity_path=collaboration_args.identity_path,
    )
    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)

    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
    if torch.cuda.device_count() != 0:
        total_batch_size_per_step *= torch.cuda.device_count()

    adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead

    optimizer = Optimizer(
        dht=dht,
        run_id=collaboration_args.experiment_prefix,
        target_batch_size=adjusted_target_batch_size,
        batch_size_per_step=total_batch_size_per_step,
        optimizer=opt,
        params=params,
        scheduler=scheduler,
        matchmaking_time=collaboration_args.matchmaking_time,
        averaging_timeout=collaboration_args.averaging_timeout,
        use_local_updates=False,
        offload_optimizer=True,
        delay_optimizer_step=True,
        delay_grad_averaging=True,
        client_mode=collaboration_args.client_mode,
        grad_compression=Float16Compression(),
        state_averaging_compression=Float16Compression(),
        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
        tracker_opts=asdict(tracker_args),
        verbose=True,
    )

    # optimizer = Optimizer(
    #     dht=dht,                                                  # use a DHT that is connected with other peers
    #     run_id=collaboration_args.experiment_prefix,              # unique identifier of this collaborative run
    #     batch_size_per_step=total_batch_size_per_step,            # each call to opt.step adds this many samples towards the next epoch
    #     target_batch_size=adjusted_target_batch_size,             # after peers collectively process this many samples, average weights and begin the next epoch 
    #     optimizer=opt,                                            # wrap the SGD optimizer defined above
    #     scheduler=scheduler,                                      # wrap the schedule defined above
    #     use_local_updates=True,                                   # perform optimizer steps with local gradients, average parameters in background
    #     matchmaking_time=collaboration_args.matchmaking_time,     # when averaging parameters, gather peers in background for up to this many seconds
    #     averaging_timeout=collaboration_args.averaging_timeout,   # give up on averaging if not successful in this many seconds
    #     verbose=True,                                             # print logs incessently
    #     client_mode=collaboration_args.client_mode,
    #     grad_compression=Float16Compression(),
    #     state_averaging_compression=Float16Compression(),
    #     averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
    #     tracker_opts=asdict(tracker_args),
    # )

    trainer = Trainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        data_collator=data_collator,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        optimizers=(optimizer, NoOpScheduler(optimizer)),
        callbacks=[
            CollaborativeCallback(
                dht,
                optimizer,
                model,
                local_public_key,
                collaboration_args.statistics_expiration,
                collaboration_args.backup_every_steps,
            )
        ],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # # Training
    # if training_args.do_train:
    #     latest_checkpoint_dir = max(
    #         Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
    #     )

    #     trainer.train(model_path=latest_checkpoint_dir)

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            eval_datasets.append(raw_datasets["validation_mismatched"])

        for eval_dataset, task in zip(eval_datasets, tasks):
            metrics = trainer.evaluate(eval_dataset=eval_dataset)

            max_eval_samples = (
                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
            )
            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        predict_datasets = [predict_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            predict_datasets.append(raw_datasets["test_mismatched"])

        for predict_dataset, task in zip(predict_datasets, tasks):
            # Removing the `label` columns because it contains -1 and Trainer won't like that.
            predict_dataset = predict_dataset.remove_columns("label")
            predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
            predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)

            output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
            if trainer.is_world_process_zero():
                with open(output_predict_file, "w") as writer:
                    logger.info(f"***** Predict results {task} *****")
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if is_regression:
                            writer.write(f"{index}\t{item:3.3f}\n")
                        else:
                            item = label_list[item]
                            writer.write(f"{index}\t{item}\n")

    if training_args.push_to_hub:
        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
        if data_args.task_name is not None:
            kwargs["language"] = "en"
            kwargs["dataset_tags"] = "glue"
            kwargs["dataset_args"] = data_args.task_name
            kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"

        trainer.push_to_hub(**kwargs)

if __name__ == "__main__":
    main()

I run the code with the following command:

IP=/ip4/192.168.0.79/tcp/46671/p2p/QmS5yv4bn99eKtJwTGpWSfd9Freq71WYpL24BQd845ZfWG

TASK_NAME=MRPC
DATA_NAME=MRPC
GLUE_DIR=/home/protago/Xiangpeng/hivemind/examples/glue/glue_data
# MODEL_PATH=roberta-base # error???
# MODEL_PATH=distilbert-base-uncased
MODEL_PATH=bert-base-uncased
TOKENIZER_NAME=bert-base-uncased

WANDB_DISABLED=true CUDA_VISIBLE_DEVICES=0 python run_trainer.py \
--experiment_prefix glue_experiment \
--initial_peers $IP \
--logging_first_step \
--logging_dir ./logs \
--save_steps=500 \
--warmup_steps 0 \
--model_name_or_path $MODEL_PATH \
--tokenizer_name $TOKENIZER_NAME \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--dataset_name $GLUE_DIR/$DATA_NAME \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs=3 \
--overwrite_output_dir \
--target_batch_size 256 \
--output_dir ./outputs/$TASK_NAME/

The goal is to evaluate the MRPC task with BERT. In my case, the target batch size is 256 and the per-device train batch size is 32.

The evaluation result is not good, as shown above. On the other hand, the number of training epochs shown in the logger is 6, although I actually set num_train_epochs to 3. When I set it to 2, the logger shows 4, which is a little confusing.

Anyway, if I switch the optimizer, the evaluation result is good.

mryab commented 2 years ago

Hi @elricwan! I wanted to also ask a more high-level question about your motivation for using distributed training on a finetuning task for bert-base: as said by @borzunov, since the batch sizes you're using are quite small, it might be the case that averaging parameters with hivemind will result in a significant bottleneck. You might be better off running independent experiments for different GLUE tasks, for example.

Regarding the MRPC experiment: have you tested the performance of this finetuning script without using hivemind (ideally just with one GPU) or with just one hivemind peer? This might give us an upper bound on the finetuning performance. Also, I'd look at the variance of the results with respect to the random seed, since the setup might be quite sensitive to it. Finally, make sure that you're restarting the DHT between the experiments: otherwise, you might be observing some side effects.

By the way, could you please also state the environment that you are running these experiments in? For example, at least the number of peers, their GPUs and their connection speeds.

elricwan commented 2 years ago

Sure, my motivation is to test the hivemind framework on different training setups: since a pre-training task takes much longer to finish, while fine-tuning is much faster, I use a GLUE task to test the performance.

Also, I have tested MRPC without hivemind on one GPU, and the result is similar: eval_accuracy = 0.84, eval_f1 = 0.88.

The random seed might influence the result, but the same random seed works for the local-updates setup, where I get fairly good results:

  1. opt = AdamW(params, lr=training_args.learning_rate)
  2. optimizer = Optimizer( dht=dht,
    run_id=collaboration_args.experiment_prefix,
    batch_size_per_step=total_batch_size_per_step,
    target_batch_size=adjusted_target_batch_size,
    optimizer=opt,
    scheduler=scheduler,
    use_local_updates=True,
    matchmaking_time=collaboration_args.matchmaking_time,
    averaging_timeout=collaboration_args.averaging_timeout,
    verbose=True
    )

I feel like the random seed is not the cause. Could the optimizer influence the result?

Also, I restart the DHT for each test. I use two 3090 GPUs, one peer, and one server, fully connected.

mryab commented 2 years ago

Thanks for clarifying! In that case, I guess the differences between optimization setups should affect the quality somewhat (at least that's my experience with BERT finetuning). However, in general, please note that the intended use of hivemind is training in more resource-intensive scenarios (such as unsupervised pretraining and large-scale supervised learning).

So do I assume correctly that for the working setup, you use AdamW and not LAMB? Also, what is the batch size in this case? You should be able to emulate a single-GPU setup with one peer connected to the DHT.

elricwan commented 2 years ago

Sure, I would use hivemind for pretraining in the future. Yes, I use AdamW for the working setup, and the batch size is also 32. And yes, I use one GPU so far; the machine contains two GPUs, so I could test two peers too.

mryab commented 2 years ago

In that case, I would suggest not comparing the quality across two different settings: it seems that for the non-hivemind setup you're using AdamW and a batch of 32 examples, while in the hivemind setup it's LAMB and 256 examples (judging by target_batch_size). As @justheuristic suggested, it might be the case that LAMB (used in our example for pretraining) is not a good match for your task and hyperparameters, hence the worse results.

To make the experiments comparable, I would suggest either switching to LAMB and a batch of 256 in the vanilla case, or switching to AdamW and a target batch size of 32 in the hivemind case (since you're going to run it on just one peer, you shouldn't see the deadlock from before). Feel free to show us the results from running in the same setup, and then we'll be able to see the differences caused by hivemind.Optimizer.

elricwan commented 2 years ago

Just to make sure my question is clear: when I say "local optimizer", I mean it is still the hivemind optimizer, but with use_local_updates=True. In this case, I use a target batch size of 256, the AdamW optimizer, and a batch of 32 examples. The result is close to the upper bound from the non-hivemind case. The setup is shown below:

  1. opt = AdamW(params, lr=training_args.learning_rate)
  2. optimizer = Optimizer(
    dht=dht,
    run_id=collaboration_args.experiment_prefix,
    batch_size_per_step=total_batch_size_per_step,
    target_batch_size=adjusted_target_batch_size,
    optimizer=opt,
    scheduler=scheduler,
    use_local_updates=True,
    matchmaking_time=collaboration_args.matchmaking_time,
    averaging_timeout=collaboration_args.averaging_timeout,
    verbose=True,
    )

The result is shown below: eval_accuracy = 0.8456, eval_f1 = 0.8952. In the non-hivemind case, I also use the AdamW optimizer and a batch of 32 examples, but without a target batch size.

The setup that did not work is when I use the LAMB optimizer (wrapped in a lambda factory) in the hivemind case. It is shown below:

  1. opt = lambda params: Lamb(
    params,
    lr=training_args.learning_rate,
    betas=(training_args.adam_beta1, training_args.adam_beta2),
    eps=training_args.adam_epsilon,
    weight_decay=training_args.weight_decay,
    clamp_value=training_args.clamp_value,
    debias=True,
    )
  2. optimizer = Optimizer(
    dht=dht,
    run_id=collaboration_args.experiment_prefix,
    target_batch_size=adjusted_target_batch_size,
    batch_size_per_step=total_batch_size_per_step,
    optimizer=opt,
    params=params,
    scheduler=scheduler,
    matchmaking_time=collaboration_args.matchmaking_time,
    averaging_timeout=collaboration_args.averaging_timeout,
    use_local_updates=False,
    offload_optimizer=True,
    delay_optimizer_step=True,
    delay_grad_averaging=True,
    client_mode=collaboration_args.client_mode,
    grad_compression=Float16Compression(),
    state_averaging_compression=Float16Compression(),
    averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
    tracker_opts=asdict(tracker_args),
    verbose=True,
    )

Now I would like to change the optimizer factory from Lamb to AdamW; in that case, I would write:

  1. opt = lambda params: AdamW(
    params,
    lr=training_args.learning_rate,
    )
  2. scheduler = lambda opt: get_linear_schedule_with_warmup(
    opt,
    num_warmup_steps=training_args.warmup_steps,
    num_training_steps=training_args.max_steps,
    )
  3. optimizer = Optimizer(
    dht=dht,
    run_id=collaboration_args.experiment_prefix,
    target_batch_size=adjusted_target_batch_size,
    batch_size_per_step=total_batch_size_per_step,
    optimizer=opt,
    params=params,
    scheduler=scheduler,
    matchmaking_time=collaboration_args.matchmaking_time,
    averaging_timeout=collaboration_args.averaging_timeout,
    use_local_updates=False,
    offload_optimizer=True,
    delay_optimizer_step=True,
    delay_grad_averaging=True,
    client_mode=collaboration_args.client_mode,
    grad_compression=Float16Compression(),
    state_averaging_compression=Float16Compression(),
    averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
    tracker_opts=asdict(tracker_args),
    verbose=True,
    )

And the evaluation result is still not good, as shown below: eval_accuracy = 0.6838, eval_f1 = 0.8122.

mryab commented 2 years ago

I see; judging by the docstring of use_local_updates, it seems to be the case that you are actually using batches of 32 examples for gradient descent steps when the option is set to True. However, when disabling it, the optimizer sticks to the target batch size and makes optimizer steps for batches of 256 examples: so, both of your underperforming runs are using 8 times larger batches than the baseline.

I guess the problem then lies either in the fact that your LR needs to be scaled with the batch size or in generally worse performance for MRPC with batches this large.
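For reference, here is a minimal sketch of how one might rescale the learning rate for the larger effective batch. This is only an illustration, not part of the example scripts: the base values are taken from the MRPC command above, and both scaling rules are common heuristics rather than something hivemind prescribes.

base_lr = 2e-5            # learning rate tuned for a local batch of 32 examples
base_batch_size = 32
target_batch_size = 256   # effective batch size used by hivemind.Optimizer in this run

# linear scaling rule: keep the ratio of learning rate to batch size constant
linear_scaled_lr = base_lr * target_batch_size / base_batch_size         # 1.6e-4
# square-root scaling, often gentler for Adam/LAMB-style optimizers
sqrt_scaled_lr = base_lr * (target_batch_size / base_batch_size) ** 0.5  # ~5.7e-5

print(f"linear: {linear_scaled_lr:.2e}, sqrt: {sqrt_scaled_lr:.2e}")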

elricwan commented 2 years ago

I see, that makes sense. I will do more tests to confirm, thank you!

Besides, in this example I set num_train_epochs to 3, but it trained for six epochs, as shown below:

Jan 12 18:24:15.824 [INFO] Num Epochs = 6

If I set num_train_epochs to 2, it trains for 4 epochs. Any idea why that happens? Thanks!

mryab commented 2 years ago

I set the num_train_epochs to be 3, however, it trained six epochs

No idea as of now, since training with hivemind should not interfere with your training loop. Does this bug disappear when removing all mentions of hivemind from the script?

elricwan commented 2 years ago

I will try to fix it. Thank you for your patience!

borzunov commented 2 years ago

I've renamed the issue to better reflect the topic of the discussion. Feel free to rename it again if I got something wrong :)

elricwan commented 2 years ago

Hi there, is there a place to add gradient clipping in the hivemind optimizer? In regular training, we could simply use:

from torch.nn.utils import clip_grad_norm_
clip_grad_norm_(model.parameters(), self.args.max_clip_norm)

Could we use this with the hivemind optimizer, and how?

Thank you.

justheuristic commented 2 years ago

Hi! In order to 100% match training on a single machine, you will need to clip the global gradients right before the optimizer step. The easiest way to do that is to wrap the optimizer. For instance:

import torch
from torch_optimizer import Lamb

class LambWithGradientClipping(Lamb):
    """A version of LAMB that clips gradients based on their norm."""

    def __init__(self, *args, max_grad_norm: float, **kwargs):
        self.max_grad_norm = max_grad_norm
        super().__init__(*args, **kwargs)

    def step(self, *args, **kwargs):
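        # clip the global gradient norm over all parameter groups before the regular LAMB update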
        iter_params = (param for group in self.param_groups for param in group["params"])
        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
        return super().step(*args, **kwargs)

P.S. We will probably benefit from having some kind of "recipe book" with these things in the docs.

elricwan commented 2 years ago

Got it, thank you!

elricwan commented 2 years ago

I tried to implement the clip norm. My code looks like:

class LambWithGradientClipping(Lamb):

    def __init__(self, *args, max_grad_norm=1.0, **kwargs):
        self.max_grad_norm = max_grad_norm
        super().__init__(*args, **kwargs)

    def step(self, *args, **kwargs):
        iter_params = (param for group in self.param_groups for param in group["params"])
        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
        return super().step(*args, **kwargs)

lambGrad = LambWithGradientClipping(Lamb)

opt = lambda params: lambGrad(
        params,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        clamp_value=training_args.clamp_value,
        debias=True,
    )

But I got the error:

File "run_trainer.py", line 270, in main lambGrad = LambWithGradientClipping(Lamb) File "run_trainer.py", line 70, in init super().init(*args, **kwargs) File "/home/protago/miniconda3/envs/hivemind/lib/python3.8/site-packages/torch_optimizer/lamb.py", line 80, in init super(Lamb, self).init(params, defaults) File "/home/protago/miniconda3/envs/hivemind/lib/python3.8/site-packages/torch/optim/optimizer.py", line 47, in init param_groups = list(params) TypeError: 'type' object is not iterable.

Could you tell me how to make it right? Thank you! (I am looking forward to the recipe book)

borzunov commented 2 years ago

@elricwan Please try this code instead:

class LambWithGradientClipping(Lamb):
    def __init__(self, *args, max_grad_norm=1.0, **kwargs):
        self.max_grad_norm = max_grad_norm
        super().__init__(*args, **kwargs)

    def step(self, *args, **kwargs):
        iter_params = (param for group in self.param_groups for param in group["params"])
        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
        return super().step(*args, **kwargs)

opt = lambda params: LambWithGradientClipping(
    params,
    lr=training_args.learning_rate,
    betas=(training_args.adam_beta1, training_args.adam_beta2),
    eps=training_args.adam_epsilon,
    weight_decay=training_args.weight_decay,
    clamp_value=training_args.clamp_value,
    debias=True,
)

Here, the LambWithGradientClipping class inherits from the Lamb class, so you should call the constructor of the new class directly; there is no need to write things like LambWithGradientClipping(Lamb).

Also, for the sake of convenience, could you please use the code formatting blocks when pasting the logs and code to the issue?

For the logs, please use:

```
Insert the logs here
```

For the Python code, please use:

```python
# Python code here
```

The latter will give you syntax highlighting, as is the case for my code above :)

elricwan commented 2 years ago

Sure, I will use code formatting blocks when pasting logs in the future. And it works, thank you!

elricwan commented 2 years ago

Hi there, if I would like to add one more peer in the middle of the run to boost training speed, how can I make sure that the newly added peer has the same (or a similar) learning rate as the previous peers? Since we use a scheduler, the learning rate changes over time; is there a quick way to do that? Thank you!

justheuristic commented 2 years ago

Hi! TL;DR: use the same hyperparameters and make sure that the learning rate scheduler is fed into hivemind.Optimizer.

The optimizer guarantees that peers are on the same epoch of their scheduler. If a new peer joins, it checks how its local epoch differs from the others'. If others are at epoch 100 while you've just started at epoch 0 (or any smaller epoch), your peer will download state from the others -- and update the learning rate.

Here's what happens under the hood (a short illustrative sketch follows the list):

  1. Typically, the learning rate is a function of the number of steps, measured in global batches (target_batch_size examples each, also referred to as epochs). For instance, this is how the learning rate is defined in examples/albert [1] and in CALM [2]. Hence, if you know the current epoch, you know the current learning rate.

  2. Each time hivemind.Optimizer reports new progress, it also reports its current iteration and checks what other peers are reporting. If they report a different (higher) epoch, you detect that and automatically call .load_state_from_peers(). This is where it happens.

  3. When you download state from a peer, assuming everything went successfully, you also set your epoch to whatever epoch that other peer had. So if you download state from someone who was at epoch 150, you are now also at epoch 150. This is where this happens.
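To illustrate the guarantee above, here is a minimal conceptual sketch (not hivemind's actual implementation): as long as every peer builds its scheduler with the same hyperparameters, the learning rate is a pure function of the global epoch, so a peer that adopts the group's epoch after load_state_from_peers also ends up with the group's learning rate. The warmup and base LR values below are taken from the commands in this thread; total_steps is an arbitrary placeholder.

def lr_at_epoch(epoch, base_lr=0.00176, warmup_steps=5000, total_steps=125000):
    """Linear warmup followed by linear decay, expressed as a pure function of the global epoch."""
    if epoch < warmup_steps:
        return base_lr * epoch / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - epoch) / max(1, total_steps - warmup_steps))

# A freshly joined peer starts at local epoch 0, sees that the group is at epoch 150,
# downloads the state, sets its own epoch to 150, and from then on computes
# exactly the same learning rate as everyone else:
group_epoch = 150
print(lr_at_epoch(group_epoch))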

elricwan commented 2 years ago

Hi,

Thank you for the informative response! But I am not sure I understand it correctly. From my understanding, hivemind checks the global epoch across the running peers. However, each peer may start with a different number of total training epochs: for instance, one peer could start training for 3 total epochs while another starts with 1 epoch, as shown below. Apparently, the total number of training epochs affects the scheduler and the learning rate. In that case, does hivemind take the longest training schedule as the global one? How does hivemind make sure that two schedules with different numbers of training epochs end up with the same learning rate?

...
IP=/ip4/192.168.1.13/tcp/33495/p2p/QmdySuj4wSpQ9cjucJ9KVUEb2gJaq4enTkoSx21xibT16o

WANDB_DISABLED=true CUDA_VISIBLE_DEVICES=0 python run_trainer.py \
--experiment_prefix bert_experiment \
--initial_peers $IP \
--logging_first_step \
--config_path='bert-base-uncase' \
--output_dir ./outputs \
--overwrite_output_dir \
--logging_dir ./logs \
--do_eval=False \
--dataset_path="data" \
--warmup_steps 5000 \
--per_device_train_batch_size 16 \
--learning_rate 0.00176 \
--num_train_epochs=3 \
--overwrite_output_dir \
--target_batch_size 4096 \
--save_steps=10000

IP=/ip4/192.168.1.13/tcp/33495/p2p/QmdySuj4wSpQ9cjucJ9KVUEb2gJaq4enTkoSx21xibT16o

WANDB_DISABLED=true CUDA_VISIBLE_DEVICES=0 python run_trainer.py \
--experiment_prefix bert_experiment \
--initial_peers $IP \
--logging_first_step \
--config_path='bert-base-uncase' \
--output_dir ./outputs \
--overwrite_output_dir \
--logging_dir ./logs \
--do_eval=False \
--dataset_path="data" \
--warmup_steps 5000 \
--per_device_train_batch_size 16 \
--learning_rate 0.00176 \
--num_train_epochs=1 \
--overwrite_output_dir \
--target_batch_size 4096 \
--save_steps=10000
...

justheuristic commented 2 years ago

And apparently the total training epochs affect the scheduler and learning rate.

I'm afraid the best answer I can give is "make sure the schedule has the same hyperparameters on all peers".

In your example, the distinction is that --num_train_epochs corresponds to the number of full passes over the training data, and this is not the same as the learning rate scheduler's epochs.

In a typical large-scale training scenario, the training data is so huge that a single peer will not be able to complete even a single full pass over it. For instance, a single 3090 would take years to train GPT-J-6B over its training dataset (the Pile). So instead, it is more convenient to think in terms of local and global training steps.

In examples/albert, there are two distinct parameters:

Admittedly, this can be a bit confusing, but we'll try to make it more intuitive in an upcoming update to the examples.
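As a worked example of this distinction (a sketch using values from the commands above; it assumes a single GPU per peer and gradient_accumulation_steps=1):

per_device_train_batch_size = 16   # local batch processed by one peer per forward/backward pass
target_batch_size = 4096           # samples the whole collaboration must process per global epoch

# Number of local steps a *single* peer would need to finish one global epoch on its own:
local_steps_per_global_epoch = target_batch_size // per_device_train_batch_size
print(local_steps_per_global_epoch)  # 256

# With two identical peers contributing equally, each performs roughly half of those local steps,
# yet both observe the same global epoch -- and therefore the same scheduler state and learning rate.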

elricwan commented 2 years ago

I see, but for a linear scheduler the learning rate decays linearly to zero after warmup (a sketch of the schedule follows this comment).

That means that after total_steps the learning rate becomes 0, so if max_steps is larger than total_steps, the peer would no longer update the parameters, right?
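For reference, the learning-rate multiplier produced by get_linear_schedule_with_warmup has roughly the following shape (a simplified sketch, not the exact transformers source); once current_step reaches num_training_steps, the multiplier stays at zero, so further optimizer steps no longer change the parameters:

def linear_schedule_multiplier(current_step, num_warmup_steps, num_training_steps):
    if current_step < num_warmup_steps:
        # linear warmup from 0 to 1
        return current_step / max(1, num_warmup_steps)
    # linear decay from 1 to 0, clamped at 0 once current_step >= num_training_steps
    return max(0.0, (num_training_steps - current_step) / max(1, num_training_steps - num_warmup_steps))

# Example with placeholder values: warmup for 5000 steps, decay until step 125000
print(linear_schedule_multiplier(200_000, 5_000, 125_000))  # 0.0 -> effective learning rate is zero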

elricwan commented 2 years ago

Oh, I understand now: max_steps counts training steps of the local batch size, not of the target batch size. Thanks!