Accelerate
1. Accelerate
1.1. Accelerator
The Accelerator is the main class for adapting your code to work with Accelerate. It knows about the distributed setup you're using such as the number of different processes and your hardware type. This class also provides access to many of the necessary methods for enabling your PyTorch code to work in any distributed training environment and for managing and executing processes across devices.
That's why you should always start by importing and creating an Accelerator instance in your script.
```python
from accelerate import Accelerator
accelerator = Accelerator()
```
The Accelerator also knows which device to move your PyTorch objects to, so it is recommended to let Accelerate handle this for you.
```diff
- device = "cuda"
+ device = accelerator.device
  model.to(device)
```
1.2. Prepare PyTorch objects
Next, you need to prepare your PyTorch objects (model, optimizer, scheduler, etc.) for distributed training. The prepare() method takes care of placing your model in the appropriate container (like single GPU or multi-GPU) for your training setup, adapting the optimizer and scheduler to use Accelerate's AcceleratedOptimizer and AcceleratedScheduler, and creating a new dataloader that can be sharded across processes.
Accelerate only prepares objects that inherit from their respective PyTorch classes such as torch.optim.Optimizer.
The PyTorch objects are returned in the same order they're sent.
```python
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)
```
1.3. Training loop
Finally, remove the to(device) calls on the inputs and targets in the training loop because Accelerate's DataLoader classes automatically place them on the right device. You should also replace the usual backward() pass with Accelerate's backward() method, which scales the gradients for you and uses the appropriate backward() method depending on your distributed setup (for example, DeepSpeed or Megatron).
```diff
- inputs = inputs.to(device)
- targets = targets.to(device)
  outputs = model(inputs)
  loss = loss_function(outputs, targets)
- loss.backward()
+ accelerator.backward(loss)
```
1.4. Training features
- Gradient accumulation
- Gradient clipping
- Mixed precision
Mixed precision accelerates training by using a lower precision data type like fp16 (half-precision) to calculate the gradients. For the best performance with Accelerate, the loss should be computed inside your model (like in Transformers models) because computations outside of the model are computed in full precision.
Set the mixed precision type to use in the Accelerator, and then use the autocast() context manager to automatically cast the values to the specified data type.
Accelerate enables automatic mixed precision, so autocast() is only needed if there are other mixed precision operations besides those performed on loss by backward(), which already handles the scaling.
1.5. Save and Load
Once all processes are complete, unwrap the model with the unwrap_model() method before saving it because the prepare() method wrapped your model into the proper interface for distributed training. If you don't unwrap the model, saving the model state dictionary also saves any potential extra layers from the larger model and you won't be able to load the weights back into your base model.
1.5.1. Single checkpoint
```python
accelerator.wait_for_everyone()
accelerator.save_model(model, save_directory)
```
To load your weights, use the unwrap_model() method to unwrap the model first before loading the weights. All model parameters are references to tensors, so this loads your weights inside model.
```python
import os
import torch

unwrapped_model = accelerator.unwrap_model(model)
path_to_checkpoint = os.path.join(save_directory, "pytorch_model.bin")
unwrapped_model.load_state_dict(torch.load(path_to_checkpoint))
```
For models from the Transformers library, save the model with the save_pretrained() method so that it can be reloaded with the from_pretrained() method.
```python
from transformers import AutoModel

unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    "path/to/my_model_directory",
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
)

model = AutoModel.from_pretrained("path/to/my_model_directory")
```
1.5.2. Multiple checkpoints
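Set safe_serialization=True to save the model in the safetensors format, and use max_shard_size to cap the size of each checkpoint file so that larger models are split into multiple shards.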
```python
accelerator.wait_for_everyone()
accelerator.save_model(model, save_directory, max_shard_size="1GB", safe_serialization=True)
```
To load a sharded checkpoint or a safetensors-formatted checkpoint, use the load_checkpoint_in_model() method. This method allows you to load a checkpoint onto a specific device.
```python
from accelerate.utils import load_checkpoint_in_model

load_checkpoint_in_model(unwrapped_model, save_directory, device_map={"": device})
```
1.5.3. State
During training, you may want to save the current state of the model, optimizer, random generators, and potentially learning rate schedulers so they can be restored in the same script. You should add the save_state() and load_state() methods to your script to save and load states.
To further customize where and how states are saved through save_state(), use the ProjectConfiguration class. For example, if automatic_checkpoint_naming is enabled, each saved checkpoint is stored at Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}.
Any other stateful items to be stored should be registered with the register_for_checkpointing() method so they can be saved and loaded. Every object passed to this method to be stored must have a load_state_dict and state_dict function.
If you have torchdata>=0.8.0 installed, you can additionally pass use_stateful_dataloader=True into your DataLoaderConfiguration. This extends Accelerate's DataLoader classes with a load_state_dict and state_dict function, so that Accelerator.save_state and Accelerator.load_state also track how far into the training dataset the dataloader has read when persisting the model.
1.6. Execution process
1.6.1. Execute on one process
Certain code only needs to be run once on a given machine, such as printing a log statement or only displaying one progress bar on the local main process. You could also direct Accelerate to execute code once across all processes regardless of the number of machines. This is useful if you're uploading a final model to the Hub.
1.6.1.1. Statements
You should use accelerator.is_local_main_process to indicate code that should only be executed once.
```python
from tqdm.auto import tqdm

progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
```
For standalone print statements that aren't wrapped in accelerator.is_local_main_process, replace print with Accelerate's print() method to only print once per machine (on the local main process) rather than once per process.
You should use accelerator.is_main_process to indicate code that should only be executed once across all processes.
```python
if accelerator.is_main_process:
    repo.push_to_hub()
```
1.6.1.2. Function
For a function that should only be executed once, use on_local_main_process().
```python
@accelerator.on_local_main_process
def do_my_thing():
    "Something done once per server"
    do_thing_once_per_server()
```
For a function that should only be executed once across all processes, use on_main_process().
1.6.1.3. Execute on a specific process
Accelerate can also help you execute functions that should only be executed on a specific process or a local process index. Use the on_process() method and specify the process index to execute a function on. Use the on_local_process() method and specify the local process index to execute a function on.
1.7. Defer execution
When you run your script on several GPUs at the same time, some code may be executed faster than others. You might need to wait for all processes to reach a certain point before executing the next set of instructions. For instance, you shouldn't save a model before making sure every process is done with training.
To do this, add wait_for_everyone() in your code. This blocks all processes that have finished first from continuing until all remaining processes have reached the same point (this has no effect if you're running on a single GPU or CPU).
```python
accelerator.wait_for_everyone()
```
1.8. Launching Accelerate scripts
First, you should rewrite the above code into a function, and make it callable as a script. For example:
```diff
+ def main():
  ...
+ if __name__ == "__main__":
+     main()
```
Next, you need to launch it with accelerate launch.
It's recommended you run accelerate config before using accelerate launch to configure your environment to your liking. Otherwise Accelerate will use very basic defaults depending on your system setup.
For instance, here is how you would launch that same script on two GPUs using mixed precision while avoiding all of the warnings:
```bash
accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=2 {script_name.py} {--arg1} {--arg2} ...
```
You can run your code on CPU as well! This is helpful for debugging and testing purposes on toy models and datasets.
```bash
accelerate launch --cpu {script_name.py} {--arg1} {--arg2}
```
1.8.1. Multi-node training
Multi-node training with Accelerate is similar to multi-node training with torchrun. The simplest way to launch a multi-node training run is to do the following:
- Copy your codebase and data to all nodes (or place them on a shared filesystem).
- Set up your Python packages on all nodes.
- Run accelerate config on the main node first. After specifying the number of nodes, you will be asked to specify the rank of each node (this will be 0 for the main/master node), along with the IP address and port for the main process. This is required for the worker nodes to communicate with the main process. Afterwards, you can copy or send this config file to all of your nodes, changing machine_rank to 1, 2, 3, etc., to avoid having to run the command again (alternatively, follow torchrun's own launch instructions directly).
The launch command must be run on every node for training to start; running it only on the main node is not enough. You can use something like SLURM or a different process executor to work around this requirement and launch everything from a single command.
It is recommended to use the intranet IP of your main node over the public IP for better latency. This is the 192.168.x.x or the 172.x.x.x address you see when you run hostname -I on the main node.
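The step of copying the config to every node and changing only machine_rank can be sketched in plain Python. The dictionary below is a hypothetical stand-in for the saved config file: the key names are assumed to mirror the fields written by accelerate config, and the IP address, port, and counts are made-up values.

```python
from copy import deepcopy

# Hypothetical values; the keys are assumed to match the generated config file.
main_config = {
    "machine_rank": 0,
    "num_machines": 3,
    "num_processes": 24,
    "main_process_ip": "192.168.1.10",
    "main_process_port": 29500,
}


def configs_for_all_nodes(base: dict) -> list:
    """Return one config per node, differing only in machine_rank."""
    node_configs = []
    for rank in range(base["num_machines"]):
        node_config = deepcopy(base)
        node_config["machine_rank"] = rank
        node_configs.append(node_config)
    return node_configs


node_configs = configs_for_all_nodes(main_config)
ranks = [node_config["machine_rank"] for node_config in node_configs]
```

Every node keeps the same main_process_ip and main_process_port, since all workers connect to the main process.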
1.9. Launching distributed training from Jupyter Notebooks
1.9.1. Using the notebook_launcher
You pass in the function, the arguments (as a tuple), and the number of processes to train on.
```python
from accelerate import notebook_launcher

args = ("fp16", 42, 64)
notebook_launcher(training_loop, args, num_processes=2)
```
Some key notes to remember:
- Make sure to save any code that uses CUDA (or CUDA imports) for the function passed to notebook_launcher()
- Set num_processes to be the number of devices used for training (such as the number of GPUs, CPUs, TPUs, etc.)
- If using the TPU, declare your model outside the training loop function
1.10. A boilerplate
```python
import math
from dataclasses import dataclass, field
from pathlib import Path

import hydra
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DataLoaderConfiguration, set_seed
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig
from omegaconf import (
    MISSING,
    OmegaConf,
)
from torch.utils.data import DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm.auto import tqdm
from transformers import get_scheduler

try:
    import datasets
except ImportError:
    datasets = None

try:
    import diffusers
except ImportError:
    diffusers = None

logger = get_logger(__name__)
torch.backends.cuda.matmul.allow_tf32 = True


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Configuration Classes
# ────────────────────────────────────────────────────────────────────────── ✣ ─


@dataclass
class ModelConfig:
    pass


@dataclass
class DataConfig:
    data_dir: str = MISSING
    batch_size: int = MISSING
    num_workers: int = 4
    pin_memory: bool = True
    drop_last: bool = True


@dataclass
class SchedulerConfig:
    name: str = "constant"  # cosine, linear, constant
    num_warmup_steps: int | None = None


@dataclass
class OptimizerConfig:
    learning_rate: float = 0.001
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 0.01
    adam_epsilon: float = 1e-08


@dataclass
class TrainingConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    data: DataConfig = field(default_factory=DataConfig)
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)

    # Project settings
    project_name: str = MISSING
    exp_name: str = MISSING

    # Training settings
    num_epochs: int = MISSING
    seed: int = 42
    gradient_accumulation_steps: int = 1
    mixed_precision: str = "fp16"
    cpu: bool = False

    # Checkpointing
    resume_from_checkpoint: str | None = None
    checkpoint_steps: int | str | None = None
    save_strategy: str = "steps"  # "epoch", "steps"

    # Gradient clipping
    max_grad_norm: float | None = None
    clip_value: float | None = None

    # Logging and evaluation
    report_to: str | None = "wandb"
    eval_steps: int | None = None


cs: ConfigStore = ConfigStore.instance()
cs.store(name="training_config", node=TrainingConfig)


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Model, Data, Optimizer, Scheduler, Loss, Metrics, Validation
# ────────────────────────────────────────────────────────────────────────── ✣ ─


def get_model(cfg: ModelConfig):
    pass


def set_requires_grad(model: torch.nn.Module, flag: bool = True):
    for p in model.parameters():
        p.requires_grad = flag


def get_dataloader(cfg: DataConfig, accelerator: Accelerator):
    # Placeholder: build your dataset here (for example from cfg.data_dir)
    train_dataset = None
    sampler: DistributedSampler = DistributedSampler(
        train_dataset,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        shuffle=True,
        drop_last=cfg.drop_last,
    )

    train_dataloader: StatefulDataLoader[int] = StatefulDataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
        drop_last=cfg.drop_last,
    )
    return sampler, train_dataloader


def get_optimizer(
    model: torch.nn.Module,
    cfg: OptimizerConfig,
):
    return torch.optim.AdamW(
        model.parameters(),
        lr=cfg.learning_rate,
        betas=(cfg.adam_beta1, cfg.adam_beta2),
        weight_decay=cfg.adam_weight_decay,
        eps=cfg.adam_epsilon,
    )


def get_loss(cfg: ModelConfig):
    pass


def compute_metrics(outputs, targets):
    pass


def validate_model(
    model: torch.nn.Module,
    accelerator: Accelerator,
):
    model.eval()
    pass


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Utility Functions
# ────────────────────────────────────────────────────────────────────────── ✣ ─


class CheckpointTracker:
    """Keeps track of which epoch and step we last processed."""

    def __init__(self, epoch: int = 0, step: int = 0) -> None:
        self.epoch: int = epoch
        self.step: int = step

    def state_dict(self) -> dict[str, int]:
        return {"epoch": self.epoch, "step": self.step}

    def load_state_dict(self, state: dict[str, int]) -> None:
        self.epoch = state["epoch"]
        self.step = state["step"]


def log_model_info(model: torch.nn.Module):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Non-trainable parameters: {total_params - trainable_params:,}")


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Training Loop
# ────────────────────────────────────────────────────────────────────────── ✣ ─


def train(cfg: TrainingConfig):
    dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
    accelerator = Accelerator(
        cpu=cfg.cpu,
        log_with=cfg.report_to if cfg.report_to else None,
        mixed_precision=cfg.mixed_precision,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        dataloader_config=dataloader_config,
    )
    tracker: CheckpointTracker = CheckpointTracker()
    accelerator.register_for_checkpointing(tracker)
    logger.info(accelerator.state, main_process_only=False)

    if accelerator.is_local_main_process:
        if datasets is not None:
            datasets.utils.logging.set_verbosity_warning()
        if diffusers is not None:
            diffusers.utils.logging.set_verbosity_info()
        transformers.utils.logging.set_verbosity_info()
    else:
        if datasets is not None:
            datasets.utils.logging.set_verbosity_error()
        if diffusers is not None:
            diffusers.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    set_seed(cfg.seed)

    # ────────────────────────────────────────────────────────────────────────── ✣ ─
    # Model, Loss, Data, Optimizer, Scheduler
    # ────────────────────────────────────────────────────────────────────────── ✣ ─

    model = get_model(cfg.model)
    set_requires_grad(model, True)

    loss_fn = get_loss(cfg.model)

    log_model_info(model)

    # The dataloader is sharded manually with a DistributedSampler and is not
    # passed through prepare(), so batches must be moved to accelerator.device
    # by hand (for example inside loss_fn).
    sampler, train_dataloader = get_dataloader(cfg.data, accelerator)

    # Calculate total training steps for scheduler
    num_training_steps = (
        math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
        * cfg.num_epochs
    )

    optimizer = get_optimizer(model, cfg.optimizer)
    # num_warmup_steps is optional, so only scale it when it is set
    num_warmup_steps = (
        cfg.scheduler.num_warmup_steps * accelerator.num_processes
        if cfg.scheduler.num_warmup_steps is not None
        else None
    )
    scheduler = get_scheduler(
        cfg.scheduler.name,
        optimizer,
        num_warmup_steps,
        num_training_steps,
    )

    (
        model,
        optimizer,
        scheduler,
    ) = accelerator.prepare(model, optimizer, scheduler)

    # ────────────────────────────────────────────────────────────────────────── ✣ ─
    # Checkpointing, Logging
    # ────────────────────────────────────────────────────────────────────────── ✣ ─

    if cfg.resume_from_checkpoint:
        logger.info(f"Resumed from checkpoint: {cfg.resume_from_checkpoint}")
        accelerator.load_state(cfg.resume_from_checkpoint)
        # Restore sampler epoch for correct shuffling
        sampler.set_epoch(tracker.epoch)
        # Load dataloader state
        # NOTE: this expects a loader_state.pt file next to the checkpoint. The
        # script never writes that file, so save train_dataloader.state_dict()
        # wherever checkpoints are written if you rely on this.
        loader_state = torch.load(Path(cfg.resume_from_checkpoint) / "loader_state.pt")
        train_dataloader.load_state_dict(loader_state)
        logger.info(f"Resuming from epoch {tracker.epoch}, step {tracker.step}")

    if cfg.report_to:
        accelerator.init_trackers(
            project_name=cfg.project_name,
            config=OmegaConf.to_container(cfg, resolve=True),
            init_kwargs={"wandb": {"name": f"{cfg.exp_name}"}},
        )

    # Calculate starting step if resuming
    if cfg.resume_from_checkpoint:
        completed_epochs = tracker.epoch - 1 if tracker.epoch > 1 else 0
        steps_per_epoch = len(train_dataloader)
        resume_step = completed_epochs * steps_per_epoch + tracker.step
    else:
        resume_step = 0

    # Main progress bar for total training steps
    total_progress = tqdm(
        range(num_training_steps),
        initial=resume_step,
        desc="Training",
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True,
    )

    # ────────────────────────────────────────────────────────────────────────── ✣ ─
    # Training Loop
    # ────────────────────────────────────────────────────────────────────────── ✣ ─

    for epoch in range(
        tracker.epoch if cfg.resume_from_checkpoint else 1, cfg.num_epochs + 1
    ):
        model.train()
        # Reseed the sampler so each epoch uses a different shuffle
        sampler.set_epoch(epoch)

        # Epoch progress bar
        epoch_progress = tqdm(
            train_dataloader,
            desc=f"Epoch {epoch}/{cfg.num_epochs}",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            leave=False,
        )

        for step, batch in enumerate(epoch_progress):
            # Calculate global step
            if epoch > 1:
                completed_steps = (epoch - 1) * len(train_dataloader)
                global_step = completed_steps + step
            else:
                global_step = step

            tracker.epoch = epoch
            tracker.step = global_step
            with accelerator.accumulate(model):
                loss = loss_fn(model, batch)
                loss_mean = loss.mean()
                accelerator.backward(loss_mean)

                grad_norm = None
                if accelerator.sync_gradients:
                    if cfg.clip_value is not None:
                        accelerator.clip_grad_value_(model.parameters(), cfg.clip_value)
                    if cfg.max_grad_norm is not None:
                        grad_norm = accelerator.clip_grad_norm_(
                            model.parameters(), cfg.max_grad_norm
                        )
                    else:
                        grad_norm = None

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                # Log training state
                log_metrics = {
                    "loss": accelerator.gather(loss_mean).mean().detach().item(),
                    "learning_rate": scheduler.get_last_lr()[0],
                }
                if grad_norm is not None:
                    log_metrics["grad_norm"] = (
                        accelerator.gather(grad_norm).mean().detach().item()
                    )

                if torch.cuda.is_available():
                    memory_used = torch.cuda.memory_reserved() / 1024**3
                    log_metrics["memory_gb"] = memory_used

                if cfg.report_to:
                    accelerator.log(log_metrics)

                total_progress.update(1)

                display_metrics = {
                    "loss": f"{log_metrics['loss']:.4f}",
                    "lr": f"{log_metrics['learning_rate']:.2e}",
                }
                if "memory_gb" in log_metrics:
                    display_metrics["mem"] = f"{log_metrics['memory_gb']:.1f}GB"
                if grad_norm is not None:
                    display_metrics["grad"] = f"{log_metrics['grad_norm']:.3f}"

                total_progress.set_postfix(display_metrics)
                epoch_progress.set_postfix(display_metrics)

                if cfg.save_strategy == "steps" and isinstance(
                    cfg.checkpoint_steps, int
                ):
                    if global_step % cfg.checkpoint_steps == 0 and global_step > 0:
                        output_dir = (
                            Path(HydraConfig.get().runtime.output_dir)
                            / f"step_{global_step}"
                        )
                        accelerator.save_state(output_dir)
                        logger.info(f"Saved {output_dir} checkpoint")

                # Run validation if specified
                if (
                    cfg.eval_steps is not None
                    and global_step % cfg.eval_steps == 0
                    and cfg.report_to
                ):
                    validate_model(model, accelerator)
                    model.train()  # Switch back to training mode

        # Close epoch progress bar
        epoch_progress.close()

        if cfg.save_strategy == "epoch":
            output_dir = Path(HydraConfig.get().runtime.output_dir) / f"epoch_{epoch}"
            accelerator.save_state(output_dir)
            logger.info(f"Saved {output_dir} checkpoint")

    accelerator.wait_for_everyone()
    total_progress.close()
    accelerator.save_model(model, Path(HydraConfig.get().runtime.output_dir))
    accelerator.end_training()


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Main
# ────────────────────────────────────────────────────────────────────────── ✣ ─


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: TrainingConfig) -> None:
    OmegaConf.to_container(cfg, throw_on_missing=True)
    train(cfg)


if __name__ == "__main__":
    main()
```