diff --git a/.clang-format b/.clang-format new file mode 100755 index 000000000..38790238f --- /dev/null +++ b/.clang-format @@ -0,0 +1,155 @@ +--- +# Refer to the following link for the explanation of each params: +# http://releases.llvm.org/8.0.0/tools/clang/docs/ClangFormatStyleOptions.html +Language: Cpp +# BasedOnStyle: Google +AccessModifierOffset: -4 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: true +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: true +AllowShortLoopsOnASingleLine: true +# This is deprecated +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + # disabling the below splits, else, they'll just add to the vertical length of source files! + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: WebKit +BreakBeforeInheritanceComma: false +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 100 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +# Kept the below 2 to be the same as `IndentWidth` to keep everything uniform +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^' + Priority: 2 + - Regex: '^<.*\.h>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseLabels: true +IndentPPDirectives: None +IndentWidth: 4 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Never +ObjCBlockIndentWidth: 4 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 4 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' + - Language: TextProto + Delimiters: + - pb + - PB + - proto + - PROTO + EnclosingFunctions: + - EqualsProto + - EquivToProto + - PARSE_PARTIAL_TEXT_PROTO + - PARSE_TEST_PROTO + - PARSE_TEXT_PROTO + - ParseTextOrDie + - ParseTextProtoOrDie + CanonicalDelimiter: '' + BasedOnStyle: google +# Enabling comment reflow causes doxygen comments to be messed up in their formats! +ReflowComments: true +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +# Be consistent with indent-width, even for people who use tab for indentation! +TabWidth: 4 +UseTab: Never diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 4ad22c860..643d547c7 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -21,7 +21,7 @@ Steps to reproduce the behavior: A clear and concise description of what you expected to happen. **Proposed solution** -If you have an idea for how we can fix this problem, describe it here. +If you have an idea for how we can fix this problem, describe it here. **Screenshots** If applicable, add screenshots to help explain your problem. diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 000000000..a657fe68d --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,13 @@ +name: Pull Request + +on: [pull_request] + +jobs: + pre-commit: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.8 + - uses: pre-commit/action@v2.0.3 diff --git a/.gitignore b/.gitignore index 98a29e87f..34de4e774 100644 --- a/.gitignore +++ b/.gitignore @@ -131,7 +131,7 @@ dmypy.json # wandb logs wandb/ -# data files +# data files data/**/*.idx data/**/*.bin data/**/*.json* @@ -149,4 +149,4 @@ test_checkpoint/ test_logs/ logs/ tensorboard/ -src/ \ No newline at end of file +src/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..1907f005f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.1.0 + hooks: + - id: check-case-conflict + - id: check-json + - id: check-symlinks + - id: check-yaml + - id: destroyed-symlinks + - id: end-of-file-fixer + exclude: docs/CNAME + - id: fix-byte-order-marker + - id: fix-encoding-pragma + args: [--remove] + - id: mixed-line-ending + args: [--fix=lf] + - id: requirements-txt-fixer + - id: trailing-whitespace + - repo: https://gitlab.com/daverona/pre-commit-cpp + rev: 0.8.0 + hooks: + - id: clang-format # formatter of C/C++ code based on a style guide: LLVM, Google, Chromium, Mozilla, and WebKit available + args: [] + + - repo: https://github.com/psf/black + rev: 21.8b0 + hooks: + - id: black + language_version: python3.8 diff --git a/CITATION.cff b/CITATION.cff index de7be046d..4fae700b3 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,6 +1,6 @@ # YAML 1.2 --- -authors: +authors: - affiliation: EleutherAI family-names: Andonian given-names: Alex @@ -47,7 +47,7 @@ authors: family-names: Weinbach given-names: Samuel cff-version: "1.1.0" -keywords: +keywords: - Transformers - "Massive language model" - "Autoregressive language model" diff --git a/Dockerfile b/Dockerfile index a9fdaf4ea..6cdc80f92 100644 --- a/Dockerfile +++ b/Dockerfile @@ -89,4 +89,3 @@ RUN mkdir -p /tmp && chmod 0777 /tmp #### SWITCH TO mchorse USER USER mchorse WORKDIR /home/mchorse - diff --git a/LICENSE b/LICENSE index 462788cf7..99cf99888 100644 --- a/LICENSE +++ b/LICENSE @@ -199,7 +199,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - + -- This repository also contains code from Hugging Face Inc., Google Research, diff --git a/README.md b/README.md index 0b12272bc..218898400 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # GPT-NeoX -This repository records [EleutherAI](https://www.eleuther.ai)'s work-in-progress for training large-scale language models on GPUs. Our current framework is based on NVIDIA's [Megatron Language Model](https://github.com/NVIDIA/Megatron-LM) and has been augmented with techniques from [DeepSpeed](https://www.deepspeed.ai) as well as some novel optimizations. +This repository records [EleutherAI](https://www.eleuther.ai)'s work-in-progress for training large-scale language models on GPUs. Our current framework is based on NVIDIA's [Megatron Language Model](https://github.com/NVIDIA/Megatron-LM) and has been augmented with techniques from [DeepSpeed](https://www.deepspeed.ai) as well as some novel optimizations. We aim to make this repo a centralized and accessible place to gather techniques for training large-scale autoregressive language models, and accelerate research into large-scale training. Additionally, we hope to train and open source a 175B parameter GPT-3 replication along the way. Please note, however, that this is a research codebase that is primarily designed for performance over ease of use. We endeavour to make it as easy to use as is feasible, but if there's anything in the readme that is unclear or you think you've found a bug, please open an issue. @@ -65,12 +65,12 @@ wget --cut-dirs=5 -nH -r --no-parent --reject "index.html*" https://mystic.the-e First make sure you are in an environment with Python 3.8 or later with an appropriate version of PyTorch 1.8 or later installed. -To install the remaining basic dependencies, run: +To install the remaining basic dependencies, run: ```bash pip install -r requirements/requirements.txt python ./megatron/fused_kernels/setup.py install # optional if not using fused kernels -``` +``` from the repository root. @@ -99,7 +99,7 @@ GPT-NeoX parameters are defined in a YAML configuration file which is passed to ```yaml "vocab-file": "./20B_checkpoints/20B_tokenizer.json", "save": "./20B_checkpoints", - "load": "./20B_checkpoints", + "load": "./20B_checkpoints", ``` changing `./20B_checkpoints` to the path to the root folder of the downloaded checkpoints. If the checkpoints exist at `./20B_checkpoints` you can leave this as is. @@ -128,7 +128,7 @@ We currently offer three main functions: and can be launched with: ```bash -./deepy.py [script.py] [./path/to/config_1.yml] [./path/to/config_2.yml] ... [./path/to/config_n.yml] +./deepy.py [script.py] [./path/to/config_1.yml] [./path/to/config_2.yml] ... [./path/to/config_n.yml] ``` E.G To generate text unconditionally with the GPT-NeoX-20B model, you can use the following: @@ -338,9 +338,9 @@ This repository hosts code that is part of EleutherAI's GPT-NeoX project. Copyri Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/configs/13B.yml b/configs/13B.yml index d60fead2b..b2f1e1368 100644 --- a/configs/13B.yml +++ b/configs/13B.yml @@ -57,7 +57,7 @@ "attention-dropout": 0, # precision settings - "fp16": { + "fp16": { "fp16": true, "enabled": true, "loss_scale": 0, diff --git a/configs/175B.yml b/configs/175B.yml index 6df6110d9..baaad3c82 100644 --- a/configs/175B.yml +++ b/configs/175B.yml @@ -57,7 +57,7 @@ "attention-dropout": 0, # precision settings - "fp16": { + "fp16": { "fp16": true, "enabled": true, "loss_scale": 0, diff --git a/configs/2-7B.yml b/configs/2-7B.yml index 617d756ff..b795c310c 100644 --- a/configs/2-7B.yml +++ b/configs/2-7B.yml @@ -19,7 +19,7 @@ "scaled-upper-triang-masked-softmax-fusion": false, "bias-gelu-fusion": false, - + # optimizer settings "optimizer": { "type": "Adam", @@ -58,7 +58,7 @@ "attention-dropout": 0, # precision settings - "fp16": { + "fp16": { "fp16": true, "enabled": true, "loss_scale": 0, diff --git a/configs/20B.yml b/configs/20B.yml index 55a1b1938..7b0d5e481 100644 --- a/configs/20B.yml +++ b/configs/20B.yml @@ -2,44 +2,44 @@ # GPUs. Depending on your system configuration, you may need to change some parameters in order to fit # the model in memory. -{ +{ # Tokenizer / checkpoint settings - you will need to change these to the location you have them saved in "vocab-file": "./20B_checkpoints/20B_tokenizer.json", "save": "./20B_checkpoints", - "load": "./20B_checkpoints", + "load": "./20B_checkpoints", # If finetuning, edit the following to the location of your finetuning dataset: "data-path": "./data/pile_20B_tokenizer/pile_20B_tokenizer_text_document", - - # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages - # across the node boundaries ) - "pipe-parallel-size": 4, - "model-parallel-size": 2, - - # model settings - "num-layers": 44, - "hidden-size": 6144, - "num-attention-heads": 64, - "seq-length": 2048, - "max-position-embeddings": 2048, - "norm": "layernorm", + + # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages + # across the node boundaries ) + "pipe-parallel-size": 4, + "model-parallel-size": 2, + + # model settings + "num-layers": 44, + "hidden-size": 6144, + "num-attention-heads": 64, + "seq-length": 2048, + "max-position-embeddings": 2048, + "norm": "layernorm", "pos-emb": "rotary", "rotary_pct": 0.25, "no-weight-tying": true, - "gpt_j_residual": true, + "gpt_j_residual": true, "output_layer_parallelism": "column", - "scaled-upper-triang-masked-softmax-fusion": true, - "bias-gelu-fusion": true, + "scaled-upper-triang-masked-softmax-fusion": true, + "bias-gelu-fusion": true, # init methods "init_method": "small_init", "output_layer_init_method": "wang_init", - - # optimizer settings - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.97e-4, + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.97e-4, "betas": [0.9, 0.95], "eps": 1.0e-8, } @@ -76,7 +76,7 @@ "attention-dropout": 0, # precision settings - "fp16": { + "fp16": { "fp16": true, "enabled": true, "loss_scale": 0, @@ -107,4 +107,4 @@ "tensorboard-dir": "./tensorboard", "log-dir": "./logs", -} \ No newline at end of file +} diff --git a/configs/6-7B.yml b/configs/6-7B.yml index fe7983bda..777848781 100644 --- a/configs/6-7B.yml +++ b/configs/6-7B.yml @@ -19,7 +19,7 @@ "scaled-upper-triang-masked-softmax-fusion": false, "bias-gelu-fusion": false, - + # optimizer settings "optimizer": { "type": "Adam", @@ -58,7 +58,7 @@ "attention-dropout": 0, # precision settings - "fp16": { + "fp16": { "fp16": true, "enabled": true, "loss_scale": 0, diff --git a/configs/README.md b/configs/README.md index 715b908ed..046f6d50b 100644 --- a/configs/README.md +++ b/configs/README.md @@ -68,7 +68,7 @@ For a detailed list of all the arguments available for neox, see [neox_arguments "attention-dropout": 0, # precision settings - "fp16": { + "fp16": { "enabled": true, "loss_scale": 0, "loss_scale_window": 1000, @@ -80,7 +80,7 @@ For a detailed list of all the arguments available for neox, see [neox_arguments "lr-decay-iters": 320000, "lr-decay-style": "cosine", "warmup": 0.01, - + # misc. training settings "distributed-backend": "nccl", "save-interval": 10000, @@ -123,14 +123,14 @@ These can be set to any integer between `0` and `num_gpus`, and `num_gpus` must "scaled-upper-triang-masked-softmax-fusion": false, "train-iters": 320000, ``` -An example of some basic settings used to configure your model's architecture and number of training steps. - +An example of some basic settings used to configure your model's architecture and number of training steps. + ### Optimizer Settings: -Our optimizer configuration has a similar syntax to deepspeed's. Different optimizers will have different arguments for "params". +Our optimizer configuration has a similar syntax to deepspeed's. Different optimizers will have different arguments for "params". Learning rate should be configured from here using the `"lr"` field of `optimizer["params"]`. -```yaml +```yaml # optimizer settings "optimizer": { "type": "Adam", @@ -156,12 +156,12 @@ Available optimizer types are: "comm_backend_name": "nccl" } ``` - -- `"CPU_Adam"`/`"CPU_torch_adam"`: Adam optimizer on CPU. Either megatron's version ("CPU_Adam") or torch's ("CPU_torch_adam") + +- `"CPU_Adam"`/`"CPU_torch_adam"`: Adam optimizer on CPU. Either megatron's version ("CPU_Adam") or torch's ("CPU_torch_adam") - `"SM3"`: SM3 or [Memory adaptive efficient optimization optimizer](https://arxiv.org/pdf/1901.11150.pdf). We have found this doesn't work well with fp16 training. - `"madgrad_wd"`: MADGRAD or [A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic - Optimizer] weight decay has been implemented AdamW style instead of the original madgrad Adam style. https://arxiv.org/abs/2101.11075 - + Optimizer] weight decay has been implemented AdamW style instead of the original madgrad Adam style. https://arxiv.org/abs/2101.11075 + ### ZeRO Optimization: ```yaml @@ -195,7 +195,7 @@ N.B - ZeRO stages 2+ are incompatible with pipeline parallelism. Please set `"pi Our global batch size configuration follows deepspeed's and can be configured in a number of ways. At least any one of `"train_batch_size"` and `"train_micro_batch_size_per_gpu"`. - `"train_batch_size"`: The effective training batch size. This is the amount of data samples that leads to one step of model update. train_batch_size is aggregated by the batch size that a single GPU processes in one forward/backward pass (a.k.a., train_step_batch_size), the gradient accumulation steps (a.k.a., gradient_accumulation_steps), and the number of GPUs. - `"train_micro_batch_size_per_gpu""`: Batch size to be processed by one GPU in one step (without gradient accumulation). When specified, `gradient_accumulation_steps` is automatically calculated using train_batch_size and number of GPUs. -- `"gradient_accumulation_steps"`: Number of training steps to accumulate gradients before averaging and applying them. This feature is sometimes useful to improve scalability since it results in less frequent communication of gradients between steps. Another impact of this feature is the ability to train with larger batch sizes per GPU. When specified, train_step_batch_size is automatically calculated using train_batch_size and number of GPUs. +- `"gradient_accumulation_steps"`: Number of training steps to accumulate gradients before averaging and applying them. This feature is sometimes useful to improve scalability since it results in less frequent communication of gradients between steps. Another impact of this feature is the ability to train with larger batch sizes per GPU. When specified, train_step_batch_size is automatically calculated using train_batch_size and number of GPUs. ### Dataset / Tokenizer / Checkpoint / Logging Settings: @@ -226,7 +226,7 @@ Our global batch size configuration follows deepspeed's and can be configured in "warmup": 0.01, ``` -Settings used to modify the learning rate over time. +Settings used to modify the learning rate over time. N.B - `OneBitAdam` requires you to use deepspeed's internal lr scheduler because reasons. Currently the lr decay style defaults to deepspeed's `WarmupDecay @@ -246,7 +246,7 @@ gpt-neox's mixed precision training is configured identically to DeepSpeed's, pl An example config for fp16 training: ```yaml - "fp16": { + "fp16": { "enabled": true, "loss_scale": 0, "loss_scale_window": 1000, @@ -257,4 +257,4 @@ An example config for fp16 training: To train in fp32, simply set `fp16["enabled"]` to `false`. -** TODO: bf16 docs ** \ No newline at end of file +** TODO: bf16 docs ** diff --git a/configs/XL.yml b/configs/XL.yml index bc466110a..16aa48d02 100644 --- a/configs/XL.yml +++ b/configs/XL.yml @@ -57,7 +57,7 @@ "attention-dropout": 0, # precision settings - "fp16": { + "fp16": { "fp16": true, "enabled": true, "loss_scale": 0, diff --git a/configs/bnb_small.yml b/configs/bnb_small.yml index 82af1e8d7..5d7e8dcd3 100644 --- a/configs/bnb_small.yml +++ b/configs/bnb_small.yml @@ -59,7 +59,7 @@ "attention-dropout": 0.0, # precision settings - "fp16": { + "fp16": { "enabled": true, "loss_scale": 0, "loss_scale_window": 1000, diff --git a/configs/eleutherai_cluster.yml b/configs/eleutherai_cluster.yml index f67978cbf..0a2c6e0e1 100644 --- a/configs/eleutherai_cluster.yml +++ b/configs/eleutherai_cluster.yml @@ -1,7 +1,7 @@ # Data paths and options when using EleutherAI cluster { "data-path": "/mnt/ssd-1/data/enron/enron_text_document", - # or for weighted datasets: + # or for weighted datasets: # "train-data-paths": ["/mnt/ssd-1/data/enron/enron_text_document", "/mnt/ssd-cluster/data/enron/enron_text_document"], # "test-data-paths": ["/mnt/ssd-1/data/enron/enron_text_document", "/mnt/ssd-cluster/data/enron/enron_text_document"], # "valid-data-paths": ["/mnt/ssd-1/data/enron/enron_text_document", "/mnt/ssd-cluster/data/enron/enron_text_document"], diff --git a/configs/gmlp_small.yml b/configs/gmlp_small.yml index d08de61b4..6724b371a 100644 --- a/configs/gmlp_small.yml +++ b/configs/gmlp_small.yml @@ -46,7 +46,7 @@ "attention-dropout": 0.0, # precision settings - "fp16": { + "fp16": { "enabled": true, "loss_scale": 0, "loss_scale_window": 1000, diff --git a/configs/large.yml b/configs/large.yml index 35bd79866..b03348d49 100644 --- a/configs/large.yml +++ b/configs/large.yml @@ -19,7 +19,7 @@ "scaled-upper-triang-masked-softmax-fusion": false, "bias-gelu-fusion": false, - + # optimizer settings "optimizer": { "type": "Adam", @@ -58,7 +58,7 @@ "attention-dropout": 0, # precision settings - "fp16": { + "fp16": { "fp16": true, "enabled": true, "loss_scale": 0, diff --git a/configs/local_setup.yml b/configs/local_setup.yml index b45dbd63b..64a57c354 100644 --- a/configs/local_setup.yml +++ b/configs/local_setup.yml @@ -1,8 +1,8 @@ # Suggested data paths when using GPT-NeoX locally { "data-path": "data/enron/enron_text_document", - - # or for weighted datasets: + + # or for weighted datasets: # "train-data-paths": ["data/enron/enron_text_document", "data/enron/enron_text_document"], # "test-data-paths": ["data/enron/enron_text_document", "data/enron/enron_text_document"], # "valid-data-paths": ["data/enron/enron_text_document", "data/enron/enron_text_document"], @@ -10,7 +10,7 @@ # "test-data-weights": [2., 1.], # "valid-data-weights": [0.5, 0.4], - # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group. + # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group. # WARNING: setting this to True will override any user provided weights # "weight_by_num_documents": false, # "weighted_sampler_alpha": 0.3, @@ -21,10 +21,10 @@ "save": "checkpoints", "load": "checkpoints", "checkpoint_validation_with_forward_pass": False, - + "tensorboard-dir": "tensorboard", "log-dir": "logs", "use_wandb": True, "wandb_host": "https://api.wandb.ai", "wandb_project": "neox" -} \ No newline at end of file +} diff --git a/configs/medium.yml b/configs/medium.yml index 9a93e83aa..0e7ca304b 100644 --- a/configs/medium.yml +++ b/configs/medium.yml @@ -20,7 +20,7 @@ "bias-gelu-fusion": false, - + # optimizer settings "optimizer": { "type": "Adam", @@ -58,7 +58,7 @@ "attention-dropout": 0, # precision settings - "fp16": { + "fp16": { "fp16": true, "enabled": true, "loss_scale": 0, diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 26fc61678..0f7d4838d 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -146,7 +146,7 @@ Logging Arguments Default = False Log the frob norm of the gradients to wandb / tensorboard (useful for debugging). - (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because + (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because deepspeed.) @@ -306,18 +306,18 @@ Model Arguments Default = None Attention configuration for gpt-neox - - The first item in the list specifies the attention type(s), and should be a list of strings. The second item + + The first item in the list specifies the attention type(s), and should be a list of strings. The second item specifies the number of times to repeat those attention types in the full list. - + attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird] - + So a 12 layer network with only global attention could be specified like: [[[`global`], 12]] - + or a 12 layer network with alternating global / local like: [[[`global`, `local`], 6]] - + If none is specified, this defaults to [[[`global`], n_layers]] @@ -328,13 +328,13 @@ Model Arguments Default = None Sparsity configuration dict as defined in https://www.deepspeed.ai/docs/config-json/#sparse-attention - - Note that since neox is autoregressive, attention is always "unidirectional" and `horizontal_global_attention` is + + Note that since neox is autoregressive, attention is always "unidirectional" and `horizontal_global_attention` is always false. - - The main difference between our sparsity config and deepspeed's is that `mode` is ignored - since it is instead + + The main difference between our sparsity config and deepspeed's is that `mode` is ignored - since it is instead specified in attention_config defining each layer. - + An example config is given below: "sparse_attention": { "block": 16, @@ -475,7 +475,7 @@ Model Arguments Default = normal - Init function used on all layers except ff residual outputs - choose from + Init function used on all layers except ff residual outputs - choose from ["normal", "scaled_normal", "orthogonal", "scaled_orthogonal", "xavier_uniform", "xavier_normal", "wang_init", "small_init"] @@ -484,7 +484,7 @@ Model Arguments Default = scaled_normal - Init function used for ff residual outputs - choose from + Init function used for ff residual outputs - choose from ["normal", "scaled_normal", "orthogonal", "scaled_orthogonal", "xavier_uniform", "xavier_normal", "wang_init", "small_init"] @@ -515,7 +515,7 @@ Model Arguments Default = None - Dictionary configuring the soft prompt tuning parameters. + Dictionary configuring the soft prompt tuning parameters. If enabled, will train *only* the soft prompt, and freezes the rest of the model. parameters in the dict are: 'enabled': bool = True # enables soft prompting @@ -787,8 +787,8 @@ Parallelism Arguments Default = type:transformer|mlp - method used to distribute model layers across pipeline stages. Choose from "parameters", which balances the number - of parameters on each pipeline stage, "uniform", which naively balances the number of layers per stage, or + method used to distribute model layers across pipeline stages. Choose from "parameters", which balances the number + of parameters on each pipeline stage, "uniform", which naively balances the number of layers per stage, or "type:[regex]", which balances layers whose class names match [regex] @@ -805,7 +805,7 @@ Parallelism Arguments Default = False - flag to determine whether pipeline parallelism is on - shouldn't be set by user, is automatically determined + flag to determine whether pipeline parallelism is on - shouldn't be set by user, is automatically determined according to pipeline parallel size. @@ -898,7 +898,7 @@ Text Generation arguments - **eval_results_prefix**: str - Default = + Default = prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json @@ -930,7 +930,7 @@ Tokenizer Arguments Default = None - Total (padded) vocabulary size of tokenizer. Configured after launching of training, + Total (padded) vocabulary size of tokenizer. Configured after launching of training, as it's dependent on the parallelism size. @@ -1005,7 +1005,7 @@ Training Arguments Default = False If True, Builds dataset weights from a multinomial distribution over groups of data according to the number of - documents in each group. + documents in each group. WARNING: setting this to True will override any user provided weights @@ -1421,7 +1421,7 @@ Args for deepspeed config dict containing the keys type and params type: The scheduler name. See here (https://deepspeed.readthedocs.io/en/latest/schedulers.html) for list of support schedulers. - + params: Dictionary of parameters to instantiate scheduler. The parameter names should match scheduler constructor signature. @@ -1486,7 +1486,7 @@ Args for deepspeed config Default = None - + @@ -1542,7 +1542,7 @@ Args for deepspeed runner (deepspeed.launcher.runner). Default = None list of hostnames / ssh aliases and the number of GPUs per host - + example file contents: worker-1 slots=4 worker-2 slots=4 @@ -1612,4 +1612,3 @@ Args for deepspeed runner (deepspeed.launcher.runner). Default = False If true, autodetects nvlink pairs and remaps cuda visible devices to place them next to each other. This is an Eleuther addition to deepspeed, and should speed up model parallel training on setups with nvlink pairs when mp=2. - diff --git a/configs/small.yml b/configs/small.yml index 06787c78a..746743ff1 100644 --- a/configs/small.yml +++ b/configs/small.yml @@ -58,7 +58,7 @@ "attention-dropout": 0.0, # precision settings - "fp16": { + "fp16": { "enabled": true, "loss_scale": 0, "loss_scale_window": 1000, diff --git a/configs/small_bf16.yml b/configs/small_bf16.yml index 1463089e9..5aa81be16 100644 --- a/configs/small_bf16.yml +++ b/configs/small_bf16.yml @@ -58,7 +58,7 @@ "attention-dropout": 0.0, # precision settings - "fp16": { + "fp16": { "enabled": true, "type": "bfloat16", # set bf16 as precision "loss_scale": 0, @@ -66,7 +66,7 @@ "hysteresis": 2, "min_loss_scale": 1 }, - + "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 # misc. training settings "train-iters": 320000, diff --git a/configs/text_generation.yml b/configs/text_generation.yml index ebd29c4ae..1dfa6d931 100644 --- a/configs/text_generation.yml +++ b/configs/text_generation.yml @@ -3,18 +3,18 @@ { # Text gen type: `input-file`, `unconditional` or `interactive` "text-gen-type": "unconditional", - + # Params for all "maximum_tokens": 102, "temperature": 1.0, "top_p": 0.0, "top_k": 0, "recompute": false, - + # `unconditional`: samples "num-samples": 10, # input/output file "sample-input-file": "sample_input.txt", "sample-output-file": "sample_output.txt", -} \ No newline at end of file +} diff --git a/deepy.py b/deepy.py index 108b00649..94faddebf 100755 --- a/deepy.py +++ b/deepy.py @@ -25,15 +25,14 @@ from megatron.utils import get_wandb_api_key - neox_args = NeoXArgs.consume_deepy_args() deepspeed_main_args = neox_args.get_deepspeed_main_args() # Extract wandb API key and inject into worker environments wandb_token = get_wandb_api_key(neox_args=neox_args) if wandb_token is not None: - deepspeed.launcher.runner.EXPORT_ENVS.append('WANDB_API_KEY') - os.environ['WANDB_API_KEY'] = wandb_token + deepspeed.launcher.runner.EXPORT_ENVS.append("WANDB_API_KEY") + os.environ["WANDB_API_KEY"] = wandb_token -if __name__ == '__main__': +if __name__ == "__main__": main(deepspeed_main_args) diff --git a/evaluate.py b/evaluate.py index 7cfff9cc2..f14f6a194 100644 --- a/evaluate.py +++ b/evaluate.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # diff --git a/generate.py b/generate.py index cb12ffe62..6e0c20874 100755 --- a/generate.py +++ b/generate.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# coding=utf-8 # Copyright (c) 2021 Josh Levy-Kramer . All rights reserved. # This file is based on code by the authors denoted below and has been modified from its original version. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. @@ -32,7 +31,9 @@ def main(): """ model, neox_args = setup_for_inference_or_eval(use_cache=True) if neox_args.recompute: - model.module.inference_mode(use_cache=False) # don't use kv cache if recomputing + model.module.inference_mode( + use_cache=False + ) # don't use kv cache if recomputing if neox_args.text_gen_type == "unconditional": print_rank_0( f"Generating samples unconditionally and saving results to {neox_args.sample_output_file}" diff --git a/megatron/__init__.py b/megatron/__init__.py index 79a7b7137..4a9f98a31 100644 --- a/megatron/__init__.py +++ b/megatron/__init__.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +13,7 @@ # limitations under the License. import torch + def print_rank_0(*message): """If distributed is initialized print only on rank 0.""" if torch.distributed.is_initialized(): @@ -25,4 +25,3 @@ def print_rank_0(*message): from .initialize import initialize_megatron from .neox_arguments import NeoXArgs - diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 1faaf6eb6..10eddb929 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # @@ -183,8 +182,8 @@ def save_ds_checkpoint(iteration, model, neox_args): if neox_args.checkpoint_validation_with_forward_pass: logits = do_forward_pass(neox_args=neox_args, model=model) - sd['checkpoint_validation_logits'] = logits - + sd["checkpoint_validation_logits"] = logits + # checkpoint folder name tag = f"global_step{iteration}" @@ -192,13 +191,14 @@ def save_ds_checkpoint(iteration, model, neox_args): model.save_checkpoint(neox_args.save, tag=tag, client_state=sd) # save config files - if torch.distributed.get_rank() == 0 and neox_args.config_files is not None: + if torch.distributed.get_rank() == 0 and neox_args.config_files is not None: configs_directory = os.path.join(neox_args.save, tag, "configs") os.makedirs(configs_directory, exist_ok=True) for config_filename, config_data in neox_args.config_files.items(): with open(os.path.join(configs_directory, config_filename), "w") as f: f.write(config_data) + def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler): """Save a model checkpoint.""" diff --git a/megatron/data/blendable_dataset.py b/megatron/data/blendable_dataset.py index 7cfb232d5..9c714237f 100644 --- a/megatron/data/blendable_dataset.py +++ b/megatron/data/blendable_dataset.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +24,6 @@ class BlendableDataset(torch.utils.data.Dataset): - def __init__(self, datasets, weights): self.datasets = datasets num_datasets = len(datasets) @@ -48,13 +46,22 @@ def __init__(self, datasets, weights): self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) from megatron.data import helpers - helpers.build_blending_indices(self.dataset_index, - self.dataset_sample_index, - weights, num_datasets, self.size, - torch.distributed.get_rank() == 0) - print('> RANK {} elapsed time for building blendable dataset indices: ' - '{:.2f} (sec)'.format(torch.distributed.get_rank(), time.time() - start_time)) + helpers.build_blending_indices( + self.dataset_index, + self.dataset_sample_index, + weights, + num_datasets, + self.size, + torch.distributed.get_rank() == 0, + ) + + print( + "> RANK {} elapsed time for building blendable dataset indices: " + "{:.2f} (sec)".format( + torch.distributed.get_rank(), time.time() - start_time + ) + ) def __len__(self): return self.size @@ -67,5 +74,6 @@ def __getitem__(self, idx): except IndexError: new_idx = idx % len(self) print( - f'WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})') + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + ) return self[new_idx] diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 3f3eed484..2447fc9ad 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -2,8 +2,8 @@ import torch import numpy as np from typing import List, Tuple -from itertools import zip_longest -from functools import partial +from itertools import zip_longest +from functools import partial from megatron import mpu, print_rank_0 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset @@ -24,106 +24,128 @@ def make_data_loader(dataset, neox_args): # Use a simple sampler with distributed batch sampler. sampler = torch.utils.data.SequentialSampler(dataset) - batch_sampler = DistributedBatchSampler(sampler=sampler, - batch_size=global_batch_size, - drop_last=True, - rank=rank, - world_size=world_size) + batch_sampler = DistributedBatchSampler( + sampler=sampler, + batch_size=global_batch_size, + drop_last=True, + rank=rank, + world_size=world_size, + ) # Torch dataloader. - return torch.utils.data.DataLoader(dataset, - batch_sampler=batch_sampler, - num_workers=num_workers, - pin_memory=True) - - -def build_the_dataset(data_prefix, name, data_impl, - num_samples, - seq_length, seed, skip_warmup, build_index_mappings=True): + return torch.utils.data.DataLoader( + dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True + ) + + +def build_the_dataset( + data_prefix, + name, + data_impl, + num_samples, + seq_length, + seed, + skip_warmup, + build_index_mappings=True, +): """Build train/valid/test datasets.""" - indexed_dataset = make_indexed_dataset(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) - total_num_of_documents = indexed_dataset.sizes.shape[0] - print_rank_0(' {}:'.format(name)) - print_rank_0(' no. of documents:{}'.format(total_num_of_documents)) + total_num_of_documents = indexed_dataset.sizes.shape[0] + print_rank_0(" {}:".format(name)) + print_rank_0(" no. of documents:{}".format(total_num_of_documents)) dataset = None - documents = np.arange(start=0, stop=total_num_of_documents, - step=1, dtype=np.int32) - dataset = GPT2Dataset(name, data_prefix, - documents, indexed_dataset, - num_samples, - seq_length, seed, - build_index_mappings=build_index_mappings) + documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32) + dataset = GPT2Dataset( + name, + data_prefix, + documents, + indexed_dataset, + num_samples, + seq_length, + seed, + build_index_mappings=build_index_mappings, + ) return dataset -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup): +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + skip_warmup, +): """Build train, valid, and test datasets.""" # Indexed dataset. - indexed_dataset = make_indexed_dataset(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) total_num_of_documents = indexed_dataset.sizes.shape[0] splits = get_train_valid_test_split_(splits_string, total_num_of_documents) # Print stats about the splits. - print_rank_0(' > dataset split:') + print_rank_0(" > dataset split:") def print_split_stats(name, index): - print_rank_0(' {}:'.format(name)) - print_rank_0(' document indices in [{}, {}) total of {} ' - 'documents'.format(splits[index], splits[index + 1], - splits[index + 1] - splits[index])) + print_rank_0(" {}:".format(name)) + print_rank_0( + " document indices in [{}, {}) total of {} " + "documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + ) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): dataset = None if splits[index + 1] > splits[index]: - documents = np.arange(start=splits[index], stop=splits[index + 1], - step=1, dtype=np.int32) + documents = np.arange( + start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 + ) - dataset = GPT2Dataset(name, data_prefix, - documents, indexed_dataset, - train_valid_test_num_samples[index], - seq_length, seed) + dataset = GPT2Dataset( + name, + data_prefix, + documents, + indexed_dataset, + train_valid_test_num_samples[index], + seq_length, + seed, + ) return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") return train_dataset, valid_dataset, test_dataset def get_train_valid_test_split_(splits_string, size): - """ Get dataset splits from comma or '/' separated string list.""" + """Get dataset splits from comma or '/' separated string list.""" splits = [] - if splits_string.find(',') != -1: - splits = [float(s) for s in splits_string.split(',')] - elif splits_string.find('/') != -1: - splits = [float(s) for s in splits_string.split('/')] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] else: splits = [float(splits_string)] while len(splits) < 3: - splits.append(0.) + splits.append(0.0) splits = splits[:3] splits_sum = sum(splits) assert splits_sum > 0.0 splits = [split / splits_sum for split in splits] splits_index = [0] for index, split in enumerate(splits): - splits_index.append(splits_index[index] + - int(round(split * float(size)))) + splits_index.append(splits_index[index] + int(round(split * float(size)))) diff = splits_index[-1] - size for index in range(1, len(splits_index)): splits_index[index] -= diff @@ -132,7 +154,9 @@ def get_train_valid_test_split_(splits_string, size): return splits_index -def get_normalized_weights_and_num_samples(weights: List[float], num_samples: int) -> Tuple[List[float], List[int]]: +def get_normalized_weights_and_num_samples( + weights: List[float], num_samples: int +) -> Tuple[List[float], List[int]]: # Normalize weights weight_sum = sum(weights) assert weight_sum > 0.0 @@ -145,45 +169,67 @@ def get_normalized_weights_and_num_samples(weights: List[float], num_samples: in weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005))) return weights, weighted_num_samples -def build_weighted_datasets(neox_args, train_num_samples, valid_num_samples, test_num_samples, train_weights, valid_weights, test_weights, build_index_mappings=True): + +def build_weighted_datasets( + neox_args, + train_num_samples, + valid_num_samples, + test_num_samples, + train_weights, + valid_weights, + test_weights, + build_index_mappings=True, +): # build individual datasets train_datasets, valid_datasets, test_datasets = [], [], [] - for i, (train_path, valid_path, test_path) in enumerate(zip_longest(neox_args.train_data_paths, neox_args.valid_data_paths, neox_args.test_data_paths)): + for i, (train_path, valid_path, test_path) in enumerate( + zip_longest( + neox_args.train_data_paths, + neox_args.valid_data_paths, + neox_args.test_data_paths, + ) + ): if train_path: - train_datasets.append(build_the_dataset( - data_prefix=train_path, - name=f'train_{i}', - data_impl=neox_args.data_impl, - num_samples=train_num_samples[i], - seq_length=neox_args.seq_length, - seed=neox_args.seed, - skip_warmup=(not neox_args.mmap_warmup), - build_index_mappings=build_index_mappings - )) + train_datasets.append( + build_the_dataset( + data_prefix=train_path, + name=f"train_{i}", + data_impl=neox_args.data_impl, + num_samples=train_num_samples[i], + seq_length=neox_args.seq_length, + seed=neox_args.seed, + skip_warmup=(not neox_args.mmap_warmup), + build_index_mappings=build_index_mappings, + ) + ) if valid_path: - valid_datasets.append(build_the_dataset( - data_prefix=valid_path, - name=f'valid_{i}', - data_impl=neox_args.data_impl, - num_samples=valid_num_samples[i], - seq_length=neox_args.seq_length, - seed=neox_args.seed, - skip_warmup=(not neox_args.mmap_warmup), - build_index_mappings=build_index_mappings - )) + valid_datasets.append( + build_the_dataset( + data_prefix=valid_path, + name=f"valid_{i}", + data_impl=neox_args.data_impl, + num_samples=valid_num_samples[i], + seq_length=neox_args.seq_length, + seed=neox_args.seed, + skip_warmup=(not neox_args.mmap_warmup), + build_index_mappings=build_index_mappings, + ) + ) if test_path: - test_datasets.append(build_the_dataset( - data_prefix=test_path, - name=f'test_{i}', - data_impl=neox_args.data_impl, - num_samples=test_num_samples[i], - seq_length=neox_args.seq_length, - seed=neox_args.seed, - skip_warmup=(not neox_args.mmap_warmup), - build_index_mappings=build_index_mappings - )) + test_datasets.append( + build_the_dataset( + data_prefix=test_path, + name=f"test_{i}", + data_impl=neox_args.data_impl, + num_samples=test_num_samples[i], + seq_length=neox_args.seq_length, + seed=neox_args.seed, + skip_warmup=(not neox_args.mmap_warmup), + build_index_mappings=build_index_mappings, + ) + ) return train_datasets, valid_datasets, test_datasets @@ -221,18 +267,19 @@ def weights_by_num_docs(l, alpha=0.3): return weights - def build_train_valid_test_data_iterators(neox_args): """XXX""" (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) - print_rank_0('> building train, validation, and test datasets ...') + print_rank_0("> building train, validation, and test datasets ...") # Ensure only the first/last pipeline stages have data loaders if neox_args.is_pipe_parallel: is_first_stage = mpu.get_pipe_parallel_rank() == 0 - is_last_stage = mpu.get_pipe_parallel_rank() == mpu.get_pipe_parallel_world_size() - 1 + is_last_stage = ( + mpu.get_pipe_parallel_rank() == mpu.get_pipe_parallel_world_size() - 1 + ) pipe_load = is_first_stage or is_last_stage else: pipe_load = True @@ -243,38 +290,85 @@ def build_train_valid_test_data_iterators(neox_args): train_iters = neox_args.train_iters eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters test_iters = neox_args.eval_iters - train_val_test_num_samples = [train_iters * neox_args.train_batch_size, - eval_iters * neox_args.train_batch_size, - test_iters * neox_args.train_batch_size] + train_val_test_num_samples = [ + train_iters * neox_args.train_batch_size, + eval_iters * neox_args.train_batch_size, + test_iters * neox_args.train_batch_size, + ] if neox_args.train_data_paths: # when individual train / valid / test data paths are provided # normalize weight values and get num samples for each dataset - train_weights, train_num_samples = get_normalized_weights_and_num_samples(neox_args.train_data_weights, train_val_test_num_samples[0]) - valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(neox_args.valid_data_weights, train_val_test_num_samples[1]) - test_weights, test_num_samples = get_normalized_weights_and_num_samples(neox_args.test_data_weights, train_val_test_num_samples[2]) + train_weights, train_num_samples = get_normalized_weights_and_num_samples( + neox_args.train_data_weights, train_val_test_num_samples[0] + ) + valid_weights, valid_num_samples = get_normalized_weights_and_num_samples( + neox_args.valid_data_weights, train_val_test_num_samples[1] + ) + test_weights, test_num_samples = get_normalized_weights_and_num_samples( + neox_args.test_data_weights, train_val_test_num_samples[2] + ) # build individual datasets - train_datasets, valid_datasets, test_datasets = build_weighted_datasets(neox_args, train_num_samples, valid_num_samples, test_num_samples, train_weights, valid_weights, test_weights, \ - build_index_mappings=not neox_args.weight_by_num_documents) - + train_datasets, valid_datasets, test_datasets = build_weighted_datasets( + neox_args, + train_num_samples, + valid_num_samples, + test_num_samples, + train_weights, + valid_weights, + test_weights, + build_index_mappings=not neox_args.weight_by_num_documents, + ) + if neox_args.weight_by_num_documents: - + # gets the number of documents in each datapath - get_num_docs_list = lambda datasets: [dataset.indexed_dataset.sizes.shape[0] for dataset in datasets] - train_num_docs, valid_num_docs, test_num_docs = get_num_docs_list(train_datasets), get_num_docs_list(valid_datasets), get_num_docs_list(test_datasets) - + get_num_docs_list = lambda datasets: [ + dataset.indexed_dataset.sizes.shape[0] for dataset in datasets + ] + train_num_docs, valid_num_docs, test_num_docs = ( + get_num_docs_list(train_datasets), + get_num_docs_list(valid_datasets), + get_num_docs_list(test_datasets), + ) + # builds weights according to alpha + the number of docs - fn = partial(weights_by_num_docs, alpha=neox_args.weighted_sampler_alpha) - train_weights, valid_weights, test_weights = fn(train_num_docs), fn(valid_num_docs), fn(test_num_docs) - train_weights, train_num_samples = get_normalized_weights_and_num_samples(train_weights, train_val_test_num_samples[0]) - valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(valid_weights, train_val_test_num_samples[1]) - test_weights, test_num_samples = get_normalized_weights_and_num_samples(test_weights, train_val_test_num_samples[2]) - + fn = partial( + weights_by_num_docs, alpha=neox_args.weighted_sampler_alpha + ) + train_weights, valid_weights, test_weights = ( + fn(train_num_docs), + fn(valid_num_docs), + fn(test_num_docs), + ) + ( + train_weights, + train_num_samples, + ) = get_normalized_weights_and_num_samples( + train_weights, train_val_test_num_samples[0] + ) + ( + valid_weights, + valid_num_samples, + ) = get_normalized_weights_and_num_samples( + valid_weights, train_val_test_num_samples[1] + ) + test_weights, test_num_samples = get_normalized_weights_and_num_samples( + test_weights, train_val_test_num_samples[2] + ) + # rebuild datasets weighted according to new weights - train_datasets, valid_datasets, test_datasets = build_weighted_datasets(neox_args, train_num_samples, valid_num_samples, test_num_samples, train_weights, valid_weights, test_weights) - - + train_datasets, valid_datasets, test_datasets = build_weighted_datasets( + neox_args, + train_num_samples, + valid_num_samples, + test_num_samples, + train_weights, + valid_weights, + test_weights, + ) + if train_datasets: train_ds = BlendableDataset(train_datasets, train_weights) if valid_datasets: @@ -291,7 +385,7 @@ def build_train_valid_test_data_iterators(neox_args): train_valid_test_num_samples=train_val_test_num_samples, seq_length=neox_args.seq_length, seed=neox_args.seed, - skip_warmup=(not neox_args.mmap_warmup) + skip_warmup=(not neox_args.mmap_warmup), ) # Build dataloders. @@ -304,8 +398,7 @@ def build_train_valid_test_data_iterators(neox_args): do_valid = valid_dataloader is not None and neox_args.eval_iters > 0 do_test = test_dataloader is not None and neox_args.eval_iters > 0 # Need to broadcast num_tokens and num_type_tokens. - flags = torch.cuda.LongTensor( - [int(do_train), int(do_valid), int(do_test)]) + flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)]) else: flags = torch.cuda.LongTensor([0, 0, 0]) @@ -315,26 +408,38 @@ def build_train_valid_test_data_iterators(neox_args): # broadcast globally instead of just the model parallel group. torch.distributed.broadcast(flags, src=0) else: - torch.distributed.broadcast(flags, - mpu.get_model_parallel_src_rank(), - group=mpu.get_model_parallel_group()) + torch.distributed.broadcast( + flags, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) neox_args.do_train = flags[0].item() neox_args.do_valid = flags[1].item() neox_args.do_test = flags[2].item() # Shift the start iterations. if train_dataloader is not None: - train_dataloader.batch_sampler.start_iter = (neox_args.iteration * neox_args.gradient_accumulation_steps) % \ - len(train_dataloader) - print_rank_0('setting training data start iteration to {}'. - format(train_dataloader.batch_sampler.start_iter)) + train_dataloader.batch_sampler.start_iter = ( + neox_args.iteration * neox_args.gradient_accumulation_steps + ) % len(train_dataloader) + print_rank_0( + "setting training data start iteration to {}".format( + train_dataloader.batch_sampler.start_iter + ) + ) if valid_dataloader is not None: - start_iter_val = ((neox_args.iteration * neox_args.gradient_accumulation_steps) // neox_args.eval_interval) * \ - neox_args.eval_iters - valid_dataloader.batch_sampler.start_iter = start_iter_val % \ - len(valid_dataloader) - print_rank_0('setting validation data start iteration to {}'. - format(valid_dataloader.batch_sampler.start_iter)) + start_iter_val = ( + (neox_args.iteration * neox_args.gradient_accumulation_steps) + // neox_args.eval_interval + ) * neox_args.eval_iters + valid_dataloader.batch_sampler.start_iter = start_iter_val % len( + valid_dataloader + ) + print_rank_0( + "setting validation data start iteration to {}".format( + valid_dataloader.batch_sampler.start_iter + ) + ) # Build iterators. if train_dataloader is not None: @@ -352,7 +457,7 @@ def build_train_valid_test_data_iterators(neox_args): else: test_data_iterator = None - return train_data_iterator, valid_data_iterator, test_data_iterator + return train_data_iterator, valid_data_iterator, test_data_iterator def compile_helper(): @@ -360,9 +465,11 @@ def compile_helper(): is invoked on a single process.""" import os import subprocess + path = os.path.abspath(os.path.dirname(__file__)) - ret = subprocess.run(['make', '-C', path]) + ret = subprocess.run(["make", "-C", path]) if ret.returncode != 0: print("Making C++ dataset helpers module failed, exiting.") import sys - sys.exit(1) \ No newline at end of file + + sys.exit(1) diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index 6c719c5cd..305e85f73 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # @@ -26,10 +25,19 @@ from megatron import mpu, print_rank_0 -class GPT2Dataset(torch.utils.data.Dataset): - def __init__(self, name, data_prefix, documents, indexed_dataset, - num_samples, seq_length, seed, build_index_mappings=True): +class GPT2Dataset(torch.utils.data.Dataset): + def __init__( + self, + name, + data_prefix, + documents, + indexed_dataset, + num_samples, + seq_length, + seed, + build_index_mappings=True, + ): self.name = name self.indexed_dataset = indexed_dataset @@ -41,13 +49,21 @@ def __init__(self, name, data_prefix, documents, indexed_dataset, if build_index_mappings: # Build index mappings. self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( - self.name, data_prefix, documents, self.indexed_dataset.sizes, - num_samples, seq_length, seed) + self.name, + data_prefix, + documents, + self.indexed_dataset.sizes, + num_samples, + seq_length, + seed, + ) self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 self.sample_idx_len = self.sample_idx.shape[0] - 1 if self.shuffle_idx_len != self.sample_idx_len: - print(f'WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})') + print( + f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" + ) def __len__(self): return min(self.shuffle_idx_len, self.sample_idx_len) @@ -63,31 +79,39 @@ def __getitem__(self, idx): offset_l = self.sample_idx[idx + 1][1] # If we are within the same document, just extract the chunk. if doc_index_f == doc_index_l: - sample = self.indexed_dataset.get(self.doc_idx[doc_index_f], - offset=offset_f, - length=offset_l - offset_f + 1) + sample = self.indexed_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) else: # Otherwise, get the rest of the initial document. - sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], - offset=offset_f)] + sample_list = [ + self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) # And finally add the relevant portion of last document. - sample_list.append(self.indexed_dataset.get( - self.doc_idx[doc_index_l], - length=offset_l + 1)) + sample_list.append( + self.indexed_dataset.get( + self.doc_idx[doc_index_l], length=offset_l + 1 + ) + ) sample = np.concatenate(sample_list) - return {'text': np.array(sample, dtype=np.int64)} + return {"text": np.array(sample, dtype=np.int64)} except IndexError: new_idx = idx % len(self) - print(f'WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})') + print( + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + ) return self[new_idx] -def _build_index_mappings(name, data_prefix, documents, sizes, - num_samples, seq_length, seed): +def _build_index_mappings( + name, data_prefix, documents, sizes, num_samples, seq_length, seed +): """Build doc-idx, sample-idx, and shuffle-idx. doc-idx: is an array (ordered) of documents to be used in training. sample-idx: is the start document index and document offset for each @@ -102,48 +126,60 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # Filename of the index mappings. _filename = data_prefix - _filename += '_{}_indexmap'.format(name) - _filename += '_{}ns'.format(num_samples) - _filename += '_{}sl'.format(seq_length) - _filename += '_{}s'.format(seed) - doc_idx_filename = _filename + '_doc_idx.npy' - sample_idx_filename = _filename + '_sample_idx.npy' - shuffle_idx_filename = _filename + '_shuffle_idx.npy' + _filename += "_{}_indexmap".format(name) + _filename += "_{}ns".format(num_samples) + _filename += "_{}sl".format(seq_length) + _filename += "_{}s".format(seed) + doc_idx_filename = _filename + "_doc_idx.npy" + sample_idx_filename = _filename + "_sample_idx.npy" + shuffle_idx_filename = _filename + "_shuffle_idx.npy" # Build the indexed mapping if not exist. if torch.distributed.get_rank() == 0: - if (not os.path.isfile(doc_idx_filename)) or \ - (not os.path.isfile(sample_idx_filename)) or \ - (not os.path.isfile(shuffle_idx_filename)): - print_rank_0(' > WARNING: could not find index map files, building ' - 'the indices on rank 0 ...') + if ( + (not os.path.isfile(doc_idx_filename)) + or (not os.path.isfile(sample_idx_filename)) + or (not os.path.isfile(shuffle_idx_filename)) + ): + print_rank_0( + " > WARNING: could not find index map files, building " + "the indices on rank 0 ..." + ) # doc-idx. start_time = time.time() doc_idx = _build_doc_idx(documents, num_epochs, np_rng) np.save(doc_idx_filename, doc_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save doc-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) + print_rank_0( + " > elasped time to build and save doc-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) # sample-idx. start_time = time.time() # Use C++ implementation for speed. from megatron.data import helpers + assert doc_idx.dtype == np.int32 assert sizes.dtype == np.int32 - sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, - num_epochs, tokens_per_epoch) + sample_idx = helpers.build_sample_idx( + sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + ) # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, # num_epochs, tokens_per_epoch) np.save(sample_idx_filename, sample_idx, allow_pickle=True) - print_rank_0(' > elapsed time to build and save sample-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) + print_rank_0( + " > elapsed time to build and save sample-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) # shuffle-idx. start_time = time.time() # -1 is due to data structure used to retieve the index: # sample i --> [sample_idx[i], sample_idx[i+1]) shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng) np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) - print_rank_0(' > elapsed time to build and save shuffle-idx mapping' - ' (seconds): {:4f}'.format(time.time() - start_time)) + print_rank_0( + " > elapsed time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time) + ) # This should be a barrier but nccl barrier assumes # device_index=rank which is not the case for model @@ -151,24 +187,22 @@ def _build_index_mappings(name, data_prefix, documents, sizes, counts = torch.cuda.LongTensor([1]) torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group()) assert counts[0].item() == torch.distributed.get_world_size( - group=mpu.get_io_parallel_group()) + group=mpu.get_io_parallel_group() + ) # Load mappings. start_time = time.time() - print_rank_0(' > loading doc-idx mapping from {}'.format( - doc_idx_filename)) - doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' > loading sample-idx mapping from {}'.format( - sample_idx_filename)) - sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' > loading shuffle-idx mapping from {}'.format( - shuffle_idx_filename)) - shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( - time.time() - start_time)) - print_rank_0(' total number of samples: {}'.format( - sample_idx.shape[0])) - print_rank_0(' total number of epochs: {}'.format(num_epochs)) + print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading sample-idx mapping from {}".format(sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + ) + print_rank_0(" total number of samples: {}".format(sample_idx.shape[0])) + print_rank_0(" total number of epochs: {}".format(num_epochs)) return doc_idx, sample_idx, shuffle_idx @@ -196,7 +230,7 @@ def _num_epochs(tokens_per_epoch, seq_length, num_samples): def _build_doc_idx(documents, num_epochs, np_rng): """Build an array with length = number-of-epochs * number-of-documents. Each index is mapped to a corresponding document.""" - doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] doc_idx[:] = documents doc_idx = doc_idx.reshape(-1) doc_idx = doc_idx.astype(np.int32) @@ -204,8 +238,7 @@ def _build_doc_idx(documents, num_epochs, np_rng): return doc_idx -def _build_sample_idx(sizes, doc_idx, seq_length, - num_epochs, tokens_per_epoch): +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): """Sample index mapping is a 2D array with sizes [number-of-samples + 1, 2] where [..., 0] contains the index into `doc_idx` and [..., 1] is the @@ -239,7 +272,7 @@ def _build_sample_idx(sizes, doc_idx, seq_length, # Note that -1 here is for the same reason we have -1 in # `_num_epochs` calculations. if remaining_seq_length <= 0: - doc_offset += (remaining_seq_length + doc_length - 1) + doc_offset += remaining_seq_length + doc_length - 1 remaining_seq_length = 0 else: # Otherwise, start from the beginning of the next document. @@ -261,5 +294,3 @@ def _build_shuffle_idx(size, np_rng): shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_) np_rng.shuffle(shuffle_idx) return shuffle_idx - - diff --git a/megatron/data/helpers.cpp b/megatron/data/helpers.cpp index 5216c18fb..a9ee281db 100644 --- a/megatron/data/helpers.cpp +++ b/megatron/data/helpers.cpp @@ -15,93 +15,85 @@ limitations under the License. */ - /* Helper methods for fast index mapping builds */ +#include +#include +#include #include #include #include -#include -#include -#include -#include #include +#include namespace py = pybind11; using namespace std; const int32_t LONG_SENTENCE_LEN = 512; - - void build_blending_indices(py::array_t& dataset_index, - py::array_t& dataset_sample_index, - const py::array_t& weights, - const int32_t num_datasets, - const int64_t size, - const bool verbose) { - /* Given multiple datasets and a weighting array, build samples - such that it follows those wieghts.*/ - - if (verbose) { - std::cout << "> building indices for blendable datasets ..." << std::endl; - } - - // Get the pointer access without the checks. - auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); - auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); - auto weights_ptr = weights.unchecked<1>(); - - // Initialize buffer for number of samples used for each dataset. - int64_t current_samples[num_datasets]; - for(int64_t i = 0; i < num_datasets; ++i) { - current_samples[i] = 0; - } - - // For each sample: - for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { - - // Determine where the max error in sampling is happening. - double sample_idx_double = std::max(static_cast(sample_idx), 1.0); - int64_t max_error_index = 0; - double max_error = weights_ptr[0] * sample_idx_double - - static_cast(current_samples[0]); - for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { - double error = weights_ptr[dataset_idx] * sample_idx_double - - static_cast(current_samples[dataset_idx]); - if (error > max_error) { - max_error = error; - max_error_index = dataset_idx; - } - } - - // Populate the indices. - dataset_index_ptr[sample_idx] = static_cast(max_error_index); - dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; - - // Update the total samples. - current_samples[max_error_index] += 1; + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, + const bool verbose) +{ + /* Given multiple datasets and a weighting array, build samples + such that it follows those wieghts.*/ + + if (verbose) { std::cout << "> building indices for blendable datasets ..." << std::endl; } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. + int64_t current_samples[num_datasets]; + for (int64_t i = 0; i < num_datasets; ++i) { current_samples[i] = 0; } + + // For each sample: + for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + // Determine where the max error in sampling is happening. + double sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = + weights_ptr[0] * sample_idx_double - static_cast(current_samples[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(current_samples[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } - } + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; - // print info - if (verbose) { - std::cout << " > sample ratios:" << std::endl; - for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { - auto ratio = static_cast(current_samples[dataset_idx]) / - static_cast(size); - std::cout << " dataset " << dataset_idx << ", input: " << - weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + // Update the total samples. + current_samples[max_error_index] += 1; } - } + // print info + if (verbose) { + std::cout << " > sample ratios:" << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = + static_cast(current_samples[dataset_idx]) / static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] + << ", achieved: " << ratio << std::endl; + } + } } py::array build_sample_idx(const py::array_t& sizes_, - const py::array_t& doc_idx_, - const int32_t seq_length, - const int32_t num_epochs, - const int64_t tokens_per_epoch) { + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) +{ /* Sample index (sample_idx) is used for gpt2 like dataset for which the documents are flattened and the samples are built based on this 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] @@ -119,17 +111,14 @@ py::array build_sample_idx(const py::array_t& sizes_, // Mapping and it's length (1D). int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; - int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + int32_t* sample_idx = new int32_t[2 * (num_samples + 1)]; cout << " using:" << endl << std::flush; - cout << " number of documents: " << - doc_idx_.shape(0) / num_epochs << endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " sequence length: " << seq_length << - endl << std::flush; - cout << " total number of samples: " << num_samples << - endl << std::flush; + cout << " number of documents: " << doc_idx_.shape(0) / num_epochs << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl << std::flush; + cout << " sequence length: " << seq_length << endl << std::flush; + cout << " total number of samples: " << num_samples << endl << std::flush; // Index into sample_idx. int64_t sample_index = 0; @@ -144,61 +133,57 @@ py::array build_sample_idx(const py::array_t& sizes_, while (sample_index <= num_samples) { // Start with a fresh sequence. - int32_t remaining_seq_length = seq_length + 1; - while (remaining_seq_length != 0) { + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { // Get the document length. - auto doc_id = doc_idx[doc_idx_index]; - auto doc_length = sizes[doc_id] - doc_offset; - // And add it to the current sequence. - remaining_seq_length -= doc_length; - // If we have more than a full sequence, adjust offset and set - // remaining length to zero so we return from the while loop. - // Note that -1 here is for the same reason we have -1 in - // `_num_epochs` calculations. - if (remaining_seq_length <= 0) { - doc_offset += (remaining_seq_length + doc_length - 1); - remaining_seq_length = 0; - } else { - // Otherwise, start from the begining of the next document. - ++doc_idx_index; - doc_offset = 0; - } - } - // Record the sequence. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; } // Method to deallocate memory. - py::capsule free_when_done(sample_idx, [](void *mem_) { - int32_t *mem = reinterpret_cast(mem_); - delete[] mem; - }); + py::capsule free_when_done(sample_idx, [](void* mem_) { + int32_t* mem = reinterpret_cast(mem_); + delete[] mem; + }); // Return the numpy array. const auto byte_size = sizeof(int32_t); - return py::array(std::vector{num_samples+1, 2}, // shape - {2*byte_size, byte_size}, // C-style contiguous strides - sample_idx, // the data pointer - free_when_done); // numpy array references - + return py::array(std::vector{num_samples + 1, 2}, // shape + {2 * byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references } - inline int32_t get_target_sample_len(const int32_t short_seq_ratio, - const int32_t max_length, - std::mt19937& rand32_gen) { + const int32_t max_length, + std::mt19937& rand32_gen) +{ /* Training sample length. */ const auto random_number = rand32_gen(); - if ((random_number % short_seq_ratio) == 0) { - return 2 + random_number % (max_length - 1); - } + if ((random_number % short_seq_ratio) == 0) { return 2 + random_number % (max_length - 1); } return max_length; } - -template +template py::array build_mapping_impl(const py::array_t& docs_, const py::array_t& sizes_, const int32_t num_epochs, @@ -206,7 +191,8 @@ py::array build_mapping_impl(const py::array_t& docs_, const int32_t max_seq_length, const double short_seq_prob, const int32_t seed, - const bool verbose) { + const bool verbose) +{ /* Build a mapping of (start-index, end-index, sequence-length) where start and end index are the indices of the sentences in the sample and sequence-length is the target sequence length. @@ -228,27 +214,20 @@ py::array build_mapping_impl(const py::array_t& docs_, if (verbose) { const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " short sequence probability: " << short_seq_prob << - endl << std::flush; - cout << " short sequence ration (1/prob): " << short_seq_ratio << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << endl << std::flush; + cout << " sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl << std::flush; + cout << " number of epochs: " << num_epochs << endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << endl << std::flush; + cout << " seed: " << seed << endl << std::flush; } // Mapping and it's length (1D). @@ -258,8 +237,7 @@ py::array build_mapping_impl(const py::array_t& docs_, // Perform two iterations, in the first iteration get the size // and allocate memory and in the second iteration populate the map. bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - + for (int32_t iteration = 0; iteration < 2; ++iteration) { // Set the seed so both iterations produce the same results. std::mt19937 rand32_gen(seed); @@ -269,29 +247,29 @@ py::array build_mapping_impl(const py::array_t& docs_, // Counters: uint64_t empty_docs = 0; uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; + uint64_t long_sent_docs = 0; // Current map index. uint64_t map_index = 0; // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { - if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; - } + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " << epoch + << " epochs ..." << endl + << std::flush; + } break; } // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { // Document sentences are in [sent_index_first, sent_index_last) const auto sent_index_first = docs[doc]; const auto sent_index_last = docs[doc + 1]; // At the begining of the document previous index is the - // start index. + // start index. auto prev_start_index = sent_index_first; // Remaining documents. @@ -299,138 +277,122 @@ py::array build_mapping_impl(const py::array_t& docs_, // Some bookkeeping if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } + if (num_remain_sent == 0) { ++empty_docs; } + if (num_remain_sent == 1) { ++one_sent_docs; } } - // Detect documents with long sentences. - bool contains_long_sentence = false; - if (num_remain_sent > 1) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { ++long_sent_docs; } + contains_long_sentence = true; + break; + } + } + } // If we have more than two sentences. if ((num_remain_sent > 1) && (!contains_long_sentence)) { - // Set values. auto seq_len = int32_t{0}; auto num_sent = int32_t{0}; - auto target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); + auto target_seq_len = + get_target_sample_len(short_seq_ratio, max_seq_length, rand32_gen); // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and if not only one sentence is left in the document. - // and if we have at least two sentneces. - // and if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent > 1) && - (num_sent > 1) ) || (num_remain_sent == 0)) { - - // Check for overflow. - if ((3 * map_index + 2) > - std::numeric_limits::max()) { - cout << "number of samples exceeded maximum " - << "allowed by type int64: " - << std::numeric_limits::max() - << endl; - throw std::overflow_error("Number of samples"); - } - - // Populate the map. - if (second) { - const auto map_index_0 = 3 * map_index; - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(target_seq_len); - } - - // Update indices / counters. - ++map_index; - prev_start_index = sent_index + 1; - target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - seq_len = 0; - num_sent = 0; - } - - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && (num_remain_sent > 1) && + (num_sent > 1)) || + (num_remain_sent == 0)) { + // Check for overflow. + if ((3 * map_index + 2) > std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = + get_target_sample_len(short_seq_ratio, max_seq_length, rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { - if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; - } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[3*map_index]; + if (verbose) { + cout << " number of empty documents: " << empty_docs << endl << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3 * map_index]; num_samples = static_cast(map_index); } - } // for (int iteration=0; iteration < 2; ++iteration) { + } // for (int iteration=0; iteration < 2; ++iteration) { // Shuffle. // We need a 64 bit random number generator as we might have more // than 2 billion samples. std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 3 * i; - const auto j0 = 3 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); + for (auto i = (num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); } // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); // Return the numpy array. const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 3}, // shape - {3*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references - + return py::array(std::vector{num_samples, 3}, // shape + {3 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } - py::array build_mapping(const py::array_t& docs_, const py::array_t& sizes_, const int num_epochs, @@ -438,26 +400,32 @@ py::array build_mapping(const py::array_t& docs_, const int max_seq_length, const double short_seq_prob, const int seed, - const bool verbose) { - + const bool verbose) +{ if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." << endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose); + if (verbose) { cout << " using uint64 for data mapping..." << endl << std::flush; } + return build_mapping_impl(docs_, + sizes_, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + verbose); } else { - if (verbose) { - cout << " using uint32 for data mapping..." << endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose); + if (verbose) { cout << " using uint32 for data mapping..." << endl << std::flush; } + return build_mapping_impl(docs_, + sizes_, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + verbose); } } -template +template py::array build_blocks_mapping_impl(const py::array_t& docs_, const py::array_t& sizes_, const py::array_t& titles_sizes_, @@ -466,7 +434,8 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, const int32_t max_seq_length, const int32_t seed, const bool verbose, - const bool use_one_sent_blocks) { + const bool use_one_sent_blocks) +{ /* Build a mapping of (start-index, end-index, sequence-length) where start and end index are the indices of the sentences in the sample and sequence-length is the target sequence length. @@ -487,20 +456,15 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, const auto sent_end_index = docs[docs_.shape(0) - 1]; const auto num_sentences = sent_end_index - sent_start_index; cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << endl << std::flush; + cout << " sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl << std::flush; + cout << " number of epochs: " << num_epochs << endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl << std::flush; + cout << " seed: " << seed << endl << std::flush; } // Mapping and its length (1D). @@ -509,15 +473,12 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, // Acceptable number of sentences per block. int min_num_sent = 2; - if (use_one_sent_blocks) { - min_num_sent = 1; - } + if (use_one_sent_blocks) { min_num_sent = 1; } // Perform two iterations, in the first iteration get the size // and allocate memory and in the second iteration populate the map. bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - + for (int32_t iteration = 0; iteration < 2; ++iteration) { // Set the flag on second iteration. second = (iteration == 1); @@ -528,20 +489,20 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, uint64_t one_sent_docs = 0; uint64_t long_sent_docs = 0; // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; + cout << " reached " << max_num_samples << " samples after " << epoch + << " epochs ..." << endl + << std::flush; } break; } // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { // Document sentences are in [sent_index_first, sent_index_last) const auto sent_index_first = docs[doc]; const auto sent_index_last = docs[doc + 1]; @@ -556,22 +517,16 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, // Some bookkeeping if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } + if (num_remain_sent == 0) { ++empty_docs; } + if (num_remain_sent == 1) { ++one_sent_docs; } } // Detect documents with long sentences. bool contains_long_sentence = false; if (num_remain_sent >= min_num_sent) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { ++long_sent_docs; } contains_long_sentence = true; break; } @@ -579,34 +534,32 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, } // If we have enough sentences and no long sentences. if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { - // Set values. auto seq_len = int32_t{0}; auto num_sent = int32_t{0}; // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; // If we have reached the target length. // and there are an acceptable number of sentences left // and if we have at least the minimum number of sentences. // or if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent >= min_num_sent) && - (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { - + if (((seq_len >= target_seq_len) && (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent)) || + (num_remain_sent == 0)) { // Populate the map. if (second) { const auto map_index_0 = 4 * map_index; - // Each sample has 4 items: the starting sentence index, ending sentence index, - // the index of the document from which the block comes (used for fetching titles) - // and the unique id of the block (used for creating block indexes) + // Each sample has 4 items: the starting sentence index, ending + // sentence index, the index of the document from which the block + // comes (used for fetching titles) and the unique id of the block + // (used for creating block indexes) maps[map_index_0] = static_cast(prev_start_index); maps[map_index_0 + 1] = static_cast(sent_index + 1); @@ -621,35 +574,34 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, seq_len = 0; num_sent = 0; } - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; + cout << " number of empty documents: " << empty_docs << endl << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; } assert(maps == NULL); assert(num_samples < 0); - maps = new DocIdx[4*map_index]; + maps = new DocIdx[4 * map_index]; num_samples = static_cast(map_index); } - } // for (int iteration=0; iteration < 2; ++iteration) { + } // for (int iteration=0; iteration < 2; ++iteration) { // Shuffle. // We need a 64 bit random number generator as we might have more // than 2 billion samples. std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { + for (auto i = (num_samples - 1); i > 0; --i) { const auto j = static_cast(rand64_gen() % (i + 1)); const auto i0 = 4 * i; const auto j0 = 4 * j; @@ -661,18 +613,17 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, } // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); // Return the numpy array. const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 4}, // shape - {4*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references - + return py::array(std::vector{num_samples, 4}, // shape + {4 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } py::array build_blocks_mapping(const py::array_t& docs_, @@ -682,25 +633,36 @@ py::array build_blocks_mapping(const py::array_t& docs_, const uint64_t max_num_samples, const int max_seq_length, const int seed, - const bool verbose, - const bool use_one_sent_blocks) { - + const bool verbose, + const bool use_one_sent_blocks) +{ if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." << endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + if (verbose) { cout << " using uint64 for data mapping..." << endl << std::flush; } + return build_blocks_mapping_impl(docs_, + sizes_, + titles_sizes_, + num_epochs, + max_num_samples, + max_seq_length, + seed, + verbose, + use_one_sent_blocks); } else { - if (verbose) { - cout << " using uint32 for data mapping..." << endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + if (verbose) { cout << " using uint32 for data mapping..." << endl << std::flush; } + return build_blocks_mapping_impl(docs_, + sizes_, + titles_sizes_, + num_epochs, + max_num_samples, + max_seq_length, + seed, + verbose, + use_one_sent_blocks); } } -PYBIND11_MODULE(helpers, m) { +PYBIND11_MODULE(helpers, m) +{ m.def("build_mapping", &build_mapping); m.def("build_blocks_mapping", &build_blocks_mapping); m.def("build_sample_idx", &build_sample_idx); diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 25f9e4db2..82af8673e 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -31,23 +31,27 @@ def __best_fitting_dtype(vocab_size=None): def infer_dataset_impl(path): if IndexedDataset.exists(path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) if magic == IndexedDataset._HDR_MAGIC: - return 'cached' + return "cached" elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: - return 'mmap' + return "mmap" else: return None else: print(f"Dataset does not exist: {path}") - print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) return None def make_builder(out_file, impl, vocab_size=None): - if impl == 'mmap': - return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=__best_fitting_dtype(vocab_size) + ) else: return IndexedDatasetBuilder(out_file) @@ -55,22 +59,24 @@ def make_builder(out_file, impl, vocab_size=None): def make_dataset(path, impl, skip_warmup=False): if not IndexedDataset.exists(path): print(f"Dataset does not exist: {path}") - print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) return None - if impl == 'infer': + if impl == "infer": impl = infer_dataset_impl(path) - if impl == 'lazy' and IndexedDataset.exists(path): + if impl == "lazy" and IndexedDataset.exists(path): return IndexedDataset(path) - elif impl == 'cached' and IndexedDataset.exists(path): + elif impl == "cached" and IndexedDataset.exists(path): return IndexedCachedDataset(path) - elif impl == 'mmap' and MMapIndexedDataset.exists(path): + elif impl == "mmap" and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path, skip_warmup) print(f"Unknown dataset implementation: {impl}") return None def dataset_exists(path, impl): - if impl == 'mmap': + if impl == "mmap": return MMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) @@ -94,7 +100,7 @@ def write_longs(f, a): 5: np.int64, 6: np.float, 7: np.double, - 8: np.uint16 + 8: np.uint16, } @@ -106,11 +112,11 @@ def code(dtype): def index_file_path(prefix_path): - return prefix_path + '.idx' + return prefix_path + ".idx" def data_file_path(prefix_path): - return prefix_path + '.bin' + return prefix_path + ".bin" def create_doc_idx(sizes): @@ -123,7 +129,8 @@ def create_doc_idx(sizes): class IndexedDataset(torch.utils.data.Dataset): """Loader for IndexedDataset""" - _HDR_MAGIC = b'TNTIDX\x00\x00' + + _HDR_MAGIC = b"TNTIDX\x00\x00" def __init__(self, path): super().__init__() @@ -132,29 +139,29 @@ def __init__(self, path): self.read_index(path) def read_index(self, path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) assert magic == self._HDR_MAGIC, ( - 'Index file doesn\'t match expected format. ' - 'Make sure that --dataset-impl is configured properly.' + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." ) version = f.read(8) - assert struct.unpack('= self._len: - raise IndexError('index out of range') + raise IndexError("index out of range") def __del__(self): if self.data_file: @@ -167,7 +174,7 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) @@ -176,7 +183,7 @@ def __getitem__(self, idx): start, stop, step = idx.indices(len(self)) if step != 1: raise ValueError("Slices into indexed_dataset must be contiguous") - sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] size = sum(sizes) a = np.empty(size, dtype=self.dtype) self.data_file.seek(self.data_offsets[start] * self.element_size) @@ -196,8 +203,8 @@ def size(self, index): @staticmethod def exists(path): - return ( - os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + return os.path.exists(index_file_path(path)) and os.path.exists( + data_file_path(path) ) @property @@ -206,7 +213,6 @@ def supports_prefetch(self): class IndexedCachedDataset(IndexedDataset): - def __init__(self, path): super().__init__(path) self.cache = None @@ -231,7 +237,7 @@ def prefetch(self, indices): for i in indices: self.cache_index[i] = ptx size = self.data_offsets[i + 1] - self.data_offsets[i] - a = self.cache[ptx: ptx + size] + a = self.cache[ptx : ptx + size] self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) ptx += size @@ -245,10 +251,10 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) ptx = self.cache_index[i] - np.copyto(a, self.cache[ptx: ptx + a.size]) + np.copyto(a, self.cache[ptx : ptx + a.size]) return a elif isinstance(idx, slice): # Hack just to make this work, can optimizer later if necessary @@ -266,11 +272,11 @@ class IndexedDatasetBuilder(object): np.int32: 4, np.int64: 8, np.float: 4, - np.double: 8 + np.double: 8, } def __init__(self, out_file, dtype=np.int32): - self.out_file = open(out_file, 'wb') + self.out_file = open(out_file, "wb") self.dtype = dtype self.data_offsets = [0] self.dim_offsets = [0] @@ -300,7 +306,7 @@ def merge_file_(self, another_file): for dim_offset in index.dim_offsets[1:]: self.dim_offsets.append(begin + dim_offset) - with open(data_file_path(another_file), 'rb') as f: + with open(data_file_path(another_file), "rb") as f: while True: data = f.read(1024) if data: @@ -310,12 +316,12 @@ def merge_file_(self, another_file): def finalize(self, index_file): self.out_file.close() - index = open(index_file, 'wb') - index.write(b'TNTIDX\x00\x00') - index.write(struct.pack('= 0: g.manual_seed(self.epoch) if self.replacement: - return iter(torch.randint(high=n, size=(self.num_samples,), - dtype=torch.int64, generator=g).tolist()) + return iter( + torch.randint( + high=n, size=(self.num_samples,), dtype=torch.int64, generator=g + ).tolist() + ) return iter(torch.randperm(n, generator=g).tolist()) def __len__(self): @@ -81,23 +88,30 @@ class DistributedBatchSampler(data.sampler.BatchSampler): sampler level. This allows wrapping of arbitrary data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. - + The `interleave` argument specifies how to distribute a batch. A value of True combined with the above random sampler is equivalent to pytorch's torch.utils.data.distributed.DistributedSampler. - For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2 + For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2 specifying True will result in the following samples for each gpu: GPU0: [0,2,4,6] GPU1: [1,3,5,7] specifying False will result in the following samples: GPU0: [0,1,2,3] GPU1: [4,5,6,7]""" - def __init__(self, sampler, batch_size, drop_last, rank=-1, - world_size=2, wrap_last=False, interleave=False): - super(DistributedBatchSampler, self).__init__(sampler, batch_size, - drop_last) + def __init__( + self, + sampler, + batch_size, + drop_last, + rank=-1, + world_size=2, + wrap_last=False, + interleave=False, + ): + super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last) if rank == -1: - assert False, 'should not be here' + assert False, "should not be here" rank = torch.distributed.get_rank() self.rank = rank self.world_size = world_size @@ -122,8 +136,8 @@ def __iter__(self): batch_len = len(batch) if batch_len > 0 and not self.drop_last: if self.wrap_last: - self.sampler.wrap_around -= (self.batch_size) - self.wrap_around += (len(batch)) + self.sampler.wrap_around -= self.batch_size + self.wrap_around += len(batch) self.wrap_around %= self.batch_size yield self._batch(batch) if self.wrap_last: @@ -142,7 +156,7 @@ def data_iterator(self, _iter, wrap_around=False): def _batch(self, batch): """extracts samples only pertaining to this worker's batch""" if self.interleave: - return batch[self.rank:self.batch_size:self.world_size] + return batch[self.rank : self.batch_size : self.world_size] start = self.rank * self.batch_size // self.world_size end = (self.rank + 1) * self.batch_size // self.world_size return batch[start:end] diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py index 609ce7a17..3fb7bf86a 100644 --- a/megatron/fused_kernels/__init__.py +++ b/megatron/fused_kernels/__init__.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,14 +28,17 @@ # extra_cuda_cflags below os.environ["TORCH_CUDA_ARCH_LIST"] = "" + def load_fused_kernels(neox_args): try: import scaled_upper_triang_masked_softmax_cuda import scaled_masked_softmax_cuda except (ImportError, ModuleNotFoundError): - print('\n') - print('='*100) - print(f'ERROR: Please run `python {str(srcpath / "setup.py")} install` to install the fused kernels') - print('='*100) + print("\n") + print("=" * 100) + print( + f'ERROR: Please run `python {str(srcpath / "setup.py")} install` to install the fused kernels' + ) + print("=" * 100) exit() - return \ No newline at end of file + return diff --git a/megatron/fused_kernels/compat.h b/megatron/fused_kernels/compat.h index 92e7eb772..25818b2ed 100644 --- a/megatron/fused_kernels/compat.h +++ b/megatron/fused_kernels/compat.h @@ -18,8 +18,6 @@ * https://github.com/NVIDIA/apex * with minor changes. */ - - #ifndef TORCH_CHECK #define TORCH_CHECK AT_CHECK #endif diff --git a/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron/fused_kernels/scaled_masked_softmax.cpp index 1852aee6f..b7c162c78 100644 --- a/megatron/fused_kernels/scaled_masked_softmax.cpp +++ b/megatron/fused_kernels/scaled_masked_softmax.cpp @@ -22,76 +22,62 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor); -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads); -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); +torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) +{ + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - return fwd_cuda(input, mask, scale_factor); + return fwd_cuda(input, mask, scale_factor); } -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) +{ + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); } -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) +{ return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); } -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); - m.def("get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size."); } diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h index 1f98291ca..84d173dbc 100644 --- a/megatron/fused_kernels/scaled_masked_softmax.h +++ b/megatron/fused_kernels/scaled_masked_softmax.h @@ -16,60 +16,77 @@ #pragma once -#include #include +#include #include +#include +#include #include #include -#include -#include -#include namespace { template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); +__device__ __inline__ void copy_vector(Datatype* dst, const Datatype* src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::BFloat16* dst, + const c10::BFloat16* src) +{ + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::BFloat16* dst, + const c10::BFloat16* src) +{ + *((float2*)dst) = *((float2*)src); +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half* dst, const c10::Half* src) +{ + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::Half* dst, const c10::Half* src) +{ + *((float2*)dst) = *((float2*)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t* dst, const uint8_t* src) +{ + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t* dst, const uint8_t* src) +{ + *((half2*)dst) = *((half2*)src); +} -int log2_ceil(int value) { +int log2_ceil(int value) +{ int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; } }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 return __shfl_xor_sync(mask, value, laneMask, width); @@ -78,13 +95,14 @@ __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int wid #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) +{ ReduceOp r; - #pragma unroll +#pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); sum[i] = r(sum[i], b); } @@ -95,40 +113,43 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { * Extended softmax (from native aten pytorch) with following additional features * 1) input scaling * 2) Explicit masking - */ + */ template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) +__global__ void scaled_masked_softmax_warp_forward(output_t* dst, + const input_t* src, + const uint8_t* mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two + : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style + if (pad_batches != 1) { // bert style + pad_first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; } // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the batch int local_idx = threadIdx.x; @@ -141,29 +162,29 @@ __global__ void scaled_masked_softmax_warp_forward( acc_t elements[WARP_BATCH][WARP_ITERATIONS]; input_t temp_data[ELEMENTS_PER_LDG_STG]; uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; + int itr_idx = i * element_count + it * WARP_SIZE; copy_vector(temp_data, src + itr_idx); copy_vector(temp_mask, mask + itr_idx); - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } } else { - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { elements[i][it + element] = -std::numeric_limits::infinity(); } @@ -173,21 +194,21 @@ __global__ void scaled_masked_softmax_warp_forward( // compute max_value acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } warp_reduce(max_value); - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { elements[i][it] = std::exp((elements[i][it] - max_value[i])); sum[i] += elements[i][it]; } @@ -196,52 +217,51 @@ __global__ void scaled_masked_softmax_warp_forward( // store result output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { out[element] = elements[i][it + element] / sum[i]; } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); + copy_vector( + dst + i * element_count + it * WARP_SIZE, out); } else { break; - } + } } } } template -__global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) +__global__ void scaled_masked_softmax_warp_backward(output_t* gradInput, + input_t* grad, + const input_t* output, + acc_t scale, + int micro_batch_size, + int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two + : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) + // gridDim/blockIdx = (seq_len, attn_heads, batches) int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - + // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the batch int local_idx = threadIdx.x; @@ -253,67 +273,70 @@ __global__ void scaled_masked_softmax_warp_backward( gradInput += thread_offset; // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; input_t temp_grad[ELEMENTS_PER_LDG_STG]; input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : element_count; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + copy_vector( + temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count + it * WARP_SIZE); - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { output_reg[i][it + element] = (acc_t)temp_output[element]; } - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; } - } + } } } - + acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { sum[i] += grad_reg[i][it]; } } warp_reduce(sum); - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { // compute gradients output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + out[element] = (output_t)( + scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } + copy_vector( + gradInput + i * element_count + it * WARP_SIZE, out); + } } } } -} // end of anonymous namespace +} // end of anonymous namespace -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) +{ int log2_elements = log2_ceil(key_seq_len); const int next_power_of_two = 1 << log2_elements; @@ -328,17 +351,16 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att return batches_per_block; } -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) +template +void dispatch_scaled_masked_softmax_forward(output_t* dst, + const input_t* src, + const uint8_t* mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) { if (key_seq_len == 0) { return; @@ -350,7 +372,8 @@ void dispatch_scaled_masked_softmax_forward( // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; // use 128 threads per block to maximimize gpu utilization @@ -358,86 +381,98 @@ void dispatch_scaled_masked_softmax_forward( int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 + case 0: // 1 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 1: // 2 + case 1: // 2 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 2: // 4 + case 2: // 4 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 3: // 8 + case 3: // 8 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 4: // 16 + case 4: // 16 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 5: // 32 + case 5: // 32 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 6: // 64 + case 6: // 64 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 7: // 128 + case 7: // 128 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 8: // 256 + case 8: // 256 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 9: // 512 + case 9: // 512 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 10: // 1024 + case 10: // 1024 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; - case 11: // 2048 + case 11: // 2048 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; + default: break; } } } -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) +template +void dispatch_scaled_masked_softmax_backward(output_t* grad_input, + input_t* grad, + const input_t* output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { if (key_seq_len == 0) { - return; + return; } else { int log2_elements = log2_ceil(key_seq_len); const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; + int batch_count = batches * attn_heads * query_seq_len; - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; // use 128 threads per block to maximimize gpu utilization @@ -445,60 +480,71 @@ void dispatch_scaled_masked_softmax_backward( int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; + int blocks = batch_count / batches_per_block; dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 + case 0: // 1 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 1: // 2 + case 1: // 2 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 2: // 4 + case 2: // 4 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 3: // 8 + case 3: // 8 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 4: // 16 + case 4: // 16 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 5: // 32 + case 5: // 32 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 6: // 64 + case 6: // 64 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 7: // 128 + case 7: // 128 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 8: // 256 + case 8: // 256 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 9: // 512 + case 9: // 512 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 10: // 1024 + case 10: // 1024 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; - case 11: // 2048 + case 11: // 2048 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); break; + default: break; } } } diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu index 902d36dd0..757850d2b 100644 --- a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu +++ b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu @@ -15,11 +15,11 @@ */ #include +#include #include -#include #include #include -#include +#include #include #include "scaled_masked_softmax.h" #include "type_shim.h" @@ -28,90 +28,82 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads) +{ return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); } - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); - TORCH_INTERNAL_ASSERT(query_seq_len > 1); - TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); - TORCH_INTERNAL_ASSERT(mask.size(1) == 1); - TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); - TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* mask_ptr = static_cast(mask.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_masked_softmax_forward", - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches); - ); - return softmax_results; + DISPATCH_HALF_AND_BFLOAT(input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches);); + return softmax_results; } -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) +{ + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); - //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.size(0); - const int attn_heads = output_grads.size(1); - const int query_seq_len = output_grads.size(2); - const int key_seq_len = output_grads.size(3); + // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); - void* output_grads_ptr = static_cast(output_grads.data_ptr()); + void* output_grads_ptr = static_cast(output_grads.data_ptr()); - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_masked_softmax_backward", - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT(output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads);); + + // backward pass is completely in-place + return output_grads; } +} // namespace scaled_masked_softmax +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp index 83c9ef595..945c48c43 100644 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp @@ -22,51 +22,49 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor); +torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); -torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) +{ + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); - return fwd_cuda(input, scale_factor); + return fwd_cuda(input, scale_factor); } -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) +{ + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); + return bwd_cuda(output_grads, softmax_results, scale_factor); } -} // end namespace scaled_upper_triang_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); } diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h index bffc29a0f..b74dd4040 100644 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @@ -17,73 +17,102 @@ #pragma once #include +#include #include +#include #include #include -#include -#include namespace { template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); +__device__ __inline__ void copy_vector(Datatype* dst, const Datatype* src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::BFloat16* dst, + const c10::BFloat16* src) +{ + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - +__device__ __inline__ void copy_vector(c10::BFloat16* dst, + const c10::BFloat16* src) +{ + *((float2*)dst) = *((float2*)src); +} + template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half* dst, const c10::Half* src) +{ + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::Half* dst, const c10::Half* src) +{ + *((float2*)dst) = *((float2*)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t* dst, const uint8_t* src) +{ + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t* dst, const uint8_t* src) +{ + *((half2*)dst) = *((half2*)src); +} template -__device__ __inline__ void copy_zero_vector(Datatype *dst); +__device__ __inline__ void copy_zero_vector(Datatype* dst); template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector(c10::BFloat16* dst) +{ + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } +__device__ __inline__ void copy_zero_vector(c10::BFloat16* dst) +{ + *((float2*)dst) = make_float2(0.0f, 0.0f); +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector(c10::Half* dst) +{ + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - +__device__ __inline__ void copy_zero_vector(c10::Half* dst) +{ + *((float2*)dst) = make_float2(0.0f, 0.0f); +} -int log2_ceil(int value) { +int log2_ceil(int value) +{ int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; } }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 return __shfl_xor_sync(mask, value, laneMask, width); @@ -92,13 +121,14 @@ __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int wid #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) +{ ReduceOp r; - #pragma unroll +#pragma unroll for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); sum[i] = r(sum[i], b); } @@ -111,31 +141,30 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { * 2) Implicit time (diagonal masking) */ template -__global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) +__global__ void scaled_upper_triang_masked_softmax_warp_forward(output_t* dst, + const input_t* src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two + : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the batch int local_idx = threadIdx.x; @@ -146,27 +175,28 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( // load data from global memory acc_t elements[WARP_BATCH][WARP_ITERATIONS]; input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : local_seq; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); + copy_vector( + temp_data, src + i * element_count * stride + it * WARP_SIZE); - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; + elements[i][it + element] = (acc_t)temp_data[element] * scale; } else { elements[i][it + element] = -std::numeric_limits::infinity(); } } } else { - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { elements[i][it + element] = -std::numeric_limits::infinity(); } @@ -176,42 +206,40 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( // compute max_value acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } } warp_reduce(max_value); - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { if (it < warp_iteration_limit) { elements[i][it] = std::exp((elements[i][it] - max_value[i])); sum[i] += elements[i][it]; - } + } } } warp_reduce(sum); // store result output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < local_seq) { - - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < local_seq) { out[element] = elements[i][it + element] / sum[i]; @@ -219,42 +247,43 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( out[element] = 0; } } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + copy_vector( + dst + i * element_count * stride + it * WARP_SIZE, out); } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + copy_zero_vector(dst + i * element_count * stride + + it * WARP_SIZE); } else { break; - } + } } } } template -__global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) +__global__ void scaled_upper_triang_masked_softmax_warp_backward(output_t* gradInput, + input_t* grad, + const input_t* output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and // warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two + : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - + int local_seq = blockIdx.x + 1; + // micro_batch_size might not be a multiple of WARP_BATCH. Check how // many batches have to computed within this WARP. int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; // there might be multiple batches per warp. compute the index within the batch int local_idx = threadIdx.x; @@ -266,79 +295,80 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( gradInput += thread_offset; // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; input_t temp_grad[ELEMENTS_PER_LDG_STG]; input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { int batch_element_count = (i >= local_batches) ? 0 : local_seq; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count * stride + it * WARP_SIZE); - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < batch_element_count) { output_reg[i][it + element] = (acc_t)temp_output[element]; } } - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; } } } } } - + acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { sum[i] += grad_reg[i][it]; } } warp_reduce(sum); - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { // compute gradients output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll +#pragma unroll for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + out[element] = (output_t)( + scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } + copy_vector( + gradInput + i * element_count * stride + it * WARP_SIZE, out); + } } } } -} // end of anonymous namespace +} // end of anonymous namespace -template -void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) +template +void dispatch_scaled_upper_triang_masked_softmax_forward(output_t* dst, + const input_t* src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) { if (softmax_elements == 0) { return; @@ -351,7 +381,8 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; // use 128 threads per block to maximimize gpu utilization @@ -364,82 +395,94 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 + case 0: // 1 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 1: // 2 + case 1: // 2 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 2: // 4 + case 2: // 4 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 3: // 8 + case 3: // 8 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 4: // 16 + case 4: // 16 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 5: // 32 + case 5: // 32 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 6: // 64 + case 6: // 64 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 7: // 128 + case 7: // 128 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 8: // 256 + case 8: // 256 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 9: // 512 + case 9: // 512 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 10: // 1024 + case 10: // 1024 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; - case 11: // 2048 + case 11: // 2048 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; + default: break; } } } -template -void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) +template +void dispatch_scaled_upper_triang_masked_softmax_backward(output_t* grad_input, + input_t* grad, + const input_t* output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) { if (softmax_elements == 0) { - return; + return; } else { int log2_elements = log2_ceil(softmax_elements); const int next_power_of_two = 1 << log2_elements; int seq_len = softmax_elements; int batch_count = attn_batches * seq_len; - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; // use 128 threads per block to maximimize gpu utilization @@ -452,56 +495,139 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { - case 0: // 1 + case 0: // 1 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 1: // 2 + case 1: // 2 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 2: // 4 + case 2: // 4 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 3: // 8 + case 3: // 8 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 4: // 16 + case 4: // 16 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 5: // 32 + case 5: // 32 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 6: // 64 + case 6: // 64 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 7: // 128 + case 7: // 128 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 8: // 256 + case 8: // 256 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 9: // 512 + case 9: // 512 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 10: // 1024 + case 10: // 1024 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; - case 11: // 2048 + case 11: // 2048 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: + <<>>( + grad_input, + grad, + output, + scale, + batch_count, + softmax_elements_stride, + softmax_elements); break; + default: break; } } } diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu index 1adb94905..7ced78acd 100644 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu @@ -15,11 +15,11 @@ */ #include +#include #include -#include #include #include -#include +#include #include #include "scaled_upper_triang_masked_softmax.h" #include "type_shim.h" @@ -28,71 +28,64 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor) +torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { - // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 2048); + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_forward", - dispatch_scaled_upper_triang_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - return softmax_results; + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + seq_len, + seq_len, + attn_batches);); + return softmax_results; } +torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) +{ + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); + // output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); - //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = output_grads.size(0); - const int seq_len = output_grads.size(1); - TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + void* output_grads_ptr = static_cast(output_grads.data_ptr()); - void* output_grads_ptr = static_cast(output_grads.data_ptr()); + // Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + seq_len, + seq_len, + attn_batches);); - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_backward", - dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} + // backward pass is completely in-place + return output_grads; } +} // namespace scaled_upper_triang_masked_softmax +} // namespace fused_softmax +} // namespace multihead_attn diff --git a/megatron/fused_kernels/type_shim.h b/megatron/fused_kernels/type_shim.h index 6437dcc7c..1a3a30953 100644 --- a/megatron/fused_kernels/type_shim.h +++ b/megatron/fused_kernels/type_shim.h @@ -14,78 +14,59 @@ * limitations under the License. */ - #include #include "compat.h" - -#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Half: \ - { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - - -#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ - switch(TYPEIN) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_in = float; \ - switch(TYPEOUT) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ - } \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_in = at::Half; \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_in = at::BFloat16; \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch (TYPEIN) { \ + case at::ScalarType::Float: { \ + using scalar_t_in = float; \ + switch (TYPEOUT) { \ + case at::ScalarType::Float: { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } diff --git a/megatron/gradient_noise_scale/__init__.py b/megatron/gradient_noise_scale/__init__.py index 863a094bd..6b0d07e94 100644 --- a/megatron/gradient_noise_scale/__init__.py +++ b/megatron/gradient_noise_scale/__init__.py @@ -1 +1 @@ -from .gradient_noise_scale import GradientNoiseScale \ No newline at end of file +from .gradient_noise_scale import GradientNoiseScale diff --git a/megatron/gradient_noise_scale/gradient_noise_scale.py b/megatron/gradient_noise_scale/gradient_noise_scale.py index a0414fead..5fe20f7f3 100644 --- a/megatron/gradient_noise_scale/gradient_noise_scale.py +++ b/megatron/gradient_noise_scale/gradient_noise_scale.py @@ -3,7 +3,8 @@ def ema(avg, beta, yi, i): """Exponential moving average""" - if avg is None: avg = 0 + if avg is None: + avg = 0 avg = beta * avg + (1 - beta) * yi return avg, avg / (1 - beta ** (i + 1)) @@ -40,7 +41,16 @@ class GradientNoiseScale: single-gpu environments. Unfortunately it does come with some memory overhead. """ - def __init__(self, model, batch_size_small, n_batches=10, beta=0.99, cpu_offload=False, neox_args=None, mpu=None): + def __init__( + self, + model, + batch_size_small, + n_batches=10, + beta=0.99, + cpu_offload=False, + neox_args=None, + mpu=None, + ): self.batch_size_small = batch_size_small self.batch_size_large = batch_size_small * n_batches self.n_batches = n_batches @@ -59,7 +69,9 @@ def __init__(self, model, batch_size_small, n_batches=10, beta=0.99, cpu_offload def flatten_grads(self): grads = [] - assert hasattr(self.model, 'stored_gradients'), "You might need to update DeeperSpeed" + assert hasattr( + self.model, "stored_gradients" + ), "You might need to update DeeperSpeed" if self.model.stored_gradients is not None: for g in self.model.stored_gradients: if g is not None and not g.isnan().any() and not g.isinf().any(): @@ -78,9 +90,11 @@ def _sync_overflow(self, is_overflow): # Since each model parallel GPU carries only part of the model, # make sure overflow flag is synced across all the pipe parallel GPUs overflow_gpu = torch.cuda.ByteTensor([is_overflow]) - torch.distributed.all_reduce(overflow_gpu, - op=torch.distributed.ReduceOp.MAX, - group=self.mpu.get_pipe_parallel_group()) + torch.distributed.all_reduce( + overflow_gpu, + op=torch.distributed.ReduceOp.MAX, + group=self.mpu.get_pipe_parallel_group(), + ) overflow = overflow_gpu[0].item() else: overflow = is_overflow @@ -115,12 +129,16 @@ def _update(self): g_small = g_small.to(self.model.device) # avg g_big / g_small across pipe parallel groups - torch.distributed.all_reduce(g_big, - op=torch.distributed.ReduceOp.SUM, - group=self.mpu.get_pipe_parallel_group()) - torch.distributed.all_reduce(g_small, - op=torch.distributed.ReduceOp.SUM, - group=self.mpu.get_pipe_parallel_group()) + torch.distributed.all_reduce( + g_big, + op=torch.distributed.ReduceOp.SUM, + group=self.mpu.get_pipe_parallel_group(), + ) + torch.distributed.all_reduce( + g_small, + op=torch.distributed.ReduceOp.SUM, + group=self.mpu.get_pipe_parallel_group(), + ) g_big /= self.mpu.get_pipe_parallel_world_size() g_small /= self.mpu.get_pipe_parallel_world_size() @@ -129,24 +147,40 @@ def _update(self): g_small = torch.square(torch.norm(grad.to(torch.float))) # communicate any overflows - is_overflow = (g_small.isinf().any() or g_small.isnan().any() or g_big.isinf().any() or g_big.isnan().any()) + is_overflow = ( + g_small.isinf().any() + or g_small.isnan().any() + or g_big.isinf().any() + or g_big.isnan().any() + ) is_overflow = self._sync_overflow(is_overflow) if is_overflow: return # calculate noise / scale - noise = 1 / (self.batch_size_large - self.batch_size_small) * ( - self.batch_size_large * g_big - self.batch_size_small * g_small) - scale = 1 / (1 / self.batch_size_small - 1 / self.batch_size_large) * (g_small - g_big) + noise = ( + 1 + / (self.batch_size_large - self.batch_size_small) + * (self.batch_size_large * g_big - self.batch_size_small * g_small) + ) + scale = ( + 1 + / (1 / self.batch_size_small - 1 / self.batch_size_large) + * (g_small - g_big) + ) # calculate running average - self.ema_noise, noise = ema(self.ema_noise, self.beta, noise, self.n_updates) - self.ema_scale, scale = ema(self.ema_scale, self.beta, scale, self.n_updates) + self.ema_noise, noise = ema( + self.ema_noise, self.beta, noise, self.n_updates + ) + self.ema_scale, scale = ema( + self.ema_scale, self.beta, scale, self.n_updates + ) # calculate noise scale scale = scale.item() noise = noise.item() - self.noise_scale = (scale / noise) + self.noise_scale = scale / noise self.n_updates += 1 diff --git a/megatron/initialize.py b/megatron/initialize.py index b4718d268..8960a7d76 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # @@ -36,16 +35,15 @@ def initialize_megatron(neox_args, allow_no_cuda=False): """Set initialize distributed and set autoresume and random seeds. - `allow_no_cuda` should not be set unless using megatron for cpu only - data processing. In general this arg should not be set unless you know + `allow_no_cuda` should not be set unless using megatron for cpu only + data processing. In general this arg should not be set unless you know what you are doing. - Returns a function to finalize distributed env initialization + Returns a function to finalize distributed env initialization (optionally, only when args.lazy_mpu_init == True) - -""" + """ if not allow_no_cuda: # Make sure cuda is available. - assert torch.cuda.is_available(), 'Megatron requires CUDA.' + assert torch.cuda.is_available(), "Megatron requires CUDA." # torch.distributed initialization def finish_mpu_init(): @@ -54,7 +52,7 @@ def finish_mpu_init(): # Random seeds for reproducibility. if neox_args.rank == 0: - print('> setting random seeds to {} ...'.format(neox_args.seed)) + print("> setting random seeds to {} ...".format(neox_args.seed)) _set_random_seed(neox_args.seed) # load scaled_upper_triang_masked_softmax_fusion kernel @@ -63,7 +61,7 @@ def finish_mpu_init(): if neox_args.lazy_mpu_init: neox_args.use_cpu_initialization = True # delayed initialization of DDP-related stuff - # We only set basic DDP globals + # We only set basic DDP globals set_model_parallel_world_size(neox_args.model_parallel_size) # and return function for external DDP manager to call when it has DDP initialized set_model_parallel_rank(neox_args.rank) @@ -78,6 +76,7 @@ def finish_mpu_init(): # Compile dataset C++ code. if neox_args.local_rank == 0: from megatron.data.data_utils import compile_helper + compile_helper() # Write arguments to tensorboard. @@ -87,7 +86,7 @@ def finish_mpu_init(): def setup_deepspeed_random_and_activation_checkpointing(neox_args): - '''Optional DeepSpeed Activation Checkpointing features. + """Optional DeepSpeed Activation Checkpointing features. Gives access to partition activations, contiguous memory optimizations and cpu checkpointing. @@ -99,9 +98,13 @@ def setup_deepspeed_random_and_activation_checkpointing(neox_args): we overwrite them to maintain consistency. This must be called before all the calls to mpu.model_parallel_cuda_manual_seed - ''' + """ num_layers = neox_args.num_layers // neox_args.checkpoint_num_layers - num_layers = num_layers if neox_args.num_layers % neox_args.checkpoint_num_layers == 0 else num_layers + 1 + num_layers = ( + num_layers + if neox_args.num_layers % neox_args.checkpoint_num_layers == 0 + else num_layers + 1 + ) deepspeed.checkpointing.configure( mpu, @@ -110,7 +113,8 @@ def setup_deepspeed_random_and_activation_checkpointing(neox_args): num_checkpoints=num_layers, checkpoint_in_cpu=neox_args.checkpoint_in_cpu, synchronize=neox_args.synchronize_each_layer, - profile=neox_args.profile_backward) + profile=neox_args.profile_backward, + ) def _initialize_distributed(neox_args): @@ -120,21 +124,25 @@ def _initialize_distributed(neox_args): if torch.distributed.is_initialized(): if neox_args.rank == 0: - print('torch distributed is already initialized, ' - 'skipping initialization ...', flush=True) + print( + "torch distributed is already initialized, " + "skipping initialization ...", + flush=True, + ) neox_args.rank = torch.distributed.get_rank() neox_args.world_size = torch.distributed.get_world_size() else: if neox_args.rank == 0: - print('> initializing torch distributed ...', flush=True) + print("> initializing torch distributed ...", flush=True) # Manually set the device ids. if device_count > 0: device = neox_args.rank % device_count if neox_args.local_rank is not None: - assert neox_args.local_rank == device, \ - 'expected local-rank to be the same as rank % device-count.' + assert ( + neox_args.local_rank == device + ), "expected local-rank to be the same as rank % device-count." else: neox_args.local_rank = device torch.cuda.set_device(device) @@ -142,17 +150,20 @@ def _initialize_distributed(neox_args): distributed.init_distributed( dist_backend=neox_args.distributed_backend, auto_mpi_discovery=True, - distributed_port=os.getenv('MASTER_PORT', '6000'), + distributed_port=os.getenv("MASTER_PORT", "6000"), verbose=True, ) # Setup 3D topology. pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1 mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1 - assert neox_args.world_size % (pp * mp) == 0, f'world_size={neox_args.world_size}, pp={pp}, mp={mp}' + assert ( + neox_args.world_size % (pp * mp) == 0 + ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}" dp = neox_args.world_size // (pp * mp) from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology + # this does pipe on the most outside, then data, then model. # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order. topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp) @@ -160,17 +171,23 @@ def _initialize_distributed(neox_args): # Offset base seeds for the interior pipeline stages. # TODO: adjust last stage too once IO is improved. stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe - if 0 < stage_id < topo.get_dim('pipe') - 1: + if 0 < stage_id < topo.get_dim("pipe") - 1: offset = neox_args.seed + 1138 neox_args.seed = offset + (stage_id * mp) # Set the model-parallel / data-parallel communicators. if device_count > 0: if mpu.model_parallel_is_initialized(): - print('_initialize_distributed() model parallel is already initialized', flush=True) + print( + "_initialize_distributed() model parallel is already initialized", + flush=True, + ) else: - mpu.initialize_model_parallel(neox_args.model_parallel_size, topology=topo, - fp32_allreduce=neox_args.fp32_allreduce) + mpu.initialize_model_parallel( + neox_args.model_parallel_size, + topology=topo, + fp32_allreduce=neox_args.fp32_allreduce, + ) # Init DeepSpeed Activation Checkpointing Features setup_deepspeed_random_and_activation_checkpointing(neox_args=neox_args) @@ -180,12 +197,12 @@ def _init_autoresume(neox_args): """Set autoresume start time.""" if neox_args.adlr_autoresume: - print_rank_0('> enabling autoresume ...') - sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.')) + print_rank_0("> enabling autoresume ...") + sys.path.append(os.environ.get("SUBMIT_SCRIPTS", ".")) try: from userlib.auto_resume import AutoResume except BaseException: - print('> ADLR autoresume is not available, exiting ...', flush=True) + print("> ADLR autoresume is not available, exiting ...", flush=True) sys.exit() neox_args.adlr_autoresume_object = AutoResume @@ -204,7 +221,7 @@ def _set_random_seed(seed): if torch.cuda.device_count() > 0: mpu.model_parallel_cuda_manual_seed(seed) else: - raise ValueError('Seed ({}) should be a positive integer.'.format(seed)) + raise ValueError("Seed ({}) should be a positive integer.".format(seed)) def _write_args_to_tensorboard(neox_args): @@ -212,4 +229,6 @@ def _write_args_to_tensorboard(neox_args): """Write arguments to tensorboard.""" if neox_args.tensorboard_writer: for arg_name in vars(neox_args): - neox_args.tensorboard_writer.add_text(arg_name, str(getattr(neox_args, arg_name))) + neox_args.tensorboard_writer.add_text( + arg_name, str(getattr(neox_args, arg_name)) + ) diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index 1a449be89..81f20ec58 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,11 +22,18 @@ class AnnealingLR(object): """Anneals the learning rate.""" - def __init__(self, optimizer, start_lr, - warmup_iter, total_iters, - decay_style, last_iter, min_lr=0.0, - use_checkpoint_lr_scheduler=True, - override_lr_scheduler=False): + def __init__( + self, + optimizer, + start_lr, + warmup_iter, + total_iters, + decay_style, + last_iter, + min_lr=0.0, + use_checkpoint_lr_scheduler=True, + override_lr_scheduler=False, + ): # Class values. self.optimizer = optimizer @@ -41,16 +47,17 @@ def __init__(self, optimizer, start_lr, self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: - assert not self.use_checkpoint_lr_scheduler, 'both override and '\ - 'use-checkpoint are set.' + assert not self.use_checkpoint_lr_scheduler, ( + "both override and " "use-checkpoint are set." + ) # Set the learning rate self.step(self.num_iters) - print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) + print_rank_0("> learning rate decay style: {}".format(self.decay_style)) def get_lr(self): """Learning rate decay functions from: - https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter) # Warmup. @@ -58,12 +65,15 @@ def get_lr(self): return float(self.start_lr) * num_iters_ / self.warmup_iter num_iters_ = num_iters_ - self.warmup_iter - if self.decay_style == 'linear': + if self.decay_style == "linear": lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter - elif self.decay_style == 'cosine': - lr = self.start_lr / 2.0 * (math.cos( - math.pi * num_iters_ / self.end_iter) + 1) - elif self.decay_style == 'exponential': + elif self.decay_style == "cosine": + lr = ( + self.start_lr + / 2.0 + * (math.cos(math.pi * num_iters_ / self.end_iter) + 1) + ) + elif self.decay_style == "exponential": # exp(-0.693) = 1/2 lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter) else: @@ -77,16 +87,16 @@ def step(self, step_num=None): self.num_iters = step_num new_lr = self.get_lr() for group in self.optimizer.param_groups: - group['lr'] = new_lr + group["lr"] = new_lr def state_dict(self): state_dict = { - 'start_lr': self.start_lr, - 'warmup_iter': self.warmup_iter, - 'num_iters': self.num_iters, - 'decay_style': self.decay_style, - 'end_iter': self.end_iter, - 'min_lr': self.min_lr + "start_lr": self.start_lr, + "warmup_iter": self.warmup_iter, + "num_iters": self.num_iters, + "decay_style": self.decay_style, + "end_iter": self.end_iter, + "min_lr": self.min_lr, } return state_dict @@ -94,30 +104,34 @@ def _check_and_set(self, cls_value, sd_value, name): """Auxiliary function for checking the values in the checkpoint and setting them.""" if self.override_lr_scheduler: - print_rank_0(' > overriding {} value to {}'.format(name, cls_value)) + print_rank_0(" > overriding {} value to {}".format(name, cls_value)) return cls_value if not self.use_checkpoint_lr_scheduler: - assert cls_value == sd_value, 'AnnealingLR: class input value' \ - 'and checkpoint values for {} do not match'.format(name) - print_rank_0(' > using checkpoint value {} for {}'.format(sd_value, - name)) + assert cls_value == sd_value, ( + "AnnealingLR: class input value" + "and checkpoint values for {} do not match".format(name) + ) + print_rank_0(" > using checkpoint value {} for {}".format(sd_value, name)) return sd_value def load_state_dict(self, sd): - self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], - 'learning rate') - self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], - 'minimum learning rate') - self.warmup_iter = self._check_and_set(self.warmup_iter, - sd['warmup_iter'], - 'warmup iterations') - self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'], - 'total number of iterations') - self.decay_style = self._check_and_set(self.decay_style, - sd['decay_style'], - 'decay style') - - self.num_iters = sd['num_iters'] + self.start_lr = self._check_and_set( + self.start_lr, sd["start_lr"], "learning rate" + ) + self.min_lr = self._check_and_set( + self.min_lr, sd["min_lr"], "minimum learning rate" + ) + self.warmup_iter = self._check_and_set( + self.warmup_iter, sd["warmup_iter"], "warmup iterations" + ) + self.end_iter = self._check_and_set( + self.end_iter, sd["end_iter"], "total number of iterations" + ) + self.decay_style = self._check_and_set( + self.decay_style, sd["decay_style"], "decay style" + ) + + self.num_iters = sd["num_iters"] self.step(self.num_iters) diff --git a/megatron/logging.py b/megatron/logging.py index fc404e010..08beedf68 100644 --- a/megatron/logging.py +++ b/megatron/logging.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # @@ -22,9 +21,10 @@ class Tee: - """ Duplicate output to both stdout/err and file """ + """Duplicate output to both stdout/err and file""" + def __init__(self, file, err=False): - self.file = open(file, 'w') + self.file = open(file, "w") self.err = err if not err: self.std = sys.stdout @@ -58,44 +58,68 @@ def flush(self): def human_readable_flops(num): - for unit in ['', 'KFLOPS', 'MFLOPS', 'GFLOPS', 'TFLOPS', 'PFLOPS', 'EFLOPS', 'ZFLOPS']: + for unit in [ + "", + "KFLOPS", + "MFLOPS", + "GFLOPS", + "TFLOPS", + "PFLOPS", + "EFLOPS", + "ZFLOPS", + ]: if abs(num) < 1000.0: return "%3.1f%s" % (num, unit) num /= 1000.0 - return "%.1f%s" % (num, 'Yi') + return "%.1f%s" % (num, "Yi") def get_flops(neox_args, model, iter_time_s): world_size = torch.distributed.get_world_size() ff = model.total_params * 6 attn = neox_args.seq_length * neox_args.hidden_size * neox_args.num_layers * 60 - flops = neox_args.train_batch_size * neox_args.seq_length * (ff + attn) / (iter_time_s * world_size) + flops = ( + neox_args.train_batch_size + * neox_args.seq_length + * (ff + attn) + / (iter_time_s * world_size) + ) return flops -def training_log(neox_args, timers, loss_dict, total_loss_dict, learning_rate, iteration, - loss_scale, report_memory_flag, skipped_iter, model, optimizer, noise_scale_logger): +def training_log( + neox_args, + timers, + loss_dict, + total_loss_dict, + learning_rate, + iteration, + loss_scale, + report_memory_flag, + skipped_iter, + model, + optimizer, + noise_scale_logger, +): """Log training information such as losses, timing, etc.""" # Update losses. - skipped_iters_key = 'skipped iterations' - total_loss_dict[skipped_iters_key] = total_loss_dict.get( - skipped_iters_key, 0) + skipped_iter - got_nan_key = 'got nan' + skipped_iters_key = "skipped iterations" + total_loss_dict[skipped_iters_key] = ( + total_loss_dict.get(skipped_iters_key, 0) + skipped_iter + ) + got_nan_key = "got nan" got_nan = False for key in loss_dict: if not skipped_iter: - total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key] + total_loss_dict[key] = total_loss_dict.get(key, 0.0) + loss_dict[key] else: value = loss_dict[key].float().sum().item() - is_nan = value == float('inf') or \ - value == -float('inf') or \ - value != value + is_nan = value == float("inf") or value == -float("inf") or value != value got_nan = got_nan or is_nan - total_loss_dict[got_nan_key] = total_loss_dict.get( - got_nan_key, 0) + int(got_nan) + total_loss_dict[got_nan_key] = total_loss_dict.get(got_nan_key, 0) + int(got_nan) # Logging. timers_to_log = [] @@ -105,52 +129,93 @@ def add_to_logging(name): timers_to_log.append(name) if not neox_args.is_pipe_parallel: - add_to_logging('forward') - add_to_logging('backward') - add_to_logging('backward-backward') - add_to_logging('backward-allreduce') - add_to_logging('backward-master-grad') - add_to_logging('backward-clip-grad') - add_to_logging('optimizer') - add_to_logging('batch generator') + add_to_logging("forward") + add_to_logging("backward") + add_to_logging("backward-backward") + add_to_logging("backward-allreduce") + add_to_logging("backward-master-grad") + add_to_logging("backward-clip-grad") + add_to_logging("optimizer") + add_to_logging("batch generator") # Log timer info to tensorboard and wandb normalizer = iteration % neox_args.log_interval if normalizer == 0: normalizer = neox_args.log_interval if torch.distributed.get_rank() == 0: - timers.write(names=timers_to_log, iteration=iteration, normalizer=normalizer) + timers.write( + names=timers_to_log, iteration=iteration, normalizer=normalizer + ) else: # with pipeline parallel, the megatron timers are overridden by the deepspeed ones. # Try to grab timer values from model engine. Only recently added to deeperspeed, so check that the engine # has that attribute first - if hasattr(model, 'timer_values') and model.timer_values is not None: - if model.wall_clock_breakdown() and model.global_steps % model.steps_per_print() == 0: + if hasattr(model, "timer_values") and model.timer_values is not None: + if ( + model.wall_clock_breakdown() + and model.global_steps % model.steps_per_print() == 0 + ): timer_values = model.timer_values # deepspeed already logs to tensorboard / prints values, so just log to wandb if neox_args.use_wandb and torch.distributed.get_rank() == 0: for key in timer_values: - tb_wandb_log(f"timers/{key}", timer_values[key], iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer) + tb_wandb_log( + f"timers/{key}", + timer_values[key], + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + ) # write losses, lr, etc. every step - tb_wandb_log('train/learning_rate', learning_rate, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer) + tb_wandb_log( + "train/learning_rate", + learning_rate, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + ) for key in loss_dict: - tb_wandb_log(f'train/{key.replace(" ", "_")}', loss_dict[key], iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer) + tb_wandb_log( + f'train/{key.replace(" ", "_")}', + loss_dict[key], + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + ) if neox_args.fp16: - tb_wandb_log(f'train/loss_scale', loss_scale, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer) + tb_wandb_log( + f"train/loss_scale", + loss_scale, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + ) # log gradient noise scale if neox_args.log_gradient_noise_scale: if noise_scale_logger.noise_scale is not None: - tb_wandb_log(f'train/noise_scale', noise_scale_logger.noise_scale, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer) + tb_wandb_log( + f"train/noise_scale", + noise_scale_logger.noise_scale, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + ) # (optional) Log optimizer states to wandb / tb every step if neox_args.log_optimizer_states: - for k, v in optimizer.state_dict()['optimizer_state_dict']['state'].items(): + for k, v in optimizer.state_dict()["optimizer_state_dict"]["state"].items(): for ki, vi in v.items(): # step, module - if ki != 'step': - opt_state_norm = torch.norm(vi) if hasattr(vi, 'dim') else vi - tb_wandb_log(f'optimizer_state_norms/{k}_{ki}', opt_state_norm, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer) + if ki != "step": + opt_state_norm = torch.norm(vi) if hasattr(vi, "dim") else vi + tb_wandb_log( + f"optimizer_state_norms/{k}_{ki}", + opt_state_norm, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + ) # (optional) Log grad/param norms to wandb / tb every step if neox_args.log_grad_norm or neox_args.log_param_norm: @@ -158,51 +223,99 @@ def add_to_logging(name): model.store_gradients = True # start storing gradients for i, (name, param) in enumerate(model.module.named_parameters()): if neox_args.log_grad_norm: - if hasattr(model, 'stored_gradients') and model.stored_gradients is not None: + if ( + hasattr(model, "stored_gradients") + and model.stored_gradients is not None + ): grad = model.stored_gradients[i] if grad is not None: - tb_wandb_log(f'gradient_norms/{name}', torch.norm(grad), iteration, - use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer, all_ranks=True) + tb_wandb_log( + f"gradient_norms/{name}", + torch.norm(grad), + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + all_ranks=True, + ) if neox_args.log_param_norm: - tb_wandb_log(f'parameter_norms/{name}', torch.norm(param), iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer, all_ranks=True) + tb_wandb_log( + f"parameter_norms/{name}", + torch.norm(param), + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + all_ranks=True, + ) if iteration % neox_args.log_interval == 0: # log other stuff every neox_args.log_interval iters - elapsed_time = timers('interval time').elapsed() + elapsed_time = timers("interval time").elapsed() iteration_time = elapsed_time / neox_args.log_interval samples_per_sec = neox_args.train_batch_size / iteration_time - log_string = ' samples/sec: {:.3f} |'.format(samples_per_sec) - tb_wandb_log('runtime/samples_per_sec', samples_per_sec, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer) - tb_wandb_log('runtime/iteration_time', iteration_time, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer) - log_string += ' iteration {:8d}/{:8d} |'.format(iteration, neox_args.train_iters) - log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( - elapsed_time * 1000.0 / neox_args.log_interval) - log_string += ' learning rate: {:.3E} |'.format(learning_rate) + log_string = " samples/sec: {:.3f} |".format(samples_per_sec) + tb_wandb_log( + "runtime/samples_per_sec", + samples_per_sec, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + ) + tb_wandb_log( + "runtime/iteration_time", + iteration_time, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + ) + log_string += " iteration {:8d}/{:8d} |".format( + iteration, neox_args.train_iters + ) + log_string += " elapsed time per iteration (ms): {:.1f} |".format( + elapsed_time * 1000.0 / neox_args.log_interval + ) + log_string += " learning rate: {:.3E} |".format(learning_rate) num_iterations = max( - 1, neox_args.log_interval - total_loss_dict[skipped_iters_key]) + 1, neox_args.log_interval - total_loss_dict[skipped_iters_key] + ) # log tflop / gpu - flops_per_s_per_gpu = get_flops(neox_args=neox_args, model=model, iter_time_s=iteration_time) - log_string += f' approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |' - tb_wandb_log('runtime/flops_per_sec_per_gpu', flops_per_s_per_gpu, iteration, use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer) + flops_per_s_per_gpu = get_flops( + neox_args=neox_args, model=model, iter_time_s=iteration_time + ) + log_string += ( + f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |" + ) + tb_wandb_log( + "runtime/flops_per_sec_per_gpu", + flops_per_s_per_gpu, + iteration, + use_wandb=neox_args.use_wandb, + tensorboard_writer=neox_args.tensorboard_writer, + ) for key in total_loss_dict: if key not in [skipped_iters_key, got_nan_key]: - v = total_loss_dict[key].item() if hasattr(total_loss_dict[key], 'item') else total_loss_dict[key] + v = ( + total_loss_dict[key].item() + if hasattr(total_loss_dict[key], "item") + else total_loss_dict[key] + ) avg = v / float(num_iterations) - log_string += ' {}: {:.6E} |'.format(key, avg) + log_string += " {}: {:.6E} |".format(key, avg) total_loss_dict[key] = 0.0 if neox_args.precision == "fp16": - log_string += ' loss scale: {:.1f} |'.format(loss_scale) - log_string += ' number of skipped iterations: {:3d} |'.format( - total_loss_dict[skipped_iters_key]) - log_string += ' number of nan iterations: {:3d} |'.format( - total_loss_dict[got_nan_key]) + log_string += " loss scale: {:.1f} |".format(loss_scale) + log_string += " number of skipped iterations: {:3d} |".format( + total_loss_dict[skipped_iters_key] + ) + log_string += " number of nan iterations: {:3d} |".format( + total_loss_dict[got_nan_key] + ) total_loss_dict[skipped_iters_key] = 0 total_loss_dict[got_nan_key] = 0 print_rank_0(log_string) if report_memory_flag: - report_memory('after {} iterations'.format(iteration)) + report_memory("after {} iterations".format(iteration)) report_memory_flag = False timers.log(timers_to_log, normalizer=neox_args.log_interval) @@ -210,9 +323,11 @@ def add_to_logging(name): return report_memory_flag -def tb_wandb_log(key, value, iteration_no, use_wandb, tensorboard_writer=None, all_ranks=False): +def tb_wandb_log( + key, value, iteration_no, use_wandb, tensorboard_writer=None, all_ranks=False +): # logs to both tb and wandb (if present) from the zeroth rank - do_log = torch.distributed.get_rank() == 0 or all_ranks + do_log = torch.distributed.get_rank() == 0 or all_ranks if do_log and value is not None: if tensorboard_writer: tensorboard_writer.add_scalar(key, value, iteration_no) diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index cd428d944..9af46de95 100755 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2021 Biderman et al. This file is based on code by the authors denoted below and has been modified from its original version. # @@ -18,4 +17,4 @@ from .gpt2_model import GPT2ModelPipe from .utils import get_params_for_weight_decay_optimization -from .word_embeddings import SoftEmbedding \ No newline at end of file +from .word_embeddings import SoftEmbedding diff --git a/megatron/model/activations.py b/megatron/model/activations.py index 5eba73c3b..62e7861e7 100644 --- a/megatron/model/activations.py +++ b/megatron/model/activations.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +27,7 @@ def get_activation(neox_args): activation_func = GEGLU(neox_args=neox_args) elif neox_args.activation == "gelu": if neox_args.onnx_safe and neox_args.bias_gelu_fusion: - raise ValueError('onnx_safe + bias_gelu_fusion not compatible') + raise ValueError("onnx_safe + bias_gelu_fusion not compatible") if neox_args.onnx_safe: activation_func = erf_gelu elif neox_args.bias_gelu_fusion: @@ -56,6 +55,7 @@ def get_activation(neox_args): # actual gelu is: # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + @torch.jit.script def bias_gelu(bias, y): x = bias + y @@ -70,7 +70,9 @@ def bias_gelu_back(g, bias, y): x = bias + y tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + ff = 0.5 * x * ( + (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) + ) + 0.5 * (1 + tanh_out) return ff * g @@ -94,7 +96,14 @@ def backward(ctx, grad_output): # This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter @torch.jit.script def erf_gelu(x): - return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype)) + return ( + x + * 0.5 + * ( + torch.erf(x / 1.41421).to(dtype=x.dtype) + + torch.ones_like(x).to(dtype=x.dtype) + ) + ) @torch.jit.script @@ -108,7 +117,6 @@ def mish(x): class GEGLU(torch.nn.Module): - def __init__(self, neox_args): super(GEGLU, self).__init__() if neox_args.onnx_safe: diff --git a/megatron/model/fused_bias_dropout.py b/megatron/model/fused_bias_dropout.py index a3ee1e57a..4f29e5d7a 100644 --- a/megatron/model/fused_bias_dropout.py +++ b/megatron/model/fused_bias_dropout.py @@ -10,7 +10,9 @@ torch._C._jit_override_can_fuse_on_gpu(True) -def bias_dropout_add(x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float, training: bool) -> Tensor: +def bias_dropout_add( + x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float, training: bool +) -> Tensor: out = torch.nn.functional.dropout(x + bias, p=prob, training=training) if residual is not None: out = residual + out @@ -20,14 +22,19 @@ def bias_dropout_add(x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add @torch.jit.script -def bias_dropout_add_fused_train(x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float) -> Tensor: +def bias_dropout_add_fused_train( + x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float +) -> Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script -def bias_dropout_add_fused_inference(x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float) -> Tensor: +def bias_dropout_add_fused_inference( + x: Tensor, bias: Tensor, residual: Optional[Tensor], prob: float +) -> Tensor: return bias_dropout_add(x, bias, residual, prob, False) diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 04175ff27..ad98f5e3c 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,90 +16,107 @@ import torch.nn as nn import enum + class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. """ @staticmethod def forward(ctx, inputs, scale): import scaled_upper_triang_masked_softmax_cuda + scale_t = torch.tensor([scale]) - softmax_results = \ - scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0]) + softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( + inputs, scale_t[0] + ) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @staticmethod def backward(ctx, output_grads): import scaled_upper_triang_masked_softmax_cuda + softmax_results, scale_t = ctx.saved_tensors - input_grads = \ - scaled_upper_triang_masked_softmax_cuda.backward(output_grads, - softmax_results, - scale_t[0]) + input_grads = scaled_upper_triang_masked_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) return input_grads, None class ScaledMaskedSoftmax(torch.autograd.Function): """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. """ @staticmethod def forward(ctx, inputs, mask, scale): import scaled_masked_softmax_cuda + scale_t = torch.tensor([scale]) - softmax_results = \ - scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) + softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @staticmethod def backward(ctx, output_grads): import scaled_masked_softmax_cuda + softmax_results, scale_t = ctx.saved_tensors - input_grads = \ - scaled_masked_softmax_cuda.backward(output_grads, - softmax_results, - scale_t[0]) + input_grads = scaled_masked_softmax_cuda.backward( + output_grads, softmax_results, scale_t[0] + ) return input_grads, None, None + class SoftmaxFusionTypes(enum.Enum): - upper_triang = 1 # causal mask - general = 2 # general mask - none = 3 # no fusion + upper_triang = 1 # causal mask + general = 2 # general mask + none = 3 # no fusion + class FusedScaleMaskSoftmax(nn.Module): """ - fused operation: scaling + mask + softmax - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - fusion_type: type of fusion to perform, should be either upper_triang, general or none. None will perform a regular torch softmax. - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. + fused operation: scaling + mask + softmax + Arguments: + input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. + fusion_type: type of fusion to perform, should be either upper_triang, general or none. None will perform a regular torch softmax. + mask_func: mask function to be applied. + softmax_in_fp32: if true, softmax in performed at fp32 precision. + scale: scaling factor used in input tensor scaling. """ - def __init__(self, input_in_fp16, input_in_bf16, fusion_type, mask_func, softmax_in_fp32, scale): + def __init__( + self, + input_in_fp16, + input_in_bf16, + fusion_type, + mask_func, + softmax_in_fp32, + scale, + ): super().__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - assert fusion_type in [SoftmaxFusionTypes.upper_triang, SoftmaxFusionTypes.general, SoftmaxFusionTypes.none], f"Invalid fusion type {fusion_type}" + assert fusion_type in [ + SoftmaxFusionTypes.upper_triang, + SoftmaxFusionTypes.general, + SoftmaxFusionTypes.none, + ], f"Invalid fusion type {fusion_type}" self.upper_triang_mask_fusion = fusion_type == SoftmaxFusionTypes.upper_triang self.general_mask_fusion = fusion_type == SoftmaxFusionTypes.general self.fusion = fusion_type != SoftmaxFusionTypes.none @@ -108,8 +124,9 @@ def __init__(self, input_in_fp16, input_in_bf16, fusion_type, mask_func, softmax self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale - assert self.scale is None or softmax_in_fp32, \ - 'softmax should be in fp32 when scaled' + assert ( + self.scale is None or softmax_in_fp32 + ), "softmax should be in fp32 when scaled" def forward(self, input, mask): # [b, np, sq, sk] @@ -140,7 +157,7 @@ def is_kernel_available(self, mask, b, np, sq, sk): if sq % batch_per_block == 0: return True return False - + def forward_fused_softmax(self, input, mask): b, np, sq, sk = input.size() scale = self.scale if self.scale is not None else 1.0 @@ -169,10 +186,11 @@ def forward_torch_softmax(self, input, mask): probs = probs.half() else: probs = probs.bfloat16() - + return probs - + @staticmethod def get_batch_per_block(b, np, sq, sk): import scaled_masked_softmax_cuda - return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) \ No newline at end of file + + return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/megatron/model/gmlp.py b/megatron/model/gmlp.py index b48fddcf0..96e7e654f 100644 --- a/megatron/model/gmlp.py +++ b/megatron/model/gmlp.py @@ -14,7 +14,7 @@ class TinyAttention(nn.Module): def __init__(self, neox_args, d_attn, d_ff, mask_fn): super().__init__() self.proj_qkv = nn.Linear(d_ff * 2, 3 * d_attn) - self.scale = d_attn**-0.5 + self.scale = d_attn ** -0.5 self.proj_ffn = nn.Linear(d_attn, d_ff) self.softmax = FusedScaleMaskSoftmax( input_in_fp16=neox_args.precision == "fp16", diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index eea0791a0..76e971dbb 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2021 Biderman et al. This file is based on code by the authors denoted below and has been modified from its original version. # @@ -323,7 +322,7 @@ def train_mode(self): recursive_setattr(self.forward_funcs, "use_cache", False) # then set parallel output to true (more efficient training) self._set_parallel_output(True) - + def clear_cache(self): """ Recursively clears the kv cache on all layers diff --git a/megatron/model/init_functions.py b/megatron/model/init_functions.py index 267120802..32ddf47ea 100644 --- a/megatron/model/init_functions.py +++ b/megatron/model/init_functions.py @@ -21,6 +21,7 @@ def init_(tensor): return init_ + # orthogonal init does not support fp16, so have to patch it def _orthogonal(tensor, gain=1): if tensor.ndimension() < 2: @@ -35,7 +36,7 @@ def _orthogonal(tensor, gain=1): # Compute the qr factorization dt = flattened.dtype - flattened = flattened.to(torch.float32) # orthogonal init does not support fp16 + flattened = flattened.to(torch.float32) # orthogonal init does not support fp16 q, r = torch.qr(flattened) q, r = q.to(dtype=dt), r.to(dtype=dt) # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf @@ -51,18 +52,20 @@ def _orthogonal(tensor, gain=1): tensor.mul_(gain) return tensor + def orthogonal_init_method(n_layers=1): - """Fills the input Tensor with a (semi) orthogonal matrix, as described in + """Fills the input Tensor with a (semi) orthogonal matrix, as described in Exact solutions to the nonlinear dynamics of learning in deep linear neural networks - Saxe, A. et al. (2013) - Optionally scaling by number of layers possible, as introduced in OBST - Nestler et. al. (2021, to be released) """ + Optionally scaling by number of layers possible, as introduced in OBST - Nestler et. al. (2021, to be released)""" def init_(tensor): return _orthogonal(tensor, math.sqrt(2 / n_layers)) return init_ + def xavier_uniform_init_method(): - """Fills the input Tensor with values according to the method described in Understanding the difficulty of + """Fills the input Tensor with values according to the method described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a uniform distribution.""" def init_(tensor): @@ -70,8 +73,9 @@ def init_(tensor): return init_ + def xavier_normal_init_method(): - """Fills the input Tensor with values according to the method described in Understanding the difficulty of + """Fills the input Tensor with values according to the method described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution.""" def init_(tensor): @@ -79,8 +83,9 @@ def init_(tensor): return init_ + def small_init_init_method(dim): - """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving + """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution.""" std = math.sqrt(2 / (5 * dim)) @@ -89,6 +94,7 @@ def init_(tensor): return init_ + def wang_init_method(n_layers, dim): std = 2 / n_layers / math.sqrt(dim) @@ -97,9 +103,10 @@ def init_(tensor): return init_ + def get_init_methods(args): def _get(name): - if name == "normal": + if name == "normal": return init_method_normal(args.init_method_std) elif name == "scaled_normal": return scaled_init_method_normal(args.init_method_std, args.num_layers) @@ -117,5 +124,5 @@ def _get(name): return small_init_init_method(args.hidden_size) else: raise NotImplementedError(f"Unkown init method {name}") - + return _get(args.init_method), _get(args.output_layer_init_method) diff --git a/megatron/model/norms.py b/megatron/model/norms.py index e05528d46..94d335d17 100644 --- a/megatron/model/norms.py +++ b/megatron/model/norms.py @@ -69,4 +69,4 @@ def __init__(self, dim, eps=1e-5): def forward(self, x): n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) - return x / n * self.g \ No newline at end of file + return x / n * self.g diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index 24ff17bee..c7c4b69dc 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -1,12 +1,12 @@ import torch -import math +import math -class SinusoidalPositionalEmbedding(torch.nn.Module): +class SinusoidalPositionalEmbedding(torch.nn.Module): def __init__(self, dim, base=10000, precision=torch.half): super().__init__() - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) self.precision = precision def forward(self, x, seq_dim=1): @@ -22,11 +22,10 @@ def forward(self, x, seq_dim=1): class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000, precision=torch.half): super().__init__() - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) self.seq_len_cached = None self.cos_cached = None self.sin_cached = None @@ -38,7 +37,7 @@ def forward(self, x, seq_dim=1, seq_len=None): if seq_len != self.seq_len_cached: self.seq_len_cached = seq_len t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(x.device) if self.precision == torch.bfloat16: emb = emb.float() @@ -52,75 +51,101 @@ def forward(self, x, seq_dim=1, seq_len=None): # rotary pos emb helpers: + def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat( + (-x2, x1), dim=x1.ndim - 1 + ) # dim=-1 triggers a bug in earlier torch versions + @torch.jit.script def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): - cos, sin = cos[offset:q.shape[0]+offset, ...], sin[offset:q.shape[0]+offset, ...] + cos, sin = ( + cos[offset : q.shape[0] + offset, ...], + sin[offset : q.shape[0] + offset, ...], + ) return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) -def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16 - cos, sin = cos[offset:q.shape[0]+offset, ...], sin[offset:q.shape[0]+offset, ...] + +def apply_rotary_pos_emb_torch( + q, k, cos, sin, offset: int = 0 +): # jitting fails with bf16 + cos, sin = ( + cos[offset : q.shape[0] + offset, ...], + sin[offset : q.shape[0] + offset, ...], + ) return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) class AliBi(torch.nn.Module): - - def __init__(self, num_heads, mp_size=1, mp_rank=1): - super().__init__() - # megatron splits across heads, so we need to make sure each - # head receives the correct matrix - assert mp_size <= num_heads and mp_rank <= mp_size - self.mp_size = mp_size - self.mp_rank = mp_rank - self.num_heads = num_heads - self.slice_size = num_heads // mp_size - self.cached_matrix = None - self.cached_seq_len = None - slopes = torch.Tensor(self._get_slopes(num_heads))[mp_rank * self.slice_size : (mp_rank + 1) * self.slice_size] - self.register_buffer('slopes', slopes) - - - def _get_slopes(self, n): - """ - Get slopes for Alibi positional embedding - n : int = number of heads. - For best performance, restrict n to a power of 2. - """ - def get_slopes_power_of_2(n): - start = (2**(-2**-(math.log2(n)-3))) - ratio = start - return [start*ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return get_slopes_power_of_2(closest_power_of_2) + self._get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2] - - def forward(self, x): - # [b, np, sq, sk] - seq_len_q = x.shape[-2] - seq_len_k = x.shape[-1] - if self.cached_seq_len != seq_len_k: - a = -torch.tril(torch.arange(seq_len_k).view(seq_len_k, 1).repeat(1, seq_len_k) + torch.arange(0, -seq_len_k, -1)) - a = a.to(x.device).to(x.dtype) - slopes = self.slopes.to(a.device).to(a.dtype) - a = a * slopes.view(self.slopes.shape[0], 1, 1) - self.cached_seq_len = seq_len_k - self.cached_matrix = a - else: - a = self.cached_matrix - - if seq_len_q != seq_len_k: - # In the train case x has dimensionality [b, np, sq, sk] with sq == sk - # The number of query tokens is equal to the number of key tokens - # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) - # In this case we use the appropriate token index of the cache matrix. - # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used - assert seq_len_q == 1, "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" - a = a[:, seq_len_k - 1, :].view(a.shape[0], 1, a.shape[2]) # seq_len_k - 1 points to the last token index in the current inference batch. - - return x + a \ No newline at end of file + def __init__(self, num_heads, mp_size=1, mp_rank=1): + super().__init__() + # megatron splits across heads, so we need to make sure each + # head receives the correct matrix + assert mp_size <= num_heads and mp_rank <= mp_size + self.mp_size = mp_size + self.mp_rank = mp_rank + self.num_heads = num_heads + self.slice_size = num_heads // mp_size + self.cached_matrix = None + self.cached_seq_len = None + slopes = torch.Tensor(self._get_slopes(num_heads))[ + mp_rank * self.slice_size : (mp_rank + 1) * self.slice_size + ] + self.register_buffer("slopes", slopes) + + def _get_slopes(self, n): + """ + Get slopes for Alibi positional embedding + n : int = number of heads. + For best performance, restrict n to a power of 2. + """ + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + self._get_slopes(2 * closest_power_of_2)[0::2][ + : n - closest_power_of_2 + ] + ) + + def forward(self, x): + # [b, np, sq, sk] + seq_len_q = x.shape[-2] + seq_len_k = x.shape[-1] + if self.cached_seq_len != seq_len_k: + a = -torch.tril( + torch.arange(seq_len_k).view(seq_len_k, 1).repeat(1, seq_len_k) + + torch.arange(0, -seq_len_k, -1) + ) + a = a.to(x.device).to(x.dtype) + slopes = self.slopes.to(a.device).to(a.dtype) + a = a * slopes.view(self.slopes.shape[0], 1, 1) + self.cached_seq_len = seq_len_k + self.cached_matrix = a + else: + a = self.cached_matrix + + if seq_len_q != seq_len_k: + # In the train case x has dimensionality [b, np, sq, sk] with sq == sk + # The number of query tokens is equal to the number of key tokens + # At inference time with cache in layer_past sq is not equal to sk. sq only contains one token (the last one in the full sequence) + # In this case we use the appropriate token index of the cache matrix. + # As the cache matrix could already be bigger from a past inference, not the last token index in the sq sequence is used + assert ( + seq_len_q == 1 + ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1" + a = a[:, seq_len_k - 1, :].view( + a.shape[0], 1, a.shape[2] + ) # seq_len_k - 1 points to the last token index in the current inference batch. + + return x + a diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 67426d896..64d024a35 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2021 Biderman et al. This file is based on code by the authors denoted below and has been modified from its original version. # diff --git a/megatron/model/utils.py b/megatron/model/utils.py index bf7579101..ea345fdfe 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2021 Biderman et al. This file is based on code by the authors denoted below and has been modified from its original version. # diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 419e596e5..611d2adbf 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/megatron/mpu/cross_entropy.py b/megatron/mpu/cross_entropy.py index 79ea83db9..850442388 100644 --- a/megatron/mpu/cross_entropy.py +++ b/megatron/mpu/cross_entropy.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,15 +22,16 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): - @staticmethod def forward(ctx, vocab_parallel_logits, target): # Maximum value along vocab dimension across all GPUs. logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=get_model_parallel_group()) + torch.distributed.all_reduce( + logits_max, + op=torch.distributed.ReduceOp.MAX, + group=get_model_parallel_group(), + ) # Subtract the maximum value. vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) @@ -41,7 +41,8 @@ def forward(ctx, vocab_parallel_logits, target): rank = get_model_parallel_rank() world_size = get_model_parallel_world_size() vocab_start_index, vocab_end_index = get_vocab_range( - partition_vocab_size, rank, world_size) + partition_vocab_size, rank, world_size + ) # Create a mask of valid vocab ids (1 means it needs to be masked). target_mask = (target < vocab_start_index) | (target >= vocab_end_index) @@ -53,24 +54,29 @@ def forward(ctx, vocab_parallel_logits, target): # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], - device=logits_2d.device) + arange_1d = torch.arange( + start=0, end=logits_2d.size()[0], device=logits_2d.device + ) predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] predicted_logits_1d = predicted_logits_1d.clone().contiguous() predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 # All reduce is needed to get the chunks from other GPUs. - torch.distributed.all_reduce(predicted_logits, - op=torch.distributed.ReduceOp.SUM, - group=get_model_parallel_group()) + torch.distributed.all_reduce( + predicted_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_model_parallel_group(), + ) # Sum of exponential of logits along vocab dimension across all GPUs. exp_logits = vocab_parallel_logits torch.exp(vocab_parallel_logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) - torch.distributed.all_reduce(sum_exp_logits, - op=torch.distributed.ReduceOp.SUM, - group=get_model_parallel_group()) + torch.distributed.all_reduce( + sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_model_parallel_group(), + ) # Loss = log(sum(exp(logits))) - predicted-logit. loss = torch.log(sum_exp_logits) - predicted_logits @@ -94,10 +100,8 @@ def backward(ctx, grad_output): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], - device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= ( - 1.0 - target_mask.view(-1).float()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(grad_output.unsqueeze(dim=-1)) diff --git a/megatron/mpu/data.py b/megatron/mpu/data.py index 84b0af6cd..521cb05c1 100644 --- a/megatron/mpu/data.py +++ b/megatron/mpu/data.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,8 +25,11 @@ def _check_data_types(keys, data, target_dtype): """Check that all the keys have the same target data type.""" for key in keys: - assert data[key].dtype == target_dtype, '{} has data type {} which '\ - 'is different than {}'.format(key, data[key].dtype, target_dtype) + assert ( + data[key].dtype == target_dtype + ), "{} has data type {} which " "is different than {}".format( + key, data[key].dtype, target_dtype + ) def _build_key_size_numel_dictionaries(keys, data): @@ -39,7 +41,7 @@ def _build_key_size_numel_dictionaries(keys, data): if get_model_parallel_rank() == 0: offset = 0 for key in keys: - assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM" size = data[key].size() for i, s in enumerate(size): sizes[i + offset] = s @@ -47,8 +49,9 @@ def _build_key_size_numel_dictionaries(keys, data): # Move to GPU and broadcast. sizes_cuda = torch.cuda.LongTensor(sizes) - torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(), - group=get_model_parallel_group()) + torch.distributed.broadcast( + sizes_cuda, get_model_parallel_src_rank(), group=get_model_parallel_group() + ) # Move back to cpu and unpack. sizes_cpu = sizes_cuda.cpu() @@ -85,8 +88,7 @@ def broadcast_data(keys, data, datatype): """ # Build (key, size) and (key, number of elements) dictionaries along # with the total number of elements on all ranks. - key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, - data) + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) # Pack on rank zero. if get_model_parallel_rank() == 0: @@ -94,15 +96,17 @@ def broadcast_data(keys, data, datatype): _check_data_types(keys, data, datatype) # Flatten the data associated with the keys flatten_data = torch.cat( - [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + [data[key].contiguous().view(-1) for key in keys], dim=0 + ).cuda() else: - flatten_data = torch.empty(total_numel, - device=torch.cuda.current_device(), - dtype=datatype) + flatten_data = torch.empty( + total_numel, device=torch.cuda.current_device(), dtype=datatype + ) # Boradcast - torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(), - group=get_model_parallel_group()) + torch.distributed.broadcast( + flatten_data, get_model_parallel_src_rank(), group=get_model_parallel_group() + ) # Unpack output = {} diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py index e7c62a56d..80d6d5847 100644 --- a/megatron/mpu/initialize.py +++ b/megatron/mpu/initialize.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -68,8 +67,7 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce ranks 8 to 15 belong to the second box. """ if torch.distributed.get_rank() == 0: - print('> initializing model parallel with size {}'.format( - model_parallel_size)) + print("> initializing model parallel with size {}".format(model_parallel_size)) # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size = torch.distributed.get_world_size() @@ -84,13 +82,12 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce # Build the data parallel groups. global _DATA_PARALLEL_GROUP - assert _DATA_PARALLEL_GROUP is None, \ - 'data parallel group is already initialized' + assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" if topology: - for dp_group in topology.get_axis_comm_lists('data'): + for dp_group in topology.get_axis_comm_lists("data"): group = torch.distributed.new_group(ranks=dp_group) if rank == 0: - print(f'MPU DP:', dp_group) + print(f"MPU DP:", dp_group) if rank in dp_group: _DATA_PARALLEL_GROUP = group else: @@ -103,22 +100,22 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce # Build pipeline parallel group if topology is not None: global _PIPE_PARALLEL_GROUP - for pp_group in topology.get_axis_comm_lists('pipe'): + for pp_group in topology.get_axis_comm_lists("pipe"): group = torch.distributed.new_group(ranks=pp_group) if rank == 0: - print(f'MPU PP:', pp_group) + print(f"MPU PP:", pp_group) if rank in pp_group: _PIPE_PARALLEL_GROUP = group # Build IO group global _IO_PARALLEL_GROUP - if topology and topology.get_dim('pipe') > 1: - io_stages = [0, topology.get_dim('pipe') - 1] + if topology and topology.get_dim("pipe") > 1: + io_stages = [0, topology.get_dim("pipe") - 1] io_group = [] for stage in io_stages: io_group.extend(topology.filter_match(pipe=stage, model=0)) if rank == 0: - print(f'MPU IO:', io_group) + print(f"MPU IO:", io_group) group = torch.distributed.new_group(ranks=io_group) if rank in io_group: _IO_PARALLEL_GROUP = group @@ -127,8 +124,7 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce # Build the model parallel groups. global _MODEL_PARALLEL_GROUP - assert _MODEL_PARALLEL_GROUP is None, \ - 'model parallel group is already initialized' + assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" if topology: # Short circuit case without model parallelism. # TODO: it would be nice to avoid this branching case? @@ -136,28 +132,27 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce for group_rank in range(world_size): group = torch.distributed.new_group(ranks=[group_rank]) if rank == 0: - print(f'MPU MP:', [group_rank]) + print(f"MPU MP:", [group_rank]) if rank == group_rank: _MODEL_PARALLEL_GROUP = group return - for mp_group in topology.get_axis_comm_lists('model'): + for mp_group in topology.get_axis_comm_lists("model"): group = torch.distributed.new_group(ranks=mp_group) if rank == 0: - print(f'MPU MP:', mp_group) + print(f"MPU MP:", mp_group) if rank in mp_group: _MODEL_PARALLEL_GROUP = group else: for i in range(world_size // model_parallel_size): - ranks = range(i * model_parallel_size, - (i + 1) * model_parallel_size) + ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) group = torch.distributed.new_group(ranks) if i == (rank // model_parallel_size): _MODEL_PARALLEL_GROUP = group global _FP32_ALLREDUCE - assert _FP32_ALLREDUCE is None, 'fp32_allreduce is already initialized' + assert _FP32_ALLREDUCE is None, "fp32_allreduce is already initialized" _FP32_ALLREDUCE = fp32_allreduce @@ -170,22 +165,19 @@ def model_parallel_is_initialized(): def get_model_parallel_group(): """Get the model parallel group the caller rank belongs to.""" - assert _MODEL_PARALLEL_GROUP is not None, \ - 'model parallel group is not initialized' + assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" return _MODEL_PARALLEL_GROUP def get_data_parallel_group(): """Get the data parallel group the caller rank belongs to.""" - assert _DATA_PARALLEL_GROUP is not None, \ - 'data parallel group is not initialized' + assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" return _DATA_PARALLEL_GROUP def get_io_parallel_group(): """Get the IO parallel group the caller rank belongs to.""" - assert _IO_PARALLEL_GROUP is not None, \ - 'IO parallel group is not initialized' + assert _IO_PARALLEL_GROUP is not None, "IO parallel group is not initialized" return _IO_PARALLEL_GROUP @@ -235,7 +227,7 @@ def get_data_parallel_src_rank(): return global_rank % get_model_parallel_world_size() else: # We are using pipeline parallel - d = topo.get_axis_comm_lists('data') + d = topo.get_axis_comm_lists("data") for l in d: if global_rank in l: return l[0] @@ -257,8 +249,7 @@ def get_topology(): def get_pipe_parallel_group(): """Get the pipe parallel group the caller rank belongs to.""" - assert _PIPE_PARALLEL_GROUP is not None, \ - 'data parallel group is not initialized' + assert _PIPE_PARALLEL_GROUP is not None, "data parallel group is not initialized" return _PIPE_PARALLEL_GROUP @@ -294,6 +285,5 @@ def destroy_model_parallel(): def get_fp32_allreduce(): """Get the fp32 allreduce flag""" - assert _FP32_ALLREDUCE is not None, \ - 'fp32_allreduce is not Initialized' + assert _FP32_ALLREDUCE is not None, "fp32_allreduce is not Initialized" return _FP32_ALLREDUCE diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index d6895f60d..514b847db 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # @@ -38,8 +37,8 @@ from .utils import divide from .utils import VocabUtility -def _initialize_affine_weight_gpu(weight, init_method, - partition_dim, stride=1): + +def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): """Initialize affine weight for model parallel on GPU.""" weight.model_parallel = True @@ -50,10 +49,17 @@ def _initialize_affine_weight_gpu(weight, init_method, init_method(weight) -def _initialize_affine_weight_cpu(neox_args, weight, output_size, input_size, - per_partition_size, partition_dim, - init_method, stride=1, - return_master_weight=False): +def _initialize_affine_weight_cpu( + neox_args, + weight, + output_size, + input_size, + per_partition_size, + partition_dim, + init_method, + stride=1, + return_master_weight=False, +): """Initialize affine weight for model parallel. Build the master weight on all processes and scatter @@ -64,16 +70,17 @@ def _initialize_affine_weight_cpu(neox_args, weight, output_size, input_size, weight.partition_stride = stride # Initialize master weight - master_weight = torch.empty(output_size, input_size, - dtype=torch.float, - requires_grad=False) + master_weight = torch.empty( + output_size, input_size, dtype=torch.float, requires_grad=False + ) init_method(master_weight) master_weight = master_weight.to(dtype=neox_args.params_dtype) # Split and copy per_partition_per_stride_size = divide(per_partition_size, stride) - weight_list = torch.split(master_weight, per_partition_per_stride_size, - dim=partition_dim) + weight_list = torch.split( + master_weight, per_partition_per_stride_size, dim=partition_dim + ) rank = get_model_parallel_rank() world_size = get_model_parallel_world_size() my_weight_list = weight_list[rank::world_size] @@ -96,7 +103,9 @@ class VocabParallelEmbedding(torch.nn.Module): init_method: method to initialize weights. """ - def __init__(self, neox_args, num_embeddings, embedding_dim, init_method=init.xavier_normal_): + def __init__( + self, neox_args, num_embeddings, embedding_dim, init_method=init.xavier_normal_ + ): super(VocabParallelEmbedding, self).__init__() # Keep the input dimensions. self.num_embeddings = num_embeddings @@ -104,49 +113,74 @@ def __init__(self, neox_args, num_embeddings, embedding_dim, init_method=init.xa # Set the detauls for compatibility. self.padding_idx = None self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None self.model_parallel_size = get_model_parallel_world_size() # Divide the weight matrix along the vocabulary dimension. - self.vocab_start_index, self.vocab_end_index = \ - VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, get_model_parallel_rank(), - self.model_parallel_size) - self.num_embeddings_per_partition = self.vocab_end_index - \ - self.vocab_start_index + ( + self.vocab_start_index, + self.vocab_end_index, + ) = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_model_parallel_rank(), self.model_parallel_size + ) + self.num_embeddings_per_partition = ( + self.vocab_end_index - self.vocab_start_index + ) # Allocate weights and initialize. if neox_args.use_cpu_initialization: - self.weight = Parameter(torch.empty( - self.num_embeddings_per_partition, self.embedding_dim, - dtype=neox_args.params_dtype)) + self.weight = Parameter( + torch.empty( + self.num_embeddings_per_partition, + self.embedding_dim, + dtype=neox_args.params_dtype, + ) + ) _initialize_affine_weight_cpu( - neox_args, self.weight, self.num_embeddings, self.embedding_dim, - self.num_embeddings_per_partition, 0, init_method) + neox_args, + self.weight, + self.num_embeddings, + self.embedding_dim, + self.num_embeddings_per_partition, + 0, + init_method, + ) else: - self.weight = Parameter(torch.empty( - self.num_embeddings_per_partition, self.embedding_dim, - device=torch.cuda.current_device(), dtype=neox_args.params_dtype)) - _initialize_affine_weight_gpu(self.weight, init_method, - partition_dim=0, stride=1) + self.weight = Parameter( + torch.empty( + self.num_embeddings_per_partition, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + _initialize_affine_weight_gpu( + self.weight, init_method, partition_dim=0, stride=1 + ) def forward(self, input_): if self.model_parallel_size > 1: # Build the mask. - input_mask = (input_ < self.vocab_start_index) | \ - (input_ >= self.vocab_end_index) + input_mask = (input_ < self.vocab_start_index) | ( + input_ >= self.vocab_end_index + ) # Mask the input. masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 else: masked_input = input_ # Get the embeddings. - output_parallel = F.embedding(masked_input, self.weight, - self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, - self.sparse) + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) # Mask the output embedding. if self.model_parallel_size > 1: output_parallel[input_mask, :] = 0.0 @@ -169,7 +203,16 @@ class ParallelRelativePositionBias(torch.nn.Module): heads: number of attention heads (total) """ - def __init__(self, neox_args, scale, causal=True, num_buckets=32, max_distance=128, heads=8, init_method=init.xavier_normal_): + def __init__( + self, + neox_args, + scale, + causal=True, + num_buckets=32, + max_distance=128, + heads=8, + init_method=init.xavier_normal_, + ): super().__init__() self.scale = scale self.causal = causal @@ -180,7 +223,7 @@ def __init__(self, neox_args, scale, causal=True, num_buckets=32, max_distance=1 # Set the defaults for compatibility. self.padding_idx = None self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None @@ -188,24 +231,41 @@ def __init__(self, neox_args, scale, causal=True, num_buckets=32, max_distance=1 self.model_parallel_rank = get_model_parallel_rank() # Divide the weight matrix along the heads dimension. - self.head_start_index, self.head_end_index = self.get_heads_range(self.heads, self.model_parallel_rank, - self.model_parallel_size) + self.head_start_index, self.head_end_index = self.get_heads_range( + self.heads, self.model_parallel_rank, self.model_parallel_size + ) self.num_heads_per_partition = self.head_end_index - self.head_start_index # Allocate weights and initialize. if neox_args.use_cpu_initialization: - self.weight = Parameter(torch.empty( - self.num_buckets, self.num_heads_per_partition, - dtype=neox_args.params_dtype)) + self.weight = Parameter( + torch.empty( + self.num_buckets, + self.num_heads_per_partition, + dtype=neox_args.params_dtype, + ) + ) _initialize_affine_weight_cpu( - neox_args, self.weight, self.num_buckets, self.heads, - self.num_heads_per_partition, partition_dim=1, init_method=init_method) + neox_args, + self.weight, + self.num_buckets, + self.heads, + self.num_heads_per_partition, + partition_dim=1, + init_method=init_method, + ) else: - self.weight = Parameter(torch.empty( - self.num_buckets, self.num_heads_per_partition, - device=torch.cuda.current_device(), dtype=neox_args.params_dtype)) - _initialize_affine_weight_gpu(self.weight, init_method, - partition_dim=1, stride=1) + self.weight = Parameter( + torch.empty( + self.num_buckets, + self.num_heads_per_partition, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + _initialize_affine_weight_gpu( + self.weight, init_method, partition_dim=1, stride=1 + ) self._q_len_cached = None self._k_len_cached = None self._rel_pos_bucket_cached = None @@ -217,7 +277,9 @@ def get_heads_range(global_n_heads, rank, world_size): index_l = index_f + per_partition_n_heads return index_f, index_l - def _relative_position_bucket(self, relative_position, num_buckets=32, max_distance=128): + def _relative_position_bucket( + self, relative_position, num_buckets=32, max_distance=128 + ): ret = 0 n = -relative_position if not self.causal: @@ -230,10 +292,17 @@ def _relative_position_bucket(self, relative_position, num_buckets=32, max_dista max_exact = num_buckets // 2 is_small = n < max_exact - val_if_large = max_exact + ( - torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) - ).long() - val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + val_if_large = ( + max_exact + + ( + torch.log(n.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).long() + ) + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1) + ) ret += torch.where(is_small, n, val_if_large) self._rel_pos_bucket_cached = ret @@ -243,16 +312,28 @@ def forward(self, q_len, k_len): if self._q_len_cached != q_len or self._k_len_cached != k_len: # cache bucket if first step seq len stays constant self._q_len_cached, self._k_len_cached = q_len, k_len - q_pos = torch.arange(q_len, dtype=torch.long, device=torch.cuda.current_device()) - k_pos = torch.arange(k_len, dtype=torch.long, device=torch.cuda.current_device()) + q_pos = torch.arange( + q_len, dtype=torch.long, device=torch.cuda.current_device() + ) + k_pos = torch.arange( + k_len, dtype=torch.long, device=torch.cuda.current_device() + ) rel_pos = k_pos[None, :] - q_pos[:, None] - rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, - max_distance=self.max_distance) + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance + ) else: rp_bucket = self._rel_pos_bucket_cached - values = F.embedding(rp_bucket, self.weight, self.padding_idx, - self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) - bias = values.movedim(2,0).unsqueeze(0) + values = F.embedding( + rp_bucket, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + bias = values.movedim(2, 0).unsqueeze(0) return bias * self.scale @@ -276,14 +357,22 @@ class ColumnParallelLinear(torch.nn.Module): set to False. It returns the master weights used for initialization. skip_bias_add: This was added to enable performance optimations where bias - can be fused with other elementwise operations. we skip + can be fused with other elementwise operations. we skip adding bias but instead return it. """ - def __init__(self, neox_args, input_size, output_size, bias=True, gather_output=True, - init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False): + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + gather_output=True, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + ): super(ColumnParallelLinear, self).__init__() # Keep input parameters @@ -300,29 +389,52 @@ def __init__(self, neox_args, input_size, output_size, bias=True, gather_output= # we allocate the transpose. # Initialize weight. if neox_args.use_cpu_initialization: - self.weight = Parameter(torch.empty(self.output_size_per_partition, - self.input_size, - dtype=neox_args.params_dtype)) + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, + self.input_size, + dtype=neox_args.params_dtype, + ) + ) self.master_weight = _initialize_affine_weight_cpu( - neox_args, self.weight, self.output_size, self.input_size, - self.output_size_per_partition, 0, init_method, - stride=stride, return_master_weight=keep_master_weight_for_test) + neox_args, + self.weight, + self.output_size, + self.input_size, + self.output_size_per_partition, + 0, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test, + ) else: - self.weight = Parameter(torch.empty( - self.output_size_per_partition, self.input_size, - device=torch.cuda.current_device(), dtype=neox_args.params_dtype)) - _initialize_affine_weight_gpu(self.weight, init_method, - partition_dim=0, stride=stride) + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, + self.input_size, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + _initialize_affine_weight_gpu( + self.weight, init_method, partition_dim=0, stride=stride + ) if bias: if neox_args.use_cpu_initialization: - self.bias = Parameter(torch.empty( - self.output_size_per_partition, dtype=neox_args.params_dtype)) + self.bias = Parameter( + torch.empty( + self.output_size_per_partition, dtype=neox_args.params_dtype + ) + ) else: - self.bias = Parameter(torch.empty( - self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=neox_args.params_dtype)) + self.bias = Parameter( + torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) self.bias.model_parallel = True self.bias.partition_dim = 0 self.bias.stride = stride @@ -330,11 +442,13 @@ def __init__(self, neox_args, input_size, output_size, bias=True, gather_output= with torch.no_grad(): self.bias.zero_() else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def set_parallel_output(self, value: bool): assert isinstance(value, bool) - self.gather_output = not value # if gather_output is True, parallel output is False, so we set the opposite + self.gather_output = ( + not value + ) # if gather_output is True, parallel output is False, so we set the opposite def forward(self, input_): # Set up backprop all-reduce. @@ -378,16 +492,23 @@ class RowParallelLinear(torch.nn.Module): set to False. It returns the master weights used for initialization. skip_bias_add: This was added to enable performance optimations where bias - can be fused with other elementwise operations. we skip + can be fused with other elementwise operations. we skip adding bias but instead return it. """ - def __init__(self, neox_args, input_size, output_size, bias=True, - input_is_parallel=False, - init_method=init.xavier_normal_, stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - parallel_output=False): + def __init__( + self, + neox_args, + input_size, + output_size, + bias=True, + input_is_parallel=False, + init_method=init.xavier_normal_, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + parallel_output=False, + ): super(RowParallelLinear, self).__init__() # Keep input parameters @@ -405,32 +526,54 @@ def __init__(self, neox_args, input_size, output_size, bias=True, # we allocate the transpose. # Initialize weight. if neox_args.use_cpu_initialization: - self.weight = Parameter(torch.empty(self.output_size, - self.input_size_per_partition, - dtype=neox_args.params_dtype)) + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size_per_partition, + dtype=neox_args.params_dtype, + ) + ) self.master_weight = _initialize_affine_weight_cpu( - neox_args, self.weight, self.output_size, self.input_size, - self.input_size_per_partition, 1, init_method, - stride=stride, return_master_weight=keep_master_weight_for_test) + neox_args, + self.weight, + self.output_size, + self.input_size, + self.input_size_per_partition, + 1, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test, + ) else: - self.weight = Parameter(torch.empty( - self.output_size, self.input_size_per_partition, - device=torch.cuda.current_device(), dtype=neox_args.params_dtype)) - _initialize_affine_weight_gpu(self.weight, init_method, - partition_dim=1, stride=stride) + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size_per_partition, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) + _initialize_affine_weight_gpu( + self.weight, init_method, partition_dim=1, stride=stride + ) if bias: if neox_args.use_cpu_initialization: - self.bias = Parameter(torch.empty(self.output_size, - dtype=neox_args.params_dtype)) + self.bias = Parameter( + torch.empty(self.output_size, dtype=neox_args.params_dtype) + ) else: - self.bias = Parameter(torch.empty( - self.output_size, device=torch.cuda.current_device(), - dtype=neox_args.params_dtype)) + self.bias = Parameter( + torch.empty( + self.output_size, + device=torch.cuda.current_device(), + dtype=neox_args.params_dtype, + ) + ) # Always initialize bias to zero. with torch.no_grad(): self.bias.zero_() else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def set_parallel_output(self, parallel_output: bool): assert isinstance(parallel_output, bool) diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py index 897ec2aca..dd01da7ed 100644 --- a/megatron/mpu/mappings.py +++ b/megatron/mpu/mappings.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +14,12 @@ import torch -from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank, get_fp32_allreduce +from .initialize import ( + get_model_parallel_group, + get_model_parallel_world_size, + get_model_parallel_rank, + get_fp32_allreduce, +) from .utils import split_tensor_along_last_dim @@ -168,6 +172,7 @@ def backward(ctx, grad_output): # Helper functions. # ----------------- + def copy_to_model_parallel_region(input_): return _CopyToModelParallelRegion.apply(input_) diff --git a/megatron/mpu/random.py b/megatron/mpu/random.py index b3eaecba6..2c2ea432e 100644 --- a/megatron/mpu/random.py +++ b/megatron/mpu/random.py @@ -5,7 +5,9 @@ import deepspeed.runtime.activation_checkpointing.checkpointing as checkpointing # Default name for the model parallel rng tracker. -_MODEL_PARALLEL_RNG_TRACKER_NAME = deepspeed.checkpointing._MODEL_PARALLEL_RNG_TRACKER_NAME +_MODEL_PARALLEL_RNG_TRACKER_NAME = ( + deepspeed.checkpointing._MODEL_PARALLEL_RNG_TRACKER_NAME +) # Whether apply model parallelsim to checkpointed hidden states. _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None @@ -19,4 +21,3 @@ checkpoint = checkpointing.checkpoint model_parallel_cuda_manual_seed = checkpointing.model_parallel_cuda_manual_seed get_cuda_rng_tracker = checkpointing.get_cuda_rng_tracker - diff --git a/megatron/mpu/utils.py b/megatron/mpu/utils.py index f4efad3c7..e13a99f00 100644 --- a/megatron/mpu/utils.py +++ b/megatron/mpu/utils.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,8 +18,9 @@ def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, '{} is not divisible by {}'.format( - numerator, denominator) + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator + ) def divide(numerator, denominator): @@ -30,8 +30,7 @@ def divide(numerator, denominator): return numerator // denominator -def split_tensor_along_last_dim(tensor, num_partitions, - contiguous_split_chunks=False): +def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): """Split a tensor along its last dimension. Arguments: tensor: input tensor. @@ -53,12 +52,13 @@ def split_tensor_along_last_dim(tensor, num_partitions, class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the - first and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [first, last]""" + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [first, last]""" @staticmethod - def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, - rank, world_size): + def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size + ): index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size return index_f, index_l @@ -67,4 +67,5 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): per_partition_vocab_size = divide(global_vocab_size, world_size) return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, world_size) + per_partition_vocab_size, rank, world_size + ) diff --git a/megatron/neox_arguments/__init__.py b/megatron/neox_arguments/__init__.py index afe9cb571..db5825e75 100644 --- a/megatron/neox_arguments/__init__.py +++ b/megatron/neox_arguments/__init__.py @@ -1,9 +1,9 @@ """ -NeoX Arguments manages all configuration arguments. +NeoX Arguments manages all configuration arguments. **general** -* The implementation makes use of the python dataclass. +* The implementation makes use of the python dataclass. * The main class 'NeoXArgs' (in ./arguments) exposes all configuration attributes that are relevant to GPT NeoX * No attributes are nested (apart from attributes with type dict) * Output functions (enable_logging, save_yml, print) are implemented @@ -28,7 +28,7 @@ * The Subclasses group args according to their purpose * The attributes of NeoXArgsDeepspeedRunner are directly mapped to the expected command line args of deepspeed.launcher.runner.main; no attributes unknown to deepspeed should be included; no arguments relevant for deepspeed should be ommitted * The attributes of NeoXArgsDeepspeedConfig are directly mapped to the expected keys of the deepspeed config; no arguments relevant for deepspeed should be ommitted -* calculated attributes (decorator '@property') are available as attribute, but would not be included in dataclass fields (e.g. NeoXArgs().__dataclass_fields__.items()) +* calculated attributes (decorator '@property') are available as attribute, but would not be included in dataclass fields (e.g. NeoXArgs().__dataclass_fields__.items()) * refer to docstrings in code for more information """ diff --git a/megatron/neox_arguments/deepspeed_args.py b/megatron/neox_arguments/deepspeed_args.py index b1b8ba4db..9287725f3 100644 --- a/megatron/neox_arguments/deepspeed_args.py +++ b/megatron/neox_arguments/deepspeed_args.py @@ -46,7 +46,7 @@ class NeoXArgsDeepspeedConfig(NeoXArgsTemplate): dict containing the keys type and params type: The scheduler name. See here (https://deepspeed.readthedocs.io/en/latest/schedulers.html) for list of support schedulers. - + params: Dictionary of parameters to instantiate scheduler. The parameter names should match scheduler constructor signature. """ @@ -72,7 +72,7 @@ class NeoXArgsDeepspeedConfig(NeoXArgsTemplate): fp16: dict = None """ - Configuration for using mixed precision/FP16 training that leverages NVIDIA’s Apex package. + Configuration for using mixed precision/FP16 training that leverages NVIDIA’s Apex package. """ amp: dict = None @@ -124,7 +124,7 @@ class NeoXArgsDeepspeedRunner(NeoXArgsTemplate): hostfile: str = None """ list of hostnames / ssh aliases and the number of GPUs per host - + example file contents: worker-1 slots=4 worker-2 slots=4 diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index e33cb9a8d..3a0aceade 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -51,8 +51,8 @@ class NeoXArgsParallelism(NeoXArgsTemplate): pipe_partition_method: str = "type:transformer|mlp" """ - method used to distribute model layers across pipeline stages. Choose from "parameters", which balances the number - of parameters on each pipeline stage, "uniform", which naively balances the number of layers per stage, or + method used to distribute model layers across pipeline stages. Choose from "parameters", which balances the number + of parameters on each pipeline stage, "uniform", which naively balances the number of layers per stage, or "type:[regex]", which balances layers whose class names match [regex] """ @@ -63,7 +63,7 @@ class NeoXArgsParallelism(NeoXArgsTemplate): is_pipe_parallel: bool = False """ - flag to determine whether pipeline parallelism is on - shouldn't be set by user, is automatically determined + flag to determine whether pipeline parallelism is on - shouldn't be set by user, is automatically determined according to pipeline parallel size. """ @@ -150,18 +150,18 @@ class NeoXArgsModel(NeoXArgsTemplate): """ Attention configuration for gpt-neox - - The first item in the list specifies the attention type(s), and should be a list of strings. The second item + + The first item in the list specifies the attention type(s), and should be a list of strings. The second item specifies the number of times to repeat those attention types in the full list. - + attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird] - + So a 12 layer network with only global attention could be specified like: [[[`global`], 12]] - + or a 12 layer network with alternating global / local like: [[[`global`, `local`], 6]] - + If none is specified, this defaults to [[[`global`], n_layers]] """ @@ -170,13 +170,13 @@ class NeoXArgsModel(NeoXArgsTemplate): """ Sparsity configuration dict as defined in https://www.deepspeed.ai/docs/config-json/#sparse-attention - - Note that since neox is autoregressive, attention is always "unidirectional" and `horizontal_global_attention` is + + Note that since neox is autoregressive, attention is always "unidirectional" and `horizontal_global_attention` is always false. - - The main difference between our sparsity config and deepspeed's is that `mode` is ignored - since it is instead + + The main difference between our sparsity config and deepspeed's is that `mode` is ignored - since it is instead specified in attention_config defining each layer. - + An example config is given below: "sparse_attention": { "block": 16, @@ -278,7 +278,7 @@ class NeoXArgsModel(NeoXArgsTemplate): "small_init", ] = "normal" """ - Init function used on all layers except ff residual outputs - choose from + Init function used on all layers except ff residual outputs - choose from ["normal", "scaled_normal", "orthogonal", "scaled_orthogonal", "xavier_uniform", "xavier_normal", "wang_init", "small_init"] """ @@ -293,7 +293,7 @@ class NeoXArgsModel(NeoXArgsTemplate): "small_init", ] = "scaled_normal" """ - Init function used for ff residual outputs - choose from + Init function used for ff residual outputs - choose from ["normal", "scaled_normal", "orthogonal", "scaled_orthogonal", "xavier_uniform", "xavier_normal", "wang_init", "small_init"] """ @@ -315,7 +315,7 @@ class NeoXArgsModel(NeoXArgsTemplate): soft_prompt_tuning: dict = None """ - Dictionary configuring the soft prompt tuning parameters. + Dictionary configuring the soft prompt tuning parameters. If enabled, will train *only* the soft prompt, and freezes the rest of the model. parameters in the dict are: 'enabled': bool = True # enables soft prompting @@ -469,7 +469,7 @@ class NeoXArgsLogging(NeoXArgsTemplate): log_grad_norm: bool = False """ Log the frob norm of the gradients to wandb / tensorboard (useful for debugging). - (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because + (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because deepspeed.) """ @@ -480,7 +480,7 @@ class NeoXArgsLogging(NeoXArgsTemplate): log_gradient_noise_scale: bool = False """ - Whether to log the gradient noise scale when training (cf. https://arxiv.org/abs/1812.06162 for explanation) + Whether to log the gradient noise scale when training (cf. https://arxiv.org/abs/1812.06162 for explanation) """ gradient_noise_scale_n_batches: int = 5 @@ -606,7 +606,11 @@ class NeoXArgsTokenizer(NeoXArgsTemplate): """ tokenizer_type: Literal[ - "GPT2BPETokenizer", "HFTokenizer", "HFGPT2Tokenizer", "SPMTokenizer", "CharLevelTokenizer" + "GPT2BPETokenizer", + "HFTokenizer", + "HFGPT2Tokenizer", + "SPMTokenizer", + "CharLevelTokenizer", ] = "GPT2BPETokenizer" """ Type of tokenizer to use - should be one of ["GPT2BPETokenizer", "HFTokenizer", "HFGPT2Tokenizer", "SPMTokenizer", "CharLevelTokenizer"] @@ -614,7 +618,7 @@ class NeoXArgsTokenizer(NeoXArgsTemplate): padded_vocab_size: int = None """ - Total (padded) vocabulary size of tokenizer. Configured after launching of training, + Total (padded) vocabulary size of tokenizer. Configured after launching of training, as it's dependent on the parallelism size. """ @@ -671,7 +675,7 @@ class NeoXArgsTraining(NeoXArgsTemplate): weight_by_num_documents: bool = False """ If True, Builds dataset weights from a multinomial distribution over groups of data according to the number of - documents in each group. + documents in each group. WARNING: setting this to True will override any user provided weights @@ -939,7 +943,7 @@ class NeoXArgsTextgen(NeoXArgsTemplate): sample_output_file: str = "samples.txt" """ - Output file + Output file """ num_samples: int = 1 diff --git a/megatron/neox_arguments/template.py b/megatron/neox_arguments/template.py index 14138c436..b0e3869ff 100644 --- a/megatron/neox_arguments/template.py +++ b/megatron/neox_arguments/template.py @@ -1,32 +1,37 @@ from dataclasses import dataclass -import logging +import logging + @dataclass class NeoXArgsTemplate: - def defaults(self): """ generator for getting default values. """ for key, field_def in self.__dataclass_fields__.items(): yield key, field_def.default - + def update_value(self, key: str, value): """ updates a property value if the key already exists - Problem: a previously non-existing property can be added to the class instance without error. + Problem: a previously non-existing property can be added to the class instance without error. """ if hasattr(self, key): setattr(self, key, value) else: - error_message = self.__class__.__name__+".update_value() to be updated property "+str(key)+" does not exist" + error_message = ( + self.__class__.__name__ + + ".update_value() to be updated property " + + str(key) + + " does not exist" + ) logging.error(error_message) raise ValueError(error_message) - + def update_values(self, d): """ Updates multiple values in self if the keys already exists """ for k, v in d.items(): - self.update_value(k, v) \ No newline at end of file + self.update_value(k, v) diff --git a/megatron/optimizers.py b/megatron/optimizers.py index 1b2cd2af2..e5c8a4712 100644 --- a/megatron/optimizers.py +++ b/megatron/optimizers.py @@ -32,7 +32,7 @@ def __init__(self, params, lr=0.1, momentum=0.0, beta=0.0, eps=1e-30): if not 0.0 <= eps: raise ValueError("Invalid eps: {0}".format(eps)) - defaults = {'lr': lr, 'momentum': momentum, 'beta': beta, 'eps': eps} + defaults = {"lr": lr, "momentum": momentum, "beta": beta, "eps": eps} super(SM3, self).__init__(params, defaults) @torch.no_grad() @@ -48,10 +48,10 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - momentum = group['momentum'] - beta = group['beta'] - eps = group['eps'] - for p in group['params']: + momentum = group["momentum"] + beta = group["beta"] + eps = group["eps"] + for p in group["params"]: if p is None: continue grad = p.grad @@ -62,8 +62,8 @@ def step(self, closure=None): # State initialization if len(state) == 0: - state['step'] = 0 - state['momentum_buffer'] = 0. + state["step"] = 0 + state["momentum_buffer"] = 0.0 _add_initial_accumulators(state, grad) if grad.is_sparse: @@ -80,9 +80,13 @@ def make_sparse(values): return constructor(grad_indices, values, grad.size()) acc = state[_key(0)] - update_values = _compute_sparse_update(beta, acc, grad_values, grad_indices) + update_values = _compute_sparse_update( + beta, acc, grad_values, grad_indices + ) - self._update_sparse_accumulator(beta, acc, make_sparse(update_values)) + self._update_sparse_accumulator( + beta, acc, make_sparse(update_values) + ) # Add small amount for numerical stability update_values.add_(eps).rsqrt_().mul_(grad_values) @@ -104,20 +108,20 @@ def make_sparse(values): # Add small amount for numerical stability update.add_(eps).rsqrt_().mul_(grad) - if momentum > 0.: - m = state['momentum_buffer'] - update.mul_(1. - momentum).add_(m, alpha=momentum) - state['momentum_buffer'] = update.detach() + if momentum > 0.0: + m = state["momentum_buffer"] + update.mul_(1.0 - momentum).add_(m, alpha=momentum) + state["momentum_buffer"] = update.detach() - p.sub_(update, alpha=group['lr']) - state['step'] += 1 + p.sub_(update, alpha=group["lr"]) + state["step"] += 1 return loss @staticmethod def _update_accumulator(beta, acc_list, update): for i, acc in enumerate(acc_list): nu_max = _max_reduce_except_dim(update, i) - if beta > 0.: + if beta > 0.0: torch.max(acc, nu_max, out=acc) else: # No need to compare - nu_max is bigger because of grad ** 2 @@ -126,7 +130,7 @@ def _update_accumulator(beta, acc_list, update): @staticmethod def _update_sparse_accumulator(beta, acc, update): nu_max = _max_reduce_except_dim(update.to_dense(), 0).squeeze() - if beta > 0.: + if beta > 0.0: torch.max(acc, nu_max, out=acc) else: # No need to compare - nu_max is bigger because of grad ** 2 @@ -136,9 +140,9 @@ def _update_sparse_accumulator(beta, acc, update): def _compute_sparse_update(beta, acc, grad_values, grad_indices): # In the sparse case, a single accumulator is used. update_values = torch.gather(acc, 0, grad_indices[0]) - if beta > 0.: + if beta > 0.0: update_values.mul_(beta) - update_values.addcmul_(grad_values, grad_values, value=1. - beta) + update_values.addcmul_(grad_values, grad_values, value=1.0 - beta) return update_values @@ -148,16 +152,16 @@ def _compute_update(beta, acc_list, grad): for i in range(1, rank): # We rely on broadcasting to get the proper end shape. update = torch.min(update, acc_list[i]) - if beta > 0.: + if beta > 0.0: update.mul_(beta) - update.addcmul_(grad, grad, value=1. - beta) + update.addcmul_(grad, grad, value=1.0 - beta) return update def _key(i): # Returns key used for accessing accumulators - return 'accumulator_' + str(i) + return "accumulator_" + str(i) def _add_initial_accumulators(state, grad): @@ -167,7 +171,7 @@ def _add_initial_accumulators(state, grad): # accumulator of shape (n,). shape = grad.shape rank = len(shape) - defaults = {'device': grad.device, 'dtype': grad.dtype} + defaults = {"device": grad.device, "dtype": grad.dtype} acc = {} if grad.is_sparse: @@ -219,6 +223,7 @@ def _max_reduce_except_dim(tensor, dim): else: _params_t = Any + class madgrad_wd(torch.optim.Optimizer): """ MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index a0f65b067..5f38b5b20 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021 Josh Levy-Kramer . All rights reserved. # This file is based on code by the authors denoted below and has been modified from its original version. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. @@ -436,7 +435,7 @@ def generate_samples_from_prompt( # generate completions generated_texts = [] while True: - model.module.clear_cache() # clear kv cache between batches + model.module.clear_cache() # clear kv cache between batches start_time = time.time() # Tokenize text, and check whether we should terminate process @@ -718,7 +717,7 @@ def generate_samples_interactive( """ while True: - model.module.clear_cache() # clear kv cache between batches + model.module.clear_cache() # clear kv cache between batches torch.distributed.barrier(group=mpu.get_model_parallel_group()) terminate_runs = 0 diff --git a/megatron/tokenizer/__init__.py b/megatron/tokenizer/__init__.py index 311f2fdca..22b0f7b9e 100644 --- a/megatron/tokenizer/__init__.py +++ b/megatron/tokenizer/__init__.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/megatron/tokenizer/gpt2_tokenization.py b/megatron/tokenizer/gpt2_tokenization.py index adcb7231e..fd7b88843 100644 --- a/megatron/tokenizer/gpt2_tokenization.py +++ b/megatron/tokenizer/gpt2_tokenization.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # @@ -18,8 +17,7 @@ """Tokenization classes for OpenAI GPT.""" -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import absolute_import, division, print_function, unicode_literals import sys import json @@ -31,22 +29,21 @@ from functools import lru_cache - logger = logging.getLogger(__name__) PRETRAINED_VOCAB_ARCHIVE_MAP = { - 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", + "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", } PRETRAINED_MERGES_ARCHIVE_MAP = { - 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", + "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", } PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { - 'gpt2': 1024, + "gpt2": 1024, } -VOCAB_NAME = 'vocab.json' -MERGES_NAME = 'merges.txt' -SPECIAL_TOKENS_NAME = 'special_tokens.txt' +VOCAB_NAME = "vocab.json" +MERGES_NAME = "merges.txt" +SPECIAL_TOKENS_NAME = "special_tokens.txt" @lru_cache() @@ -61,14 +58,17 @@ def bytes_to_unicode(): And avoids mapping to whitespace/control characters the bpe code barfs on. """ _chr = unichr if sys.version_info[0] == 2 else chr - bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \ - list(range(ord("®"), ord("ÿ") + 1)) + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) cs = bs[:] n = 0 - for b in range(2**8): + for b in range(2 ** 8): if b not in bs: bs.append(b) - cs.append(2**8 + n) + cs.append(2 ** 8 + n) n += 1 cs = [_chr(n) for n in cs] return dict(zip(bs, cs)) @@ -92,8 +92,11 @@ class GPT2Tokenizer(object): GPT-2 BPE tokenizer. Peculiarities: - Byte-level BPE """ + @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + def from_pretrained( + cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs + ): """ Instantiate a PreTrainedBertModel from a pre-trained model file. Download and cache the pre-trained model file if needed. @@ -105,14 +108,19 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, else: vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) - special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + special_tokens_file = os.path.join( + pretrained_model_name_or_path, SPECIAL_TOKENS_NAME + ) if not os.path.exists(special_tokens_file): special_tokens_file = None else: - logger.info("loading special tokens file {}".format(special_tokens_file)) + logger.info( + "loading special tokens file {}".format(special_tokens_file) + ) # redirect to the cache, if necessary try: from .file_utils import cached_path + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) except EnvironmentError: @@ -121,52 +129,76 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, "We assumed '{}' was a path or url but couldn't find files {} and {} " "at this path or url.".format( pretrained_model_name_or_path, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + ", ".join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, - vocab_file, merges_file)) + vocab_file, + merges_file, + ) + ) return None if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: logger.info("loading vocabulary file {}".format(vocab_file)) logger.info("loading merges file {}".format(merges_file)) else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - logger.info("loading merges file {} from cache at {}".format( - merges_file, resolved_merges_file)) - if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + logger.info( + "loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file + ) + ) + logger.info( + "loading merges file {} from cache at {}".format( + merges_file, resolved_merges_file + ) + ) + if ( + pretrained_model_name_or_path + in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP + ): # if we're using a pretrained model, ensure the tokenizer wont index sequences longer # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ + pretrained_model_name_or_path + ] + kwargs["max_len"] = min(kwargs.get("max_len", int(1e12)), max_len) # Instantiate tokenizer. - if special_tokens_file and 'special_tokens' not in kwargs: - special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + if special_tokens_file and "special_tokens" not in kwargs: + special_tokens = ( + open(special_tokens_file, encoding="utf-8").read().split("\n")[:-1] + ) else: - special_tokens = kwargs.pop('special_tokens', []) + special_tokens = kwargs.pop("special_tokens", []) tokenizer = cls( resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, - **kwargs) + **kwargs + ) return tokenizer - def __init__(self, vocab_file, merges_file, errors='replace', - special_tokens=None, max_len=None): + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + special_tokens=None, + max_len=None, + ): self.max_len = max_len if max_len is not None else int(1e12) self.encoder = json.load(open(vocab_file)) self.decoder = {v: k for k, v in self.encoder.items()} self.errors = errors # how to handle errors in decoding self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] + bpe_data = open(merges_file, encoding="utf-8").read().split("\n")[1:-1] bpe_merges = [tuple(merge.split()) for merge in bpe_data] self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) # Should haved added re.IGNORECASE so BPE merges can happen for # capitalized versions of contractions self.pat = re.compile( - r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) self.special_tokens = {} self.special_tokens_decoder = {} @@ -176,16 +208,17 @@ def __len__(self): return len(self.encoder) + len(self.special_tokens) def set_special_tokens(self, special_tokens): - """ Add a list of additional tokens to the encoder. - The additional tokens are indexed starting from the last index of the - current vocabulary in the order of the `special_tokens` list. + """Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. """ if not special_tokens: self.special_tokens = {} self.special_tokens_decoder = {} return - self.special_tokens = dict((tok, len(self.encoder) + i) - for i, tok in enumerate(special_tokens)) + self.special_tokens = dict( + (tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens) + ) self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()} logger.info("Special tokens {}".format(self.special_tokens)) @@ -198,7 +231,7 @@ def bpe(self, token): return token while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: break first, second = bigram @@ -225,24 +258,26 @@ def bpe(self, token): break else: pairs = get_pairs(word) - word = ' '.join(word) + word = " ".join(word) return word def tokenize(self, text): - """ Tokenize a string. """ + """Tokenize a string.""" bpe_tokens = [] for token in re.findall(self.pat, text): if sys.version_info[0] == 2: - token = ''.join(self.byte_encoder[ord(b)] for b in token) + token = "".join(self.byte_encoder[ord(b)] for b in token) else: - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) return bpe_tokens def convert_tokens_to_ids(self, tokens): - """ Converts a sequence of tokens into ids using the vocab. """ + """Converts a sequence of tokens into ids using the vocab.""" ids = [] - if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): + if isinstance(tokens, str) or ( + sys.version_info[0] == 2 and isinstance(tokens, unicode) + ): if tokens in self.special_tokens: return self.special_tokens[tokens] else: @@ -257,7 +292,8 @@ def convert_tokens_to_ids(self, tokens): "Token indices sequence length is longer than the specified maximum " " sequence length for this OpenAI GPT model ({} > {}). Running this" " sequence through the model will result in indexing errors".format( - len(ids), self.max_len) + len(ids), self.max_len + ) ) return ids @@ -276,41 +312,57 @@ def encode(self, text): return self.convert_tokens_to_ids(self.tokenize(text)) def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + "utf-8", errors=self.errors + ) return text def save_vocabulary(self, vocab_path): """Save the tokenizer vocabulary and merge files to a directory.""" if not os.path.isdir(vocab_path): - logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + logger.error( + "Vocabulary path ({}) should be a directory".format(vocab_path) + ) return vocab_file = os.path.join(vocab_path, VOCAB_NAME) merge_file = os.path.join(vocab_path, MERGES_NAME) special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) - with open(vocab_file, 'w', encoding='utf-8') as f: + with open(vocab_file, "w", encoding="utf-8") as f: f.write(json.dumps(self.encoder, ensure_ascii=False)) index = 0 with open(merge_file, "w", encoding="utf-8") as writer: - writer.write(u'#version: 0.2\n') - for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted( + self.bpe_ranks.items(), key=lambda kv: kv[1] + ): if index != token_index: - logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." - " Please check that the tokenizer is not corrupted!".format(merge_file)) + logger.warning( + "Saving vocabulary to {}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format( + merge_file + ) + ) index = token_index - writer.write(' '.join(bpe_tokens) + u'\n') + writer.write(" ".join(bpe_tokens) + "\n") index += 1 index = len(self.encoder) - with open(special_tokens_file, 'w', encoding='utf-8') as writer: - for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): + with open(special_tokens_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted( + self.special_tokens.items(), key=lambda kv: kv[1] + ): if index != token_index: - logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." - " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) + logger.warning( + "Saving special tokens vocabulary to {}: BPE indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format( + special_tokens_file + ) + ) index = token_index - writer.write(token + u'\n') + writer.write(token + "\n") index += 1 return vocab_file, merge_file, special_tokens_file diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 52a6883de..69ce98e2e 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # @@ -32,32 +31,34 @@ def build_tokenizer(args): """Initialize tokenizer.""" if args.rank == 0: - print('> building {} tokenizer ...'.format(args.tokenizer_type), flush=True) + print("> building {} tokenizer ...".format(args.tokenizer_type), flush=True) # Select and instantiate the tokenizer. - if args.tokenizer_type.lower() == 'GPT2BPETokenizer'.lower(): + if args.tokenizer_type.lower() == "GPT2BPETokenizer".lower(): assert args.vocab_file is not None assert args.merge_file is not None tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) - elif args.tokenizer_type.lower() == 'SPMTokenizer'.lower(): + elif args.tokenizer_type.lower() == "SPMTokenizer".lower(): assert args.vocab_file is not None tokenizer = SentencePieceTokenizer(args.vocab_file) - elif args.tokenizer_type.lower() == 'HFTokenizer'.lower(): + elif args.tokenizer_type.lower() == "HFTokenizer".lower(): assert args.vocab_file is not None tokenizer = HFTokenizer(args.vocab_file) - elif args.tokenizer_type.lower() == 'HFGPT2Tokenizer'.lower(): + elif args.tokenizer_type.lower() == "HFGPT2Tokenizer".lower(): if args.vocab_file is None: - print("WARNING: No vocab file found, loading Huggingface's pretrained GPT2Tokenizer") + print( + "WARNING: No vocab file found, loading Huggingface's pretrained GPT2Tokenizer" + ) tokenizer = HFGPT2Tokenizer(args.vocab_file) elif args.tokenizer_type.lower() == "CharLevelTokenizer".lower(): tokenizer = CharLevelTokenizer(vocab_size=512) else: - raise NotImplementedError('{} tokenizer is not ' - 'implemented.'.format(args.tokenizer_type)) + raise NotImplementedError( + "{} tokenizer is not " "implemented.".format(args.tokenizer_type) + ) # Add vocab size. - args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, - args) + args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args) return tokenizer @@ -67,14 +68,15 @@ def _vocab_size_with_padding(orig_vocab_size, args): still having GPU friendly size.""" after = orig_vocab_size - multiple = args.make_vocab_size_divisible_by * \ - args.model_parallel_size + multiple = args.make_vocab_size_divisible_by * args.model_parallel_size while (after % multiple) != 0: after += 1 if args.rank == 0: - print(' > padded vocab (size: {}) with {} dummy tokens ' - '(new size: {})'.format( - orig_vocab_size, after - orig_vocab_size, after), flush=True) + print( + " > padded vocab (size: {}) with {} dummy tokens " + "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after), + flush=True, + ) return after @@ -107,45 +109,52 @@ def tokenize(self, text): pass def detokenize(self, token_ids): - raise NotImplementedError('detokenizer is not implemented for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError( + "detokenizer is not implemented for {} " "tokenizer".format(self.name) + ) @property def cls(self): - raise NotImplementedError('CLS is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError( + "CLS is not provided for {} " "tokenizer".format(self.name) + ) @property def sep(self): - raise NotImplementedError('SEP is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError( + "SEP is not provided for {} " "tokenizer".format(self.name) + ) @property def pad(self): - raise NotImplementedError('PAD is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError( + "PAD is not provided for {} " "tokenizer".format(self.name) + ) @property def eod(self): - raise NotImplementedError('EOD is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError( + "EOD is not provided for {} " "tokenizer".format(self.name) + ) @property def mask(self): - raise NotImplementedError('MASK is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError( + "MASK is not provided for {} " "tokenizer".format(self.name) + ) class _GPT2BPETokenizer(AbstractTokenizer): """Original GPT2 BPE tokenizer.""" def __init__(self, vocab_file, merge_file): - name = 'GPT2 BPE' + name = "GPT2 BPE" super().__init__(name) - self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace', - special_tokens=[], max_len=None) - self.eod_id = self.tokenizer.encoder['<|endoftext|>'] + self.tokenizer = GPT2Tokenizer( + vocab_file, merge_file, errors="replace", special_tokens=[], max_len=None + ) + self.eod_id = self.tokenizer.encoder["<|endoftext|>"] @property def vocab_size(self): @@ -174,11 +183,11 @@ class SentencePieceTokenizer(AbstractTokenizer): """Designed to Integrate SP's Tokenizer.""" def __init__(self, vocab_file): - name = 'SPM' + name = "SPM" super().__init__(name) self.tokenizer = spm.SentencePieceProcessor(model_file=vocab_file) - self.eod_id = self.tokenizer.piece_to_id('<|endoftext|>') + self.eod_id = self.tokenizer.piece_to_id("<|endoftext|>") @property def vocab_size(self): @@ -186,11 +195,17 @@ def vocab_size(self): @property def vocab(self): - return {self.tokenizer.id_to_piece(idx):idx for idx in range(self.tokenizer.get_piece_size())} + return { + self.tokenizer.id_to_piece(idx): idx + for idx in range(self.tokenizer.get_piece_size()) + } @property def inv_vocab(self): - return {idx:self.tokenizer.id_to_piece(idx) for idx in range(self.tokenizer.get_piece_size())} + return { + idx: self.tokenizer.id_to_piece(idx) + for idx in range(self.tokenizer.get_piece_size()) + } def tokenize(self, text): return self.tokenizer.encode(text) @@ -207,12 +222,12 @@ class HFTokenizer(AbstractTokenizer): """Designed to Integrate HF's Tokenizer library.""" def __init__(self, vocab_file): - name = 'HFTokenizer' + name = "HFTokenizer" super().__init__(name) self.tokenizer = Tokenizer.from_file(vocab_file) - self.eod_id = self.tokenizer.token_to_id('<|endoftext|>') - self.pad_id = self.tokenizer.token_to_id('<|padding|>') + self.eod_id = self.tokenizer.token_to_id("<|endoftext|>") + self.pad_id = self.tokenizer.token_to_id("<|padding|>") @property def vocab_size(self): @@ -244,17 +259,18 @@ class HFGPT2Tokenizer(AbstractTokenizer): """Designed to Integrate the pretrained OpenAI GPT2 Tokenizers from HF""" def __init__(self, vocab_file=None, fast=True): - name = 'HFGPT2Tokenizer' - if fast: name += "Fast" + name = "HFGPT2Tokenizer" + if fast: + name += "Fast" super().__init__(name) if vocab_file is None: - vocab_file = 'gpt2' + vocab_file = "gpt2" if fast: self.tokenizer = GPT2TokenizerFast.from_pretrained(vocab_file) else: self.tokenizer = GPT2Tokenizer.from_pretrained(vocab_file) - self.tokenizer.add_special_tokens({'pad_token': '<|padding|>'}) + self.tokenizer.add_special_tokens({"pad_token": "<|padding|>"}) self.eod_id = self.tokenizer.eos_token_id self.pad_id = self.tokenizer.pad_token_id @@ -290,7 +306,7 @@ class CharLevelTokenizer(AbstractTokenizer): """Character Level Tokenizer""" def __init__(self, vocab_size): - name = 'CharLevelTokenizer' + name = "CharLevelTokenizer" super().__init__(name) self._vocab_size = vocab_size self.eod_id = 0 @@ -324,7 +340,7 @@ def tokenize_batch(self, text_batch: Union[List[str], str]): return self.tokenize(text_batch) def detokenize(self, token_ids): - return ''.join(list(map(self.decode_token, token_ids))) + return "".join(list(map(self.decode_token, token_ids))) @property def eod(self): diff --git a/megatron/tokenizer/train_tokenizer.py b/megatron/tokenizer/train_tokenizer.py index 1786572b3..ab81314b7 100644 --- a/megatron/tokenizer/train_tokenizer.py +++ b/megatron/tokenizer/train_tokenizer.py @@ -2,8 +2,7 @@ Assumes a dataset of jsonl files in the same format as the neox training set. """ -from tokenizers import (Tokenizer, decoders, models, pre_tokenizers, - processors, trainers) +from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers from tokenizers.normalizers import NFKC from glob import glob @@ -11,27 +10,31 @@ import json import argparse + def load_jsonl(input_path, quiet=True) -> list: """ Read list of objects from a JSON lines file. """ data = [] - with open(input_path, 'r', encoding='utf-8') as f: + with open(input_path, "r", encoding="utf-8") as f: for line in f: - data.append(json.loads(line.rstrip('\n|\r'))) + data.append(json.loads(line.rstrip("\n|\r"))) if not quiet: - print('Loaded {} records from {}'.format(len(data), input_path)) + print("Loaded {} records from {}".format(len(data), input_path)) return data -def json_iterator(input_dir, text_key='text'): - all_jsonls = glob(f'{input_dir}/*.jsonl') + glob(f'{input_dir}/*.json') + +def json_iterator(input_dir, text_key="text"): + all_jsonls = glob(f"{input_dir}/*.jsonl") + glob(f"{input_dir}/*.json") for j in all_jsonls: data = load_jsonl(j) for doc in data: yield doc[text_key] - -def train_tokenizer(input_dir: str, save_path: str, tokenizer_type: str = "BPE", vocab_size: int = 52000): + +def train_tokenizer( + input_dir: str, save_path: str, tokenizer_type: str = "BPE", vocab_size: int = 52000 +): """ Trains a tokenizer on all the json files in `input_dir` and saves it to `save_path` @@ -45,7 +48,7 @@ def train_tokenizer(input_dir: str, save_path: str, tokenizer_type: str = "BPE", if tokenizer_type == "BPE": model = models.BPE() else: - raise NotImplementedError(f'Tokenizer type {tokenizer_type} not implemented') + raise NotImplementedError(f"Tokenizer type {tokenizer_type} not implemented") tokenizer = Tokenizer(model) # Customize pre-tokenization and decoding @@ -55,27 +58,45 @@ def train_tokenizer(input_dir: str, save_path: str, tokenizer_type: str = "BPE", tokenizer.normalizer = NFKC() # And then train - trainer = trainers.BpeTrainer(vocab_size=vocab_size, special_tokens=["<|endoftext|>", "<|padding|>"]) + trainer = trainers.BpeTrainer( + vocab_size=vocab_size, special_tokens=["<|endoftext|>", "<|padding|>"] + ) tokenizer.train_from_iterator(json_iterator(input_dir), trainer) # And Save it tokenizer.save(save_path, pretty=True) - print(f'Tokenizer saved at {save_path}') + print(f"Tokenizer saved at {save_path}") + def parse_args(): - parser = argparse.ArgumentParser(description='script for training a multilingual ' - 'HF tokenizer on CC dumps with upweighting for low resource languages') - parser.add_argument('--json_input_dir', type=str, - help='Path to folder containing tokenizer training data in jsonl format') - parser.add_argument('--tokenizer_output_path', type=str, - help='Path to which your trained tokenizer will be saved (should end in .json)') - parser.add_argument('--tokenizer_type', type=str, - help="type of tokenizer to train, currently only BPE is supported", - choices=['BPE'], - default=['BPE']) - parser.add_argument('-v', '--vocab_size', - help='vocabulary size of tokenizer, default=52k', - type=int, default=52000) + parser = argparse.ArgumentParser( + description="script for training a multilingual " + "HF tokenizer on CC dumps with upweighting for low resource languages" + ) + parser.add_argument( + "--json_input_dir", + type=str, + help="Path to folder containing tokenizer training data in jsonl format", + ) + parser.add_argument( + "--tokenizer_output_path", + type=str, + help="Path to which your trained tokenizer will be saved (should end in .json)", + ) + parser.add_argument( + "--tokenizer_type", + type=str, + help="type of tokenizer to train, currently only BPE is supported", + choices=["BPE"], + default=["BPE"], + ) + parser.add_argument( + "-v", + "--vocab_size", + help="vocabulary size of tokenizer, default=52k", + type=int, + default=52000, + ) return parser.parse_args() @@ -83,7 +104,9 @@ def parse_args(): args = parse_args() - train_tokenizer(args.json_input_dir, - save_path=args.tokenizer_output_path, - tokenizer_type=args.tokenizer_type, - vocab_size=args.vocab_size) + train_tokenizer( + args.json_input_dir, + save_path=args.tokenizer_output_path, + tokenizer_type=args.tokenizer_type, + vocab_size=args.vocab_size, + ) diff --git a/megatron/training.py b/megatron/training.py index 63b977e4e..0741c0a43 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # diff --git a/megatron/utils.py b/megatron/utils.py index 674ae8deb..593c38d6d 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -1,5 +1,3 @@ -# coding=utf-8 - # Copyright (c) 2021 Josh Levy-Kramer . # This file is based on code by the authors denoted below and has been modified from its original version. # diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 2cb3173a9..2e7812f5a 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,4 +1,4 @@ +autopep8==1.5.6 pytest==6.2.3 pytest-cov==2.11.1 pytest-forked==1.3.0 -autopep8==1.5.6 diff --git a/requirements/requirements-tensorboard.txt b/requirements/requirements-tensorboard.txt index a1b71d380..fef3fbbc7 100644 --- a/requirements/requirements-tensorboard.txt +++ b/requirements/requirements-tensorboard.txt @@ -1 +1 @@ -tensorboard==2.5.0 \ No newline at end of file +tensorboard==2.5.0 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 47cb11bc2..f77522f31 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,14 +1,14 @@ -pybind11==2.6.2 -six -regex -numpy==1.21.0 git+git://github.com/EleutherAI/DeeperSpeed.git@eb7f5cff36678625d23db8a8fe78b4a93e5d2c75#egg=deepspeed -mpi4py==3.0.3 -wandb==0.10.28 einops==0.3.0 -transformers==4.5.0 -tokenizers==0.10.2 -lm_dataformat==0.0.19 ftfy==6.0.1 +lm_dataformat==0.0.19 git+https://github.com/EleutherAI/lm-evaluation-harness.git@dc937d4b70af819c5695e09d94e59e4cdb1e40ad#egg=lm_eval +mpi4py==3.0.3 +numpy==1.21.0 +pybind11==2.6.2 +regex sentencepiece +six +tokenizers==0.10.2 +transformers==4.5.0 +wandb==0.10.28 diff --git a/tests/Readme.md b/tests/Readme.md index a477ff6c0..ba0895411 100644 --- a/tests/Readme.md +++ b/tests/Readme.md @@ -8,7 +8,7 @@ pip install -r requirements/requirements-dev.txt # Run -Tests can be run using pytest. +Tests can be run using pytest. * The argument --forked needs to be provided * A coverage report can be created using the optional arguments --cov-report and --cov (see pytest documentation) diff --git a/tests/model/test_fused_kernels.py b/tests/model/test_fused_kernels.py index 2f2436eb2..0301231c1 100644 --- a/tests/model/test_fused_kernels.py +++ b/tests/model/test_fused_kernels.py @@ -1,9 +1,9 @@ - import os if __name__ == "__main__": import sys - sys.path.append(os.path.abspath('')) + + sys.path.append(os.path.abspath("")) import math @@ -220,6 +220,7 @@ def test_fused_upper_triangle_mask_softmax(): f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" ) + if __name__ == "__main__": try: from transformers import BertTokenizer, GPT2Tokenizer diff --git a/tests/model/test_model_checkpoint.py b/tests/model/test_model_checkpoint.py index 36edf3ec6..22a24b291 100644 --- a/tests/model/test_model_checkpoint.py +++ b/tests/model/test_model_checkpoint.py @@ -8,67 +8,117 @@ if __name__ == "__main__": import sys - sys.path.append(os.path.abspath('')) + + sys.path.append(os.path.abspath("")) import pytest -from tests.common import distributed_test, clear_test_dirs, model_setup, binary, parametrize +from tests.common import ( + distributed_test, + clear_test_dirs, + model_setup, + binary, + parametrize, +) import torch PARAMS_TO_TEST = { "pipe_parallel_size,model_parallel_size": [[0, 1], [1, 2], [0, 2], [2, 1]], "checkpoint_validation_with_forward_pass": [True], - "fp16,fp32_allreduce": [[{ - "enabled": True, - "type": "bfloat16", - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, True], [{ - "enabled": True, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, False]] + "fp16,fp32_allreduce": [ + [ + { + "enabled": True, + "type": "bfloat16", + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + }, + True, + ], + [ + { + "enabled": True, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + }, + False, + ], + ], } -parameters, names = parametrize(PARAMS_TO_TEST, max_tests=int(os.getenv('MAX_TESTCASES', 50)), seed=None) +parameters, names = parametrize( + PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None +) + + @pytest.mark.parametrize("param_dict", parameters, ids=names) def test_train(param_dict): @distributed_test(world_size=2) def wrapper(): run_checkpoint_test(param_dict=param_dict) + wrapper() def run_checkpoint_test(yaml_list=None, param_dict=None): - + from megatron.checkpointing import load_checkpoint from megatron.checkpointing import save_checkpoint - model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list, param_dict, clear_data=True) + model, optimizer, lr_scheduler, args_loaded = model_setup( + yaml_list, param_dict, clear_data=True + ) # save model checkpoint - save_checkpoint(neox_args=args_loaded, iteration=42, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler) - + save_checkpoint( + neox_args=args_loaded, + iteration=42, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) # reload model from checkpoint - reloaded_model, reloaded_optimizer, reloaded_lr_scheduler, args_reloaded = model_setup(yaml_list, param_dict, clear_data=False) - iteration = load_checkpoint(neox_args=args_reloaded, model=reloaded_model, optimizer=reloaded_optimizer, lr_scheduler=reloaded_lr_scheduler) + ( + reloaded_model, + reloaded_optimizer, + reloaded_lr_scheduler, + args_reloaded, + ) = model_setup(yaml_list, param_dict, clear_data=False) + iteration = load_checkpoint( + neox_args=args_reloaded, + model=reloaded_model, + optimizer=reloaded_optimizer, + lr_scheduler=reloaded_lr_scheduler, + ) - #ensure same checkpoint is loaded - assert iteration == 42, "run_checkpoint_test() iteration loaded from checkpoint correct" + # ensure same checkpoint is loaded + assert ( + iteration == 42 + ), "run_checkpoint_test() iteration loaded from checkpoint correct" - #check all weight groups are the same - for idx, ((n1, p1), (n2, p2)) in enumerate(zip(list(model.module.named_parameters()), list(reloaded_model.module.named_parameters()))): + # check all weight groups are the same + for idx, ((n1, p1), (n2, p2)) in enumerate( + zip( + list(model.module.named_parameters()), + list(reloaded_model.module.named_parameters()), + ) + ): assert n1 == n2 params_equal = (p1 == p2).all().item() - assert params_equal, "run_checkpoint_test() params equal: "+str(n1) + assert params_equal, "run_checkpoint_test() params equal: " + str(n1) if torch.distributed.get_world_size() == 1 or torch.distributed.get_rank() == 0: clear_test_dirs() + if __name__ == "__main__": - params = list(parametrize(PARAMS_TO_TEST, max_tests=int(os.getenv('MAX_TESTCASES', 50)), seed=None)) - test_train(params[0]) \ No newline at end of file + params = list( + parametrize( + PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None + ) + ) + test_train(params[0]) diff --git a/tests/model/test_model_generation.py b/tests/model/test_model_generation.py index 7bb0e10eb..ddb1aca73 100644 --- a/tests/model/test_model_generation.py +++ b/tests/model/test_model_generation.py @@ -10,38 +10,59 @@ if __name__ == "__main__": import sys - sys.path.append(os.path.abspath('')) + + sys.path.append(os.path.abspath("")) import pytest from tests.common import distributed_test, model_setup, parametrize, dict_repr import torch PARAMS_TO_TEST = { - "pipe_parallel_size,model_parallel_size,world_size": [[0, 1, 1], [0, 1, 2], [1, 2, 2], [0, 2, 2], [2, 1, 2]], + "pipe_parallel_size,model_parallel_size,world_size": [ + [0, 1, 1], + [0, 1, 2], + [1, 2, 2], + [0, 2, 2], + [2, 1, 2], + ], "top_p,temperature,top_k": [[0.0, 0.5, 0], [0.5, 0.0, 100], [0.5, 0.5, 0]], "prompt": ["", "hello world"], - "fp16,fp32_allreduce": [[{ - "enabled": True, - "type": "bfloat16", - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, True], [{ - "enabled": True, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, False]] + "fp16,fp32_allreduce": [ + [ + { + "enabled": True, + "type": "bfloat16", + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + }, + True, + ], + [ + { + "enabled": True, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + }, + False, + ], + ], } -parameters, names = parametrize(PARAMS_TO_TEST, max_tests=int(os.getenv('MAX_TESTCASES', 50)), seed=None) +parameters, names = parametrize( + PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None +) + + @pytest.mark.parametrize("param_dict", parameters, ids=names) def test_train(param_dict): @distributed_test(world_size=param_dict.pop("world_size", 2)) def wrapper(): run_generate_test(param_dict=param_dict, prompt=param_dict.pop("prompt")) + wrapper() @@ -50,36 +71,37 @@ def run_generate_test(param_dict, prompt): from megatron.utils import is_mp_rank_0 fixed_params = { - "num_samples": 3, - "maximum_tokens": 50, - "make_vocab_size_divisible_by": 2, - "sample_output_file": "test_sample_output.txt", - "checkpoint_activations": False, - "partition_activations": False, - "no_load_optim": True, + "num_samples": 3, + "maximum_tokens": 50, + "make_vocab_size_divisible_by": 2, + "sample_output_file": "test_sample_output.txt", + "checkpoint_activations": False, + "partition_activations": False, + "no_load_optim": True, } - + param_dict.update(fixed_params) # TODO: we don't need to reinstantiate the model every time if we're only changing sampling settings - should be a workaround for this - model, _, _, args_loaded = model_setup(None, param_dict, clear_data=True, inference=True) + model, _, _, args_loaded = model_setup( + None, param_dict, clear_data=True, inference=True + ) model.eval() prompts = [prompt for _ in range(args_loaded.num_samples)] output = generate_samples_from_prompt( - neox_args=args_loaded, - model=model, + neox_args=args_loaded, + model=model, text=prompts, - maximum_tokens=args_loaded.maximum_tokens, - recompute=False, - temperature=args_loaded.temperature, - top_k=args_loaded.top_k, - top_p=args_loaded.top_p, - ) + maximum_tokens=args_loaded.maximum_tokens, + recompute=False, + temperature=args_loaded.temperature, + top_k=args_loaded.top_k, + top_p=args_loaded.top_p, + ) # outputs only get generated on mp rank 0 if is_mp_rank_0(): assert len(output) == len(prompts) - for prompt, out in zip(prompts, output): - assert(prompt == out["context"]) - assert(len(out["text"]) > 0) - + for prompt, out in zip(prompts, output): + assert prompt == out["context"] + assert len(out["text"]) > 0 diff --git a/tests/model/test_model_instantiation.py b/tests/model/test_model_instantiation.py index baa4d36bc..e5484c771 100644 --- a/tests/model/test_model_instantiation.py +++ b/tests/model/test_model_instantiation.py @@ -9,59 +9,96 @@ from ..common import distributed_test, model_setup, clear_test_dirs, parametrize, binary PARAMS_TO_TEST = { - "pipe_parallel_size,model_parallel_size,world_size": [[0, 1, 1], [1, 2, 2], [0, 2, 2]], + "pipe_parallel_size,model_parallel_size,world_size": [ + [0, 1, 1], + [1, 2, 2], + [0, 2, 2], + ], "no_weight_tying": binary, - "attention_config": [[[["global"], "all"]], [[["local"], "all"]], [[["sparse_variable"], "all"]], - [[["sparse_fixed"], "all"]]], - "scaled_upper_triang_masked_softmax_fusion,bias_gelu_fusion": [[True, False], [False, True]], - "fp16,fp32_allreduce": [[{ - "enabled": True, - "type": "bfloat16", - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, True], [{ - "enabled": True, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, False]] + "attention_config": [ + [[["global"], "all"]], + [[["local"], "all"]], + [[["sparse_variable"], "all"]], + [[["sparse_fixed"], "all"]], + ], + "scaled_upper_triang_masked_softmax_fusion,bias_gelu_fusion": [ + [True, False], + [False, True], + ], + "fp16,fp32_allreduce": [ + [ + { + "enabled": True, + "type": "bfloat16", + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + }, + True, + ], + [ + { + "enabled": True, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + }, + False, + ], + ], } -parameters, names = parametrize(PARAMS_TO_TEST, max_tests=int(os.getenv('MAX_TESTCASES', 50)), seed=None) +parameters, names = parametrize( + PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None +) + + @pytest.mark.parametrize("param_dict", parameters, ids=names) def test_instantiate(param_dict): @distributed_test(world_size=param_dict.pop("world_size", 2)) def wrapper(): run_test_model_instantiation(param_dict=param_dict) + wrapper() + OPTIMIZER_PARAMS = { "optimizer": [ - {"type": "adam","params": {"lr": 0.0006}}, - {"type": "onebitadam","params": {"lr": 0.0006}}, - {"type": "cpu_adam","params": {"lr": 0.0006}}, - {"type": "cpu_torch_adam","params": {"lr": 0.0006}}, - {"type": "sm3","params": {"lr": 0.0006}}, - {"type": "madgrad_wd","params": {"lr": 0.0006}} - ] - } -opt_params, opt_name = parametrize(OPTIMIZER_PARAMS, max_tests=int(os.getenv('MAX_TESTCASES', 50)), seed=None) + {"type": "adam", "params": {"lr": 0.0006}}, + {"type": "onebitadam", "params": {"lr": 0.0006}}, + {"type": "cpu_adam", "params": {"lr": 0.0006}}, + {"type": "cpu_torch_adam", "params": {"lr": 0.0006}}, + {"type": "sm3", "params": {"lr": 0.0006}}, + {"type": "madgrad_wd", "params": {"lr": 0.0006}}, + ] +} +opt_params, opt_name = parametrize( + OPTIMIZER_PARAMS, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None +) + + @pytest.mark.parametrize("param_dict", parameters, ids=names) def test_instantiate_optimizers(param_dict): @distributed_test(world_size=2) def wrapper(): run_test_model_instantiation(param_dict=param_dict) + wrapper() + def run_test_model_instantiation(yaml_list=None, param_dict=None): from deepspeed.runtime.pipe.engine import PipelineEngine, DeepSpeedEngine + model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list, param_dict) if args_loaded.pipe_parallel_size < 2: - assert isinstance(model, DeepSpeedEngine), "test model instantiation " + str(yaml_list) + assert isinstance(model, DeepSpeedEngine), "test model instantiation " + str( + yaml_list + ) else: - assert isinstance(model, PipelineEngine), "test model instantiation " + str(yaml_list) + assert isinstance(model, PipelineEngine), "test model instantiation " + str( + yaml_list + ) if torch.distributed.get_world_size() == 1 or torch.distributed.get_rank() == 0: clear_test_dirs() diff --git a/tests/model/test_model_train.py b/tests/model/test_model_train.py index 4314d3277..882bffd03 100644 --- a/tests/model/test_model_train.py +++ b/tests/model/test_model_train.py @@ -9,73 +9,114 @@ from ..common import distributed_test, clear_test_dirs, model_setup, binary, parametrize import torch -import os +import os PARAMS_TO_TEST = { - "norm,pos_emb,activation": [["layernorm", "learned", "gelu"], ["rmsnorm", "rotary", "relu"], - ["scalenorm", "sinusoidal", "mish"], ["layernorm", "rpe", "geglu"], - ["rmsnorm", "none", "swish"]], + "norm,pos_emb,activation": [ + ["layernorm", "learned", "gelu"], + ["rmsnorm", "rotary", "relu"], + ["scalenorm", "sinusoidal", "mish"], + ["layernorm", "rpe", "geglu"], + ["rmsnorm", "none", "swish"], + ], "pipe_parallel_size,model_parallel_size": [[0, 1], [1, 2], [0, 2]], "no_weight_tying": binary, - "attention_config,num_layers": [[[[["global"], "all"]], 2], [[[["local", "global"], "all"]], 12], [[[["sparse_variable", "global"], "all"]], 12], - [[[["sparse_fixed", "global"], "all"]], 12]], # the sparse attention models need more layers to be stable - "scaled_upper_triang_masked_softmax_fusion,bias_gelu_fusion": [[True, False], [False, True]], + "attention_config,num_layers": [ + [[[["global"], "all"]], 2], + [[[["local", "global"], "all"]], 12], + [[[["sparse_variable", "global"], "all"]], 12], + [[[["sparse_fixed", "global"], "all"]], 12], + ], # the sparse attention models need more layers to be stable + "scaled_upper_triang_masked_softmax_fusion,bias_gelu_fusion": [ + [True, False], + [False, True], + ], "checkpoint_activations": binary, "log_gradient_noise_scale": [True], - "sparsity_config": [{ - "block": 16, # block size - "num_local_blocks": 32, - }], + "sparsity_config": [ + { + "block": 16, # block size + "num_local_blocks": 32, + } + ], } -parameters, names = parametrize(PARAMS_TO_TEST, max_tests=int(os.getenv('MAX_TESTCASES', 50)), seed=None) +parameters, names = parametrize( + PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None +) + + @pytest.mark.parametrize("param_dict", parameters, ids=names) def test_train(param_dict): @distributed_test(world_size=2) def wrapper(): run_train_test(param_dict=param_dict) + wrapper() -BF16_PARAMS_TO_TEST = {"fp16,fp32_allreduce": [[{ - "enabled": True, - "type": "bfloat16", - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, True], [{ - "enabled": True, - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, False]]} - -parameters, names = parametrize(BF16_PARAMS_TO_TEST, max_tests=int(os.getenv('MAX_TESTCASES', 50)), seed=None) + +BF16_PARAMS_TO_TEST = { + "fp16,fp32_allreduce": [ + [ + { + "enabled": True, + "type": "bfloat16", + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + }, + True, + ], + [ + { + "enabled": True, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1, + }, + False, + ], + ] +} + +parameters, names = parametrize( + BF16_PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None +) + + @pytest.mark.parametrize("param_dict", parameters, ids=names) def test_train_bf16(param_dict): @distributed_test(world_size=2) def wrapper(): run_train_test(param_dict=param_dict) + wrapper() + OPTIMIZER_PARAMS = { "optimizer": [ - {"type": "adam","params": {"lr": 0.0006}}, - {"type": "onebitadam","params": {"lr": 0.0006}}, - {"type": "cpu_adam","params": {"lr": 0.0006}}, - {"type": "cpu_torch_adam","params": {"lr": 0.0006}}, - {"type": "sm3","params": {"lr": 0.0006}}, - {"type": "madgrad_wd","params": {"lr": 0.0006}} - ] - } -opt_params, opt_name = parametrize(OPTIMIZER_PARAMS, max_tests=int(os.getenv('MAX_TESTCASES', 50)), seed=None) + {"type": "adam", "params": {"lr": 0.0006}}, + {"type": "onebitadam", "params": {"lr": 0.0006}}, + {"type": "cpu_adam", "params": {"lr": 0.0006}}, + {"type": "cpu_torch_adam", "params": {"lr": 0.0006}}, + {"type": "sm3", "params": {"lr": 0.0006}}, + {"type": "madgrad_wd", "params": {"lr": 0.0006}}, + ] +} +opt_params, opt_name = parametrize( + OPTIMIZER_PARAMS, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None +) + + @pytest.mark.parametrize("param_dict", parameters, ids=names) def test_train_optimizers(param_dict): @distributed_test(world_size=2) def wrapper(): run_train_test(param_dict=param_dict) + wrapper() @@ -94,8 +135,9 @@ def run_train_test(yaml_list=None, param_dict=None): # generate some random data on which we can overfit # context size of data is model seq_len + 1 in order to compute loss data_list = list() - context_tokens_tensor = torch.randint(0, args_loaded.padded_vocab_size, (4, args_loaded.seq_length + 1)).to( - torch.int64) + context_tokens_tensor = torch.randint( + 0, args_loaded.padded_vocab_size, (4, args_loaded.seq_length + 1) + ).to(torch.int64) for i in range(max_steps): data_list.append({"text": context_tokens_tensor.clone()}) data_iterator = iter(data_list) @@ -109,17 +151,21 @@ def run_train_test(yaml_list=None, param_dict=None): data_iterator=data_iterator, model=model, optimizer=optimizer, - lr_scheduler=lr_scheduler + lr_scheduler=lr_scheduler, ) losses.append(loss_dict["lm_loss"]) if len(losses) >= 2: - if torch.isnan(losses[-1]): continue - if torch.isnan(losses[-2]): continue + if torch.isnan(losses[-1]): + continue + if torch.isnan(losses[-2]): + continue if losses[-1] < losses[-2]: return # all good # loss should have decreased by now (otherwise increasing the max_steps parameter could have the testcase pass) - assert losses[-1] < losses[-2], "run_train_test() loss going down within " + str(max_steps) + " steps" + assert losses[-1] < losses[-2], ( + "run_train_test() loss going down within " + str(max_steps) + " steps" + ) if torch.distributed.get_world_size() == 1 or torch.distributed.get_rank() == 0: clear_test_dirs() diff --git a/tests/neox_args/__init__.py b/tests/neox_args/__init__.py index 1ed4d46ed..b2ef7435c 100644 --- a/tests/neox_args/__init__.py +++ b/tests/neox_args/__init__.py @@ -1,3 +1,3 @@ """ testing of implementation of command line arguments and configuration (NeoXArgs) -""" \ No newline at end of file +""" diff --git a/tests/neox_args/test_neoxargs_commandline.py b/tests/neox_args/test_neoxargs_commandline.py index 0149916c5..95c06e42b 100644 --- a/tests/neox_args/test_neoxargs_commandline.py +++ b/tests/neox_args/test_neoxargs_commandline.py @@ -7,7 +7,8 @@ from ..common import get_root_directory, get_config_directory, get_configs_with_path -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_consume_deepy_args_with_config_dir(): """ verify consume_deepy_args processes command line arguments without config dir @@ -16,12 +17,17 @@ def test_neoxargs_consume_deepy_args_with_config_dir(): from megatron.neox_arguments import NeoXArgs # load neox args with command line - with patch('sys.argv', [str(get_root_directory() / "deepy.py"), "pretrain_gpt2.py"] + get_configs_with_path(["small.yml", "local_setup.yml"])): + with patch( + "sys.argv", + [str(get_root_directory() / "deepy.py"), "pretrain_gpt2.py"] + + get_configs_with_path(["small.yml", "local_setup.yml"]), + ): args_loaded_consume = NeoXArgs.consume_deepy_args() - # load neox args directly from yaml files - args_loaded_yamls = NeoXArgs.from_ymls(get_configs_with_path(["small.yml", "local_setup.yml"])) + args_loaded_yamls = NeoXArgs.from_ymls( + get_configs_with_path(["small.yml", "local_setup.yml"]) + ) # update values from yaml files that cannot otherwise be matched args_loaded_yamls.update_value("user_script", "pretrain_gpt2.py") @@ -29,7 +35,8 @@ def test_neoxargs_consume_deepy_args_with_config_dir(): assert args_loaded_yamls == args_loaded_consume -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_consume_deepy_args_without_yml_suffix(): """ verify consume_deepy_args processes command line arguments without yaml suffix @@ -38,12 +45,17 @@ def test_neoxargs_consume_deepy_args_without_yml_suffix(): from megatron.neox_arguments import NeoXArgs # load neox args with command line - with patch('sys.argv', [str(get_root_directory() / "deepy.py"), "pretrain_gpt2.py"] + get_configs_with_path(["small", "local_setup"])): + with patch( + "sys.argv", + [str(get_root_directory() / "deepy.py"), "pretrain_gpt2.py"] + + get_configs_with_path(["small", "local_setup"]), + ): args_loaded_consume = NeoXArgs.consume_deepy_args() - # load neox args directly from yaml files - args_loaded_yamls = NeoXArgs.from_ymls(get_configs_with_path(["small.yml", "local_setup.yml"])) + args_loaded_yamls = NeoXArgs.from_ymls( + get_configs_with_path(["small.yml", "local_setup.yml"]) + ) # update values from yaml files that cannot otherwise be matched args_loaded_yamls.update_value("user_script", "pretrain_gpt2.py") @@ -51,7 +63,8 @@ def test_neoxargs_consume_deepy_args_without_yml_suffix(): assert args_loaded_yamls == args_loaded_consume -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_consume_deepy_args_with_config_dir(): """ verify consume_deepy_args processes command line arguments including config dir @@ -60,12 +73,22 @@ def test_neoxargs_consume_deepy_args_with_config_dir(): from megatron.neox_arguments import NeoXArgs # load neox args with command line - with patch('sys.argv', [str(get_root_directory() / "deepy.py"), "pretrain_gpt2.py", '-d', str(get_config_directory())] + ["small.yml", "local_setup.yml"]): + with patch( + "sys.argv", + [ + str(get_root_directory() / "deepy.py"), + "pretrain_gpt2.py", + "-d", + str(get_config_directory()), + ] + + ["small.yml", "local_setup.yml"], + ): args_loaded_consume = NeoXArgs.consume_deepy_args() - # load neox args directly from yaml files - args_loaded_yamls = NeoXArgs.from_ymls(get_configs_with_path(["small.yml", "local_setup.yml"])) + args_loaded_yamls = NeoXArgs.from_ymls( + get_configs_with_path(["small.yml", "local_setup.yml"]) + ) # update values from yaml files that cannot otherwise be matched args_loaded_yamls.update_value("user_script", "pretrain_gpt2.py") @@ -73,24 +96,26 @@ def test_neoxargs_consume_deepy_args_with_config_dir(): assert args_loaded_yamls == args_loaded_consume -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_consume_neox_args(): """ verify megatron args are correctly consumed after sending via deepspeed """ from megatron.neox_arguments import NeoXArgs - + # intitially load config from files as would be the case in deepy.py yaml_list = get_configs_with_path(["small.yml", "local_setup.yml"]) args_baseline = NeoXArgs.from_ymls(yaml_list) - args_baseline.update_value("user_script", str(get_root_directory() / "pretrain_gpt2.py")) + args_baseline.update_value( + "user_script", str(get_root_directory() / "pretrain_gpt2.py") + ) deepspeed_main_args = args_baseline.get_deepspeed_main_args() # patch sys.argv so that args can be access by set_global_variables within initialize_megatron - with patch('sys.argv', deepspeed_main_args): + with patch("sys.argv", deepspeed_main_args): args_loaded = NeoXArgs.consume_neox_args() - #TODO is the wandb group really to be changed? + # TODO is the wandb group really to be changed? args_loaded.wandb_group = args_baseline.wandb_group assert args_baseline.megatron_config == args_loaded.megatron_config - diff --git a/tests/neox_args/test_neoxargs_implementation.py b/tests/neox_args/test_neoxargs_implementation.py index 5a096bd2c..7c7d60ff7 100644 --- a/tests/neox_args/test_neoxargs_implementation.py +++ b/tests/neox_args/test_neoxargs_implementation.py @@ -3,10 +3,12 @@ """ import pytest -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_duplicates(): """ tests that there are no duplicates among parent classes of NeoXArgs """ from megatron import NeoXArgs - assert NeoXArgs.validate_keys(), "test_neoxargs_duplicates" \ No newline at end of file + + assert NeoXArgs.validate_keys(), "test_neoxargs_duplicates" diff --git a/tests/neox_args/test_neoxargs_load.py b/tests/neox_args/test_neoxargs_load.py index 830a92dd1..a098044e5 100644 --- a/tests/neox_args/test_neoxargs_load.py +++ b/tests/neox_args/test_neoxargs_load.py @@ -1,10 +1,11 @@ """ -load all confings in neox/configs in order to perform validations implemented in NeoXArgs +load all confings in neox/configs in order to perform validations implemented in NeoXArgs """ import pytest import yaml from ..common import get_configs_with_path + def run_neox_args_load_test(yaml_files): from megatron.neox_arguments import NeoXArgs @@ -26,101 +27,121 @@ def run_neox_args_load_test(yaml_files): for conf_key, conf_value in conf.items(): if conf_key in config: raise ValueError( - f'Conf file {conf_file_name} has the following duplicate keys with previously loaded file: {conf_key}') + f"Conf file {conf_file_name} has the following duplicate keys with previously loaded file: {conf_key}" + ) - conf_key_converted = conf_key.replace("-", "_") # TODO remove replace and update configuration files? + conf_key_converted = conf_key.replace( + "-", "_" + ) # TODO remove replace and update configuration files? config[conf_key_converted] = conf_value # validate that neox args has the same value as specified in the config (if specified in the config) for k, v in config.items(): neox_args_value = getattr(args_loaded, k) - assert v == neox_args_value, "loaded neox args value "+str(k)+" == "+str(neox_args_value)+" different from config file "+str(v) + assert v == neox_args_value, ( + "loaded neox args value " + + str(k) + + " == " + + str(neox_args_value) + + " different from config file " + + str(v) + ) + -@pytest.mark.cpu +@pytest.mark.cpu def test_neoxargs_load_arguments_small_local_setup(): """ verify small.yml can be loaded without raising validation errors """ run_neox_args_load_test(["small.yml", "local_setup.yml"]) -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_load_arguments_small_local_setup_text_generation(): """ verify small.yml can be loaded together with text generation without raising validation errors """ run_neox_args_load_test(["small.yml", "local_setup.yml", "text_generation.yml"]) -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_load_arguments_medium_local_setup(): """ verify medium.yml can be loaded without raising validation errors """ run_neox_args_load_test(["medium.yml", "local_setup.yml"]) -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_load_arguments_large_local_setup(): """ verify large.yml can be loaded without raising validation errors """ run_neox_args_load_test(["large.yml", "local_setup.yml"]) -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_load_arguments_2_7B_local_setup(): """ verify 2-7B.yml can be loaded without raising validation errors """ run_neox_args_load_test(["2-7B.yml", "local_setup.yml"]) -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_load_arguments_6_7B_local_setup(): """ verify 6-7B.yml can be loaded without raising validation errors """ run_neox_args_load_test(["6-7B.yml", "local_setup.yml"]) -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_load_arguments_13B_local_setup(): """ verify 13B.yml can be loaded without raising validation errors """ run_neox_args_load_test(["13B.yml", "local_setup.yml"]) -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_load_arguments_XL_local_setup(): """ verify XL.yml can be loaded without raising validation errors """ run_neox_args_load_test(["XL.yml", "local_setup.yml"]) -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_load_arguments_175B_local_setup(): """ verify 13B.yml can be loaded without raising validation errors """ run_neox_args_load_test(["175B.yml", "local_setup.yml"]) -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_fail_instantiate_without_required_params(): """ verify assertion error if required arguments are not provided """ - + try: run_neox_args_load_test(["local_setup.yml"]) assert False except Exception as e: assert True -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_fail_instantiate_without_any_params(): """ verify assertion error if required arguments are not provided """ from megatron.neox_arguments import NeoXArgs - + try: args_loaded = NeoXArgs() assert False except Exception as e: assert True - diff --git a/tests/neox_args/test_neoxargs_usage.py b/tests/neox_args/test_neoxargs_usage.py index d86b1b6c0..36439fcb6 100644 --- a/tests/neox_args/test_neoxargs_usage.py +++ b/tests/neox_args/test_neoxargs_usage.py @@ -5,37 +5,59 @@ import re from ..common import get_root_directory -@pytest.mark.cpu + +@pytest.mark.cpu def test_neoxargs_usage(): - """" + """ " checks for code pieces of the pattern "args.*" and verifies that such used arg is defined in NeoXArgs """ from megatron.neox_arguments import NeoXArgs - + declared_all = True neox_args_attributes = set(NeoXArgs.__dataclass_fields__.keys()) # we exlude a number of properties (implemented with the @property decorator) or functions that we know exists - exclude = set(['params_dtype', 'deepspeed_config', 'get', 'pop', 'get_deepspeed_main_args', 'optimizer["params"]', 'attention_config[layer_number]', 'adlr_autoresume_object', 'update_value', 'all_config', 'tensorboard_writer', 'tokenizer', 'train_batch_size]']) + exclude = set( + [ + "params_dtype", + "deepspeed_config", + "get", + "pop", + "get_deepspeed_main_args", + 'optimizer["params"]', + "attention_config[layer_number]", + "adlr_autoresume_object", + "update_value", + "all_config", + "tensorboard_writer", + "tokenizer", + "train_batch_size]", + ] + ) # test file by file - for filename in (get_root_directory() / "megatron").glob('**/*.py'): - if filename.name in ["text_generation_utils.py", "train_tokenizer.py"]: continue + for filename in (get_root_directory() / "megatron").glob("**/*.py"): + if filename.name in ["text_generation_utils.py", "train_tokenizer.py"]: + continue # load file - with open(filename, 'r') as f: + with open(filename, "r") as f: file_contents = f.read() # find args matches - matches = list(re.findall(r"(?<=args\.).{2,}?(?=[\s\n(){}+-/*;:,=])", file_contents)) - if len(matches) == 0: continue + matches = list( + re.findall(r"(?<=args\.).{2,}?(?=[\s\n(){}+-/*;:,=])", file_contents) + ) + if len(matches) == 0: + continue # compare for match in matches: if match not in neox_args_attributes and match not in exclude: - print(f"(arguments used not found in neox args): {filename.name}: {match}", flush=True) + print( + f"(arguments used not found in neox args): {filename.name}: {match}", + flush=True, + ) declared_all = False assert declared_all, "all arguments used in code defined in NeoXArgs" - - diff --git a/tests/pytest.ini b/tests/pytest.ini index baa2acc0b..6b5e00def 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,3 +1,3 @@ [pytest] markers = - cpu: marks tests that can be run on cpu \ No newline at end of file + cpu: marks tests that can be run on cpu diff --git a/tests/test_configs/test_train_base.yml b/tests/test_configs/test_train_base.yml index e742fe1e3..bc82cc400 100644 --- a/tests/test_configs/test_train_base.yml +++ b/tests/test_configs/test_train_base.yml @@ -58,7 +58,7 @@ "attention_dropout": 0.0, # precision settings - "fp16": { + "fp16": { "enabled": true, "loss_scale": 0, "loss_scale_window": 1000, diff --git a/tools/corpora.py b/tools/corpora.py index b4b1e8d0c..a14018cd4 100644 --- a/tools/corpora.py +++ b/tools/corpora.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2021, EleutherAI contributors # This file is based on code by the authors denoted below and has been modified from its original version. # @@ -22,10 +21,10 @@ """ This registry is for automatically downloading and extracting datasets. -To register a class you need to inherit the DataDownloader class, and provide name and url attributes, and (optionally) +To register a class you need to inherit the DataDownloader class, and provide name and url attributes, and (optionally) the number of documents. -When done, add it to the DATA_DOWNLOADERS dict. The function process_data runs the pre-processing for the selected +When done, add it to the DATA_DOWNLOADERS dict. The function process_data runs the pre-processing for the selected dataset. """ @@ -35,23 +34,30 @@ class DataDownloader(ABC): """Dataset registry class to automatically download / extract datasets""" - - def __init__(self, tokenizer_type=None, merge_file=None, vocab_file=None, data_dir=None, num_workers=None): + + def __init__( + self, + tokenizer_type=None, + merge_file=None, + vocab_file=None, + data_dir=None, + num_workers=None, + ): if tokenizer_type is None: tokenizer_type = "GPT2BPETokenizer" if data_dir is None: - data_dir = os.environ.get('DATA_DIR', './data') + data_dir = os.environ.get("DATA_DIR", "./data") if merge_file is None: merge_file = f"{data_dir}/gpt2-merges.txt" if vocab_file is None: if tokenizer_type == "GPT2BPETokenizer": vocab_file = f"{data_dir}/gpt2-vocab.json" elif tokenizer_type == "HFGPT2Tokenizer": - vocab_file = 'gpt2' + vocab_file = "gpt2" elif tokenizer_type == "CharLevelTokenizer": pass else: - assert vocab_file is not None, 'No vocab file provided' + assert vocab_file is not None, "No vocab file provided" if num_workers is None: num_workers = cpu_count() self._tokenizer_type = tokenizer_type @@ -115,15 +121,16 @@ def download(self): """downloads dataset""" os.makedirs(os.path.join(self.base_dir, self.name), exist_ok=True) for url in self.urls: - os.system(f"wget {url} -O {os.path.join(self.base_dir, self.name, os.path.basename(url))}") + os.system( + f"wget {url} -O {os.path.join(self.base_dir, self.name, os.path.basename(url))}" + ) def tokenize(self): """tokenizes dataset""" parent_folder = os.path.join(self.base_dir, self.name) - jsonl_filepath = ",".join([ - os.path.join(parent_folder, os.path.basename(url)) - for url in self.urls - ]) + jsonl_filepath = ",".join( + [os.path.join(parent_folder, os.path.basename(url)) for url in self.urls] + ) cmd = f"python tools/preprocess_data.py \ --input {jsonl_filepath} \ @@ -162,7 +169,10 @@ class PileSubset(DataDownloader): class Pile(DataDownloader): name = "pile" - urls = [f"https://mystic.the-eye.eu/public/AI/pile/train/{i:02}.jsonl.zst" for i in range(30)] + urls = [ + f"https://mystic.the-eye.eu/public/AI/pile/train/{i:02}.jsonl.zst" + for i in range(30) + ] class Github(DataDownloader): @@ -173,37 +183,50 @@ class Github(DataDownloader): class ArXiv(DataDownloader): name = "arxiv" urls = [ - "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/2020-09-08-arxiv-extracts-nofallback-until-2007-068.tar.gz"] + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/2020-09-08-arxiv-extracts-nofallback-until-2007-068.tar.gz" + ] class EuroParl(DataDownloader): name = "europarl" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/EuroParliamentProceedings_1996_2011.jsonl.zst"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/EuroParliamentProceedings_1996_2011.jsonl.zst" + ] class FreeLaw(DataDownloader): name = "freelaw" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/FreeLaw_Opinions.jsonl.zst"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/FreeLaw_Opinions.jsonl.zst" + ] class NiH(DataDownloader): name = "nih" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/NIH_ExPORTER_awarded_grant_text.jsonl.zst"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/NIH_ExPORTER_awarded_grant_text.jsonl.zst" + ] class PubMed(DataDownloader): name = "pubmed" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/PMC_extracts.tar.gz"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/PMC_extracts.tar.gz" + ] class Books1(DataDownloader): name = "books1" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/books1.tar.gz"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/books1.tar.gz" + ] class Books3(DataDownloader): name = "books3" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/books3.tar.gz"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/books3.tar.gz" + ] class HackerNews(DataDownloader): @@ -214,33 +237,47 @@ class HackerNews(DataDownloader): class OpenWebText2(DataDownloader): name = "openwebtext2" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/openwebtext2.jsonl.zst.tar"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/openwebtext2.jsonl.zst.tar" + ] num_docs = 17103000 class StackExchange(DataDownloader): name = "stackexchange" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/stackexchange_dataset.tar"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/stackexchange_dataset.tar" + ] class UbuntuIRC(DataDownloader): name = "ubuntu_irc" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/ubuntu_irc_until_2020_9_1.jsonl.zst"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/ubuntu_irc_until_2020_9_1.jsonl.zst" + ] class YoutubeSubtitles(DataDownloader): name = "youtube_subtitles" - urls = ["https://mystic.the-eye.eu/public/AI/pile_preliminary_components/yt_subs.jsonl.zst"] + urls = [ + "https://mystic.the-eye.eu/public/AI/pile_preliminary_components/yt_subs.jsonl.zst" + ] class C4(DataDownloader): name = "c4" - urls = [f"https://mystic.the-eye.eu/eleuther_staging/c4/en/c4-train.{i:05}-of-01024.json.gz" for i in range(1024)] + urls = [ + f"https://mystic.the-eye.eu/eleuther_staging/c4/en/c4-train.{i:05}-of-01024.json.gz" + for i in range(1024) + ] class C4OpenWebText(DataDownloader): name = "c4_openwebtext" - urls = [f"https://mystic.the-eye.eu/eleuther_staging/c4/realnewslike/c4-train.{i:05}-of-00512.json.gz" for i in range(512)] + urls = [ + f"https://mystic.the-eye.eu/eleuther_staging/c4/realnewslike/c4-train.{i:05}-of-00512.json.gz" + for i in range(512) + ] class Enwik8(DataDownloader): @@ -253,9 +290,9 @@ def maybe_download_gpt2_tokenizer_data(tokenizer_type, data_dir): GPT2_VOCAB_FP = f"{data_dir}//gpt2-vocab.json" GPT2_MERGE_FP = f"{data_dir}/gpt2-merges.txt" if not os.path.isfile(GPT2_VOCAB_FP): - os.system(f'wget {GPT2_VOCAB_URL} -O {GPT2_VOCAB_FP}') + os.system(f"wget {GPT2_VOCAB_URL} -O {GPT2_VOCAB_FP}") if not os.path.isfile(GPT2_MERGE_FP): - os.system(f'wget {GPT2_MERGE_URL} -O {GPT2_MERGE_FP}') + os.system(f"wget {GPT2_MERGE_URL} -O {GPT2_MERGE_FP}") DATA_DOWNLOADERS = { @@ -278,28 +315,40 @@ def maybe_download_gpt2_tokenizer_data(tokenizer_type, data_dir): "youtube_subtitles": YoutubeSubtitles, "c4": C4, "c4_openwebtext": C4OpenWebText, - "enwik8": Enwik8 + "enwik8": Enwik8, } -def prepare_dataset(dataset_name: str, tokenizer_type: str = None, data_dir: str = None, vocab_file: str = None, - merge_file: str = None, num_workers: int = None): +def prepare_dataset( + dataset_name: str, + tokenizer_type: str = None, + data_dir: str = None, + vocab_file: str = None, + merge_file: str = None, + num_workers: int = None, +): """ Downloads + tokenizes a dataset in the registry (dataset_name) and saves output .npy files to data_dir. """ if data_dir is None: - data_dir = os.environ.get('DATA_DIR', './data') + data_dir = os.environ.get("DATA_DIR", "./data") os.makedirs(data_dir, exist_ok=True) maybe_download_gpt2_tokenizer_data(tokenizer_type, data_dir) DownloaderClass = DATA_DOWNLOADERS.get(dataset_name.lower(), None) if DownloaderClass is None: raise NotImplementedError( - f'Dataset "{dataset_name}" not recognized - please choose from {list(DATA_DOWNLOADERS.keys())}') + f'Dataset "{dataset_name}" not recognized - please choose from {list(DATA_DOWNLOADERS.keys())}' + ) elif DownloaderClass == "pass": # pass on building dataset (for unit tests) pass else: num_workers = 1 if dataset_name == "enwik8" else num_workers - d = DownloaderClass(tokenizer_type=tokenizer_type, vocab_file=vocab_file, merge_file=merge_file, - data_dir=data_dir, num_workers=num_workers) + d = DownloaderClass( + tokenizer_type=tokenizer_type, + vocab_file=vocab_file, + merge_file=merge_file, + data_dir=data_dir, + num_workers=num_workers, + ) d.prepare() diff --git a/tools/inspect_checkpoints.py b/tools/inspect_checkpoints.py index 67404f112..d6483b239 100644 --- a/tools/inspect_checkpoints.py +++ b/tools/inspect_checkpoints.py @@ -14,30 +14,33 @@ class COLORS: BLUE = "\033[94m" CYAN = "\033[96m" GREEN = "\033[92m" - RED = '\033[31m' - YELLOW = '\033[33m' - MAGENTA = '\033[35m' - WHITE = '\033[37m' - UNDERLINE = '\033[4m' + RED = "\033[31m" + YELLOW = "\033[33m" + MAGENTA = "\033[35m" + WHITE = "\033[37m" + UNDERLINE = "\033[4m" END = "\033[0m" PRIMITIVE_TYPES = (int, float, bool, str, type) + def natural_sort(l): convert = lambda text: int(text) if text.isdigit() else text.lower() - alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', str(key))] + alphanum_key = lambda key: [convert(c) for c in re.split("([0-9]+)", str(key))] return sorted(l, key=alphanum_key) -def sizeof_fmt(num, suffix='B'): - for unit in ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi']: + +def sizeof_fmt(num, suffix="B"): + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: if abs(num) < 1024.0: return "%3.1f%s%s" % (num, unit, suffix) num /= 1024.0 - return "%.1f%s%s" % (num, 'Yi', suffix) + return "%.1f%s%s" % (num, "Yi", suffix) + def pretty_print(contents: dict): - """ Prints a nice summary of the top-level contens in a checkpoint dictionary. """ + """Prints a nice summary of the top-level contens in a checkpoint dictionary.""" col_size = max(len(str(k)) for k in contents) for k, v in sorted(contents.items()): key_length = len(str(k)) @@ -60,7 +63,10 @@ def pretty_print(contents: dict): line += f"{COLORS.CYAN}shape={list(v.shape)}{COLORS.END}" line += ", " line += f"{COLORS.CYAN}dtype={v.dtype}{COLORS.END}" - line += ", " + f"{COLORS.CYAN}size={sizeof_fmt(v.nelement() * v.element_size())}{COLORS.END}" + line += ( + ", " + + f"{COLORS.CYAN}size={sizeof_fmt(v.nelement() * v.element_size())}{COLORS.END}" + ) print(line) @@ -72,8 +78,10 @@ def common_entries(*dcts): def pretty_print_double(contents1: dict, contents2: dict, args): - """ Prints a nice summary of the top-level contens in a checkpoint dictionary. """ - col_size = max(max(len(str(k)) for k in contents1), max(len(str(k)) for k in contents2)) + """Prints a nice summary of the top-level contens in a checkpoint dictionary.""" + col_size = max( + max(len(str(k)) for k in contents1), max(len(str(k)) for k in contents2) + ) common_keys = list(contents1.keys() & contents2.keys()) uncommon_keys_1 = [i for i in contents2.keys() if i not in common_keys] uncommon_keys_2 = [i for i in contents1.keys() if i not in common_keys] @@ -81,14 +89,20 @@ def pretty_print_double(contents1: dict, contents2: dict, args): if uncommon_keys_1 + uncommon_keys_2: diffs_found = True if uncommon_keys_1: - print(f"{COLORS.RED}{len(uncommon_keys_1)} key(s) found in ckpt 1 that isn't present in ckpt 2:{COLORS.END} \n\t{COLORS.BLUE}{' '.join(uncommon_keys_1)}{COLORS.END}") + print( + f"{COLORS.RED}{len(uncommon_keys_1)} key(s) found in ckpt 1 that isn't present in ckpt 2:{COLORS.END} \n\t{COLORS.BLUE}{' '.join(uncommon_keys_1)}{COLORS.END}" + ) if uncommon_keys_2: - print(f"{COLORS.RED}{len(uncommon_keys_2)} key(s) found in ckpt 2 that isn't present in ckpt 1:{COLORS.END} \n\t{COLORS.BLUE}{' '.join(uncommon_keys_2)}{COLORS.END}") + print( + f"{COLORS.RED}{len(uncommon_keys_2)} key(s) found in ckpt 2 that isn't present in ckpt 1:{COLORS.END} \n\t{COLORS.BLUE}{' '.join(uncommon_keys_2)}{COLORS.END}" + ) for k, v1, v2 in sorted(common_entries(contents1, contents2)): key_length = len(str(k)) line = " " * (col_size - key_length) if type(v1) != type(v2): - print(f"{COLORS.RED}{k} is a different type between ckpt1 and ckpt2: ({type(v1).__name__} vs. {type(v2).__name__}){COLORS.END}") + print( + f"{COLORS.RED}{k} is a different type between ckpt1 and ckpt2: ({type(v1).__name__} vs. {type(v2).__name__}){COLORS.END}" + ) continue else: prefix = f"{k}: {COLORS.BLUE}{type(v1).__name__} | {type(v2).__name__}{COLORS.END}" @@ -115,12 +129,14 @@ def pretty_print_double(contents1: dict, contents2: dict, args): line += ", " line += f"{c}len={len(v1)} | len={len(v2)}{COLORS.END}" elif isinstance(v1, torch.Tensor): - if (v1.ndimension() != v2.ndimension()): + if v1.ndimension() != v2.ndimension(): c = COLORS.RED else: c = COLORS.CYAN - if (v1.ndimension() in (0, 1) and v1.numel() == 1) and (v2.ndimension() in (0, 1) and v2.numel() == 1): + if (v1.ndimension() in (0, 1) and v1.numel() == 1) and ( + v2.ndimension() in (0, 1) and v2.numel() == 1 + ): if not args.diff: line += f" = " line += f"{c}{v1.item()} | {c}{v2.item()}{COLORS.END}" @@ -151,7 +167,7 @@ def pretty_print_double(contents1: dict, contents2: dict, args): pass else: if not args.diff: - print('\n') + print("\n") return diffs_found @@ -168,10 +184,10 @@ def get_files(pth): if os.path.isdir(pth): files = list(Path(pth).glob("*.pt")) + list(Path(pth).glob("*.ckpt")) elif os.path.isfile(pth): - assert pth.endswith(".pt") or pth.endswith('.ckpt') + assert pth.endswith(".pt") or pth.endswith(".ckpt") files = list(Path(pth)) else: - raise ValueError('Dir / File not found.') + raise ValueError("Dir / File not found.") return natural_sort(files) @@ -192,7 +208,7 @@ def peek(args: Namespace): current = get_attribute(current, part) selection.update({name: current}) pretty_print(selection) - print('\n') + print("\n") if args.interactive: code.interact( @@ -207,7 +223,9 @@ def get_shared_fnames(files_1, files_2): names_2 = [Path(i).name for i in files_2] names_2_parent = Path(files_2[0]).parent shared_names = list(set.intersection(*map(set, [names_1, names_2]))) - return [names_1_parent / i for i in shared_names], [names_2_parent / i for i in shared_names] + return [names_1_parent / i for i in shared_names], [ + names_2_parent / i for i in shared_names + ] def get_selection(filename, args): @@ -224,7 +242,7 @@ def get_selection(filename, args): def compare(args: Namespace): - dirs = [i.strip() for i in args.dir.split(',')] + dirs = [i.strip() for i in args.dir.split(",")] assert len(dirs) == 2, "Only works with 2 directories / files" files_1 = get_files(dirs[0]) files_2 = get_files(dirs[1]) @@ -239,16 +257,21 @@ def compare(args: Namespace): selection_2 = get_selection(file2, args) diffs_found = pretty_print_double(selection_1, selection_2, args) if args.diff and diffs_found: - print(f"{COLORS.RED}THE ABOVE DIFFS WERE FOUND IN {file1.name} & {file2.name} ^{COLORS.END}\n") + print( + f"{COLORS.RED}THE ABOVE DIFFS WERE FOUND IN {file1.name} & {file2.name} ^{COLORS.END}\n" + ) if args.interactive: code.interact( banner="Entering interactive shell. You can access the checkpoint contents through the local variable 'selection_1' / 'selection_2'.\nPress Ctrl-D to exit.", - local={"selection_1": selection_1, "selection_2": selection_2, "torch": torch}, + local={ + "selection_1": selection_1, + "selection_2": selection_2, + "torch": torch, + }, ) - def main(): parser = ArgumentParser() parser.add_argument( @@ -263,7 +286,7 @@ def main(): "--attributes", nargs="*", help="Name of one or several attributes to query. To access an attribute within a nested structure, use '/' as separator.", - default=None + default=None, ) parser.add_argument( "--interactive", @@ -271,8 +294,15 @@ def main(): action="store_true", help="Drops into interactive shell after printing the summary.", ) - parser.add_argument("--compare", "-c", action="store_true", help="If true, script will compare two directories separated by commas") - parser.add_argument("--diff", "-d", action="store_true", help="In compare mode, only print diffs") + parser.add_argument( + "--compare", + "-c", + action="store_true", + help="If true, script will compare two directories separated by commas", + ) + parser.add_argument( + "--diff", "-d", action="store_true", help="In compare mode, only print diffs" + ) args = parser.parse_args() if args.compare: diff --git a/tools/kill.sh b/tools/kill.sh index f39cbf1aa..bccd46d7e 100755 --- a/tools/kill.sh +++ b/tools/kill.sh @@ -1 +1 @@ -pkill -9 python \ No newline at end of file +pkill -9 python diff --git a/tools/merge_mp_partitions.py b/tools/merge_mp_partitions.py index 33bdb44ba..e949a6458 100644 --- a/tools/merge_mp_partitions.py +++ b/tools/merge_mp_partitions.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,8 +16,10 @@ import os import sys -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - os.path.pardir))) + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) +) import torch @@ -32,18 +33,16 @@ def split_into_partitions(tensor, num_partitions, partition_dim, stride): - per_partition_size = mpu.utils.divide(tensor.size(partition_dim), - num_partitions) + per_partition_size = mpu.utils.divide(tensor.size(partition_dim), num_partitions) per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride) - partitions_list = torch.split(tensor, - per_partition_per_stride_size, - dim=partition_dim) + partitions_list = torch.split( + tensor, per_partition_per_stride_size, dim=partition_dim + ) partitions = [] for i in range(num_partitions): - partition = torch.cat(partitions_list[i::num_partitions], - dim=partition_dim) + partition = torch.cat(partitions_list[i::num_partitions], dim=partition_dim) partitions.append(partition) return partitions @@ -62,20 +61,24 @@ def merge_partitions(merged, partitions, partition_dim, stride): def concat_partitions(partitions_): with torch.no_grad(): - if (per_partition_size * num_partitions) == merged.size( - partition_dim): + if (per_partition_size * num_partitions) == merged.size(partition_dim): torch.cat(partitions_, dim=partition_dim, out=merged) else: - print(' ***WARNING*** sizes do not match. Will cut ' - 'the merged partitions by {} along dimension {} ' - 'to reduce the size from {} to {} ...'.format( - (per_partition_size * num_partitions) - \ - merged.size(partition_dim), partition_dim, - per_partition_size * num_partitions, - merged.size(partition_dim))) + print( + " ***WARNING*** sizes do not match. Will cut " + "the merged partitions by {} along dimension {} " + "to reduce the size from {} to {} ...".format( + (per_partition_size * num_partitions) + - merged.size(partition_dim), + partition_dim, + per_partition_size * num_partitions, + merged.size(partition_dim), + ) + ) merged_ = torch.cat(partitions_, dim=partition_dim) - merged_split = torch.split(merged_, merged.size(partition_dim), - dim=partition_dim) + merged_split = torch.split( + merged_, merged.size(partition_dim), dim=partition_dim + ) merged_ = merged_split[0] assert merged_.size(partition_dim) == merged.size(partition_dim) merged.data.copy_(merged_.data) @@ -90,12 +93,10 @@ def concat_partitions(partitions_): # Chunk and build a list. chunks = None for i, partition in enumerate(partitions): - chunk = torch.split(partition, - per_partition_per_stride_size, - dim=partition_dim) + chunk = torch.split(partition, per_partition_per_stride_size, dim=partition_dim) if chunks is None: - chunks = [0]*(num_partitions*len(chunk)) + chunks = [0] * (num_partitions * len(chunk)) chunks[i::num_partitions] = chunk # Concatinate. @@ -106,10 +107,10 @@ def concat_partitions(partitions_): def get_model(model_type): - if model_type == 'GPT2': + if model_type == "GPT2": from pretrain_gpt2 import model_provider else: - raise Exception('unrecognized model type: {}'.format(model_type)) + raise Exception("unrecognized model type: {}".format(model_type)) model = model_provider() model = model.half() @@ -121,7 +122,7 @@ def get_parallel_checkpoint_name(path): tracker_filename = get_checkpoint_tracker_filename(path) iteration = 0 - with open(tracker_filename, 'r') as f: + with open(tracker_filename, "r") as f: metastring = f.read().strip() iteration = int(metastring) assert iteration > 0 @@ -132,42 +133,49 @@ def get_parallel_checkpoint_name(path): def test_split_merge(): - print('testing split and merge ...') - - #[QKV.ROW-COL] - tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15], - [1.21, 1.22, 1.23, 1.24, 1.25], - [1.31, 1.32, 1.33, 1.34, 1.35], - [1.41, 1.42, 1.43, 1.44, 1.45], - [2.11, 2.12, 2.13, 2.14, 2.15], - [2.21, 2.22, 2.23, 2.24, 2.25], - [2.31, 2.32, 2.33, 2.34, 2.35], - [2.41, 2.42, 2.43, 2.44, 2.45], - [3.11, 3.12, 3.13, 3.14, 3.15], - [3.21, 3.22, 3.23, 3.24, 3.25], - [3.31, 3.32, 3.33, 3.34, 3.35], - [3.41, 3.42, 3.43, 3.44, 3.45]]) + print("testing split and merge ...") + + # [QKV.ROW-COL] + tensor = torch.FloatTensor( + [ + [1.11, 1.12, 1.13, 1.14, 1.15], + [1.21, 1.22, 1.23, 1.24, 1.25], + [1.31, 1.32, 1.33, 1.34, 1.35], + [1.41, 1.42, 1.43, 1.44, 1.45], + [2.11, 2.12, 2.13, 2.14, 2.15], + [2.21, 2.22, 2.23, 2.24, 2.25], + [2.31, 2.32, 2.33, 2.34, 2.35], + [2.41, 2.42, 2.43, 2.44, 2.45], + [3.11, 3.12, 3.13, 3.14, 3.15], + [3.21, 3.22, 3.23, 3.24, 3.25], + [3.31, 3.32, 3.33, 3.34, 3.35], + [3.41, 3.42, 3.43, 3.44, 3.45], + ] + ) num_partitions = 2 partition_dim = 0 stride = 3 - partitions = split_into_partitions(tensor, num_partitions, - partition_dim, stride) + partitions = split_into_partitions(tensor, num_partitions, partition_dim, stride) merged = torch.zeros_like(tensor) merge_partitions(merged, partitions, partition_dim, stride) max_error = (merged - tensor).abs().max() - print(' > max error (should be zero): {}'.format(max_error)) + print(" > max error (should be zero): {}".format(max_error)) def get_mp_merge_args(parser): """Provide extra arguments required for merging.""" - group = parser.add_argument_group(title='mp merge') + group = parser.add_argument_group(title="mp merge") - group.add_argument('--model-type', type=str, required=True, - choices=['BERT', 'GPT2', 'RACE', 'MNLI', 'QQP'], - help='Type of the mdoel.') + group.add_argument( + "--model-type", + type=str, + required=True, + choices=["BERT", "GPT2", "RACE", "MNLI", "QQP"], + help="Type of the mdoel.", + ) return parser @@ -181,21 +189,20 @@ def main(): args.model_parallel_size = 1 tokenizer = rebuild_tokenizer(args) - print('\n merging model parallel partitions ...') - print(' > number of partitions: {}'.format(orig_model_parallel_size)) - print(' > checkpoint path: {}'.format(args.load)) - print(' > model parameters:') - print(' number of tokens ................ {} '.format( - tokenizer.vocab_size)) - print(' number of layers ................ {}'.format(args.num_layers)) - print(' hidden sise ..................... {}'.format(args.hidden_size)) - print(' number of attention heads ....... {}'.format( - args.num_attention_heads)) - print(' maximum position embeddings ..... {}'.format( - args.max_position_embeddings)) + print("\n merging model parallel partitions ...") + print(" > number of partitions: {}".format(orig_model_parallel_size)) + print(" > checkpoint path: {}".format(args.load)) + print(" > model parameters:") + print(" number of tokens ................ {} ".format(tokenizer.vocab_size)) + print(" number of layers ................ {}".format(args.num_layers)) + print(" hidden sise ..................... {}".format(args.hidden_size)) + print(" number of attention heads ....... {}".format(args.num_attention_heads)) + print( + " maximum position embeddings ..... {}".format(args.max_position_embeddings) + ) # Full model. - print('> building the full model ...') + print("> building the full model ...") mpu.initialize.set_model_parallel_world_size(1) mpu.initialize.set_model_parallel_rank(0) merged_model = get_model(model_type) @@ -209,67 +216,75 @@ def main(): for rank in range(args.model_parallel_size): mpu.initialize.set_model_parallel_rank(rank) checkpoint_name, iteration = get_parallel_checkpoint_name(args.load) - print('> loading {} ...'.format(checkpoint_name)) + print("> loading {} ...".format(checkpoint_name)) model_ = get_model(model_type) - sd = torch.load(checkpoint_name, map_location='cpu') - model_.load_state_dict(sd['model']) + sd = torch.load(checkpoint_name, map_location="cpu") + model_.load_state_dict(sd["model"]) partitions.append(model_) - # Parameter generators so we can loop through them semiltaneouly. merged_params_gen = merged_model.named_parameters() - partitions_params_gen = [partition.named_parameters() - for partition in partitions] + partitions_params_gen = [partition.named_parameters() for partition in partitions] while True: try: # Get the params and check names. name, merged_param = next(merged_params_gen) - print(' > working on {} ...'.format(name)) - print(' merged type: {}, size: {}'.format( - merged_param.dtype, list(merged_param.size()))) + print(" > working on {} ...".format(name)) + print( + " merged type: {}, size: {}".format( + merged_param.dtype, list(merged_param.size()) + ) + ) partitions_param = [] for rank, partition_params_gen in enumerate(partitions_params_gen): partition_name, partition_param = next(partition_params_gen) assert partition_name == name partitions_param.append(partition_param) - print(' partition {} type: {}, size: {}'.format( - rank, partition_param.dtype, list(partition_param.size()))) + print( + " partition {} type: {}, size: {}".format( + rank, partition_param.dtype, list(partition_param.size()) + ) + ) # For the non-parallel parameters, simply copy the rank 0 values. - if not hasattr(merged_param, 'model_parallel'): - print(' none-parallel parameter, simple copy from rank 0') + if not hasattr(merged_param, "model_parallel"): + print(" none-parallel parameter, simple copy from rank 0") with torch.no_grad(): merged_param.data.copy_(partitions_param[0].data) # For parallel parameters, merge the values else: - print(' parallel parameter merge with stride {} along ' - 'dimention {}'.format(merged_param.stride, - merged_param.partition_dim)) - merge_partitions(merged_param, - partitions_param, - merged_param.partition_dim, - merged_param.stride) + print( + " parallel parameter merge with stride {} along " + "dimention {}".format( + merged_param.stride, merged_param.partition_dim + ) + ) + merge_partitions( + merged_param, + partitions_param, + merged_param.partition_dim, + merged_param.stride, + ) except StopIteration: break - # Save the model. args.model_parallel_size = 1 mpu.initialize.set_model_parallel_rank(0) sd = {} - sd['model'] = merged_model.state_dict() - sd['iteration'] = iteration - merged_path = os.path.join(args.load, 'merged') + sd["model"] = merged_model.state_dict() + sd["iteration"] = iteration + merged_path = os.path.join(args.load, "merged") checkpoint_name = get_checkpoint_name(merged_path, iteration) ensure_directory_exists(checkpoint_name) - print('> saving merged model to {}'.format(checkpoint_name)) + print("> saving merged model to {}".format(checkpoint_name)) torch.save(sd, checkpoint_name) - print('done :-)') + print("done :-)") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 4dec16145..a562b15ac 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tools/sync.sh b/tools/sync.sh index 80cf21140..28f537b70 100755 --- a/tools/sync.sh +++ b/tools/sync.sh @@ -11,4 +11,4 @@ do full_path=$(realpath $file) echo Uploading $full_path pdcp -f 1024 -R ssh -w ^/job/hosts $full_path $full_path -done \ No newline at end of file +done diff --git a/tools/sync_cmd.sh b/tools/sync_cmd.sh index da629a632..b8dac7b24 100644 --- a/tools/sync_cmd.sh +++ b/tools/sync_cmd.sh @@ -5,4 +5,4 @@ # sync_cmd.sh 'echo "hello world"' echo "Command: $1"; -pdsh -R ssh -w ^/job/hosts $1 \ No newline at end of file +pdsh -R ssh -w ^/job/hosts $1 diff --git a/tools/syncdir.sh b/tools/syncdir.sh index 754a74fea..272f20d13 100755 --- a/tools/syncdir.sh +++ b/tools/syncdir.sh @@ -12,4 +12,4 @@ do parentdir="$(dirname "$full_path")" echo Uploading $full_path to $parentdir pdcp -f 1024 -R ssh -w ^/job/hosts -r $full_path $parentdir -done \ No newline at end of file +done diff --git a/train.py b/train.py index 36537c81b..3dd4985a5 100644 --- a/train.py +++ b/train.py @@ -22,6 +22,6 @@ if __name__ == "__main__": neox_args = NeoXArgs.consume_neox_args() neox_args.configure_distributed_args() - neox_args.build_tokenizer() # tokenizer needs to be build in training in order to set the padding vocab + neox_args.build_tokenizer() # tokenizer needs to be build in training in order to set the padding vocab neox_args.initialize_tensorboard_writer() # is initialized if tensorboard directory is defined pretrain(neox_args=neox_args)