Skip to content

Commit c0d449f

Browse files
FUNCTIONAL BREAKING CHANGE. Transform chooses score scope by default. (#6269)
1 parent 61c347c commit c0d449f

File tree

6 files changed

+55
-16
lines changed

6 files changed

+55
-16
lines changed

src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,26 +119,46 @@ public TransformerChain(params ITransformer[] transformers)
119119
}
120120

121121
public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
122+
{
123+
// Default to only scoring scope.
124+
return GetOutputSchema(inputSchema, TransformerScope.Scoring);
125+
}
126+
127+
public DataViewSchema GetOutputSchema(DataViewSchema inputSchema, TransformerScope scope)
122128
{
123129
Contracts.CheckValue(inputSchema, nameof(inputSchema));
124130

131+
var chain = GetModelFor(scope);
132+
125133
var s = inputSchema;
126-
foreach (var xf in _transformers)
134+
foreach (var xf in chain)
127135
s = xf.GetOutputSchema(s);
128136
return s;
129137
}
130138

131139
public IDataView Transform(IDataView input)
140+
{
141+
// Default to only scoring scope.
142+
return Transform(input, TransformerScope.Scoring);
143+
}
144+
145+
public IDataView Transform(IDataView input, TransformerScope scope)
132146
{
133147
Contracts.CheckValue(input, nameof(input));
134148

149+
// Default to all scopes, but still allow for smaller scopes.
150+
var chain = GetModelFor(scope);
151+
135152
// Trigger schema propagation prior to transforming.
136153
// REVIEW: does this actually constitute 'early warning', given that Transform call is lazy anyway?
137-
GetOutputSchema(input.Schema);
154+
chain.GetOutputSchema(input.Schema);
138155

139156
var dv = input;
140-
foreach (var xf in _transformers)
141-
dv = xf.Transform(dv);
157+
foreach (var transformer in chain)
158+
{
159+
dv = transformer.Transform(dv);
160+
}
161+
142162
return dv;
143163
}
144164

src/Microsoft.ML.Data/TrainCatalog.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8+
using System.Runtime.CompilerServices;
89
using Microsoft.ML.Calibrators;
910
using Microsoft.ML.Data;
1011
using Microsoft.ML.Runtime;
@@ -102,14 +103,34 @@ private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEs
102103
foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, splitColumn, numFolds))
103104
{
104105
var model = estimator.Fit(split.TrainSet);
105-
var scoredTest = model.Transform(split.TestSet);
106+
IDataView scoredTest;
107+
108+
if (IsCastableToTransformerChainOfITransformer(model))
109+
scoredTest = (Unsafe.As<TransformerChain<ITransformer>>(model)).Transform(split.TestSet, TransformerScope.Everything);
110+
else
111+
scoredTest = model.Transform(split.TestSet);
106112
result[fold] = new CrossValidationResult(model, scoredTest, fold);
107113
fold++;
108114
}
109115

110116
return result;
111117
}
112118

119+
private static bool IsCastableToTransformerChainOfITransformer(object o)
120+
{
121+
var type = o.GetType();
122+
while (!type!.FullName!.StartsWith("Microsoft.ML.Data.TransformerChain`1[", StringComparison.Ordinal))
123+
{
124+
type = type!.BaseType;
125+
if (type is null)
126+
{
127+
return false;
128+
}
129+
}
130+
131+
return true;
132+
}
133+
113134
[BestFriend]
114135
private protected TrainCatalogBase(IHostEnvironment env, string registrationName)
115136
{

test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ private void CrossValidationOn(string dataPath)
657657
// Train the model.
658658
var model = pipeline.Fit(split.TrainSet);
659659
// Compute quality metrics on the test set.
660-
var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet));
660+
var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet, TransformerScope.Everything));
661661
Console.WriteLine(metrics.MicroAccuracy);
662662

663663
// Now run the 5-fold cross-validation experiment, using the same pipeline.

test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest()
8686
Assert.True(prediction.PredictedPlant == "Iris-versicolor");
8787

8888
// Evaluate the trained pipeline
89-
var predicted = trainedModel.Transform(testData);
89+
var predicted = trainedModel.Transform(testData, TransformerScope.Everything);
9090
var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topKPredictionCount: 3);
9191

9292
Assert.Equal(.98, metrics.MacroAccuracy);

test/Microsoft.ML.Tests/TextClassificationTests.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ public void TestSingleSentence2Classes()
9292
.Append(ML.MulticlassClassification.Trainers.TextClassification(outputColumnName: "outputColumn"))
9393
.Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));
9494

95-
TestEstimatorCore(estimator, dataView);
9695
var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
9796

9897
Assert.Equal(5, estimatorSchema.Count);
@@ -104,9 +103,9 @@ public void TestSingleSentence2Classes()
104103

105104
var filteredModel = transformer.GetModelFor(TransformerScope.Scoring);
106105

107-
Assert.Equal(6, transformerSchema.Count);
108-
Assert.Equal("outputColumn", transformerSchema[4].Name);
109-
Assert.Equal(TextDataViewType.Instance, transformerSchema[4].Type);
106+
Assert.Equal(5, transformerSchema.Count);
107+
Assert.Equal("outputColumn", transformerSchema[3].Name);
108+
Assert.Equal(TextDataViewType.Instance, transformerSchema[3].Type);
110109

111110
var dataNoLabel = ML.Data.LoadFromEnumerable(
112111
new List<TestSingleSentenceDataNoLabel>(new TestSingleSentenceDataNoLabel[] {
@@ -144,16 +143,15 @@ public void TestSingleSentence2Classes()
144143
}
145144
}));
146145

147-
var predictedLabel = filteredModel.Transform(dataNoLabel).GetColumn<ReadOnlyMemory<char>>(transformerSchema[4].Name);
146+
var predictedLabel = filteredModel.Transform(dataNoLabel).GetColumn<ReadOnlyMemory<char>>(transformerSchema[3].Name);
148147

149148
// Make sure that we can use the multiclass evaluate method
150-
var metrics = ML.MulticlassClassification.Evaluate(transformer.Transform(dataView), predictedLabelColumnName: "outputColumn");
149+
var metrics = ML.MulticlassClassification.Evaluate(transformer.Transform(dataView, TransformerScope.Everything), predictedLabelColumnName: "outputColumn");
151150
Assert.NotNull(metrics);
152151

153-
// Not enough training is done to get good results so just make sure the count is right and are negative.
152+
// Not enough training is done to get good results so just make sure the count is right.
154153
var a = predictedLabel.ToList();
155154
Assert.Equal(8, a.Count());
156-
Assert.True(predictedLabel.All(value => value.ToString() == "Negative"));
157155
}
158156

159157
[Fact]

test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public void MetacomponentsFeaturesRenamed()
9595
});
9696

9797
var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
98-
.Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest)
98+
.Append(new ValueToKeyMappingEstimator(Env, "Label"))
9999
.Append(ML.MulticlassClassification.Trainers.OneVersusAll(sdcaTrainer))
100100
.Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
101101

0 commit comments

Comments
 (0)