public class SGD extends AbstractClassifier implements Regressor
| Modifier and Type | Field and Description |
|---|---|
| protected static int | HINGE |
| FloatOption | lambdaRegularizationOption |
| FloatOption | learningRateOption |
| protected static int | LOGLOSS |
| MultiChoiceOption | lossFunctionOption |
| protected double | m_bias |
| protected double | m_lambda: The regularization parameter. |
| protected double | m_learningRate: The learning rate. |
| protected int | m_loss: The current loss function to minimize. |
| protected double | m_numInstances: The number of training instances. |
| protected double | m_t: Holds the current iteration number. |
| protected DoubleVector | m_weights: Stores the weights (+ bias in the last element). |
| protected static int | SQUAREDLOSS |
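Fields inherited from class AbstractClassifier: classifierRandom, modelContext, randomSeed, randomSeedOption, trainingWeightSeenByModel

Fields inherited from class AbstractOptionHandler: classOptionNamesToPreparedObjects, options

| Constructor and Description |
|---|
| SGD() |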
| Modifier and Type | Method and Description |
|---|---|
| protected double | dloss(double z) |
| protected static double | dotProd(weka.core.Instance inst1, DoubleVector weights, int classIndex) |
| double | getLambda(): Get the current value of lambda. |
| double | getLearningRate(): Get the learning rate. |
| int | getLossFunction(): Get the current loss function. |
| void | getModelDescription(StringBuilder result, int indent): Returns a string representation of the model. |
| protected Measurement[] | getModelMeasurementsImpl(): Gets the current measurements of this classifier. The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in overridden methods. |
| String | getPurposeString(): Gets the purpose of this object. |
| double[] | getVotesForInstance(weka.core.Instance inst): Calculates the class membership probabilities for the given test instance. |
| boolean | isRandomizable(): Gets whether this classifier needs a random seed. |
| void | reset(): Reset the classifier. |
| void | resetLearningImpl(): Resets this classifier. |
| void | setLambda(double lambda): Set the value of lambda to use. |
| void | setLearningRate(double lr): Set the learning rate. |
| void | setLossFunction(int function): Set the loss function to use. |
| String | toString(): Prints out the classifier. |
| void | trainOnInstanceImpl(weka.core.Instance instance): Trains the classifier with the given instance. |
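The method summary names `trainOnInstanceImpl` and the fields `m_weights`, `m_bias`, `m_learningRate` and `m_lambda`, but this page does not show the update rule itself. As a rough illustration of what one stochastic gradient descent step with L2 regularization can look like, here is a self-contained sketch. It assumes hinge loss, dense feature arrays, labels encoded as -1 or +1, and a bias stored separately from the weights. The class `SgdStepSketch` and all of its members are hypothetical and are not part of MOA; the actual implementation may differ (for example in how the learning rate is scheduled).

```java
/** Illustrative sketch only; not the MOA SGD implementation. */
final class SgdStepSketch {
    final double[] weights;     // one weight per attribute (assumed dense)
    double bias;                // kept separately here by assumption
    final double learningRate;  // constant step size (assumption: no decay)
    final double lambda;        // L2 regularization strength

    SgdStepSketch(int numAttributes, double learningRate, double lambda) {
        this.weights = new double[numAttributes];
        this.learningRate = learningRate;
        this.lambda = lambda;
    }

    /** Raw linear score w.x + bias. */
    double score(double[] x) {
        double sum = bias;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * x[i];
        }
        return sum;
    }

    /** One hinge-loss SGD step; y must be -1 or +1. */
    void trainOnInstance(double[] x, double y) {
        double z = y * score(x); // margin of the current example

        // L2 regularization: shrink every weight towards zero.
        double shrink = 1.0 - learningRate * lambda;
        for (int i = 0; i < weights.length; i++) {
            weights[i] *= shrink;
        }

        // Hinge loss only has a non-zero gradient when the margin is violated.
        if (z < 1.0) {
            double step = learningRate * y;
            for (int i = 0; i < weights.length; i++) {
                weights[i] += step * x[i];
            }
            bias += step;
        }
    }
}
```

With `lambda` set to 0, repeated calls on the same example stop changing the model once its margin reaches 1, which is the expected behaviour for hinge loss.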
Methods inherited from class AbstractClassifier: contextIsCompatible, copy, correctlyClassifies, getAttributeNameString, getAWTRenderer, getClassLabelString, getClassNameString, getDescription, getModelContext, getModelMeasurements, getNominalValueString, getSubClassifiers, modelAttIndexToInstanceAttIndex, modelAttIndexToInstanceAttIndex, prepareForUseImpl, resetLearning, setModelContext, setRandomSeed, trainingHasStarted, trainingWeightSeenByModel, trainOnInstance

Methods inherited from class AbstractOptionHandler: discoverOptionsViaReflection, getCLICreationString, getOptions, getPreparedClassOption, getPreparedClassOption, prepareClassOptions, prepareForUse, prepareForUse

Methods inherited from class AbstractMOAObject: copy, measureByteSize, measureByteSize

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface OptionHandler: getCLICreationString, getOptions, prepareForUse, prepareForUse

Methods inherited from interface MOAObject: measureByteSize

protected double m_lambda
public FloatOption lambdaRegularizationOption
protected double m_learningRate
public FloatOption learningRateOption
protected DoubleVector m_weights
protected double m_bias
protected double m_t
protected double m_numInstances
protected static final int HINGE
protected static final int LOGLOSS
protected static final int SQUAREDLOSS
protected int m_loss
public MultiChoiceOption lossFunctionOption
public String getPurposeString()
getPurposeString in interface OptionHandler
getPurposeString in class AbstractClassifier
public void setLambda(double lambda)
Parameters:
lambda - the value of lambda to use
public double getLambda()
public void setLossFunction(int function)
Parameters:
function - the loss function to use.
public int getLossFunction()
public void setLearningRate(double lr)
Parameters:
lr - the learning rate to use.
public double getLearningRate()
public void reset()
protected double dloss(double z)
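This page gives no description for `dloss`. Given the `HINGE`, `LOGLOSS` and `SQUAREDLOSS` constants, a plausible reading is that it returns the loss-derivative factor used in the weight update. The sketch below works under that assumption: for hinge and log loss, `z` is taken to be the margin y * (w.x + b) with y in {-1, +1}, and the method returns the negated derivative of the loss with respect to `z`; for squared loss, `z` is taken to be the residual. The class name, the extra `loss` parameter and the constant values are hypothetical, since the source does not state them.

```java
/** Illustrative sketch only; not the MOA SGD.dloss implementation. */
final class LossDerivativeSketch {
    static final int HINGE = 0;       // assumed value
    static final int LOGLOSS = 1;     // assumed value
    static final int SQUAREDLOSS = 2; // assumed value

    /**
     * Negated loss derivative.
     * Hinge and log loss: z is the margin y * (w.x + b).
     * Squared loss: z is the residual (target minus prediction).
     */
    static double dloss(int loss, double z) {
        switch (loss) {
            case HINGE:
                // L(z) = max(0, 1 - z): slope is -1 while the margin is below 1.
                return z < 1.0 ? 1.0 : 0.0;
            case LOGLOSS:
                // L(z) = log(1 + exp(-z)), so -dL/dz = 1 / (1 + exp(z)).
                // The two branches avoid overflow in exp for large |z|.
                if (z < 0) {
                    return 1.0 / (Math.exp(z) + 1.0);
                }
                double t = Math.exp(-z);
                return t / (t + 1.0);
            case SQUAREDLOSS:
                // L(r) = r * r / 2 with r = z, so the update factor is r itself.
                return z;
            default:
                throw new IllegalArgumentException("unknown loss: " + loss);
        }
    }
}
```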
protected static double dotProd(weka.core.Instance inst1,
DoubleVector weights,
int classIndex)
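The signature suggests that `dotProd` computes the inner product of an instance's attribute values with the weight vector while leaving out the class attribute. A minimal, self-contained sketch of that idea follows. It assumes a sparse instance represented by parallel index and value arrays and a plain `double[]` for the weights, standing in for `weka.core.Instance` and `DoubleVector`; the class name and parameter layout are hypothetical, and missing values are not handled.

```java
/** Illustrative sketch only; not the MOA SGD.dotProd implementation. */
final class DotProdSketch {
    /**
     * Dot product between a sparse instance and a dense weight vector,
     * skipping the class attribute.
     *
     * @param indices    attribute indices of the stored (non-zero) values
     * @param values     attribute values, parallel to indices
     * @param weights    dense weight vector indexed by attribute
     * @param classIndex index of the class attribute to exclude
     */
    static double dotProd(int[] indices, double[] values, double[] weights, int classIndex) {
        double sum = 0.0;
        for (int i = 0; i < indices.length; i++) {
            int att = indices[i];
            if (att == classIndex || att >= weights.length) {
                continue; // class attribute, or no weight stored for this attribute
            }
            sum += values[i] * weights[att];
        }
        return sum;
    }
}
```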
public void resetLearningImpl()
resetLearningImpl in class AbstractClassifier
public void trainOnInstanceImpl(weka.core.Instance instance)
trainOnInstanceImpl in class AbstractClassifier
Parameters:
instance - the new training instance to include in the model
public double[] getVotesForInstance(weka.core.Instance inst)
getVotesForInstance in interface Classifier
Parameters:
inst - the instance to be classified
public void getModelDescription(StringBuilder result, int indent)
getModelDescription in class AbstractClassifier
Parameters:
result - the StringBuilder to add the description to
indent - the number of characters to indent
public String toString()
toString in class AbstractMOAObject
protected Measurement[] getModelMeasurementsImpl()
getModelMeasurementsImpl in class AbstractClassifier
public boolean isRandomizable()
isRandomizable in interface Classifier

Copyright © 2014 University of Waikato, Hamilton, NZ. All Rights Reserved.