Skip to content

Commit dddb5c1

Browse files
authored
Fix ResultProcessor bug, LogisticRegression bug and missing value conversion bug (#1236)
* Fix some bugs, add some unit tests.
* Fix LR stats bug
* Undo accidental TermTransform change
* Sweeper needs to load all components into ComponentCatalog
* Rename Mapping.de-de.txt
* Fix cat transform issue
* Compare pr baseline only on Windows
* Move baselines to Common folder
* Compare pr baseline only on Windows in another test
* Code review comment
* Fix ConcatTransform bug
* Add baselines for ConcatTransform bug
* Fix another bug in TermTransform
* NelderMead sweeper default value for FirstBatchSweeper arg
* Add some more unit tests
* Add more unit tests
* Fix unit test baseline, and baseline comparison with tolerance.
* Change back MatchWithTolerance method
* Fix bad merge
1 parent 9a33cd4 commit dddb5c1

File tree

84 files changed

+56347
-59
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

84 files changed

+56347
-59
lines changed

src/Microsoft.ML.Data/Data/Conversion.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,7 @@ private bool IsStdMissing(ref ReadOnlySpan<char> span)
11701170
public bool TryParseKey(ref TX src, U8 min, U8 max, out U8 dst)
11711171
{
11721172
var span = src.Span;
1173-
Contracts.Check(!IsStdMissing(ref span), "Missing text value cannot be converted to unsigned integer type.");
1173+
Contracts.Check(span.IsEmpty || !IsStdMissing(ref span), "Missing text value cannot be converted to unsigned integer type.");
11741174
Contracts.Assert(min <= max);
11751175

11761176
// This simply ensures we don't have min == 0 and max == U8.MaxValue. This is illegal since
@@ -1530,7 +1530,7 @@ public bool TryParse(ref TX src, out BL dst)
15301530
{
15311531
var span = src.Span;
15321532

1533-
Contracts.Check(!IsStdMissing(ref span), "Missing text values cannot be converted to bool value.");
1533+
Contracts.Check(span.IsEmpty || !IsStdMissing(ref span), "Missing text value cannot be converted to bool type.");
15341534

15351535
char ch;
15361536
switch (src.Length)

src/Microsoft.ML.Data/Transforms/ConcatTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ public static IDataTransform Create(IHostEnvironment env, TaggedArguments args,
384384
env.CheckUserArg(Utils.Size(args.Column[i].Source) > 0, nameof(args.Column));
385385

386386
var cols = args.Column
387-
.Select(c => new ColumnInfo(c.Name, c.Source.Select(kvp => (kvp.Value, kvp.Key))))
387+
.Select(c => new ColumnInfo(c.Name, c.Source.Select(kvp => (kvp.Value, kvp.Key != "" ? kvp.Key : null))))
388388
.ToArray();
389389
var transformer = new ConcatTransform(env, cols);
390390
return transformer.MakeDataTransform(input);

src/Microsoft.ML.Data/Transforms/TermEstimator.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ public static class Defaults
2121

2222
private readonly IHost _host;
2323
private readonly TermTransform.ColumnInfo[] _columns;
24+
private readonly string _file;
25+
private readonly string _termsColumn;
26+
private readonly IComponentFactory<IMultiStreamSource, IDataLoader> _loaderFactory;
2427

2528
/// <summary>
2629
/// Convenience constructor for public facing API.
@@ -32,18 +35,23 @@ public static class Defaults
3235
/// <param name="sort">How items should be ordered when vectorized. By default, they will be in the order encountered.
3336
/// If by value, items are sorted according to their default comparison; for example, text sorting will be case sensitive ('A' then 'Z' then 'a').</param>
3437
public TermEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) :
35-
this(env, new TermTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort))
38+
this(env, new[] { new TermTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort) })
3639
{
3740
}
3841

39-
public TermEstimator(IHostEnvironment env, params TermTransform.ColumnInfo[] columns)
42+
public TermEstimator(IHostEnvironment env, TermTransform.ColumnInfo[] columns,
43+
string file = null, string termsColumn = null,
44+
IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory = null)
4045
{
4146
Contracts.CheckValue(env, nameof(env));
4247
_host = env.Register(nameof(TermEstimator));
4348
_columns = columns;
49+
_file = file;
50+
_termsColumn = termsColumn;
51+
_loaderFactory = loaderFactory;
4452
}
4553

46-
public TermTransform Fit(IDataView input) => new TermTransform(_host, input, _columns);
54+
public TermTransform Fit(IDataView input) => new TermTransform(_host, input, _columns, _file, _termsColumn, _loaderFactory);
4755

4856
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
4957
{

src/Microsoft.ML.Data/Transforms/TermTransform.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ public TermTransform(IHostEnvironment env, IDataView input,
268268
this(env, input, columns, null, null, null)
269269
{ }
270270

271-
private TermTransform(IHostEnvironment env, IDataView input,
271+
internal TermTransform(IHostEnvironment env, IDataView input,
272272
ColumnInfo[] columns,
273273
string file = null, string termsColumn = null,
274274
IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory = null)
@@ -314,13 +314,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
314314
if (!Enum.IsDefined(typeof(SortOrder), sortOrder))
315315
throw env.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected for column '{1}'", sortOrder, item.Name);
316316

317-
cols[i] = new ColumnInfo(item.Source,
317+
cols[i] = new ColumnInfo(item.Source ?? item.Name,
318318
item.Name,
319319
item.MaxNumTerms ?? args.MaxNumTerms,
320320
sortOrder,
321321
item.Term,
322322
item.TextKeyValues ?? args.TextKeyValues);
323-
cols[i].Terms = item.Terms;
323+
cols[i].Terms = item.Terms ?? args.Terms;
324324
};
325325
}
326326
return new TermTransform(env, input, cols, args.DataFile, args.TermsColumn, args.Loader).MakeDataTransform(input);

src/Microsoft.ML.ResultProcessor/ResultProcessor.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,10 +1063,10 @@ private static Experiment CreateVisualizationExperiment(ExperimentItemResult res
10631063
var experiment = new ML.Runtime.ExperimentVisualization.Experiment
10641064
{
10651065
Key = index.ToString(),
1066-
CompareGroup = string.IsNullOrEmpty(result.CustomizedTag) ? result.Trainer.Kind : result.CustomizedTag,
1066+
CompareGroup = string.IsNullOrEmpty(result.CustomizedTag) ? result.TrainerKind : result.CustomizedTag,
10671067
Trainer = new ML.Runtime.ExperimentVisualization.Trainer
10681068
{
1069-
Name = result.Trainer.Kind,
1069+
Name = result.TrainerKind,
10701070
ParameterSets = new List<ML.Runtime.ExperimentVisualization.Item>()
10711071
},
10721072
DataSet = new ML.Runtime.ExperimentVisualization.DataSet { File = result.Datafile },
@@ -1152,7 +1152,10 @@ private static object Load(Stream stream)
11521152

11531153
public static int Main(string[] args)
11541154
{
1155-
return Main(new ConsoleEnvironment(42), args);
1155+
string currentDirectory = Path.GetDirectoryName(typeof(ResultProcessor).Module.FullyQualifiedName);
1156+
using (var env = new ConsoleEnvironment(42))
1157+
using (AssemblyLoadingUtils.CreateAssemblyRegistrar(env, currentDirectory))
1158+
return Main(env, args);
11561159
}
11571160

11581161
public static int Main(IHostEnvironment env, string[] args)

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ protected override ParameterMixingCalibratedPredictor CreatePredictor()
373373
CurrentWeights.GetItemOrDefault(0, ref bias);
374374
CurrentWeights.CopyTo(ref weights, 1, CurrentWeights.Length - 1);
375375
return new ParameterMixingCalibratedPredictor(Host,
376-
new LinearBinaryPredictor(Host, ref weights, bias),
376+
new LinearBinaryPredictor(Host, ref weights, bias, _stats),
377377
new PlattCalibrator(Host, -1, 0));
378378
}
379379

src/Microsoft.ML.Sweeper/Algorithms/NelderMead.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public sealed class Arguments
2626
public IComponentFactory<IValueGenerator>[] SweptParameters;
2727

2828
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The sweeper used to get the initial results.", ShortName = "init", SignatureType = typeof(SignatureSweeperFromParameterList))]
29-
public IComponentFactory<IValueGenerator[], ISweeper> FirstBatchSweeper;
29+
public IComponentFactory<IValueGenerator[], ISweeper> FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction<IValueGenerator[], ISweeper>((host, array) => new UniformRandomSweeper(host, new SweeperBase.ArgumentsBase(), array));
3030

3131
[Argument(ArgumentType.AtMostOnce, HelpText = "Seed for the random number generator for the first batch sweeper", ShortName = "seed")]
3232
public int RandomSeed;

src/Microsoft.ML.Sweeper/ConfigRunner.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@ public virtual void Finish()
107107
if (Exe == null || Exe.EndsWith("maml", StringComparison.OrdinalIgnoreCase) ||
108108
Exe.EndsWith("maml.exe", StringComparison.OrdinalIgnoreCase))
109109
{
110+
string currentDirectory = Path.GetDirectoryName(typeof(ExeConfigRunnerBase).Module.FullyQualifiedName);
111+
110112
using (var ch = Host.Start("Finish"))
113+
using (AssemblyLoadingUtils.CreateAssemblyRegistrar(Host, currentDirectory))
111114
{
112115
var runs = RunNums.ToArray();
113116
var args = Utils.BuildArray(RunNums.Count + 2,
@@ -120,7 +123,7 @@ public virtual void Finish()
120123
return string.Format("{{{0}}}", GetFilePath(runs[i], "out"));
121124
});
122125

123-
ResultProcessorInternal.ResultProcessor.Main (args);
126+
ResultProcessorInternal.ResultProcessor.Main(args);
124127

125128
ch.Info(@"The summary of the run results has been saved to the file {0}\{1}.summary.txt", OutputFolder, Prefix);
126129
}

src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,8 @@
1616

1717
</ItemGroup>
1818

19+
<ItemGroup>
20+
<Compile Include="..\Common\AssemblyLoadingUtils.cs" Link="Common\AssemblyLoadingUtils.cs" />
21+
</ItemGroup>
22+
1923
</Project>

src/Microsoft.ML.Transforms/CategoricalTransform.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,18 +135,20 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
135135
column.MaxNumTerms ?? args.MaxNumTerms,
136136
column.Sort ?? args.Sort,
137137
column.Term ?? args.Term);
138-
col.SetTerms(column.Terms);
138+
col.SetTerms(column.Terms ?? args.Terms);
139139
columns.Add(col);
140140
}
141-
return new CategoricalEstimator(env, columns.ToArray()).Fit(input).Transform(input) as IDataTransform;
141+
return new CategoricalEstimator(env, columns.ToArray(), args.DataFile, args.TermsColumn, args.Loader).Fit(input).Transform(input) as IDataTransform;
142142
}
143143

144144
private readonly TransformerChain<ITransformer> _transformer;
145145

146146
public CategoricalTransform(TermEstimator term, IEstimator<ITransformer> toVector, IDataView input)
147147
{
148-
var chain = term.Append(toVector);
149-
_transformer = chain.Fit(input);
148+
if (toVector != null)
149+
_transformer = term.Append(toVector).Fit(input);
150+
else
151+
_transformer = new TransformerChain<ITransformer>(term.Fit(input));
150152
}
151153

152154
public Schema GetOutputSchema(Schema inputSchema) => _transformer.GetOutputSchema(inputSchema);
@@ -198,15 +200,17 @@ internal void SetTerms(string terms)
198200
/// <param name="outputKind">The type of output expected.</param>
199201
public CategoricalEstimator(IHostEnvironment env, string input,
200202
string output = null, CategoricalTransform.OutputKind outputKind = Defaults.OutKind)
201-
: this(env, new ColumnInfo(input, output ?? input, outputKind))
203+
: this(env, new[] { new ColumnInfo(input, output ?? input, outputKind) })
202204
{
203205
}
204206

205-
public CategoricalEstimator(IHostEnvironment env, params ColumnInfo[] columns)
207+
public CategoricalEstimator(IHostEnvironment env, ColumnInfo[] columns,
208+
string file = null, string termsColumn = null,
209+
IComponentFactory<IMultiStreamSource, IDataLoader> loaderFactory = null)
206210
{
207211
Contracts.CheckValue(env, nameof(env));
208212
_host = env.Register(nameof(TermEstimator));
209-
_term = new TermEstimator(_host, columns);
213+
_term = new TermEstimator(_host, columns, file, termsColumn, loaderFactory);
210214
var binaryCols = new List<(string input, string output)>();
211215
var cols = new List<(string input, string output, bool bag)>();
212216
for (int i = 0; i < columns.Length; i++)

0 commit comments

Comments
 (0)