Accelerate

1. Accelerate

1.1. Accelerator

The Accelerator is the main class for adapting your code to work with Accelerate. It knows about the distributed setup you're using, such as the number of processes and your hardware type. It also provides the methods you need to make your PyTorch code work in any distributed training environment, and to manage and execute processes across devices.

That's why you should always start by importing and creating an Accelerator instance in your script.

```python
from accelerate import Accelerator
accelerator = Accelerator()
```

The Accelerator also knows which device to move your PyTorch objects to, so it is recommended to let Accelerate handle this for you.

```diff
- device = "cuda"
+ device = accelerator.device
  model.to(device)
```

1.2. Prepare PyTorch objects

Next, you need to prepare your PyTorch objects (model, optimizer, scheduler, etc.) for distributed training. The prepare() method takes care of placing your model in the appropriate container (like single GPU or multi-GPU) for your training setup, adapting the optimizer and scheduler to use Accelerate's AcceleratedOptimizer and AcceleratedScheduler, and creating a new dataloader that can be sharded across processes.

Accelerate only prepares objects that inherit from their respective PyTorch classes such as torch.optim.Optimizer.

The PyTorch objects are returned in the same order they're passed in.

```python
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)
```

1.3. Training loop

Finally, remove the to(device) calls on the inputs and targets in the training loop, because Accelerate's DataLoader classes automatically place them on the right device. You should also replace the usual backward() pass with Accelerate's backward() method. It scales the gradients for you and uses the appropriate backward() method for your distributed setup (for example, DeepSpeed or Megatron).

```diff
- inputs = inputs.to(device)
- targets = targets.to(device)
  outputs = model(inputs)
  loss = loss_function(outputs, targets)
- loss.backward()
+ accelerator.backward(loss)
```

1.4. Training features

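  1. Gradient accumulation: set gradient_accumulation_steps on the Accelerator and wrap each training step in accelerator.accumulate(model). The prepared optimizer then updates the weights only once every gradient_accumulation_steps batches.
  2. Gradient clipping: use accelerator.clip_grad_norm_() or accelerator.clip_grad_value_() instead of the torch.nn.utils functions, because they unscale fp16 gradients before clipping. Clip only when accelerator.sync_gradients is True.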
  3. Mixed precision

    • Mixed precision accelerates training by using a lower precision data type like fp16 (half-precision) to calculate the gradients. For the best performance with Accelerate, the loss should be computed inside your model (like in Transformers models) because computations outside of the model are computed in full precision.
    • Set the mixed precision type to use in the Accelerator, and then use the autocast() context manager to automatically cast the values to the specified data type.
    • Accelerate applies autocast to the prepared model's forward pass, and backward() already handles loss scaling. You therefore only need autocast() for other mixed precision operations outside the model.

1.5. Save and Load

Once all processes are complete, unwrap the model with the unwrap_model() method before saving it because the prepare() method wrapped your model into the proper interface for distributed training. If you don't unwrap the model, saving the model state dictionary also saves any potential extra layers from the larger model and you won't be able to load the weights back into your base model.

1.5.1. Single checkpoint

```python
accelerator.wait_for_everyone()
accelerator.save_model(model, save_directory)
```

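To load your weights, unwrap the model first and then load the state dictionary into it. Because the unwrapped model shares its parameter tensors with model, this also loads the weights into model.

By default, save_model() uses safe_serialization=True. For a model smaller than max_shard_size (10GB by default), it writes a single model.safetensors file; a pytorch_model.bin file is only written with safe_serialization=False. Read the checkpoint with safetensors:

```python
import os

from safetensors.torch import load_file

unwrapped_model = accelerator.unwrap_model(model)
path_to_checkpoint = os.path.join(save_directory, "model.safetensors")
unwrapped_model.load_state_dict(load_file(path_to_checkpoint))
```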

For models from the Transformers library, save the model with the save_pretrained() method so that it can be reloaded with the from_pretrained() method.

```python
from transformers import AutoModel

unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    "path/to/my_model_directory",
    is_main_process=accelerator.is_main_process,
    save_function=accelerator.save,
)

model = AutoModel.from_pretrained("path/to/my_model_directory")
```

1.5.2. Multiple checkpoints

Pass max_shard_size to split a large model into several checkpoint files. safe_serialization=True, the default, stores the shards in the safetensors format.

```python
accelerator.wait_for_everyone()
accelerator.save_model(model, save_directory, max_shard_size="1GB", safe_serialization=True)
```

To load a sharded checkpoint or a safetensors checkpoint, use the load_checkpoint_in_model() function from accelerate.utils. It can also place the loaded weights on a specific device.

```python
from accelerate.utils import load_checkpoint_in_model

load_checkpoint_in_model(unwrapped_model, save_directory, device_map={"": device})
```

1.5.3. State

During training, you may want to save the current state of the model, optimizer, random generators, and potentially learning rate schedulers so they can be restored in the same script. Use the save_state() and load_state() methods to save and restore this state.

To further customize where and how states are saved through save_state(), use the ProjectConfiguration class. For example, if automatic_checkpoint_naming is enabled, each saved checkpoint is stored at Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}.

Any other stateful items to be stored should be registered with the register_for_checkpointing() method so they can be saved and loaded. Every object passed to this method to be stored must have a load_state_dict and state_dict function.

If you have torchdata>=0.8.0 installed, you can additionally pass use_stateful_dataloader=True into your DataLoaderConfiguration. This gives Accelerate's prepared DataLoader classes state_dict and load_state_dict methods. Accelerator.save_state and Accelerator.load_state then also record how far into the training dataset the dataloader has read.

1.6. Execution process

1.6.1. Execute on one process

Certain code only needs to be run once on a given machine, such as printing a log statement or only displaying one progress bar on the local main process. You could also direct Accelerate to execute code once across all processes regardless of the number of machines. This is useful if you're uploading a final model to the Hub.

1.6.1.1. Statements

You should use accelerator.is_local_main_process to indicate code that should only be executed once per machine.

```python
from tqdm.auto import tqdm

progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
```

For standalone print statements that aren't wrapped in accelerator.is_local_main_process, replace print with Accelerate's print() method so the message is printed only once per machine.

You should use accelerator.is_main_process to indicate code that should only be executed once across all processes.

```python
if accelerator.is_main_process:
    repo.push_to_hub()
```
1.6.1.2. Functions

For a function that should only be executed once per machine, use on_local_main_process().

```python
@accelerator.on_local_main_process
def do_my_thing():
    "Something done once per server"
    do_thing_once_per_server()
```

For a function that should only be executed once across all processes, use on_main_process().

1.6.1.3. Execute on a specific process

Accelerate can also help you execute functions that should only be executed on a specific process or a local process index. Use the on_process() method and specify the process index to execute a function on. Use the on_local_process() method and specify the local process index to execute a function on.

1.7. Defer execution

When you run your script on several GPUs at the same time, some processes may run ahead of others. You might need to wait for all processes to reach a certain point before executing the next set of instructions. For instance, you shouldn't save a model before making sure every process is done with training.

To do this, add wait_for_everyone() in your code. This blocks all processes that have finished first from continuing until all remaining processes have reached the same point (this has no effect if you're running on a single GPU or CPU).

```python
accelerator.wait_for_everyone()
```

1.8. Launching Accelerate scripts

First, wrap the code above in a function and make the file callable as a script. For example:

```diff
+ def main():
      ...
+ if __name__ == "__main__":
+     main()
```

Next, you need to launch it with accelerate launch.

It's recommended you run accelerate config before using accelerate launch to configure your environment to your liking. Otherwise Accelerate will use very basic defaults depending on your system setup.

For instance, here is how you would also launch that same script on two GPUs using mixed precision while avoiding all of the warnings:

```bash
accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=2 {script_name.py} {--arg1} {--arg2} ...
```

You can run your code on CPU as well! This is helpful for debugging and testing purposes on toy models and datasets.

```bash
accelerate launch --cpu {script_name.py} {--arg1} {--arg2}
```

1.8.1. Multi-node training

Multi-node training with Accelerate is similar to multi-node training with torchrun. The simplest way to launch a multi-node training run is to do the following:

  • Copy your codebase and data to all nodes (or place them on a shared filesystem).
  • Set up your Python packages on all nodes.
  • Run accelerate config on the main node first. After specifying the number of nodes, you will be asked for the rank of each node (0 for the main node), along with the IP address and port of the main process. The worker nodes need these to communicate with the main process. Afterwards, copy the config file to every other node and change machine_rank to 1, 2, 3, etc., so you don't have to rerun the command on each one. Alternatively, follow torchrun's own multi-node instructions and launch with torchrun.

For training to start, the launch command must be run on every node, not just on the main node. You can use SLURM or another job scheduler to handle this, so that everything starts from a single command.

For better latency, use the private (intranet) IP address of your main node rather than its public IP. This is the 192.168.x.x or 172.x.x.x address shown when you run hostname -I on the main node.

1.9. Launching distributed training from Jupyter Notebooks

1.9.1. Using the notebook_launcher

You pass in the function, the arguments (as a tuple), and the number of processes to train on.

```python
from accelerate import notebook_launcher

args = ("fp16", 42, 64)
notebook_launcher(training_loop, args, num_processes=2)
```

Some key notes to remember:

  • Keep any code that uses CUDA (including imports that initialize CUDA) inside the function passed to notebook_launcher(). CUDA must not be initialized before the processes are launched.
  • Set num_processes to the number of devices used for training (such as the number of GPUs, CPUs, or TPUs)
  • If using the TPU, declare your model outside the training loop function

1.10. A boilerplate

```python
import math
from dataclasses import dataclass, field
from pathlib import Path

import hydra
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DataLoaderConfiguration, send_to_device, set_seed
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig
from omegaconf import (
    MISSING,
    OmegaConf,
)
from torch.utils.data import DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm.auto import tqdm
from transformers import get_scheduler

try:
    import datasets
except ImportError:
    datasets = None

try:
    import diffusers
except ImportError:
    diffusers = None

logger = get_logger(__name__)
torch.backends.cuda.matmul.allow_tf32 = True


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Configuration Classes
# ────────────────────────────────────────────────────────────────────────── ✣ ─


@dataclass
class ModelConfig:
    pass


@dataclass
class DataConfig:
    data_dir: str = MISSING
    batch_size: int = MISSING
    num_workers: int = 4
    pin_memory: bool = True
    drop_last: bool = True


@dataclass
class SchedulerConfig:
    name: str = "constant"  # cosine, linear, constant
    num_warmup_steps: int | None = None


@dataclass
class OptimizerConfig:
    learning_rate: float = 0.001
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 0.01
    adam_epsilon: float = 1e-08


@dataclass
class TrainingConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    data: DataConfig = field(default_factory=DataConfig)
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)

    # Project settings
    project_name: str = MISSING
    exp_name: str = MISSING

    # Training settings
    num_epochs: int = MISSING
    seed: int = 42
    gradient_accumulation_steps: int = 1
    mixed_precision: str = "fp16"
    cpu: bool = False

    # Checkpointing. checkpoint_steps and eval_steps count micro-batches,
    # so use multiples of gradient_accumulation_steps.
    resume_from_checkpoint: str | None = None
    checkpoint_steps: int | None = None
    save_strategy: str = "steps"  # "epoch", "steps"

    # Gradient clipping
    max_grad_norm: float | None = None
    clip_value: float | None = None

    # Logging and evaluation
    report_to: str | None = "wandb"
    eval_steps: int | None = None


cs: ConfigStore = ConfigStore.instance()
cs.store(name="training_config", node=TrainingConfig)


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Model, Data, Optimizer, Scheduler, Loss, Metrics, Validation
# ────────────────────────────────────────────────────────────────────────── ✣ ─


def get_model(cfg: ModelConfig):
    pass  # TODO: build and return your torch.nn.Module


def set_requires_grad(model: torch.nn.Module, flag: bool = True):
    for p in model.parameters():
        p.requires_grad = flag


def get_dataloader(cfg: DataConfig, accelerator: Accelerator):
    train_dataset = None  # TODO: build your dataset from cfg.data_dir
    # This dataloader is not passed through accelerator.prepare(), so shard it manually.
    sampler: DistributedSampler = DistributedSampler(
        train_dataset,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        shuffle=True,
        drop_last=cfg.drop_last,
    )

    train_dataloader: StatefulDataLoader = StatefulDataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
        drop_last=cfg.drop_last,
    )
    return sampler, train_dataloader


def get_optimizer(
    model: torch.nn.Module,
    cfg: OptimizerConfig,
):
    return torch.optim.AdamW(
        model.parameters(),
        lr=cfg.learning_rate,
        betas=(cfg.adam_beta1, cfg.adam_beta2),
        weight_decay=cfg.adam_weight_decay,
        eps=cfg.adam_epsilon,
    )


def get_loss(cfg: ModelConfig):
    pass  # TODO: return a callable loss_fn(model, batch) -> loss tensor


def compute_metrics(outputs, targets):
    pass


def validate_model(
    model: torch.nn.Module,
    accelerator: Accelerator,
):
    model.eval()
    pass


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Utility Functions
# ────────────────────────────────────────────────────────────────────────── ✣ ─


class CheckpointTracker:
    """Keeps track of which epoch and step we last processed."""

    def __init__(self, epoch: int = 0, step: int = 0) -> None:
        self.epoch: int = epoch
        self.step: int = step

    def state_dict(self) -> dict[str, int]:
        return {"epoch": self.epoch, "step": self.step}

    def load_state_dict(self, state: dict[str, int]) -> None:
        self.epoch = state["epoch"]
        self.step = state["step"]


def log_model_info(model: torch.nn.Module):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Non-trainable parameters: {total_params - trainable_params:,}")


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Training Loop
# ────────────────────────────────────────────────────────────────────────── ✣ ─


def train(cfg: TrainingConfig):
    dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
    accelerator = Accelerator(
        cpu=cfg.cpu,
        log_with=cfg.report_to if cfg.report_to else None,
        mixed_precision=cfg.mixed_precision,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        dataloader_config=dataloader_config,
    )
    tracker: CheckpointTracker = CheckpointTracker()
    accelerator.register_for_checkpointing(tracker)
    logger.info(accelerator.state, main_process_only=False)

    if accelerator.is_local_main_process:
        if datasets is not None:
            datasets.utils.logging.set_verbosity_warning()
        if diffusers is not None:
            diffusers.utils.logging.set_verbosity_info()
        transformers.utils.logging.set_verbosity_info()
    else:
        if datasets is not None:
            datasets.utils.logging.set_verbosity_error()
        if diffusers is not None:
            diffusers.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    set_seed(cfg.seed)

    # ────────────────────────────────────────────────────────────────────────── ✣ ─
    # Model, Loss, Data, Optimizer, Scheduler
    # ────────────────────────────────────────────────────────────────────────── ✣ ─

    model = get_model(cfg.model)
    set_requires_grad(model, True)

    loss_fn = get_loss(cfg.model)

    log_model_info(model)

    sampler, train_dataloader = get_dataloader(cfg.data, accelerator)
    # StatefulDataLoader has state_dict()/load_state_dict(), so registering it lets
    # save_state()/load_state() persist how far into the epoch it has read.
    accelerator.register_for_checkpointing(train_dataloader)

    # Optimizer updates per process over the whole run
    num_training_steps = (
        math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
        * cfg.num_epochs
    )

    optimizer = get_optimizer(model, cfg.optimizer)
    # The prepared scheduler advances num_processes times per optimizer update
    # (unless split_batches=True), so both step counts are scaled to match.
    scheduler = get_scheduler(
        cfg.scheduler.name,
        optimizer,
        num_warmup_steps=(cfg.scheduler.num_warmup_steps or 0) * accelerator.num_processes,
        num_training_steps=num_training_steps * accelerator.num_processes,
    )

    (
        model,
        optimizer,
        scheduler,
    ) = accelerator.prepare(model, optimizer, scheduler)

    # ────────────────────────────────────────────────────────────────────────── ✣ ─
    # Checkpointing, Logging
    # ────────────────────────────────────────────────────────────────────────── ✣ ─

    if cfg.resume_from_checkpoint:
        logger.info(f"Resumed from checkpoint: {cfg.resume_from_checkpoint}")
        # Restores model, optimizer, scheduler, RNG states, tracker, and dataloader position
        accelerator.load_state(cfg.resume_from_checkpoint)
        logger.info(f"Resuming from epoch {tracker.epoch}, step {tracker.step}")

    if cfg.report_to:
        accelerator.init_trackers(
            project_name=cfg.project_name,
            config=OmegaConf.to_container(cfg, resolve=True),
            init_kwargs={"wandb": {"name": f"{cfg.exp_name}"}},
        )

    # tracker.step counts the micro-batches completed so far across all epochs
    start_epoch = tracker.epoch if cfg.resume_from_checkpoint else 1
    global_step = tracker.step if cfg.resume_from_checkpoint else 0

    # Main progress bar counts optimizer updates
    total_progress = tqdm(
        range(num_training_steps),
        initial=global_step // cfg.gradient_accumulation_steps,
        desc="Training",
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True,
    )

    # ────────────────────────────────────────────────────────────────────────── ✣ ─
    # Training Loop
    # ────────────────────────────────────────────────────────────────────────── ✣ ─

    for epoch in range(start_epoch, cfg.num_epochs + 1):
        model.train()
        # Reshuffle differently every epoch (and reproduce the same order on resume)
        sampler.set_epoch(epoch)

        # Epoch progress bar
        epoch_progress = tqdm(
            train_dataloader,
            desc=f"Epoch {epoch}/{cfg.num_epochs}",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            leave=False,
        )

        for batch in epoch_progress:
            # The dataloader is not prepared by Accelerate, so move the batch manually
            batch = send_to_device(batch, accelerator.device)
            global_step += 1
            tracker.epoch = epoch
            tracker.step = global_step

            with accelerator.accumulate(model):
                loss = loss_fn(model, batch)
                loss_mean = loss.mean()
                accelerator.backward(loss_mean)

                grad_norm = None
                if accelerator.sync_gradients:
                    if cfg.clip_value is not None:
                        accelerator.clip_grad_value_(model.parameters(), cfg.clip_value)
                    if cfg.max_grad_norm is not None:
                        grad_norm = accelerator.clip_grad_norm_(
                            model.parameters(), cfg.max_grad_norm
                        )

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                # Log training state
                log_metrics = {
                    "loss": accelerator.gather(loss_mean).mean().detach().item(),
                    "learning_rate": scheduler.get_last_lr()[0],
                }
                if grad_norm is not None:
                    log_metrics["grad_norm"] = (
                        accelerator.gather(grad_norm).mean().detach().item()
                    )

                if torch.cuda.is_available():
                    memory_used = torch.cuda.memory_reserved() / 1024**3
                    log_metrics["memory_gb"] = memory_used

                if cfg.report_to:
                    accelerator.log(log_metrics)

                total_progress.update(1)

                display_metrics = {
                    "loss": f"{log_metrics['loss']:.4f}",
                    "lr": f"{log_metrics['learning_rate']:.2e}",
                }
                if "memory_gb" in log_metrics:
                    display_metrics["mem"] = f"{log_metrics['memory_gb']:.1f}GB"
                if grad_norm is not None:
                    display_metrics["grad"] = f"{log_metrics['grad_norm']:.3f}"

                total_progress.set_postfix(display_metrics)
                epoch_progress.set_postfix(display_metrics)

                if (
                    cfg.save_strategy == "steps"
                    and cfg.checkpoint_steps is not None
                    and global_step % cfg.checkpoint_steps == 0
                ):
                    output_dir = (
                        Path(HydraConfig.get().runtime.output_dir)
                        / f"step_{global_step}"
                    )
                    accelerator.save_state(output_dir)
                    logger.info(f"Saved {output_dir} checkpoint")

                # Run validation if specified
                if (
                    cfg.eval_steps is not None
                    and global_step % cfg.eval_steps == 0
                    and cfg.report_to
                ):
                    validate_model(model, accelerator)
                    model.train()  # Switch back to training mode

        # Close epoch progress bar
        epoch_progress.close()

        if cfg.save_strategy == "epoch":
            # Mark this epoch as finished so a resumed run starts with the next one
            tracker.epoch = epoch + 1
            output_dir = Path(HydraConfig.get().runtime.output_dir) / f"epoch_{epoch}"
            accelerator.save_state(output_dir)
            logger.info(f"Saved {output_dir} checkpoint")

    accelerator.wait_for_everyone()
    total_progress.close()
    accelerator.save_model(model, Path(HydraConfig.get().runtime.output_dir))
    accelerator.end_training()


# ────────────────────────────────────────────────────────────────────────── ✣ ─
# Main
# ────────────────────────────────────────────────────────────────────────── ✣ ─


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: TrainingConfig) -> None:
    # Fail fast if any MISSING field was not set
    OmegaConf.to_container(cfg, throw_on_missing=True)
    train(cfg)


if __name__ == "__main__":
    main()
```
