Skip to content

Commit 696d511

Browse files
authored
Enhancements to online linear trainers to make them stateless. (#1455)
* Enhancements to online linear trainers to make them stateless. * Factor stateful logic into a separate internal object. * Remove direct usage of Console.Writeline * Opportunistic fixes of minor issues. * Nuke failing Mac test temporarily
1 parent 0b175ba commit 696d511

File tree

13 files changed

+631
-687
lines changed

13 files changed

+631
-687
lines changed

src/Microsoft.ML.Data/Utils/LossFunctions.cs

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ public sealed class HingeLoss : ISupportSdcaClassificationLoss
166166
public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
167167
{
168168
[Argument(ArgumentType.AtMostOnce, HelpText = "Margin value", ShortName = "marg")]
169-
public Float Margin = 1;
169+
public Float Margin = Defaults.Margin;
170170

171171
public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env) => new HingeLoss(this);
172172

@@ -177,11 +177,21 @@ public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportC
177177
private const Float Threshold = 0.5f;
178178
private readonly Float _margin;
179179

180-
public HingeLoss(Arguments args)
180+
internal HingeLoss(Arguments args)
181181
{
182182
_margin = args.Margin;
183183
}
184184

185+
private static class Defaults
186+
{
187+
public const float Margin = 1;
188+
}
189+
190+
public HingeLoss(float margin = Defaults.Margin)
191+
: this(new Arguments() { Margin = margin })
192+
{
193+
}
194+
185195
public Double Loss(Float output, Float label)
186196
{
187197
Float truth = label > 0 ? 1 : -1;
@@ -228,7 +238,7 @@ public sealed class SmoothedHingeLoss : ISupportSdcaClassificationLoss
228238
public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportClassificationLossFactory
229239
{
230240
[Argument(ArgumentType.AtMostOnce, HelpText = "Smoothing constant", ShortName = "smooth")]
231-
public Float SmoothingConst = 1;
241+
public Float SmoothingConst = Defaults.SmoothingConst;
232242

233243
public ISupportSdcaClassificationLoss CreateComponent(IHostEnvironment env) => new SmoothedHingeLoss(env, this);
234244

@@ -242,14 +252,28 @@ public sealed class Arguments : ISupportSdcaClassificationLossFactory, ISupportC
242252
private readonly Double _halfSmoothConst;
243253
private readonly Double _doubleSmoothConst;
244254

245-
public SmoothedHingeLoss(IHostEnvironment env, Arguments args)
255+
private static class Defaults
256+
{
257+
public const float SmoothingConst = 1;
258+
}
259+
260+
/// <summary>
261+
/// Constructor for smoothed hinge loss.
262+
/// </summary>
263+
/// <param name="smoothingConstant">The smoothing constant.</param>
264+
public SmoothedHingeLoss(float smoothingConstant = Defaults.SmoothingConst)
246265
{
247-
Contracts.Check(args.SmoothingConst >= 0, "smooth constant must be non-negative");
248-
_smoothConst = args.SmoothingConst;
266+
Contracts.CheckParam(smoothingConstant >= 0, nameof(smoothingConstant), "Must be non-negative.");
267+
_smoothConst = smoothingConstant;
249268
_halfSmoothConst = _smoothConst / 2;
250269
_doubleSmoothConst = _smoothConst * 2;
251270
}
252271

272+
private SmoothedHingeLoss(IHostEnvironment env, Arguments args)
273+
: this(args.SmoothingConst)
274+
{
275+
}
276+
253277
public Double Loss(Float output, Float label)
254278
{
255279
Float truth = label > 0 ? 1 : -1;

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs

Lines changed: 149 additions & 194 deletions
Large diffs are not rendered by default.

src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
75
using Microsoft.ML.Core.Data;
86
using Microsoft.ML.Runtime;
97
using Microsoft.ML.Runtime.CommandLine;
@@ -41,7 +39,7 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPred
4139

4240
private readonly Arguments _args;
4341

44-
public class Arguments : AveragedLinearArguments
42+
public sealed class Arguments : AveragedLinearArguments
4543
{
4644
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
4745
public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments();
@@ -51,6 +49,38 @@ public class Arguments : AveragedLinearArguments
5149

5250
[Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
5351
public int MaxCalibrationExamples = 1000000;
52+
53+
internal override IComponentFactory<IScalarOutputLoss> LossFunctionFactory => LossFunction;
54+
}
55+
56+
private sealed class TrainState : AveragedTrainStateBase
57+
{
58+
public TrainState(IChannel ch, int numFeatures, LinearPredictor predictor, AveragedPerceptronTrainer parent)
59+
: base(ch, numFeatures, predictor, parent)
60+
{
61+
}
62+
63+
public override LinearBinaryPredictor CreatePredictor()
64+
{
65+
Contracts.Assert(WeightsScale == 1);
66+
67+
VBuffer<float> weights = default;
68+
float bias;
69+
70+
if (!Averaged)
71+
{
72+
Weights.CopyTo(ref weights);
73+
bias = Bias;
74+
}
75+
else
76+
{
77+
TotalWeights.CopyTo(ref weights);
78+
VectorUtils.ScaleBy(ref weights, 1 / (float)NumWeightUpdates);
79+
bias = TotalBias / (float)NumWeightUpdates;
80+
}
81+
82+
return new LinearBinaryPredictor(ParentHost, ref weights, bias);
83+
}
5484
}
5585

5686
internal AveragedPerceptronTrainer(IHostEnvironment env, Arguments args)
@@ -78,32 +108,36 @@ public AveragedPerceptronTrainer(IHostEnvironment env,
78108
string label,
79109
string features,
80110
string weights = null,
81-
ISupportClassificationLossFactory lossFunction = null,
111+
IClassificationLoss lossFunction = null,
82112
float learningRate = Arguments.AveragedDefaultArgs.LearningRate,
83113
bool decreaseLearningRate = Arguments.AveragedDefaultArgs.DecreaseLearningRate,
84114
float l2RegularizerWeight = Arguments.AveragedDefaultArgs.L2RegularizerWeight,
85115
int numIterations = Arguments.AveragedDefaultArgs.NumIterations,
86116
Action<Arguments> advancedSettings = null)
87-
: this(env, new Arguments
117+
: this(env, InvokeAdvanced(advancedSettings, new Arguments
88118
{
89119
LabelColumn = label,
90120
FeatureColumn = features,
91121
InitialWeights = weights,
92122
LearningRate = learningRate,
93123
DecreaseLearningRate = decreaseLearningRate,
94124
L2RegularizerWeight = l2RegularizerWeight,
95-
NumIterations = numIterations
96-
97-
})
125+
NumIterations = numIterations,
126+
LossFunction = new TrivialFactory(lossFunction ?? new HingeLoss())
127+
}))
98128
{
99-
if (lossFunction == null)
100-
lossFunction = new HingeLoss.Arguments();
129+
}
101130

102-
LossFunction = lossFunction.CreateComponent(env);
131+
private sealed class TrivialFactory : ISupportClassificationLossFactory
132+
{
133+
private IClassificationLoss _loss;
103134

104-
if (advancedSettings != null)
105-
advancedSettings.Invoke(_args);
135+
public TrivialFactory(IClassificationLoss loss)
136+
{
137+
_loss = loss;
138+
}
106139

140+
IClassificationLoss IComponentFactory<IClassificationLoss>.CreateComponent(IHostEnvironment env) => _loss;
107141
}
108142

109143
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
@@ -120,7 +154,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
120154
};
121155
}
122156

123-
protected override void CheckLabel(RoleMappedData data)
157+
protected override void CheckLabels(RoleMappedData data)
124158
{
125159
Contracts.AssertValue(data);
126160
data.CheckBinaryLabel();
@@ -140,26 +174,9 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
140174
error();
141175
}
142176

143-
protected override LinearBinaryPredictor CreatePredictor()
177+
private protected override TrainStateBase MakeState(IChannel ch, int numFeatures, LinearPredictor predictor)
144178
{
145-
Contracts.Assert(WeightsScale == 1);
146-
147-
VBuffer<Float> weights = default(VBuffer<Float>);
148-
Float bias;
149-
150-
if (!_args.Averaged)
151-
{
152-
Weights.CopyTo(ref weights);
153-
bias = Bias;
154-
}
155-
else
156-
{
157-
TotalWeights.CopyTo(ref weights);
158-
VectorUtils.ScaleBy(ref weights, 1 / (Float)NumWeightUpdates);
159-
bias = TotalBias / (Float)NumWeightUpdates;
160-
}
161-
162-
return new LinearBinaryPredictor(Host, ref weights, bias);
179+
return new TrainState(ch, numFeatures, predictor, this);
163180
}
164181

165182
protected override BinaryPredictionTransformer<LinearBinaryPredictor> MakeTransformer(LinearBinaryPredictor model, Schema trainSchema)

0 commit comments

Comments (0)