Write and read Spark stages to/from MLeap instead of Spark classes #475
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master     #475      +/-   ##
==========================================
- Coverage   87.04%   86.74%   -0.31%
==========================================
  Files         346      346
  Lines       11782    11848      +66
  Branches      385      374      -11
==========================================
+ Hits        10256    10277      +21
- Misses       1526     1571      +45

Continue to review full report at Codecov.
@TuanNguyen27 the test you added that should have failed on local XGBoost is (correctly) failing in this PR.
.setParent(this)
.setInput(in1.asFeatureLike[RealNN], in2.asFeatureLike[OPVector])
.setMetadata(getMetadata())
.setOutputFeatureName(getOutputFeatureName)

if (model.isInstanceOf[XGBoostClassificationModel] || model.isInstanceOf[XGBoostRegressionModel]) {
  wrappedModel.setOutputDF(model.transform(dataset.limit(1)))
Just curious, why do we have .limit(1) here?
We only need one example for the XGBoost MLeap save (it has a step that calls .first() to get the vector size).
Let's add a comment for it. Looks like this is the only such exception so far.
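As background for that comment: the MLeap XGBoost save step only inspects a single row to discover the feature-vector length, so scoring one row is enough. A minimal stand-alone sketch of the idea (plain Scala, no Spark; `inferVectorSize` and the sample rows are hypothetical stand-ins for MLeap's internal `.first()` call):

```scala
// MLeap's XGBoost serializer calls .first() on the scored output to learn
// the feature-vector length; it never needs the full dataset. A plain Seq
// stands in for the DataFrame here.
def inferVectorSize(rows: Seq[Array[Double]]): Int = {
  require(rows.nonEmpty, "need at least one example row")
  rows.head.length // .head plays the role of Dataset.first()
}

// Passing only the first row (dataset.limit(1) in the PR) gives the same
// answer as passing everything, at a fraction of the cost.
val allRows   = Seq(Array(0.1, 0.2, 0.3), Array(0.4, 0.5, 0.6))
val oneRow    = allRows.take(1) // analogue of dataset.limit(1)
val sizeFull  = inferVectorSize(allRows)
val sizeLimit = inferVectorSize(oneRow)
```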
core/src/main/scala/com/salesforce/op/stages/sparkwrappers/specific/OpPredictorWrapper.scala
@@ -125,4 +127,9 @@ class OpRandomForestRegressionModel
   ttov: TypeTag[Prediction#Value]
 ) extends OpPredictionModel[RandomForestRegressionModel](
   sparkModel = sparkModel, uid = uid, operationName = operationName
-)
+) {
+  @transient lazy protected val predict: Vector => Double = getSparkMlStage().map(s => s.predict(_))
This also seems to be a very repetitive pattern. We can add a helper method for it as well.
So the problem with a helper function in this case is that there is no shared class for the MLeap regressors that contains the predict function. They all implement it, but you have to cast each one to its specific type to get at predict, so a helper function only saves one map and makes the code harder to read. I suppose I could use reflection in a shared helper... do you think that is better?
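For reference, the reflection route floated above could look roughly like this. A self-contained sketch: `LinearToy`/`TreeToy` are toy stand-ins for the real MLeap regressor types, and the unary `predict` method name is the assumption the helper relies on.

```scala
// Two toy model classes that both expose predict(Array[Double]): Double
// but share no common supertype -- mirroring the MLeap regressors.
final class LinearToy { def predict(v: Array[Double]): Double = v.sum }
final class TreeToy   { def predict(v: Array[Double]): Double = v.max }

// Shared helper: find a unary `predict` method reflectively and invoke it,
// so callers need no per-type cast.
def predictViaReflection(model: AnyRef, v: Array[Double]): Double = {
  val method = model.getClass.getMethods
    .find(m => m.getName == "predict" && m.getParameterCount == 1)
    .getOrElse(throw new NoSuchMethodException(
      s"${model.getClass.getName} has no unary predict method"))
  // The Array[Double] is passed as the single reflective argument.
  method.invoke(model, v).asInstanceOf[Double]
}

val linPred  = predictViaReflection(new LinearToy, Array(1.0, 2.0))
val treePred = predictViaReflection(new TreeToy,  Array(1.0, 2.0))
```

The trade-off the comment describes is real: reflection removes the per-type casts but moves the method lookup to runtime, where a rename breaks it without a compile error.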
core/src/main/scala/com/salesforce/op/stages/sparkwrappers/specific/OpPredictorWrapper.scala
...n/scala/com/salesforce/op/stages/sparkwrappers/specific/OpProbabilisticClassifierModel.scala
(opStage, sparkStage, i)

val mleapStages = stagesWithIndex.filterNot(_._1.isInstanceOf[OpTransformer]).collect {
  case (opStage: OPStage with SparkWrapperParams[_], i) if opStage.getLocalMlStage().isDefined =>
    val model = opStage.getLocalMlStage().get
Better to pattern match and error gracefully when the local model is missing. For example:

opStage.getLocalMlStage() match {
  case None => throw new RuntimeException(s"Local model not found for stage ${opStage.uid} of type ${opStage.getClass}")
  case Some(model) =>
    // Apply model
}
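A runnable, self-contained variant of that suggestion (`StubStage` is a hypothetical stand-in for an `OPStage with SparkWrapperParams`; note the `s` interpolator needed for the uid to render in the message):

```scala
// Stub standing in for an OPStage carrying an optional local (MLeap) model.
// `uid` and `getLocalMlStage` mirror the names used in the review comment.
final case class StubStage(uid: String, localModel: Option[String]) {
  def getLocalMlStage(): Option[String] = localModel
}

// Pattern match instead of .get, failing with a descriptive error when the
// local model is missing.
def applyLocalModel(stage: StubStage): String =
  stage.getLocalMlStage() match {
    case None =>
      throw new RuntimeException(
        s"Local model not found for stage ${stage.uid} of type ${stage.getClass}")
    case Some(model) =>
      model // apply the model here
  }

val ok  = applyLocalModel(StubStage("stage_1", Some("mleapModel")))
val err = try { applyLocalModel(StubStage("stage_2", None)); "" }
          catch { case e: RuntimeException => e.getMessage }
```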
lgtm!!
🥳 🥳 🥳
This seems to have broken some of our in-house unit tests. In some cases it was because we wrote to relative paths, I think; those were easily fixed by making the paths absolute. In other cases the paths were absolute and I am unsure why it broke at this point... the stack traces all have to do with MLeap BundleFile on reading and writing, always the same NPE in
Is protobuf 3 going to be an issue on Spark/Hadoop?
@koertkuipers Can you please open an issue to track this? Can you also share which transformer / estimator you are using in your workflow?
Thanks for the contribution! It looks like @leahmcguire is an internal user so signing the CLA is not required. However, we need to confirm this.
Thanks for the contribution! Unfortunately we can't verify the commit author(s): leahmcguire <l***@s***.com> Leah McGuire <l***@s***.com>. One possible solution is to add that email to your GitHub account. Alternatively you can change your commits to another email and force push the change. After getting your commits associated with your GitHub account, refresh the status of this Pull Request.
Related issues
Currently, the Spark save method is used to serialize and deserialize Spark-wrapped stages. This PR changes the underlying serialization to write to and read from MLeap bundles.
Describe the proposed solution
Writes to MLeap and reads from MLeap, with a fallback to reading from the Spark save format.
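The MLeap-first, Spark-fallback read can be sketched with `Try`/`orElse`. A minimal sketch with hypothetical stub loaders standing in for the real MLeap bundle and Spark readers:

```scala
import scala.util.{Try, Success}

// Hypothetical loaders standing in for the MLeap bundle and Spark readers.
def readFromMleap(path: String): String =
  if (path.endsWith(".zip")) "mleap-stage" else sys.error("not an MLeap bundle")
def readFromSparkSave(path: String): String = "spark-stage"

// Try MLeap first; if that fails for any reason, fall back to Spark save.
def loadStage(path: String): Try[String] =
  Try(readFromMleap(path)).orElse(Try(readFromSparkSave(path)))

val fromMleap = loadStage("model.zip")     // MLeap read succeeds
val fallback  = loadStage("model/parquet") // MLeap read fails, Spark fallback
```

Wrapping both reads in `Try` means any exception from the MLeap path (not just a missing file) triggers the fallback, which matches the "fallback to trying to read from Spark save" behavior described above.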
Describe alternatives you've considered
N/A
Additional context
Next steps will be PRs to read the stages directly with the MLeap context rather than the Spark context for local scoring (and possibly all scoring, to better optimize the DAG).