Sourcery refactored main branch #1
base: main
Conversation
help=f"Directory to which to download datasets / tokenizer " | ||
f"files - defaults to ./data", | ||
help="Directory to which to download datasets / tokenizer ", | ||
) | ||
|
||
parser.add_argument( | ||
"-v", "--vocab-file", default=None, help=f"Tokenizer vocab file (if required)" | ||
"-v", | ||
"--vocab-file", | ||
default=None, | ||
help="Tokenizer vocab file (if required)", | ||
) | ||
|
||
parser.add_argument( | ||
"-m", "--merge-file", default=None, help=f"Tokenizer merge file (if required)" | ||
"-m", | ||
"--merge-file", | ||
default=None, | ||
help="Tokenizer merge file (if required)", | ||
) | ||
|
Function `get_args` refactored with the following changes:

- Replace f-string with no interpolated values with string (`remove-redundant-fstring`)
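For reference, a minimal runnable sketch of this rule in isolation (made-up string, not code from this diff):

```python
# An f-string with nothing to interpolate is just a string literal,
# so the f-prefix is dropped; the value is identical.
before = f"Eps must be non-negative"
after = "Eps must be non-negative"
assert before == after
```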
```diff
-    lines = []
-    lines.append(intro_str)
+    lines = [intro_str]
     for name, doc in docs.items():
         lines.append(f"## {name}")
-        lines.append(f"{doc['doc']}")
-        lines.append("")
+        lines.extend((f"{doc['doc']}", ""))
         for field_name, field_def in doc["attributes"].items():
             # attribute name and type
             lines.append(f"- **{field_name}**: {field_def['type']}")
             # default value
             lines.append(f"  Default = {str(field_def['default'])}")
-            lines.append(f"  {field_def['doc']}")
-            lines.append("")
+            lines.extend((f"  {field_def['doc']}", ""))
```
Function `to_md` refactored with the following changes:

- Merge append into list declaration (`merge-list-append`)
- Merge consecutive list appends into a single extend (`merge-list-appends-into-extend`)
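A minimal sketch of both list refactorings with made-up values (not code from this repo):

```python
# Before: the first append can fold into the declaration, and
# consecutive appends can collapse into a single extend.
lines = []
lines.append("intro")
lines.append("body")
lines.append("")

merged = ["intro"]            # merge-list-append
merged.extend(("body", ""))   # merge-list-appends-into-extend
assert lines == merged
```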
```diff
-        self.is_last_stage = (
-            True if not self.is_pipe_parallel else model.is_last_stage()
-        )  # only the last stage of the pipeline model will receive the logits
+        self.is_last_stage = model.is_last_stage() if self.is_pipe_parallel else True
```
Function `EvalHarnessAdapter.__init__` refactored with the following changes:

- Swap if/else branches of if expression to remove negation (`swap-if-expression`)

This removes the following comments (why?): `# only the last stage of the pipeline model will receive the logits`
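In isolation, the swap looks like this (placeholder values standing in for `model.is_last_stage()`):

```python
# Swapping the branches of a conditional expression removes the "not"
# without changing the result.
for is_pipe_parallel in (True, False):
    before = True if not is_pipe_parallel else "last-stage"
    after = "last-stage" if is_pipe_parallel else True
    assert before == after
```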
```diff
-        error_message = "{} value from checkpoint ({}) is not equal to the currently set argument value ({}).".format(
-            checkpoint_arg_name, checkpoint_arg_value, args_value
-        )
+        error_message = f"{checkpoint_arg_name} value from checkpoint ({checkpoint_arg_value}) is not equal to the currently set argument value ({args_value})."
```
Function `check_checkpoint_args` refactored with the following changes:

- Replace call to format with f-string (`use-fstring-for-formatting`)
```diff
-    if (
-        logits is not None and checkpoint_logits is not None
-    ):  # this could be the case for non-final pipeline stages
-        if not (logits == checkpoint_logits).all().item():
-            if mpu.get_data_parallel_rank() == 0:
-                print(
-                    " > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly same result"
-                )
-            assert (
-                torch.isclose(logits, checkpoint_logits).all().item()
-            ), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"
+    if (logits is not None and checkpoint_logits is not None) and not (
+        logits == checkpoint_logits
+    ).all().item():
+        if mpu.get_data_parallel_rank() == 0:
+            print(
+                " > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly same result"
+            )
+        assert (
+            torch.isclose(logits, checkpoint_logits).all().item()
+        ), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"
```
Function `check_forward_pass` refactored with the following changes:

- Merge nested if conditions (`merge-nested-ifs`)

This removes the following comments (why?): `# this could be the case for non-final pipeline stages`
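A minimal runnable sketch of the merge with placeholder booleans:

```python
# Two nested ifs with no statements between them collapse into one
# condition joined by "and".
a, b = True, True
results = []
if a:            # before: nested conditions
    if b:
        results.append("nested")
if a and b:      # after: one merged condition
    results.append("merged")
assert results == ["nested", "merged"]
```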
```diff
-        if not 0.0 <= lr:
+        if lr < 0.0:
             raise ValueError("Invalid learning rate: {0}".format(lr))
         if not 0.0 <= momentum < 1.0:
             raise ValueError("Invalid momentum: {0}".format(momentum))
         if not 0.0 <= beta < 1.0:
             raise ValueError("Invalid beta: {0}".format(beta))
-        if not 0.0 <= eps:
+        if eps < 0.0:
```
Function `SM3.__init__` refactored with the following changes:

- Simplify logical expression using De Morgan identities (`de-morgan`)
- Ensure constant in comparison is on the right (`flip-comparison`)
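The De Morgan rewrite can be checked directly (a sketch, not repo code):

```python
# "not 0.0 <= lr" negates a comparison; De Morgan rewrites it as the
# opposite comparison, and flipping puts the constant on the right.
for lr in (-0.1, 0.0, 0.5):
    assert (not 0.0 <= lr) == (lr < 0.0)
```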
return "accumulator_" + str(i) | ||
return f"accumulator_{str(i)}" |
Function `_key` refactored with the following changes:

- Use f-string instead of string concatenation (`use-fstring-for-concatenation`)
raise ValueError(f"Eps must be non-negative") | ||
raise ValueError("Eps must be non-negative") |
Function `madgrad_wd.__init__` refactored with the following changes:

- Replace f-string with no interpolated values with string (`remove-redundant-fstring`)
```diff
-    else:
-        # we need to format inputs this way because:
-        # a) deepspeed pipeline only accepts iterables
-        # b) deepspeed pipeline *requires* that you pass in labels for the loss, it's not easy to get around this
-        # so we wrap the inputs in an iterable, and pad them (because internally, we get labels as inputs[:, 1:] and inputs as inputs[:, :-1])
-        model_inputs = iter([{"text": F.pad(model_inputs[0], pad=(0, 1))}])
-
-        # set num microbatches to 1 at inference time
-        micro_batches_before = model.micro_batches
-        model.micro_batches = 1
-
-        # deepspeed sends metadata across pipeline stages only once in the first step, then assumes it will stay
-        # constant. In inference, the metadata of the tensors being sent across pipe stages may change, so we need to set
-        # these two flags in order for deepspeed to send the metadata every step, otherwise torch.distributed hangs
-        # silently. Fun stuff.
-        model.first_output_send = True
-        model.pipe_recv_buf = None
-
-        loss, logits = model.eval_batch(model_inputs, return_logits=True)
-        model.micro_batches = micro_batches_before
-        return logits
+    # we need to format inputs this way because:
+    # a) deepspeed pipeline only accepts iterables
+    # b) deepspeed pipeline *requires* that you pass in labels for the loss, it's not easy to get around this
+    # so we wrap the inputs in an iterable, and pad them (because internally, we get labels as inputs[:, 1:] and inputs as inputs[:, :-1])
+    model_inputs = iter([{"text": F.pad(model_inputs[0], pad=(0, 1))}])
+
+    # set num microbatches to 1 at inference time
+    micro_batches_before = model.micro_batches
+    model.micro_batches = 1
+
+    # deepspeed sends metadata across pipeline stages only once in the first step, then assumes it will stay
+    # constant. In inference, the metadata of the tensors being sent across pipe stages may change, so we need to set
+    # these two flags in order for deepspeed to send the metadata every step, otherwise torch.distributed hangs
+    # silently. Fun stuff.
+    model.first_output_send = True
+    model.pipe_recv_buf = None
+
+    loss, logits = model.eval_batch(model_inputs, return_logits=True)
+    model.micro_batches = micro_batches_before
+    return logits
```
Function `forward_model` refactored with the following changes:

- Swap if/else branches (`swap-if-else-branches`)
- Add guard clause (`last-if-guard`)
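A minimal sketch of the guard-clause transformation, with a dummy body standing in for the pipeline-inference code:

```python
def with_else(x):
    if x is None:
        return 0
    else:
        doubled = x * 2   # stand-in for the long else-branch body
        return doubled + 1

def with_guard(x):
    if x is None:   # guard clause: handle the short case and return early
        return 0
    doubled = x * 2  # main body dedents out of the else
    return doubled + 1

assert with_else(None) == with_guard(None)
assert with_else(3) == with_guard(3)
```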
"generate_samples_input_from_file() loading input from {}".format(input_file) | ||
f"generate_samples_input_from_file() loading input from {input_file}" | ||
) | ||
|
||
with open(input_file, "r") as f: | ||
prompts = f.readlines() | ||
prompts = [p.strip() for p in prompts] | ||
prompts = [p for p in prompts if len(p) > 0] | ||
print_rank_0( | ||
"generate_samples_input_from_file() prompts loaded: {}".format(len(prompts)) | ||
f"generate_samples_input_from_file() prompts loaded: {len(prompts)}" | ||
) | ||
|
||
if is_mp_rank_0(): | ||
if output_file is None: | ||
output_file = str(input_file) + ".output.jsonl" | ||
print_rank_0( | ||
"generate_samples_input_from_file() setting default output file to {}".format( | ||
output_file | ||
) | ||
) | ||
|
||
if is_mp_rank_0() and output_file is None: | ||
output_file = f"{str(input_file)}.output.jsonl" | ||
print_rank_0( | ||
f"generate_samples_input_from_file() setting default output file to {output_file}" | ||
) | ||
|
Function `generate_samples_input_from_file` refactored with the following changes:

- Replace call to format with f-string (`use-fstring-for-formatting`)
- Merge nested if conditions (`merge-nested-ifs`)
- Use f-string instead of string concatenation (`use-fstring-for-concatenation`)
```diff
-    if is_mp_rank_0():
-        if output_file is not None:
-            with open(output_file, "w") as f_out:
-                for item in generated_texts:
-                    f_out.write(json.dumps(item) + "\n")
+    if is_mp_rank_0() and output_file is not None:
+        with open(output_file, "w") as f_out:
+            for item in generated_texts:
+                f_out.write(json.dumps(item) + "\n")
```
Function `generate_samples_unconditional` refactored with the following changes:

- Merge nested if conditions (`merge-nested-ifs`)
print_rank_0("Generated Text: " + generated_text) | ||
print_rank_0(f"Generated Text: {generated_text}") |
Function `generate_samples_interactive` refactored with the following changes:

- Use f-string instead of string concatenation (`use-fstring-for-concatenation`)
```diff
-    if data_iterator is not None:
-        data = next(data_iterator)
-    else:
-        data = None
+    data = next(data_iterator) if data_iterator is not None else None
```
Function `get_batch` refactored with the following changes:

- Replace if statement with if expression (`assign-if-exp`)
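A minimal sketch of the same transformation with a hypothetical helper:

```python
def next_or_none(it):
    # The four-line if/else assignment becomes one conditional expression.
    return next(it) if it is not None else None

assert next_or_none(iter([1, 2])) == 1
assert next_or_none(None) is None
```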
if not "soft_embedding" in name: | ||
if "soft_embedding" not in name: |
Function `get_model` refactored with the following changes:

- Simplify logical expression using De Morgan identities (`de-morgan`)
```diff
-    lr_scheduler = AnnealingLR(
+    return AnnealingLR(
```
Function `get_learning_rate_scheduler` refactored with the following changes:

- Inline variable that is immediately returned (`inline-immediately-returned-variable`)
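In isolation (made-up function, not repo code):

```python
def before():
    result = sum(range(5))   # assigned, then immediately returned
    return result

def after():
    return sum(range(5))     # inlined

assert before() == after() == 10
```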
print_rank_0(" {}:".format(name)) | ||
print_rank_0(" no. of documents:{}".format(total_num_of_documents)) | ||
print_rank_0(f" {name}:") | ||
print_rank_0(f" no. of documents:{total_num_of_documents}") |
Function `build_the_dataset` refactored with the following changes:

- Replace call to format with f-string (`use-fstring-for-formatting`)
print_rank_0(" {}:".format(name)) | ||
print_rank_0(f" {name}:") | ||
print_rank_0( | ||
" document indices in [{}, {}) total of {} " | ||
"documents".format( | ||
splits[index], splits[index + 1], splits[index + 1] - splits[index] | ||
) | ||
f" document indices in [{splits[index]}, {splits[index + 1]}) total of {splits[index + 1] - splits[index]} documents" |
Function `build_train_valid_test_datasets.print_split_stats` refactored with the following changes:

- Replace call to format with f-string (`use-fstring-for-formatting`)
```diff
-    for index, split in enumerate(splits):
-        splits_index.append(splits_index[index] + int(round(split * float(size))))
+    splits_index.extend(
+        splits_index[index] + int(round(split * float(size)))
+        for index, split in enumerate(splits)
+    )
```
Function `get_train_valid_test_split_` refactored with the following changes:

- Replace a for append loop with list extend (`for-append-to-extend`)
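A minimal sketch with made-up weights; note that in CPython `list.extend` consumes the generator one item at a time, which is why the self-referencing read of the list (like `splits_index[index]` above) still sees each newly appended value:

```python
# Before: an append-only loop that reads the previously appended element.
splits_index = [0]
for w in (2, 3, 5):
    splits_index.append(splits_index[-1] + w)

# After: extend over a generator; items are appended as they are produced.
refactored = [0]
refactored.extend(refactored[-1] + w for w in (2, 3, 5))
assert splits_index == refactored == [0, 2, 5, 10]
```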
```diff
-    weighted_num_samples = []
-    for weight in weights:
-        weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
+    weighted_num_samples = [
+        int(math.ceil(num_samples * weight * 1.005)) for weight in weights
+    ]
```
Function `get_normalized_weights_and_num_samples` refactored with the following changes:

- Convert for loop into list comprehension (`list-comprehension`)
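The same pattern in a runnable sketch with made-up inputs:

```python
import math

num_samples, weights = 100, [0.2, 0.8]
loop_version = []
for weight in weights:
    loop_version.append(int(math.ceil(num_samples * weight * 1.005)))
# The append-only loop becomes a list comprehension.
comp_version = [int(math.ceil(num_samples * weight * 1.005)) for weight in weights]
assert loop_version == comp_version == [21, 81]
```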
"setting training data start iteration to {}".format( | ||
train_dataloader.batch_sampler.start_iter | ||
) | ||
f"setting training data start iteration to {train_dataloader.batch_sampler.start_iter}" | ||
) | ||
|
Function `build_train_valid_test_data_iterators` refactored with the following changes:

- Replace call to format with f-string (`use-fstring-for-formatting`)
- Swap if/else branches (`swap-if-else-branches`)
- Replace if statement with if expression (`assign-if-exp`)
```diff
-            for i in range(doc_index_f + 1, doc_index_l):
-                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
+            sample_list.extend(
+                self.indexed_dataset.get(self.doc_idx[i])
+                for i in range(doc_index_f + 1, doc_index_l)
+            )
```
Function `GPT2Dataset.__getitem__` refactored with the following changes:

- Replace a for append loop with list extend (`for-append-to-extend`)
```diff
-    doc_idx_filename = _filename + "_doc_idx.npy"
-    sample_idx_filename = _filename + "_sample_idx.npy"
-    shuffle_idx_filename = _filename + "_shuffle_idx.npy"
+    _filename += f"_{name}_indexmap"
+    _filename += f"_{num_samples}ns"
+    _filename += f"_{seq_length}sl"
+    _filename += f"_{seed}s"
+    doc_idx_filename = f"{_filename}_doc_idx.npy"
+    sample_idx_filename = f"{_filename}_sample_idx.npy"
+    shuffle_idx_filename = f"{_filename}_shuffle_idx.npy"

     # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0:
-        if (
-            (not os.path.isfile(doc_idx_filename))
-            or (not os.path.isfile(sample_idx_filename))
-            or (not os.path.isfile(shuffle_idx_filename))
-        ):
-            print_rank_0(
-                " > WARNING: could not find index map files, building "
-                "the indices on rank 0 ..."
-            )
-            # doc-idx.
-            start_time = time.time()
-            doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
-            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
-            print_rank_0(
-                " > elasped time to build and save doc-idx mapping "
-                "(seconds): {:4f}".format(time.time() - start_time)
-            )
-            # sample-idx.
-            start_time = time.time()
-            # Use C++ implementation for speed.
-            from megatron.data import helpers
-
-            assert doc_idx.dtype == np.int32
-            assert sizes.dtype == np.int32
-            sample_idx = helpers.build_sample_idx(
-                sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
-            )
-            # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
-            #                                num_epochs, tokens_per_epoch)
-            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
-            print_rank_0(
-                " > elapsed time to build and save sample-idx mapping "
-                "(seconds): {:4f}".format(time.time() - start_time)
-            )
-            # shuffle-idx.
-            start_time = time.time()
-            # -1 is due to data structure used to retieve the index:
-            #    sample i --> [sample_idx[i], sample_idx[i+1])
-            shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
-            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
-            print_rank_0(
-                " > elapsed time to build and save shuffle-idx mapping"
-                " (seconds): {:4f}".format(time.time() - start_time)
-            )
+    if torch.distributed.get_rank() == 0 and (
+        (not os.path.isfile(doc_idx_filename))
+        or (not os.path.isfile(sample_idx_filename))
+        or (not os.path.isfile(shuffle_idx_filename))
+    ):
+        print_rank_0(
+            " > WARNING: could not find index map files, building "
+            "the indices on rank 0 ..."
+        )
+        # doc-idx.
+        start_time = time.time()
+        doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
+        np.save(doc_idx_filename, doc_idx, allow_pickle=True)
+        print_rank_0(
+            " > elasped time to build and save doc-idx mapping "
+            "(seconds): {:4f}".format(time.time() - start_time)
+        )
+        # sample-idx.
+        start_time = time.time()
+        # Use C++ implementation for speed.
+        from megatron.data import helpers
+
+        assert doc_idx.dtype == np.int32
+        assert sizes.dtype == np.int32
+        sample_idx = helpers.build_sample_idx(
+            sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
+        )
+        # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
+        #                                num_epochs, tokens_per_epoch)
+        np.save(sample_idx_filename, sample_idx, allow_pickle=True)
+        print_rank_0(
+            " > elapsed time to build and save sample-idx mapping "
+            "(seconds): {:4f}".format(time.time() - start_time)
+        )
+        # shuffle-idx.
+        start_time = time.time()
+        # -1 is due to data structure used to retieve the index:
+        #    sample i --> [sample_idx[i], sample_idx[i+1])
+        shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
+        np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
+        print_rank_0(
+            " > elapsed time to build and save shuffle-idx mapping"
+            " (seconds): {:4f}".format(time.time() - start_time)
+        )
```
Function `_build_index_mappings` refactored with the following changes:

- Replace call to format with f-string (`use-fstring-for-formatting`)
- Use f-string instead of string concatenation (`use-fstring-for-concatenation`)
- Merge nested if conditions (`merge-nested-ifs`)
```diff
-    if vocab_size is not None and vocab_size < 65500:
-        return np.uint16
-    else:
-        return np.int32
+    return np.uint16 if vocab_size is not None and vocab_size < 65500 else np.int32
```
Function `__best_fitting_dtype` refactored with the following changes:

- Replace if statement with if expression (`assign-if-exp`)
return prefix_path + ".idx" | ||
return f"{prefix_path}.idx" |
Function `index_file_path` refactored with the following changes:

- Use f-string instead of string concatenation (`use-fstring-for-concatenation`)
return prefix_path + ".bin" | ||
return f"{prefix_path}.bin" |
Function `data_file_path` refactored with the following changes:

- Use f-string instead of string concatenation (`use-fstring-for-concatenation`)
```diff
-        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+        [f"{cuda_dir}/bin/nvcc", "-V"], universal_newlines=True
     )
```
Function `_get_cuda_bare_metal_version` refactored with the following changes:

- Use f-string instead of string concatenation (`use-fstring-for-concatenation`)
```diff
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_80,code=sm_80")
+    cc_flag.extend(("-gencode", "arch=compute_80,code=sm_80"))
```
Lines 26-28 refactored with the following changes:

- Merge consecutive list appends into a single extend (`merge-list-appends-into-extend`)
```diff
-        grads = []
         assert hasattr(
             self.model, "stored_gradients"
         ), "You might need to update DeeperSpeed"
         if self.model.stored_gradients is not None:
+            grads = []
             for g in self.model.stored_gradients:
-                if g is not None and not g.isnan().any() and not g.isinf().any():
-                    g = g.flatten().view(-1, 1)
-                    if self.cpu_offload:
-                        g = g.cpu()
-                    grads.append(g)
-                else:
-                    return None
+                if g is None or g.isnan().any() or g.isinf().any():
+                    return None
+                g = g.flatten().view(-1, 1)
+                if self.cpu_offload:
+                    g = g.cpu()
+                grads.append(g)
```
Function `GradientNoiseScale.flatten_grads` refactored with the following changes:

- Move assignments closer to their usage (`move-assign`)
- Swap if/else branches (`swap-if-else-branches`)
- Remove unnecessary else after guard condition (`remove-unnecessary-else`)
```diff
-        if self.neox_args.is_pipe_parallel:
-            # Since each model parallel GPU carries only part of the model,
-            # make sure overflow flag is synced across all the pipe parallel GPUs
-            overflow_gpu = torch.cuda.ByteTensor([is_overflow])
-            torch.distributed.all_reduce(
-                overflow_gpu,
-                op=torch.distributed.ReduceOp.MAX,
-                group=self.mpu.get_pipe_parallel_group(),
-            )
-            overflow = overflow_gpu[0].item()
-        else:
-            overflow = is_overflow
-        return overflow
+        if not self.neox_args.is_pipe_parallel:
+            return is_overflow
+        # Since each model parallel GPU carries only part of the model,
+        # make sure overflow flag is synced across all the pipe parallel GPUs
+        overflow_gpu = torch.cuda.ByteTensor([is_overflow])
+        torch.distributed.all_reduce(
+            overflow_gpu,
+            op=torch.distributed.ReduceOp.MAX,
+            group=self.mpu.get_pipe_parallel_group(),
+        )
+        return overflow_gpu[0].item()
```
Function `GradientNoiseScale._sync_overflow` refactored with the following changes:

- Lift return into if (`lift-return-into-if`)
- Swap if/else branches (`swap-if-else-branches`)
- Remove unnecessary else after guard condition (`remove-unnecessary-else`)
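A minimal sketch of lifting returns into the branches, with `max()` standing in for the `all_reduce` sync:

```python
def assign_then_return(is_pipe_parallel, flag):
    if is_pipe_parallel:
        overflow = max(flag, 0)   # stand-in for the synced-overflow path
    else:
        overflow = flag
    return overflow

def lifted(is_pipe_parallel, flag):
    if not is_pipe_parallel:      # swapped branches become a guard
        return flag               # return lifted into the branch
    return max(flag, 0)

assert assign_then_return(True, 1) == lifted(True, 1)
assert assign_then_return(False, -1) == lifted(False, -1)
```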
```diff
-        is_overflow = self._sync_overflow(grad is None)
-        if is_overflow:
+        if is_overflow := self._sync_overflow(grad is None):
```
Function `GradientNoiseScale._update` refactored with the following changes:

- Use named expression to simplify assignment and conditional (`use-named-expression`)
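A minimal sketch of the named expression (walrus operator, Python 3.8+), with a stand-in for `self._sync_overflow`:

```python
def sync_overflow(flag):   # stand-in for self._sync_overflow(...)
    return flag

grad = None
# The walrus operator assigns and tests in one expression, removing
# the separate assignment line.
if is_overflow := sync_overflow(grad is None):
    print("overflow detected, skipping update")
assert is_overflow is True
```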
**Sourcery Code Quality Report**

✅ Merging this PR will increase code quality in the affected files by 0.27%.
Branch `main` refactored by Sourcery. If you're happy with these changes, merge this Pull Request using the "Squash and merge" strategy. See our documentation here.