-
Notifications
You must be signed in to change notification settings - Fork 254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DO NOT MERGE] Goodput prototype v1 #337
base: main
Are you sure you want to change the base?
Conversation
@@ -0,0 +1,127 @@ | |||
""" | |||
Goodput Prototype (TODO: Move to packaged library) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also why is this writing to GCS? Is there some metric sservice we can use?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Writing to GCS is a lot more complicated that it seems so we'd need to think about that carefully.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1, I think we can readily use Cloud Logging to stream logs: https://cloud.google.com/logging/docs/write-query-log-entries-python#linux
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excited for goodput! But lots to discuss here.
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): | ||
state, metrics, nextrng = p_train_step( | ||
state, example_batch, nextrng | ||
) | ||
step_stop_time = datetime.datetime.now() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code doesn't do what you think it does -- Jax calls are non blocking. I think it would better to provide a goodput.goodput_log_step_time call that took step time (and inside of it you can get a datetime).
@@ -0,0 +1,127 @@ | |||
""" | |||
Goodput Prototype (TODO: Move to packaged library) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Writing to GCS is a lot more complicated that it seems so we'd need to think about that carefully.
storage_client = storage.Client() | ||
bucket = storage_client.bucket(bucket_name) | ||
blob = bucket.blob(blob_name) | ||
blob.download_to_filename(destination_file_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW this type of logic is unwelcome on our control thread! We need to be writing asynchronously to something that gracefully handles appends.
(GCS really doesn't like appends. And this code seems to be synchronous on our control thread.)
with diagnostic.diagnose(diagnostic_config): | ||
train_loop(config) | ||
|
||
job_end_time = datetime.datetime.now() | ||
goodput.goodput_log_job_runtime(config, job_end_time - job_start_time) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also you really don't want to pass a config outside of MaxText into your generic interface! Pass what you need!
job_end_time = datetime.datetime.now() | ||
goodput.goodput_log_job_runtime(config, job_end_time - job_start_time) | ||
total_goodput = goodput.get_job_goodput(config) | ||
max_logging.log(f"Total job goodput: {total_goodput:.2f}%") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we sure we want to print this? Because this run doesn't know the overall goodput -- the goodput depends on fusing multiple runs together.
Quick prototype to compute Goodput based on total step time and job time on MaxText.