22// The .NET Foundation licenses this file to you under the MIT license.
33// See the LICENSE file in the project root for more information.
44
5- using Float = System . Single ;
6-
75using Microsoft . ML . Core . Data ;
86using Microsoft . ML . Runtime ;
97using Microsoft . ML . Runtime . CommandLine ;
@@ -41,7 +39,7 @@ public sealed class AveragedPerceptronTrainer : AveragedLinearTrainer<BinaryPred
4139
4240 private readonly Arguments _args ;
4341
44- public class Arguments : AveragedLinearArguments
42+ public sealed class Arguments : AveragedLinearArguments
4543 {
4644 [ Argument ( ArgumentType . Multiple , HelpText = "Loss Function" , ShortName = "loss" , SortOrder = 50 ) ]
4745 public ISupportClassificationLossFactory LossFunction = new HingeLoss . Arguments ( ) ;
@@ -51,6 +49,38 @@ public class Arguments : AveragedLinearArguments
5149
5250 [ Argument ( ArgumentType . AtMostOnce , HelpText = "The maximum number of examples to use when training the calibrator" , Visibility = ArgumentAttribute . VisibilityType . EntryPointsOnly ) ]
5351 public int MaxCalibrationExamples = 1000000 ;
52+
53+ internal override IComponentFactory < IScalarOutputLoss > LossFunctionFactory => LossFunction ;
54+ }
55+
56+ private sealed class TrainState : AveragedTrainStateBase
57+ {
58+ public TrainState ( IChannel ch , int numFeatures , LinearPredictor predictor , AveragedPerceptronTrainer parent )
59+ : base ( ch , numFeatures , predictor , parent )
60+ {
61+ }
62+
63+ public override LinearBinaryPredictor CreatePredictor ( )
64+ {
65+ Contracts . Assert ( WeightsScale == 1 ) ;
66+
67+ VBuffer < float > weights = default ;
68+ float bias ;
69+
70+ if ( ! Averaged )
71+ {
72+ Weights . CopyTo ( ref weights ) ;
73+ bias = Bias ;
74+ }
75+ else
76+ {
77+ TotalWeights . CopyTo ( ref weights ) ;
78+ VectorUtils . ScaleBy ( ref weights , 1 / ( float ) NumWeightUpdates ) ;
79+ bias = TotalBias / ( float ) NumWeightUpdates ;
80+ }
81+
82+ return new LinearBinaryPredictor ( ParentHost , ref weights , bias ) ;
83+ }
5484 }
5585
5686 internal AveragedPerceptronTrainer ( IHostEnvironment env , Arguments args )
@@ -78,32 +108,36 @@ public AveragedPerceptronTrainer(IHostEnvironment env,
78108 string label ,
79109 string features ,
80110 string weights = null ,
81- ISupportClassificationLossFactory lossFunction = null ,
111+ IClassificationLoss lossFunction = null ,
82112 float learningRate = Arguments . AveragedDefaultArgs . LearningRate ,
83113 bool decreaseLearningRate = Arguments . AveragedDefaultArgs . DecreaseLearningRate ,
84114 float l2RegularizerWeight = Arguments . AveragedDefaultArgs . L2RegularizerWeight ,
85115 int numIterations = Arguments . AveragedDefaultArgs . NumIterations ,
86116 Action < Arguments > advancedSettings = null )
87- : this ( env , new Arguments
117+ : this ( env , InvokeAdvanced ( advancedSettings , new Arguments
88118 {
89119 LabelColumn = label ,
90120 FeatureColumn = features ,
91121 InitialWeights = weights ,
92122 LearningRate = learningRate ,
93123 DecreaseLearningRate = decreaseLearningRate ,
94124 L2RegularizerWeight = l2RegularizerWeight ,
95- NumIterations = numIterations
96-
97- } )
125+ NumIterations = numIterations ,
126+ LossFunction = new TrivialFactory ( lossFunction ?? new HingeLoss ( ) )
127+ } ) )
98128 {
99- if ( lossFunction == null )
100- lossFunction = new HingeLoss . Arguments ( ) ;
129+ }
101130
102- LossFunction = lossFunction . CreateComponent ( env ) ;
131+ private sealed class TrivialFactory : ISupportClassificationLossFactory
132+ {
133+ private IClassificationLoss _loss ;
103134
104- if ( advancedSettings != null )
105- advancedSettings . Invoke ( _args ) ;
135+ public TrivialFactory ( IClassificationLoss loss )
136+ {
137+ _loss = loss ;
138+ }
106139
140+ IClassificationLoss IComponentFactory < IClassificationLoss > . CreateComponent ( IHostEnvironment env ) => _loss ;
107141 }
108142
109143 public override PredictionKind PredictionKind => PredictionKind . BinaryClassification ;
@@ -120,7 +154,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
120154 } ;
121155 }
122156
123- protected override void CheckLabel ( RoleMappedData data )
157+ protected override void CheckLabels ( RoleMappedData data )
124158 {
125159 Contracts . AssertValue ( data ) ;
126160 data . CheckBinaryLabel ( ) ;
@@ -140,26 +174,9 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol)
140174 error ( ) ;
141175 }
142176
143- protected override LinearBinaryPredictor CreatePredictor ( )
177+ private protected override TrainStateBase MakeState ( IChannel ch , int numFeatures , LinearPredictor predictor )
144178 {
145- Contracts . Assert ( WeightsScale == 1 ) ;
146-
147- VBuffer < Float > weights = default ( VBuffer < Float > ) ;
148- Float bias ;
149-
150- if ( ! _args . Averaged )
151- {
152- Weights . CopyTo ( ref weights ) ;
153- bias = Bias ;
154- }
155- else
156- {
157- TotalWeights . CopyTo ( ref weights ) ;
158- VectorUtils . ScaleBy ( ref weights , 1 / ( Float ) NumWeightUpdates ) ;
159- bias = TotalBias / ( Float ) NumWeightUpdates ;
160- }
161-
162- return new LinearBinaryPredictor ( Host , ref weights , bias ) ;
179+ return new TrainState ( ch , numFeatures , predictor , this ) ;
163180 }
164181
165182 protected override BinaryPredictionTransformer < LinearBinaryPredictor > MakeTransformer ( LinearBinaryPredictor model , Schema trainSchema )
0 commit comments