Multi-class classification training limit #414
Conversation
… that inherit from splitter can use them
…ogrifAI into LocoTestRefactor merge master
…AdamChit/TransmogrifAI into achit/regression-training-limit
…plitterTest.scala Co-Authored-By: Christopher Rupley <[email protected]>
…egressionModelSelectorTest.scala Co-Authored-By: Christopher Rupley <[email protected]>
…AdamChit/TransmogrifAI into achit/regression-training-limit
…params to DataCutterSummary
…et/get functions moved to SplitterParams
Codecov Report
@@ Coverage Diff @@
## master #414 +/- ##
==========================================
- Coverage 86.96% 86.92% -0.04%
==========================================
Files 337 337
Lines 11083 11099 +16
Branches 356 593 +237
==========================================
+ Hits 9638 9648 +10
- Misses 1445 1451 +6
Continue to review full report at Codecov.
summary = Option(DataCutterSummary(
  preSplitterDataCount = dataSetSize,
  downSamplingFraction = getDownSampleFraction,
  labelsKept = getLabelsToKeep,
  labelsDropped = getLabelsToDrop,
  labelsDroppedTotal = getLabelsDroppedTotal
))
PrevalidationVal(summary, Option(dataPrep))
I think it makes more sense to do the downsampling in the pre-validation prepare rather than in the validation prepare; the difference is that the validation prepare is called within the CV folds. Since you upsample for binary classification, that step needs to stay in the validation prepare to prevent label leakage, but since there is only downsampling here it can happen earlier.
The two main reasons for doing it this way:
- It keeps this consistent with DataSplitter/DataBalancer
- It makes it easier to implement stratified sampling (or other data balancing techniques) in the future, which would upsample the minority classes
@@ -203,7 +233,11 @@ class DataCutter(uid: String = UID[DataCutter]) extends Splitter(uid = uid) with
       s" minLabelFraction = $minLabelFract, maxLabelCategories = $maxLabels. \n" +
       s"Label counts were: ${labelCounts.collect().toSeq}")
   }
-  DataCutterSummary(labelsKept.toSeq, labelsDropped.toSeq, labelsDroppedTotal.toLong)
+  DataCutterSummary(
+    labelsKept = labelsKept.toSeq,
do you want to add the downsample fraction?
see minor comments
val dataPrep = super.validationPrepare(data)

// check if down sampling is needed
val balanced: DataFrame = if (getDownSampleFraction < 1) {
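The snippet above is truncated at the branch. As a sketch of the same pattern in plain Scala (the real code calls Spark's `DataFrame.sample`; the `downSample` helper and its seed here are illustrative, not from the PR), Bernoulli sampling keeps each row with probability equal to the fraction, and a fraction of 1.0 is a pass-through:

```scala
object DownSampleSketch {
  // Illustrative stand-in for the DataFrame down-sampling branch above:
  // keep each row with probability `fraction` (Bernoulli sampling),
  // and skip sampling entirely when fraction == 1.0.
  def downSample[T](rows: Seq[T], fraction: Double, seed: Long = 42L): Seq[T] = {
    require(fraction > 0.0 && fraction <= 1.0, "fraction must be in (0.0, 1.0]")
    if (fraction < 1.0) {
      val rng = new scala.util.Random(seed)
      rows.filter(_ => rng.nextDouble() < fraction)
    } else {
      rows // no down-sampling needed
    }
  }

  def main(args: Array[String]): Unit = {
    val data = (1 to 1000).toSeq
    println(downSample(data, 0.2).size <= data.size) // sampled set is never larger
    println(downSample(data, 1.0) == data)           // fraction 1.0 is a no-op
  }
}
```

Like `DataFrame.sample`, this yields approximately (not exactly) `fraction * n` rows, which is fine for a training-size cap.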
use 1.0
@@ -273,6 +307,8 @@ private[impl] trait DataCutterParams extends SplitterParams {
 */
case class DataCutterSummary
(
  preSplitterDataCount: Long = 0,
use 0L
@@ -129,6 +129,23 @@ trait SplitterParams extends Params {
  def setReserveTestFraction(value: Double): this.type = set(reserveTestFraction, value)
  def getReserveTestFraction: Double = $(reserveTestFraction)

  /**
   * Fraction to sample majority data
   * Value should be in ]0.0, 1.0]
(0.0, 1.0]
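The reviewer is suggesting the conventional notation `(0.0, 1.0]` for the half-open interval written `]0.0, 1.0]` in the doc comment: zero excluded, one included. A minimal stand-alone sketch of that validation predicate (the function name is illustrative; in Spark ML params this kind of check is typically wired up via a param validator rather than a free function):

```scala
object FractionValidation {
  // Predicate for the interval (0.0, 1.0]:
  // the lower bound 0.0 is excluded, the upper bound 1.0 is included.
  def isValidDownSampleFraction(f: Double): Boolean = f > 0.0 && f <= 1.0

  def main(args: Array[String]): Unit = {
    println(isValidDownSampleFraction(0.0)) // false: 0.0 is excluded
    println(isValidDownSampleFraction(0.5)) // true
    println(isValidDownSampleFraction(1.0)) // true: 1.0 is included
  }
}
```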
@@ -65,6 +68,8 @@ class DataCutterTest extends FlatSpec with TestSparkContext with SplitterSummary
  s.labelsKept.length shouldBe 1000
  s.labelsDropped.length shouldBe 0
  s shouldBe DataCutterSummary(
    preSplitterDataCount = dataSize,
    downSamplingFraction = 1,
Use 1.0 instead of 1, here and everywhere with Double values.
…/AdamChit/TransmogrifAI into achit/multi-class-training-limit
Related issues
DataBalancer for binary classification has a parameter that controls the maximum amount of data passed into modeling; multi-class classification should allow a similar limit.
Describe the proposed solution
The proposed solution is to downsample once the training-set limit is reached.
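A minimal sketch of that rule, assuming a hypothetical `maxTrainingSample` limit (the actual parameter name in the PR may differ): when the row count exceeds the limit, the sampling fraction is the ratio of the limit to the count, and otherwise no downsampling is applied.

```scala
object TrainingLimitSketch {
  // Hypothetical: derive the Bernoulli sampling fraction from a row-count limit.
  // Returns a value in (0.0, 1.0]; 1.0 means "no down-sampling needed".
  def downSampleFraction(dataCount: Long, maxTrainingSample: Long): Double =
    if (dataCount <= maxTrainingSample) 1.0
    else maxTrainingSample.toDouble / dataCount

  def main(args: Array[String]): Unit = {
    println(downSampleFraction(500L, 1000L))  // 1.0: under the limit
    println(downSampleFraction(2000L, 1000L)) // 0.5: keep roughly half the rows
  }
}
```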
Describe alternatives you've considered
N/A
Additional context
Similar to #413.