
Validate the bf16 precisions across models #2203

Closed
xuzhao9 opened this issue Mar 19, 2024 · 2 comments

Comments

@xuzhao9
Contributor

xuzhao9 commented Mar 19, 2024

A couple of models only run at the default precision, and we would like to enable bf16 precision on them.

These models fail because their inputs either can't be directly cast to bf16, or are not completely cast to bf16.
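To illustrate the failure mode, here is a minimal sketch (not the repo's actual casting code, and the helper names are hypothetical): blanket-casting every example input to bf16 breaks models whose inputs include integer tensors, such as embedding indices, which must keep their original dtype.

```python
import torch

def cast_inputs_naive(inputs):
    # Naive approach: cast every input tensor to bf16.
    return tuple(t.to(torch.bfloat16) for t in inputs)

def cast_inputs_safe(inputs):
    # Safer approach: only cast floating-point tensors, leave
    # integer tensors (e.g. embedding indices) untouched.
    return tuple(
        t.to(torch.bfloat16) if torch.is_floating_point(t) else t
        for t in inputs
    )

emb = torch.nn.Embedding(10, 4)
idx = torch.tensor([1, 2, 3])  # int64 indices

safe, = cast_inputs_safe((idx,))
emb(safe)  # works: indices keep their integer dtype

try:
    bad, = cast_inputs_naive((idx,))
    emb(bad)  # nn.Embedding rejects bf16 indices
except RuntimeError as e:
    print("naive cast fails:", e)
```

This is why some models need per-input dtype handling rather than a single cast over all inputs.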

Per @drisspg:

Using this script:

from transformer_nuggets.utils.shape_trace import ShapeLog
import torch
from pathlib import Path
from tqdm import tqdm
import logging
import json

logging.basicConfig(level=logging.INFO)

def main():
    models = []
    success_count = 0
    failure_count = 0
    model_failures = {}
    for file in Path("torchbenchmark/models/").iterdir():
        if file.is_dir():
            models.append(file.name)
    for model_name in tqdm(models, desc="Logging models", unit="model"):
        try:
            module = __import__(f"torchbenchmark.models.{model_name}", fromlist=[model_name])
            model, example_inputs = module.Model(test="train", device="cuda", extra_args=["--precision=bf16",]).get_module()
            model(*example_inputs)
            success_count += 1
        except Exception as e:
            tqdm.write(f"Failed to log {model_name}: {e}")
            failure_count += 1
            model_failures[model_name] = str(e)

    tqdm.write(f"Successfully logged {success_count} models")
    tqdm.write(f"Failed to log {failure_count} models")
    with open("model_failures_bf16.txt", "w") as f:
        json.dump(model_failures, f)



if __name__ == "__main__":
    main()

returns the following model failures:

internalfb.com/intern/paste/P1197341526

@xuzhao9
Contributor Author

xuzhao9 commented Mar 21, 2024

We also want to validate amp_bf16 on CPU.
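A minimal sketch of what an amp_bf16 check on CPU could look like, using `torch.autocast` (an assumption about the validation approach, not the benchmark runner's actual code): under CPU autocast with bf16, autocast-eligible ops such as linear layers should produce bf16 outputs.

```python
import torch

model = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)  # fp32 input

# CPU automatic mixed precision with bfloat16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

# Linear layers are autocast to bf16 on CPU.
print(out.dtype)
```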

@xuzhao9 xuzhao9 closed this as completed Jun 20, 2024
@xuzhao9
Contributor Author

xuzhao9 commented Jun 20, 2024

We are migrating to the pt2 benchmark runner, so we do not plan to support bf16 in the legacy runner.
