
Multi-class classification training limit #414

Merged
46 commits merged into salesforce:master on Nov 1, 2019

Conversation

@AdamChit (Collaborator) commented Oct 8, 2019

Related issues
The DataBalancer for binary classification has a parameter that controls the maximum amount of data passed into modeling; multiclass classification should allow a similar limit.

Describe the proposed solution
Downsample the training data once it reaches the training set limit.
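A minimal sketch of that behavior, assuming a hypothetical maxTrainingSample limit and plain Spark sampling (the helper name and seed are illustrative, not the PR's API):

import org.apache.spark.sql.DataFrame

// Hypothetical helper: if the training set exceeds the limit, sample it back
// down to roughly maxTrainingSample rows; otherwise pass it through unchanged.
def limitTrainingData(data: DataFrame, maxTrainingSample: Long, seed: Long = 42L): DataFrame = {
  val total = data.count()
  if (total <= maxTrainingSample) data
  else data.sample(withReplacement = false, fraction = maxTrainingSample.toDouble / total, seed = seed)
}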

Describe alternatives you've considered
N/A

Additional context
Similar to #413.

AdamChit and others added 30 commits September 20, 2019 15:46
codecov bot commented Oct 8, 2019

Codecov Report

Merging #414 into master will decrease coverage by 0.03%.
The diff coverage is 100%.


@@            Coverage Diff             @@
##           master     #414      +/-   ##
==========================================
- Coverage   86.96%   86.92%   -0.04%     
==========================================
  Files         337      337              
  Lines       11083    11099      +16     
  Branches      356      593     +237     
==========================================
+ Hits         9638     9648      +10     
- Misses       1445     1451       +6
Impacted Files Coverage Δ
...alesforce/op/stages/impl/tuning/DataBalancer.scala 95.95% <ø> (-0.16%) ⬇️
...alesforce/op/stages/impl/tuning/DataSplitter.scala 65% <ø> (-25%) ⬇️
...om/salesforce/op/stages/impl/tuning/Splitter.scala 98.3% <100%> (+0.22%) ⬆️
.../salesforce/op/stages/impl/tuning/DataCutter.scala 97.22% <100%> (+0.38%) ⬆️
.../op/features/types/FeatureTypeSparkConverter.scala 98.23% <0%> (-0.89%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 0f03c43...7b23861.

summary = Option(DataCutterSummary(
  preSplitterDataCount = dataSetSize,
  downSamplingFraction = getDownSampleFraction,
  labelsKept = getLabelsToKeep,
  labelsDropped = getLabelsToDrop,
  labelsDroppedTotal = getLabelsDroppedTotal
))
PrevalidationVal(summary, Option(dataPrep))
Collaborator:

I think it makes more sense to do the downsampling in preValidationPrepare than in validationPrepare; the difference is that validationPrepare is called within the CV folds. For binary classification the upsampling needs to stay there to prevent label leakage, but since this change only downsamples, it can happen earlier.

Collaborator (Author):

The two main reasons for doing it this way:

  1. It keeps the behavior consistent with DataSplitter/DataBalancer (see the sketch below).
  2. It makes it easier to implement stratified sampling (or other data balancing techniques) in the future, which would upsample the minority classes.
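A minimal sketch of that placement, assuming the method and param names quoted later in this thread (validationPrepare, getDownSampleFraction); illustrative, not the PR's exact code:

import org.apache.spark.sql.{Dataset, Row}

// Sketch: the down-sample fraction is computed once up front and stored as a
// param, so the per-fold call only applies a plain sample; nothing is learned
// from the fold's labels here.
override def validationPrepare(data: Dataset[Row]): Dataset[Row] = {
  val dataPrep = super.validationPrepare(data)
  if (getDownSampleFraction < 1.0) {
    dataPrep.sample(withReplacement = false, getDownSampleFraction)
  } else dataPrep
}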

@@ -203,7 +233,11 @@ class DataCutter(uid: String = UID[DataCutter]) extends Splitter(uid = uid) with
         s" minLabelFraction = $minLabelFract, maxLabelCategories = $maxLabels. \n" +
         s"Label counts were: ${labelCounts.collect().toSeq}")
     }
-    DataCutterSummary(labelsKept.toSeq, labelsDropped.toSeq, labelsDroppedTotal.toLong)
+    DataCutterSummary(
+      labelsKept = labelsKept.toSeq,
Collaborator:

do you want to add the downsample fraction?

Collaborator (Author):

I didn't want to have to pass the dataset (or the dataset count) to the estimate function, so I added the dataset count and downsample fraction to the summary here.

@tovbinm (Collaborator) left a review: see minor comments

val dataPrep = super.validationPrepare(data)

// check if down sampling is needed
val balanced: DataFrame = if (getDownSampleFraction < 1) {
Collaborator:

use 1.0

@@ -273,6 +307,8 @@ private[impl] trait DataCutterParams extends SplitterParams {
   */
 case class DataCutterSummary
 (
+  preSplitterDataCount: Long = 0,
Collaborator:

use 0L

@@ -129,6 +129,23 @@ trait SplitterParams extends Params {
   def setReserveTestFraction(value: Double): this.type = set(reserveTestFraction, value)
   def getReserveTestFraction: Double = $(reserveTestFraction)

+  /**
+   * Fraction to sample majority data
+   * Value should be in ]0.0, 1.0]
Collaborator:

(0.0, 1.0]
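For context, a hedged sketch of how the documented range could be enforced with a Spark ML param validator (the trait name and wiring are illustrative, not the PR's code):

import org.apache.spark.ml.param.{DoubleParam, Params}

// Illustrative only: a DoubleParam whose validator enforces (0.0, 1.0],
// i.e. the fraction must be strictly positive and at most 1.0.
trait HasDownSampleFraction extends Params {
  final val downSampleFraction = new DoubleParam(this, "downSampleFraction",
    "fraction used to sample the data, in (0.0, 1.0]",
    (f: Double) => f > 0.0 && f <= 1.0)
  setDefault(downSampleFraction, 1.0)
  def getDownSampleFraction: Double = $(downSampleFraction)
}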

@@ -65,6 +68,8 @@ class DataCutterTest extends FlatSpec with TestSparkContext with SplitterSummary
     s.labelsKept.length shouldBe 1000
     s.labelsDropped.length shouldBe 0
     s shouldBe DataCutterSummary(
+      preSplitterDataCount = dataSize,
+      downSamplingFraction = 1,
Collaborator:

Use 1.0 instead of 1 here and everywhere with Double values
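The underlying point: Scala silently widens Int literals to Double, so 1 compiles wherever a Double is expected, but the explicit literal states the intent. A standalone illustration (not from the PR):

// Both lines compile, because Int literals widen to Double automatically,
// but the explicit 1.0 makes the intended type obvious at a glance.
val viaWidening: Double = 1     // works via Int -> Double widening
val explicit: Double = 1.0      // preferred for Double-valued params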

@AdamChit AdamChit merged commit c7f363f into salesforce:master Nov 1, 2019
@nicodv nicodv mentioned this pull request Jun 11, 2020