Package weka.classifiers.functions
Class SGD
java.lang.Object
weka.classifiers.AbstractClassifier
weka.classifiers.RandomizableClassifier
weka.classifiers.functions.SGD
- All Implemented Interfaces:
Serializable, Cloneable, Classifier, UpdateableClassifier, Aggregateable<SGD>, BatchPredictor, CapabilitiesHandler, CapabilitiesIgnorer, CommandlineRunnable, OptionHandler, Randomizable, RevisionHandler
public class SGD
extends RandomizableClassifier
implements UpdateableClassifier, OptionHandler, Aggregateable<SGD>
Implements stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression, squared loss, Huber loss and epsilon-insensitive loss linear regression). Globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data.
For numeric class attributes, the squared, Huber or epsilon-insensitive loss function must be used. Epsilon-insensitive and Huber loss may require a much higher learning rate. Valid options are:
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression), 3 = epsilon insensitive loss (regression), 4 = Huber loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-C <double> The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
-N Don't normalize the data
-M Don't replace missing values
-S <num> Random number seed. (default 1)
-output-debug-info If set, classifier is run in debug mode and may output additional info to the console
-do-not-check-capabilities If set, classifier capabilities are not checked before classifier is built (use with caution).
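To make the -F codes above concrete, here is a minimal sketch of the textbook form of each loss, written in plain Java rather than taken from Weka's source. The class name LossSketch and its method names are hypothetical. Here z stands for the label (+1 or -1) times the linear score, and r stands for prediction minus target. The exact scaling Weka applies internally (for example, whether squared loss carries a factor of 1/2) is not stated on this page, so treat the constants as assumptions.

```java
/** Textbook loss functions matching the -F codes; illustrative only, not Weka's code. */
final class LossSketch {

    /** -F 0: hinge loss, where z = label (+1 or -1) times the linear score. */
    static double hinge(double z) {
        return Math.max(0.0, 1.0 - z);
    }

    /** -F 1: log loss, log(1 + exp(-z)). */
    static double logLoss(double z) {
        return Math.log1p(Math.exp(-z));
    }

    /** -F 2: squared loss on the residual r = prediction - target. */
    static double squared(double r) {
        return r * r;
    }

    /** -F 3: epsilon-insensitive loss; residuals within eps cost nothing. */
    static double epsilonInsensitive(double r, double eps) {
        return Math.max(0.0, Math.abs(r) - eps);
    }

    /** -F 4: Huber loss; quadratic within eps of zero, linear beyond it. */
    static double huber(double r, double eps) {
        double a = Math.abs(r);
        return a <= eps ? 0.5 * r * r : eps * (a - 0.5 * eps);
    }

    public static void main(String[] args) {
        double eps = 1e-3; // the documented default for -C
        System.out.println("hinge(0.5)   = " + hinge(0.5));
        System.out.println("logLoss(0.0) = " + logLoss(0.0));
        System.out.println("squared(3.0) = " + squared(3.0));
        System.out.println("epsIns(0.5)  = " + epsilonInsensitive(0.5, eps));
        System.out.println("huber(0.5)   = " + huber(0.5, eps));
    }
}
```

With the default epsilon of 1e-3, the Huber loss is linear for almost every residual on normalized data. This is consistent with the note above that Huber and epsilon-insensitive loss may need a larger learning rate.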
- Version:
- $Revision: 15520 $
- Author:
- Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz), Mark Hall (mhall{[at]}pentaho{[dot]}com)
Field Summary
Fields
Modifier and Type, Field - Description
static final int EPSILON_INSENSITIVE - The epsilon insensitive loss function
static final int HINGE - The hinge loss function.
static final int HUBER - The Huber loss function
static final int LOGLOSS - The log loss function.
static final int SQUAREDLOSS - The squared loss function.
static final Tag[] TAGS_SELECTION - Loss functions to choose from
Fields inherited from class weka.classifiers.AbstractClassifier
BATCH_SIZE_DEFAULT, NUM_DECIMAL_PLACES_DEFAULT
Constructor Summary
Constructors
SGD()
Method Summary
Modifier and Type, Method - Description
SGD aggregate(SGD toAggregate) - Aggregate an object with this one
void buildClassifier(Instances data) - Method for building the classifier.
double[] distributionForInstance(Instance inst) - Computes the distribution for a given instance
String dontNormalizeTipText() - Returns the tip text for this property
String dontReplaceMissingTipText() - Returns the tip text for this property
String epochsTipText() - Returns the tip text for this property
String epsilonTipText() - Returns the tip text for this property
void finalizeAggregation() - Call to complete the aggregation process.
Capabilities getCapabilities() - Returns default capabilities of the classifier.
boolean getDontNormalize() - Get whether normalization has been turned off.
boolean getDontReplaceMissing() - Get whether global replacement of missing values has been disabled.
int getEpochs() - Get current number of epochs
double getEpsilon() - Get the epsilon threshold on the error for epsilon insensitive and Huber loss functions
double getLambda() - Get the current value of lambda
double getLearningRate() - Get the learning rate.
SelectedTag getLossFunction() - Get the current loss function.
String[] getOptions() - Gets the current settings of the classifier.
String getRevision() - Returns the revision string.
double[] getWeights()
String globalInfo() - Returns a string describing classifier
String lambdaTipText() - Returns the tip text for this property
String learningRateTipText() - Returns the tip text for this property
Enumeration<Option> listOptions() - Returns an enumeration describing the available options.
String lossFunctionTipText() - Returns the tip text for this property
static void main(String[] args) - Main method for testing this class.
void reset() - Reset the classifier.
void setDontNormalize(boolean m) - Turn normalization off/on.
void setDontReplaceMissing(boolean m) - Turn global replacement of missing values off/on.
void setEpochs(int e) - Set the number of epochs to use
void setEpsilon(double e) - Set the epsilon threshold on the error for epsilon insensitive and Huber loss functions
void setLambda(double lambda) - Set the value of lambda to use
void setLearningRate(double lr) - Set the learning rate.
void setLossFunction(SelectedTag function) - Set the loss function to use.
void setOptions(String[] options) - Parses a given list of options.
String toString() - Prints out the classifier.
void updateClassifier(Instance instance) - Updates the classifier with the given instance.
Methods inherited from class weka.classifiers.RandomizableClassifier
getSeed, seedTipText, setSeed
Methods inherited from class weka.classifiers.AbstractClassifier
batchSizeTipText, classifyInstance, debugTipText, distributionsForInstances, doNotCheckCapabilitiesTipText, forName, getBatchSize, getDebug, getDoNotCheckCapabilities, getNumDecimalPlaces, implementsMoreEfficientBatchPrediction, makeCopies, makeCopy, numDecimalPlacesTipText, postExecution, preExecution, run, runClassifier, setBatchSize, setDebug, setDoNotCheckCapabilities, setNumDecimalPlaces
Field Details
-
HINGE
public static final int HINGE
The hinge loss function.
-
LOGLOSS
public static final int LOGLOSS
The log loss function.
-
SQUAREDLOSS
public static final int SQUAREDLOSS
The squared loss function.
-
EPSILON_INSENSITIVE
public static final int EPSILON_INSENSITIVE
The epsilon insensitive loss function.
-
HUBER
public static final int HUBER
The Huber loss function.
-
TAGS_SELECTION
public static final Tag[] TAGS_SELECTION
Loss functions to choose from.
-
-
Constructor Details
-
SGD
public SGD()
-
-
Method Details
-
getCapabilities
public Capabilities getCapabilities()
Returns default capabilities of the classifier.
- Specified by:
getCapabilities in interface CapabilitiesHandler
- Specified by:
getCapabilities in interface Classifier
- Overrides:
getCapabilities in class AbstractClassifier
- Returns:
the capabilities of this classifier
-
epsilonTipText
public String epsilonTipText()
Returns the tip text for this property.
- Returns:
tip text for this property suitable for displaying in the explorer/experimenter GUI
-
setEpsilon
public void setEpsilon(double e)
Set the epsilon threshold on the error for epsilon insensitive and Huber loss functions.
- Parameters:
e - the value of epsilon to use
-
getEpsilon
public double getEpsilon()
Get the epsilon threshold on the error for epsilon insensitive and Huber loss functions.
- Returns:
the value of epsilon to use
-
lambdaTipText
public String lambdaTipText()
Returns the tip text for this property.
- Returns:
tip text for this property suitable for displaying in the explorer/experimenter GUI
-
setLambda
public void setLambda(double lambda)
Set the value of lambda to use.
- Parameters:
lambda - the value of lambda to use
-
getLambda
public double getLambda()
Get the current value of lambda.
- Returns:
the current value of lambda
-
setLearningRate
public void setLearningRate(double lr)
Set the learning rate.
- Parameters:
lr - the learning rate to use.
-
getLearningRate
public double getLearningRate()
Get the learning rate.
- Returns:
the learning rate
-
learningRateTipText
public String learningRateTipText()
Returns the tip text for this property.
- Returns:
tip text for this property suitable for displaying in the explorer/experimenter GUI
-
epochsTipText
public String epochsTipText()
Returns the tip text for this property.
- Returns:
tip text for this property suitable for displaying in the explorer/experimenter GUI
-
setEpochs
public void setEpochs(int e)
Set the number of epochs to use.
- Parameters:
e - the number of epochs to use
-
getEpochs
public int getEpochs()
Get current number of epochs.
- Returns:
the current number of epochs
-
setDontNormalize
public void setDontNormalize(boolean m)
Turn normalization off/on.
- Parameters:
m - true if normalization is to be disabled.
-
getDontNormalize
public boolean getDontNormalize()
Get whether normalization has been turned off.
- Returns:
true if normalization has been disabled.
-
dontNormalizeTipText
public String dontNormalizeTipText()
Returns the tip text for this property.
- Returns:
tip text for this property suitable for displaying in the explorer/experimenter GUI
-
setDontReplaceMissing
public void setDontReplaceMissing(boolean m)
Turn global replacement of missing values off/on. If turned off, then missing values are effectively ignored.
- Parameters:
m - true if global replacement of missing values is to be turned off.
-
getDontReplaceMissing
public boolean getDontReplaceMissing()
Get whether global replacement of missing values has been disabled.
- Returns:
true if global replacement of missing values has been turned off
-
dontReplaceMissingTipText
public String dontReplaceMissingTipText()
Returns the tip text for this property.
- Returns:
tip text for this property suitable for displaying in the explorer/experimenter GUI
-
setLossFunction
public void setLossFunction(SelectedTag function)
Set the loss function to use.
- Parameters:
function - the loss function to use.
-
getLossFunction
public SelectedTag getLossFunction()
Get the current loss function.
- Returns:
the current loss function.
-
lossFunctionTipText
public String lossFunctionTipText()
Returns the tip text for this property.
- Returns:
tip text for this property suitable for displaying in the explorer/experimenter GUI
-
listOptions
public Enumeration<Option> listOptions()
Returns an enumeration describing the available options.
- Specified by:
listOptions in interface OptionHandler
- Overrides:
listOptions in class RandomizableClassifier
- Returns:
an enumeration of all the available options.
-
setOptions
public void setOptions(String[] options) throws Exception
Parses a given list of options. Valid options are:
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression), 3 = epsilon insensitive loss (regression), 4 = Huber loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-C <double> The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
-N Don't normalize the data
-M Don't replace missing values
-S <num> Random number seed. (default 1)
-output-debug-info If set, classifier is run in debug mode and may output additional info to the console
-do-not-check-capabilities If set, classifier capabilities are not checked before classifier is built (use with caution).
- Specified by:
setOptions in interface OptionHandler
- Overrides:
setOptions in class RandomizableClassifier
- Parameters:
options - the list of options as an array of strings
- Throws:
Exception - if an option is not supported
-
getOptions
public String[] getOptions()
Gets the current settings of the classifier.
- Specified by:
getOptions in interface OptionHandler
- Overrides:
getOptions in class RandomizableClassifier
- Returns:
an array of strings suitable for passing to setOptions
-
globalInfo
public String globalInfo()
Returns a string describing classifier.
- Returns:
a description suitable for displaying in the explorer/experimenter GUI
-
reset
public void reset()
Reset the classifier.
-
buildClassifier
public void buildClassifier(Instances data) throws Exception
Method for building the classifier.
- Specified by:
buildClassifier in interface Classifier
- Parameters:
data - the set of training instances.
- Throws:
Exception - if the classifier can't be built successfully.
-
updateClassifier
public void updateClassifier(Instance instance) throws Exception
Updates the classifier with the given instance.
- Specified by:
updateClassifier in interface UpdateableClassifier
- Parameters:
instance - the new training instance to include in the model
- Throws:
Exception - if the instance could not be incorporated in the model.
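As a rough illustration of what one incremental update involves, the following sketch performs a single Pegasos-style hinge-loss step with L2 shrinkage on a plain weight array. It is an assumption-laden sketch, not Weka's code: the names SgdStepSketch and hingeStep and the bias-in-last-slot layout are hypothetical, and the way SGD scales lambda internally is not documented on this page.

```java
/** One hinge-loss SGD step on a weight vector; illustrative only, not Weka's code. */
final class SgdStepSketch {

    /**
     * Updates w in place for one example.
     * w has x.length + 1 entries; the last entry is the bias (an assumed layout).
     * y must be +1 or -1; lr is the learning rate; lambda is the L2 constant.
     */
    static void hingeStep(double[] w, double[] x, double y, double lr, double lambda) {
        int n = x.length;

        // Margin z = y * (w . x + bias), computed before any weight changes.
        double score = w[n];
        for (int i = 0; i < n; i++) {
            score += w[i] * x[i];
        }
        double z = y * score;

        for (int i = 0; i < n; i++) {
            w[i] *= (1.0 - lr * lambda); // regularization shrinks every weight
            if (z < 1.0) {
                w[i] += lr * y * x[i];   // hinge gradient is active inside the margin
            }
        }
        if (z < 1.0) {
            w[n] += lr * y;              // the bias is not regularized in this sketch
        }
    }

    public static void main(String[] args) {
        double[] w = new double[3];      // two attribute weights plus a bias
        hingeStep(w, new double[] {1.0, 2.0}, 1.0, 0.01, 0.0001);
        System.out.println(java.util.Arrays.toString(w));
    }
}
```

Calling a step like this once per arriving instance mirrors the streaming use of updateClassifier, whereas batch training repeats the pass for the configured number of epochs.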
-
distributionForInstance
public double[] distributionForInstance(Instance inst) throws Exception
Computes the distribution for a given instance.
- Specified by:
distributionForInstance in interface Classifier
- Overrides:
distributionForInstance in class AbstractClassifier
- Parameters:
inst - the instance for which distribution is computed
- Returns:
the distribution
- Throws:
Exception - if the distribution can't be computed successfully
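For the log loss case, a two-class distribution is conventionally obtained by passing the linear score through the logistic function. The sketch below shows that standard mapping in plain Java. The class and method names are hypothetical, and this page does not say how SGD forms its distribution for the other loss functions, so the sketch covers logistic regression only.

```java
/** Standard logistic mapping from a linear score to a two-class distribution. */
final class LogisticDistributionSketch {

    /** Returns {P(class 0), P(class 1)} for weights w, a bias term and attributes x. */
    static double[] distribution(double[] w, double bias, double[] x) {
        if (w.length != x.length) {
            throw new IllegalArgumentException("weight and attribute counts differ");
        }
        double score = bias;
        for (int i = 0; i < w.length; i++) {
            score += w[i] * x[i];
        }
        double p = 1.0 / (1.0 + Math.exp(-score)); // logistic (sigmoid) function
        return new double[] {1.0 - p, p};
    }

    public static void main(String[] args) {
        double[] d = distribution(new double[] {0.5, -0.25}, 0.1, new double[] {1.0, 2.0});
        System.out.println(d[0] + " " + d[1]);
    }
}
```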
-
getWeights
public double[] getWeights()
-
toString
public String toString()
Prints out the classifier.
-
getRevision
public String getRevision()
Returns the revision string.
- Specified by:
getRevision in interface RevisionHandler
- Overrides:
getRevision in class AbstractClassifier
- Returns:
the revision
-
aggregate
public SGD aggregate(SGD toAggregate) throws Exception
Aggregate an object with this one.
- Specified by:
aggregate in interface Aggregateable<SGD>
- Parameters:
toAggregate - the object to aggregate
- Returns:
the result of aggregation
- Throws:
Exception - if the supplied object can't be aggregated for some reason
-
finalizeAggregation
public void finalizeAggregation() throws Exception
Call to complete the aggregation process. Allows implementers to do any final processing based on how many objects were aggregated.
- Specified by:
finalizeAggregation in interface Aggregateable<SGD>
- Throws:
Exception - if the aggregation can't be finalized for some reason
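One plausible reading of the aggregate/finalizeAggregation pair for a linear model is "sum the weight vectors, then divide by the number of models". The sketch below shows that pattern on plain arrays. It assumes this averaging scheme purely for illustration, since the page does not state what SGD actually does during aggregation, and the class AverageWeightsSketch is hypothetical.

```java
/** Sum-then-average aggregation of weight vectors; an assumed scheme, not Weka's code. */
final class AverageWeightsSketch {
    private final double[] sum; // running sum of all aggregated weight vectors
    private int count;          // number of vectors folded in so far

    AverageWeightsSketch(double[] initial) {
        this.sum = initial.clone();
        this.count = 1;
    }

    /** Adds another model's weights to the running sum and returns this object. */
    AverageWeightsSketch aggregate(double[] other) {
        if (other.length != sum.length) {
            throw new IllegalArgumentException("weight vectors differ in length");
        }
        for (int i = 0; i < sum.length; i++) {
            sum[i] += other[i];
        }
        count++;
        return this;
    }

    /** Final processing based on how many vectors were aggregated: the mean. */
    double[] finalizeAggregation() {
        double[] mean = new double[sum.length];
        for (int i = 0; i < sum.length; i++) {
            mean[i] = sum[i] / count;
        }
        return mean;
    }

    public static void main(String[] args) {
        double[] mean = new AverageWeightsSketch(new double[] {1.0, 2.0})
                .aggregate(new double[] {3.0, 4.0})
                .finalizeAggregation();
        System.out.println(java.util.Arrays.toString(mean));
    }
}
```

In this pattern the finalizing call is made once, after every partial model has been passed to aggregate, which matches the description of finalizeAggregation as the step that depends on how many objects were aggregated.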
-
main
public static void main(String[] args)
Main method for testing this class.