Add unit test to check backward function for conv, checks there are no graph breaks #1709
Conversation
Signed-off-by: Xavier Dupre <[email protected]>
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1709      +/-   ##
==========================================
- Coverage   74.70%   74.65%   -0.05%
==========================================
  Files         242      243       +1
  Lines       25862    25933      +71
  Branches     4661     4678      +17
==========================================
+ Hits        19319    19360      +41
- Misses       5674     5699      +25
- Partials      869      874       +5
==========================================

☔ View full report in Codecov by Sentry.
Signed-off-by: Xavier Dupre <[email protected]>
def train_loop(
    model: Any,
    *args,
    loss_fn: Any | None = None,
    optimizer: Any | None = None,
    dump_onnx_models: bool = False,
    dump_prefix: str = "dump_train_loop",
    dump_clean_first: bool = True,
) -> tuple[Any, tuple[Any, ...]] | tuple[Any, tuple[Any, ...], list[str]]:
Check notice
Code scanning / CodeQL
Returning tuples with varying lengths (Note): the function returns a tuple of size 2 on one path and a tuple of size 3 on another.
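For context, a minimal sketch (not the library's implementation) of the pattern CodeQL is flagging: the tuple length depends on a keyword flag, so callers must know which shape to unpack. Only the flag name dump_onnx_models comes from the signature above; everything else below is a placeholder.

from typing import Any


def train_loop_sketch(
    model: Any,
    *args: Any,
    dump_onnx_models: bool = False,
) -> tuple[Any, ...]:
    results = model(*args) if callable(model) else None  # placeholder training step
    gradients: tuple[Any, ...] = ()                      # placeholder gradients
    if dump_onnx_models:
        onnx_models: list[str] = []                      # paths of dumped ONNX files
        return results, gradients, onnx_models           # tuple of size 3
    return results, gradients                            # tuple of size 2

One way to address the notice would be to always return the 3-tuple (with an empty list when nothing was dumped), but that is a design choice for the author rather than something the notice requires.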
In the OpInfo data structure, I have seen a field named supports_grad (or something similar), which may make it easier for us to generate backward tests. @xiaowuhu do you have some ideas?
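For reference, a hedged sketch of that idea: PyTorch's OpInfo database can be filtered by its autograd-support flag to enumerate candidate ops. The attribute is called supports_autograd in recent PyTorch versions; whether that is the field the comment refers to is an assumption.

# Sketch: list ops whose OpInfo entries report autograd support.
# supports_autograd is assumed to be the field mentioned above.
from torch.testing._internal.common_methods_invocations import op_db

grad_ops = [op.name for op in op_db if getattr(op, "supports_autograd", False)]
print(f"{len(grad_ops)} ops report autograd support")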
This seems to be a different scenario from the OpInfo approach. Here we need to go through the aot-compile-training-backward process, which is an end-to-end scenario, even though it is not a straightforward path. But this requirement will only benefit at most around 20 backward functions, so I think it is OK.
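As a rough illustration of that end-to-end scenario (not the PR's actual test), one can compile a small conv model with fullgraph=True so that any graph break raises, then run one forward/backward step; the model and shapes below are made up.

import torch


class TinyConv(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x).relu().sum()


compiled = torch.compile(TinyConv(), fullgraph=True)  # graph breaks become errors
loss = compiled(torch.randn(1, 3, 8, 8))
loss.backward()  # also exercises the AOT backward graph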
SG. Thanks!
Signed-off-by: Xavier Dupre <[email protected]>
expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(  # pylint: disable=unbalanced-tuple-unpacking
    model, *input_tensors
)
results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
Check warning
Code scanning / lintrunner
RUFF/F841 Warning: unused variable. See https://docs.astral.sh/ruff/rules/unused-variable
Check warning
Code scanning / lintrunner
PYLINT/W0612 Warning: unused-variable. To disable, use # pylint: disable=unused-variable
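A hedged sketch of one way to make both findings go away: consume results and gradients by comparing them against the eager-mode run, for example with torch.testing.assert_close. Whether that matches the test's intent is an assumption; the helper below is illustrative only.

import torch


def check_against_eager(results, expected_results, gradients, expected_gradients):
    # Using the variables removes both the RUFF F841 and PYLINT W0612 findings.
    torch.testing.assert_close(results, expected_results)
    for got, expected in zip(gradients, expected_gradients):
        torch.testing.assert_close(got, expected)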
@xadupre do we still need this, or can we close it?
No description provided.