Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
[SPARK-32310] Add *args to different Params constructors (#515)
Browse files: browse the repository at this point in the history
* Add *args to different Params constructors

Resolves #441
  • Loading branch information
zero323 committed Sep 6, 2020
1 parent 160b33e commit 56aab84
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 5 deletions.
11 changes: 9 additions & 2 deletions third_party/3/pyspark/ml/classification.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ class _LogisticRegressionParams(
upperBoundsOnCoefficients: Param[Matrix]
lowerBoundsOnIntercepts: Param[Vector]
upperBoundsOnIntercepts: Param[Vector]
def __init__(self, *args: Any): ...
def setThreshold(self: P, value: float) -> P: ...
def getThreshold(self) -> float: ...
def setThresholds(self: P, value: List[float]) -> P: ...
Expand Down Expand Up @@ -371,7 +372,9 @@ class BinaryLogisticRegressionSummary(
class BinaryLogisticRegressionTrainingSummary(
BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary
): ...
class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams): ...

class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
def __init__(self, *args: Any): ...

class DecisionTreeClassifier(
_JavaProbabilisticClassifier[DecisionTreeClassificationModel],
Expand Down Expand Up @@ -443,7 +446,8 @@ class DecisionTreeClassificationModel(
@property
def featureImportances(self) -> Vector: ...

class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams): ...
class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
def __init__(self, *args: Any): ...

class RandomForestClassifier(
_JavaProbabilisticClassifier[RandomForestClassificationModel],
Expand Down Expand Up @@ -544,6 +548,7 @@ class BinaryRandomForestClassificationTrainingSummary(
class _GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
supportedLossTypes: List[str]
lossType: Param[str]
def __init__(self, *args: Any): ...
def getLossType(self) -> str: ...

class GBTClassifier(
Expand Down Expand Up @@ -636,6 +641,7 @@ class GBTClassificationModel(
class _NaiveBayesParams(_PredictorParams, HasWeightCol):
smoothing: Param[float]
modelType: Param[str]
def __init__(self, *args: Any): ...
def getSmoothing(self) -> float: ...
def getModelType(self) -> str: ...

Expand Down Expand Up @@ -702,6 +708,7 @@ class _MultilayerPerceptronParams(
layers: Param[List[int]]
solver: Param[str]
initialWeights: Param[Vector]
def __init__(self, *args: Any): ...
def getLayers(self) -> List[int]: ...
def getInitialWeights(self) -> Vector: ...

Expand Down
5 changes: 5 additions & 0 deletions third_party/3/pyspark/ml/clustering.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class _GaussianMixtureParams(
HasBlockSize,
):
k: Param[int]
def __init__(self, *args: Any): ...
def getK(self) -> int: ...

class GaussianMixtureModel(
Expand Down Expand Up @@ -149,6 +150,7 @@ class _KMeansParams(
k: Param[int]
initMode: Param[str]
initSteps: Param[int]
def __init__(self, *args: Any): ...
def getK(self) -> int: ...
def getInitMode(self) -> str: ...
def getInitSteps(self) -> int: ...
Expand Down Expand Up @@ -219,6 +221,7 @@ class _BisectingKMeansParams(
):
k: Param[int]
minDivisibleClusterSize: Param[float]
def __init__(self, *args: Any): ...
def getK(self) -> int: ...
def getMinDivisibleClusterSize(self) -> float: ...

Expand Down Expand Up @@ -291,6 +294,7 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval):
topicConcentration: Param[float]
topicDistributionCol: Param[str]
keepLastCheckpoint: Param[bool]
def __init__(self, *args: Any): ...
def setK(self, value: int) -> LDA: ...
def getOptimizer(self) -> str: ...
def getLearningOffset(self) -> float: ...
Expand Down Expand Up @@ -381,6 +385,7 @@ class _PowerIterationClusteringParams(HasMaxIter, HasWeightCol):
initMode: Param[str]
srcCol: Param[str]
dstCol: Param[str]
def __init__(self, *args: Any): ...
def getK(self) -> int: ...
def getInitMode(self) -> str: ...
def getSrcCol(self) -> str: ...
Expand Down
10 changes: 10 additions & 0 deletions third_party/3/pyspark/ml/feature.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class Binarizer(

class _LSHParams(HasInputCol, HasOutputCol):
numHashTables: Param[int]
def __init__(self, *args: Any): ...
def getNumHashTables(self) -> int: ...

class _LSH(Generic[JM], JavaEstimator[JM], _LSHParams, JavaMLReadable, JavaMLWritable):
Expand Down Expand Up @@ -386,6 +387,7 @@ class HashingTF(

class _IDFParams(HasInputCol, HasOutputCol):
minDocFreq: Param[int]
def __init__(self, *args: Any): ...
def getMinDocFreq(self) -> int: ...

class IDF(JavaEstimator[IDFModel], _IDFParams, JavaMLReadable[IDF], JavaMLWritable):
Expand Down Expand Up @@ -558,6 +560,7 @@ class MinHashLSHModel(_LSHModel, JavaMLReadable[MinHashLSHModel], JavaMLWritable
class _MinMaxScalerParams(HasInputCol, HasOutputCol):
min: Param[float]
max: Param[float]
def __init__(self, *args: Any): ...
def getMin(self) -> float: ...
def getMax(self) -> float: ...

Expand Down Expand Up @@ -653,6 +656,7 @@ class Normalizer(
class _OneHotEncoderParams(HasInputCols, HasOutputCols, HasHandleInvalid):
handleInvalid: Param[str]
dropLast: Param[bool]
def __init__(self, *args: Any): ...
def getDropLast(self) -> bool: ...

class OneHotEncoder(
Expand Down Expand Up @@ -813,6 +817,7 @@ class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
upper: Param[float]
withCentering: Param[bool]
withScaling: Param[bool]
def __init__(self, *args: Any): ...
def getLower(self) -> float: ...
def getUpper(self) -> float: ...
def getWithCentering(self) -> bool: ...
Expand Down Expand Up @@ -913,6 +918,7 @@ class SQLTransformer(JavaTransformer, JavaMLReadable[SQLTransformer], JavaMLWrit
class _StandardScalerParams(HasInputCol, HasOutputCol):
withMean: Param[bool]
withStd: Param[bool]
def __init__(self, *args: Any): ...
def getWithMean(self) -> bool: ...
def getWithStd(self) -> bool: ...

Expand Down Expand Up @@ -1178,6 +1184,7 @@ class VectorAssembler(
class _VectorIndexerParams(HasInputCol, HasOutputCol, HasHandleInvalid):
maxCategories: Param[int]
handleInvalid: Param[str]
def __init__(self, *args: Any): ...
def getMaxCategories(self) -> int: ...

class VectorIndexer(
Expand Down Expand Up @@ -1256,6 +1263,7 @@ class _Word2VecParams(HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCo
minCount: Param[int]
windowSize: Param[int]
maxSentenceLength: Param[int]
def __init__(self, *args: Any): ...
def getVectorSize(self) -> int: ...
def getNumPartitions(self) -> int: ...
def getMinCount(self) -> int: ...
Expand Down Expand Up @@ -1358,6 +1366,7 @@ class _RFormulaParams(HasFeaturesCol, HasLabelCol, HasHandleInvalid):
forceIndexLabel: Param[bool]
stringIndexerOrderType: Param[str]
handleInvalid: Param[str]
def __init__(self, *args: Any): ...
def getFormula(self) -> str: ...
def getForceIndexLabel(self) -> bool: ...
def getStringIndexerOrderType(self) -> str: ...
Expand Down Expand Up @@ -1406,6 +1415,7 @@ class _SelectorParams(HasFeaturesCol, HasOutputCol, HasLabelCol):
fpr: Param[float]
fdr: Param[float]
fwe: Param[float]
def __init__(self, *args: Any): ...
def getSelectorType(self) -> str: ...
def getNumTopFeatures(self) -> int: ...
def getPercentile(self) -> float: ...
Expand Down
3 changes: 2 additions & 1 deletion third_party/3/pyspark/ml/fpm.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# Stubs for pyspark.ml.fpm (Python 3)
#

from typing import Optional
from typing import Any, Optional

from pyspark.ml._typing import P
from pyspark.ml.util import *
Expand All @@ -32,6 +32,7 @@ class _FPGrowthParams(HasPredictionCol):
minSupport: Param[float]
numPartitions: Param[int]
minConfidence: Param[float]
def __init__(self, *args: Any): ...
def getItemsCol(self) -> str: ...
def getMinSupport(self) -> float: ...
def getNumPartitions(self) -> int: ...
Expand Down
1 change: 1 addition & 0 deletions third_party/3/pyspark/ml/recommendation.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class _ALSParams(
nonnegative: Param[bool]
intermediateStorageLevel: Param[str]
finalStorageLevel: Param[str]
def __init__(self, *args: Any): ...
def getRank(self) -> int: ...
def getNumUserBlocks(self) -> int: ...
def getNumItemBlocks(self) -> int: ...
Expand Down
12 changes: 10 additions & 2 deletions third_party/3/pyspark/ml/regression.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class _LinearRegressionParams(
solver: Param[str]
loss: Param[str]
epsilon: Param[float]
def __init__(self, *args: Any): ...
def getEpsilon(self) -> float: ...

class LinearRegression(
Expand Down Expand Up @@ -251,7 +252,8 @@ class IsotonicRegressionModel(

class _DecisionTreeRegressorParams(
_DecisionTreeParams, _TreeRegressorParams, HasVarianceCol
): ...
):
def __init__(self, *args: Any): ...

class DecisionTreeRegressor(
_JavaRegressor[DecisionTreeRegressionModel],
Expand Down Expand Up @@ -323,7 +325,8 @@ class DecisionTreeRegressionModel(
@property
def featureImportances(self) -> Vector: ...

class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams): ...
class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
def __init__(self, *args: Any): ...

class RandomForestRegressor(
_JavaRegressor[RandomForestRegressionModel],
Expand Down Expand Up @@ -406,6 +409,7 @@ class RandomForestRegressionModel(
class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
supportedLossTypes: List[str]
lossType: Param[str]
def __init__(self, *args: Any): ...
def getLossType(self) -> str: ...

class GBTRegressor(
Expand Down Expand Up @@ -508,6 +512,7 @@ class _AFTSurvivalRegressionParams(
censorCol: Param[str]
quantileProbabilities: Param[List[float]]
quantilesCol: Param[str]
def __init__(self, *args: Any): ...
def getCensorCol(self) -> str: ...
def getQuantileProbabilities(self) -> List[float]: ...
def getQuantilesCol(self) -> str: ...
Expand Down Expand Up @@ -593,6 +598,7 @@ class _GeneralizedLinearRegressionParams(
linkPower: Param[float]
solver: Param[str]
offsetCol: Param[str]
def __init__(self, *args: Any): ...
def getFamily(self) -> str: ...
def getLinkPredictionCol(self) -> str: ...
def getLink(self) -> str: ...
Expand Down Expand Up @@ -722,12 +728,14 @@ class _FactorizationMachinesParams(
HasSeed,
HasFitIntercept,
HasRegParam,
HasWeightCol,
):
factorSize: Param[int]
fitLinear: Param[bool]
miniBatchFraction: Param[float]
initStd: Param[float]
solver: Param[str]
def __init__(self, *args: Any): ...
def getFactorSize(self): ...
def getFitLinear(self): ...
def getMiniBatchFraction(self): ...
Expand Down
2 changes: 2 additions & 0 deletions third_party/3/pyspark/ml/tuning.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class _ValidatorParams(HasSeed):
class _CrossValidatorParams(_ValidatorParams):
numFolds: Param[int]
foldCol: Param[str]
def __init__(self, *args: Any): ...
def getNumFolds(self) -> int: ...
def getFoldCol(self) -> str: ...

Expand Down Expand Up @@ -115,6 +116,7 @@ class CrossValidatorModel(

class _TrainValidationSplitParams(_ValidatorParams):
trainRatio: Param[float]
def __init__(self, *args: Any): ...
def getTrainRatio(self) -> float: ...

class TrainValidationSplit(
Expand Down

0 comments on commit 56aab84

Please sign in to comment.