public class SGDMultiClass extends AbstractClassifier implements Regressor
Modifier and Type | Field and Description |
---|---|
protected static int |
HINGE |
FloatOption |
lambdaRegularizationOption |
FloatOption |
learningRateOption |
protected static int |
LOGLOSS |
MultiChoiceOption |
lossFunctionOption |
protected double[] |
m_bias |
protected double |
m_lambda
The regularization parameter
|
protected double |
m_learningRate
The learning rate
|
protected int |
m_loss
The current loss function to minimize
|
protected double |
m_numInstances
The number of training instances
|
protected double |
m_t
Holds the current iteration number
|
protected DoubleVector[] |
m_weights
Stores the weights (+ bias in the last element)
|
protected static int |
SQUAREDLOSS |
classifierRandom, modelContext, randomSeed, randomSeedOption, trainingWeightSeenByModel
classOptionNamesToPreparedObjects, options
Constructor and Description |
---|
SGDMultiClass() |
Modifier and Type | Method and Description |
---|---|
protected double |
dloss(double z) |
protected static double |
dotProd(weka.core.Instance inst1,
DoubleVector weights,
int classIndex) |
double |
getLambda()
Get the current value of lambda
|
double |
getLearningRate()
Get the learning rate.
|
int |
getLossFunction()
Get the current loss function.
|
void |
getModelDescription(StringBuilder result,
int indent)
Returns a string representation of the model.
|
protected Measurement[] |
getModelMeasurementsImpl()
Gets the current measurements of this classifier.
The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in overridden methods. |
String |
getPurposeString()
Gets the purpose of this object
|
double[] |
getVotesForInstance(weka.core.Instance inst)
Calculates the class membership probabilities for the given test
instance.
|
boolean |
isRandomizable()
Gets whether this classifier needs a random seed.
|
void |
reset()
Reset the classifier.
|
void |
resetLearningImpl()
Resets this classifier.
|
void |
setLambda(double lambda)
Set the value of lambda to use
|
void |
setLearningRate(double lr)
Set the learning rate.
|
void |
setLossFunction(int function)
Set the loss function to use.
|
String |
toString()
Prints out the classifier.
|
void |
trainOnInstanceImpl(weka.core.Instance instance)
Trains the classifier with the given instance.
|
void |
trainOnInstanceImpl(weka.core.Instance instance,
int classLabel) |
contextIsCompatible, copy, correctlyClassifies, getAttributeNameString, getAWTRenderer, getClassLabelString, getClassNameString, getDescription, getModelContext, getModelMeasurements, getNominalValueString, getSubClassifiers, modelAttIndexToInstanceAttIndex, modelAttIndexToInstanceAttIndex, prepareForUseImpl, resetLearning, setModelContext, setRandomSeed, trainingHasStarted, trainingWeightSeenByModel, trainOnInstance
discoverOptionsViaReflection, getCLICreationString, getOptions, getPreparedClassOption, getPreparedClassOption, prepareClassOptions, prepareForUse, prepareForUse
copy, measureByteSize, measureByteSize
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
getCLICreationString, getOptions, prepareForUse, prepareForUse
measureByteSize
protected double m_lambda
public FloatOption lambdaRegularizationOption
protected double m_learningRate
public FloatOption learningRateOption
protected DoubleVector[] m_weights
protected double[] m_bias
protected double m_t
protected double m_numInstances
protected static final int HINGE
protected static final int LOGLOSS
protected static final int SQUAREDLOSS
protected int m_loss
public MultiChoiceOption lossFunctionOption
public String getPurposeString()
OptionHandler
getPurposeString
in interface OptionHandler
getPurposeString
in class AbstractClassifier
public void setLambda(double lambda)
lambda
- the value of lambda to use
public double getLambda()
public void setLossFunction(int function)
function
- the loss function to use.
public int getLossFunction()
public void setLearningRate(double lr)
lr
- the learning rate to use.
public double getLearningRate()
public void reset()
protected double dloss(double z)
protected static double dotProd(weka.core.Instance inst1, DoubleVector weights, int classIndex)
public void resetLearningImpl()
AbstractClassifier
resetLearningImpl
in class AbstractClassifier
public void trainOnInstanceImpl(weka.core.Instance instance)
trainOnInstanceImpl
in class AbstractClassifier
instance
- the new training instance to include in the model
public void trainOnInstanceImpl(weka.core.Instance instance, int classLabel)
public double[] getVotesForInstance(weka.core.Instance inst)
getVotesForInstance
in interface Classifier
instance
- the instance to be classified
public void getModelDescription(StringBuilder result, int indent)
AbstractClassifier
getModelDescription
in class AbstractClassifier
result
- the StringBuilder to add the description to
indent
- the number of characters to indent
public String toString()
toString
in class AbstractMOAObject
protected Measurement[] getModelMeasurementsImpl()
AbstractClassifier
getModelMeasurementsImpl
in class AbstractClassifier
public boolean isRandomizable()
Classifier
isRandomizable
in interface Classifier
Copyright © 2014 University of Waikato, Hamilton, NZ. All Rights Reserved.