Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions src/Microsoft.ML.Data/Utils/LossFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public sealed class HingeLoss : ISupportSdcaClassificationLoss
public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Margin value", ShortName = "marg")]
public Float Margin = 1;
public Float Margin = Defaults.Margin;

public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env) => new HingeLoss(this);

Expand All @@ -177,11 +177,21 @@ public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportC
private const Float Threshold = 0.5f;
private readonly Float _margin;

public HingeLoss(Arguments args)
internal HingeLoss(Arguments args)
{
_margin = args.Margin;
}

// Holds the default argument values shared by the Arguments class and the public
// constructor, so the two entry points cannot drift apart.
private static class Defaults
{
// Default hinge margin (also the default of Arguments.Margin).
public const float Margin = 1;
}

/// <summary>
/// Creates a hinge loss with the given margin.
/// </summary>
/// <param name="margin">The margin value of the hinge loss; defaults to <see cref="Defaults.Margin"/>.</param>
public HingeLoss(float margin = Defaults.Margin)
: this(new Arguments() { Margin = margin })
{
}

public Double Loss(Float output, Float label)
{
Float truth = label > 0 ? 1 : -1;
Expand Down Expand Up @@ -228,7 +238,7 @@ public sealed class SmoothedHingeLoss : ISupportSdcaClassificationLoss
public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Smoothing constant", ShortName = "smooth")]
public Float SmoothingConst = 1;
public Float SmoothingConst = Defaults.SmoothingConst;

public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env) => new SmoothedHingeLoss(env, this);

Expand All @@ -242,14 +252,28 @@ public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportC
private readonly Double _halfSmoothConst;
private readonly Double _doubleSmoothConst;

public SmoothedHingeLoss(IHostEnvironment env, Arguments args)
// Holds the default argument values shared by the Arguments class and the public
// constructor, so the two entry points cannot drift apart.
private static class Defaults
{
// Default smoothing constant (also the default of Arguments.SmoothingConst).
public const float SmoothingConst = 1;
}

/// <summary>
/// Constructor for smoothed hinge loss.
/// </summary>
/// <param name="smoothingConstant">The smoothing constant.</param>
public SmoothedHingeLoss(float smoothingConstant = Defaults.SmoothingConst)
{
Contracts.Check(args.SmoothingConst >= 0, "smooth constant must be non-negative");
_smoothConst = args.SmoothingConst;
Contracts.CheckParam(smoothingConstant >= 0, nameof(smoothingConstant), "Must be non-negative.");
_smoothConst = smoothingConstant;
_halfSmoothConst = _smoothConst / 2;
_doubleSmoothConst = _smoothConst * 2;
}

// Factory-path constructor: delegates to the public constructor using the
// smoothing constant from the command-line Arguments.
// NOTE(review): the env parameter is unused here — presumably kept so the
// factory signature matches other losses; confirm before removing.
private SmoothedHingeLoss(IHostEnvironment env, Arguments args)
: this(args.SmoothingConst)
{
}

public Double Loss(Float output, Float label)
{
Float truth = label > 0 ? 1 : -1;
Expand Down
343 changes: 149 additions & 194 deletions src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Float = System.Single;

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
Expand Down Expand Up @@ -41,7 +39,7 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPred

private readonly Arguments _args;

public class Arguments : AveragedLinearArguments
public sealed class Arguments : AveragedLinearArguments
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments();
Expand All @@ -51,6 +49,38 @@ public class Arguments : AveragedLinearArguments

[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public int MaxCalibrationExamples = 1000000;

internal override IComponentFactory<IScalarOutputLoss> LossFunctionFactory => LossFunction;
}

// Per-training-run state for the averaged perceptron. Extends the shared
// averaged-linear state with the logic to materialize the final binary predictor.
private sealed class TrainState : AveragedTrainStateBase
{
public TrainState(IChannel ch, int numFeatures, LinearPredictor predictor, AveragedPerceptronTrainer parent)
: base(ch, numFeatures, predictor, parent)
{
}

/// <summary>
/// Builds the final <see cref="LinearBinaryPredictor"/> from the trained weights.
/// If averaging is enabled, uses the accumulated totals divided by the number of
/// weight updates; otherwise uses the last weights/bias directly.
/// </summary>
public override LinearBinaryPredictor CreatePredictor()
{
// Weights must already be rescaled to scale 1 before extraction.
Contracts.Assert(WeightsScale == 1);

VBuffer<float> weights = default;
float bias;

if (!Averaged)
{
// Non-averaged: take the final weight vector and bias as-is.
Weights.CopyTo(ref weights);
bias = Bias;
}
else
{
// Averaged: mean of all intermediate weight vectors and biases,
// i.e. totals scaled by 1/NumWeightUpdates.
TotalWeights.CopyTo(ref weights);
VectorUtils.ScaleBy(ref weights, 1 / (float)NumWeightUpdates);
bias = TotalBias / (float)NumWeightUpdates;
}

return new LinearBinaryPredictor(ParentHost, ref weights, bias);
}
}

internal AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
Expand Down Expand Up @@ -78,32 +108,36 @@ public AveragedPerceptronTrainer(IHostEnvironment env,
string label,
string features,
string weights = null,
ISupportClassificationLossFactory lossFunction = null,
IClassificationLoss lossFunction = null,
float learningRate = Arguments.AveragedDefaultArgs.LearningRate,
bool decreaseLearningRate = Arguments.AveragedDefaultArgs.DecreaseLearningRate,
float l2RegularizerWeight = Arguments.AveragedDefaultArgs.L2RegularizerWeight,
int numIterations = Arguments.AveragedDefaultArgs.NumIterations,
Action<Arguments> advancedSettings = null)
: this(env, new Arguments
: this(env, InvokeAdvanced(advancedSettings, new Arguments
{
LabelColumn = label,
FeatureColumn = features,
InitialWeights = weights,
LearningRate = learningRate,
DecreaseLearningRate = decreaseLearningRate,
L2RegularizerWeight = l2RegularizerWeight,
NumIterations = numIterations

})
NumIterations = numIterations,
LossFunction = new TrivialFactory(lossFunction ?? new HingeLoss())
}))
{
if (lossFunction == null)
lossFunction = new HingeLoss.Arguments();
}

LossFunction = lossFunction.CreateComponent(env);
private sealed class TrivialFactory : ISupportClassificationLossFactory
{
private IClassificationLoss _loss;

if (advancedSettings != null)
advancedSettings.Invoke(_args);
public TrivialFactory(IClassificationLoss loss)
{
_loss = loss;
}

IClassificationLoss IComponentFactory<IClassificationLoss>.CreateComponent(IHostEnvironment env) => _loss;
}

public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
Expand All @@ -120,7 +154,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
};
}

protected override void CheckLabel(RoleMappedData data)
protected override void CheckLabels(RoleMappedData data)
{
Contracts.AssertValue(data);
data.CheckBinaryLabel();
Expand All @@ -140,26 +174,9 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
error();
}

protected override LinearBinaryPredictor CreatePredictor()
private protected override TrainStateBase MakeState(IChannel ch, int numFeatures, LinearPredictor predictor)
{
Contracts.Assert(WeightsScale == 1);

VBuffer<Float> weights = default(VBuffer<Float>);
Float bias;

if (!_args.Averaged)
{
Weights.CopyTo(ref weights);
bias = Bias;
}
else
{
TotalWeights.CopyTo(ref weights);
VectorUtils.ScaleBy(ref weights, 1 / (Float)NumWeightUpdates);
bias = TotalBias / (Float)NumWeightUpdates;
}

return new LinearBinaryPredictor(Host, ref weights, bias);
return new TrainState(ch, numFeatures, predictor, this);
}

protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, Schema trainSchema)
Expand Down
Loading