[AnyPrecision optimizer] add automatic BF16 support check (network and gpu) #65

Open
wants to merge 6 commits into base: main
137 changes: 107 additions & 30 deletions src/python/torchdistx/optimizers/anyprecision_optimizer.py
@@ -8,11 +8,15 @@
# with optional Kahan summation for high precision weight updates.
# Allows direct control over momentum, variance and auxiliary compensation
# buffer dtypes.
# Optional Kahan summation is used to offset precision reduction for
# the weight updates. This allows full training in BFloat16 (equal or
# better than FP32 results in many cases) due to high precision weight upates.
# Optional Kahan summation is used to enable high precision for
# the weight updates. This allows successful training in pure BFloat16
# (often equal or better than FP32 results) due to high precision weight
# updates, while training with reduced GPU memory and
# increased training speed.

import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from torch.optim.optimizer import Optimizer


@@ -31,33 +35,61 @@ def __init__(
):
"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)

# Any Precision specific
use_kahan_summation = creates auxiliary buffer to ensure high precision
model param updates (default: False)
momentum_dtype = dtype for momentum (default: BFloat32)
variance_dtype = dtype for uncentered variance (default: BFloat16)
compensation_buffer_dtype = dtype for Kahan summation
buffer (default: BFloat16). Only used if
``use_kahan_summation=True``.

# Usage
This optimizer implements optimizer states, and Kahan summation
for high precision updates, all in user controlled dtypes.
Defaults are variance in BF16, Momentum in FP32.
This can be run in FSDP mixed precision, amp, or full precision,
depending on what training pipeline you wish to work with.

Setting to use_kahan_summation = False, and changing momentum and
variance dtypes to FP32, reverts this to a standard AdamW optimizer.
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)

# AnyPrecision specific
use_kahan_summation = use auxiliary buffer to ensure high precision
model param updates (default: False)
momentum_dtype = dtype for momentum (default: FP32)
variance_dtype = dtype for uncentered variance (default: BFloat16)
compensation_buffer_dtype = dtype for Kahan summation
buffer (default: BFloat16). Only used if use_kahan_summation=True.

# Usage
This optimizer keeps its optimizer states, and the optional Kahan
summation buffer used for high precision updates, in user-controlled
dtypes. The high precision updates enable successful training in pure
BF16, with corresponding reductions in memory use and increases in
training speed.

Defaults are variance in BF16 and momentum in FP32.
This can be run with FSDP mixed precision, AMP, or full precision,
depending on which training pipeline you wish to use.

Setting use_kahan_summation=False and changing the momentum and
variance dtypes to FP32 reverts this to a standard AdamW optimizer.
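
For example, a minimal sketch of that AdamW-equivalent configuration
(reusing my_model and lr from the Example below):

optimizer = AnyPrecisionAdamW(
    my_model.parameters(),
    lr=lr,
    momentum_dtype=torch.float32,
    variance_dtype=torch.float32,
    use_kahan_summation=False,
)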

AnyPrecision will automatically verify that proper BF16 support is
present for both the GPU and the network collectives (NCCL).

To train in pure BF16:
1 - Use model.to(torch.bfloat16) to move your model to BF16.
2 - Set momentum_dtype and variance_dtype to torch.bfloat16.
3 - Set use_kahan_summation = True.

Example:
# init model
my_model = build_model(config_args)

# ensure model is moved to all bf16
my_model.to(torch.bfloat16)

# setup AnyPrecision to run in pure BF16 with high precision updates
optimizer = AnyPrecisionAdamW(my_model.parameters(), lr=lr, ...,
momentum_dtype=torch.bfloat16,
variance_dtype=torch.bfloat16,
use_kahan_summation=True
)
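
The training loop itself is then unchanged. A sketch, assuming
loss_fn and batch come from your own pipeline:

output = my_model(batch["input"].to(torch.bfloat16))
loss = loss_fn(output, batch["target"])
loss.backward()
optimizer.step()
optimizer.zero_grad()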


"""
defaults = dict(
lr=lr,
@@ -72,6 +104,28 @@ def __init__(

super().__init__(params, defaults)

# confirm bfloat16 support if applicable
if (
    torch.bfloat16 in [momentum_dtype, variance_dtype]
    or (compensation_buffer_dtype == torch.bfloat16 and use_kahan_summation)
):
gpu_support, network_support = self._verify_bfloat_support()

if not gpu_support or not network_support:
reason = ""
if not gpu_support:
reason += "Your GPU does not support native BFloat16. "

if not network_support:
reason += "Your NCCL version does not support BFloat16. "

raise ValueError(f"Missing BFloat16 support. Details: {reason}")

@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@@ -180,3 +234,26 @@ def step(self, closure=None):
else:
# usual AdamW updates
p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
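
# For reference, a generic sketch of a Kahan-compensated update (illustrative
# only: the Kahan branch of step() is collapsed in this diff, and the names
# below are not the module's actual buffer names):
#   y = update - comp          # fold in the correction carried from prior steps
#   t = param + y              # low precision add; low-order bits may round away
#   comp = (t - param) - y     # recover exactly what that add dropped
#   param = t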

def _verify_bfloat_support(
self,
):
"""verify gpu and network support for BF16"""
# requires cuda >= 11.0
required_cuda_major = 11

# requires nccl >= 2.10
required_nccl_major = 2
required_nccl_minor = 10

# a CUDA build is required for native BF16 (guards against CPU-only/ROCm builds)
if torch.version.cuda is None:
    return False, False

gpu_support = torch.cuda.is_bf16_supported()

cuda_version_major, _ = torch.version.cuda.split(".", maxsplit=1)

network_support = (
    int(cuda_version_major) >= required_cuda_major
    and dist.is_nccl_available()
    and nccl.version() >= (required_nccl_major, required_nccl_minor)
)

return gpu_support, network_support
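
For readers who want to check support up front, a minimal standalone sketch
that mirrors the logic above (public torch APIs only; the CUDA >= 11 and
NCCL >= 2.10 thresholds match the constants in this diff):

import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist

def bf16_ready() -> bool:
    """Best-effort check that both the GPU and NCCL can handle BFloat16."""
    if torch.version.cuda is None or not torch.cuda.is_available():
        return False  # CPU-only or ROCm build, or no visible GPU
    cuda_major = int(torch.version.cuda.split(".", maxsplit=1)[0])
    gpu_ok = torch.cuda.is_bf16_supported()
    nccl_ok = (
        cuda_major >= 11
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )
    return gpu_ok and nccl_ok

print(f"BF16 ready: {bf16_ready()}")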