Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,18 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Calibration;
using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.ML.Runtime
{
public interface IPredictionTransformer<out TModel> : ITransformer
Copy link
Contributor

@Zruty0 Zruty0 Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IPredictionTransformer [](start = 21, length = 22)

I don't think we even need an interface for FFM, since for the time being it's the only trainer that accepts multiple feature columns. #Closed

Copy link
Member Author

@sfilipi sfilipi Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to inquire about all trainers, it is useful to have them extend one interface. #Closed

where TModel : IPredictor
{
string FeatureColumn { get; }
string[] FeatureColumn { get; }

ColumnType FeatureColumnType { get; }
ColumnType[] FeatureColumnType { get; }

TModel Model { get; }
}
Expand Down
93 changes: 69 additions & 24 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Model;
using static Microsoft.ML.Runtime.Data.RoleMappedSchema;

[assembly: LoadableClass(typeof(BinaryPredictionTransformer<IPredictorProducing<float>>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel),
"", BinaryPredictionTransformer.LoaderSignature)]
Expand All @@ -30,23 +32,33 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
protected readonly ISchemaBindableMapper BindableMapper;
protected readonly ISchema TrainSchema;

public string FeatureColumn { get; }
public string[] FeatureColumn { get; }
Copy link
Contributor

@Zruty0 Zruty0 Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FeatureColumn [](start = 24, length = 13)

oh no... I don't like this change already.

Forcing ALL predictors to expose parallel arrays of feature columns is not a great change #Closed

Copy link
Member Author

@sfilipi sfilipi Sep 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there's much overhead from it. Do you think it will cause problems? #Closed


public ColumnType FeatureColumnType { get; }
public ColumnType[] FeatureColumnType { get; }

public TModel Model { get; }

public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string[] featureColumns)
{
Contracts.CheckValue(host, nameof(host));
Host = host;
Host.CheckValue(trainSchema, nameof(trainSchema));
Host.CheckValue(featureColumns, nameof(featureColumns));

int featCount = featureColumns.Length;
Host.Check(featCount >= 0 , "Empty features column.");

Model = model;
FeatureColumn = featureColumn;
if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
FeatureColumnType = trainSchema.GetColumnType(col);
FeatureColumn = featureColumns;
FeatureColumnType = new ColumnType[featCount];

int i = 0;
foreach (var feat in featureColumns)
{
if (!trainSchema.TryGetColumnIndex(feat, out int col))
throw Host.ExceptSchemaMismatch(nameof(featureColumns), RoleMappedSchema.ColumnRole.Feature.Value, feat);
FeatureColumnType[i++] = trainSchema.GetColumnType(col);
}

TrainSchema = trainSchema;
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
Expand All @@ -62,7 +74,8 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
// *** Binary format ***
// model: prediction model.
// stream: empty data view that contains train schema.
// id of string: feature column.
// count of features
// id of string: feature columns.

// Clone the stream with the schema into memory.
var ms = new MemoryStream();
Expand All @@ -75,10 +88,19 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);
TrainSchema = loader.Schema;

FeatureColumn = ctx.LoadString();
if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
FeatureColumnType = TrainSchema.GetColumnType(col);
// count of feature columns. FAFM uses more than one.
int featCount = int.Parse(ctx.LoadString());

FeatureColumn = new string[featCount];
FeatureColumnType = new ColumnType[featCount];

for (int i = 0; i < featCount; i++)
{
FeatureColumn[i] = ctx.LoadString();
if (!TrainSchema.TryGetColumnIndex(FeatureColumn[i], out int col))
throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn[i]);
FeatureColumnType[i] = TrainSchema.GetColumnType(col);
}

BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
}
Expand All @@ -87,10 +109,15 @@ public ISchema GetOutputSchema(ISchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null);
if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString());
for (int i=0; i< FeatureColumn.Length; i++)
{
var feat = FeatureColumn[i];
if (!inputSchema.TryGetColumnIndex(feat, out int col))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnType[i].ToString(), null);

if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType[i]))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnType[i].ToString(), inputSchema.GetColumnType(col).ToString());
}

return Transform(new EmptyDataView(Host, inputSchema)).Schema;
}
Expand All @@ -109,6 +136,7 @@ protected virtual void SaveCore(ModelSaveContext ctx)
// *** Binary format ***
Copy link
Contributor

@Zruty0 Zruty0 Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*** Binary format *** [](start = 15, length = 21)

whenever you save or load, need *** Binary format *** #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?


In reply to: 218633828 [](ancestors = 218633828)

// model: prediction model.
Copy link
Contributor

@TomFinley TomFinley Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model: prediction model. [](start = 15, length = 24)

Technically the model isn't part of this format, since you're not writing it to the stream, you're writing it somewhere else entirely, but that's OK. Consider fixing if you have to change the code anyway. #Resolved

Copy link
Member Author

@sfilipi sfilipi Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TomFinley fixing it == remove the comment?


In reply to: 218818528 [](ancestors = 218818528)

// stream: empty data view that contains train schema.
// number of feature columns
// id of string: feature column.

ctx.SaveModel(Model, DirModel);
Expand All @@ -121,7 +149,24 @@ protected virtual void SaveCore(ModelSaveContext ctx)
}
});

ctx.SaveString(FeatureColumn);
int featCount = FeatureColumn.Length;

ctx.SaveString(featCount.ToString());
for(int i=0; i< featCount; i++)
ctx.SaveString(FeatureColumn[i]);
}

/// <summary>
/// Builds a <see cref="RoleMappedSchema"/> in which every entry of <see cref="FeatureColumn"/>
/// is mapped to the feature role and, when supplied, <paramref name="trainLabelColumn"/> is
/// mapped to the label role.
/// </summary>
/// <param name="inputSchema">Schema to wrap; when <c>null</c>, <see cref="TrainSchema"/> is used instead.</param>
/// <param name="trainLabelColumn">Name of the label column; when <c>null</c>, no label role is added.</param>
/// <returns>The role-mapped schema built from the chosen schema and the collected roles.</returns>
protected RoleMappedSchema GetSchema(ISchema inputSchema = null, string trainLabelColumn = null)
{
    var roleMappings = new List<KeyValuePair<ColumnRole, string>>();

    // Feature roles are added first, in the same order as the FeatureColumn array.
    for (int i = 0; i < FeatureColumn.Length; i++)
        roleMappings.Add(new KeyValuePair<ColumnRole, string>(ColumnRole.Feature, FeatureColumn[i]));

    if (trainLabelColumn != null)
        roleMappings.Add(new KeyValuePair<ColumnRole, string>(ColumnRole.Label, trainLabelColumn));

    return new RoleMappedSchema(inputSchema ?? TrainSchema, roleMappings);
}
}

Expand All @@ -133,12 +178,12 @@ public sealed class BinaryPredictionTransformer<TModel> : PredictionTransformerB
public readonly string ThresholdColumn;
public readonly float Threshold;

public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn,
public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string[] featureColumn,
float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
var schema = new RoleMappedSchema(inputSchema, null, featureColumn);
var schema = GetSchema(inputSchema);
Threshold = threshold;
ThresholdColumn = thresholdColumn;

Expand All @@ -157,7 +202,7 @@ public BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
Threshold = ctx.Reader.ReadSingle();
ThresholdColumn = ctx.LoadString();

var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
var schema = GetSchema();
var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn };
_scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
}
Expand Down Expand Up @@ -201,7 +246,7 @@ public sealed class MulticlassPredictionTransformer<TModel> : PredictionTransfor
private readonly string _trainLabelColumn;

public MulticlassPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, string labelColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), model, inputSchema, new[] { featureColumn })
{
Host.CheckValueOrNull(labelColumn);

Expand All @@ -220,7 +265,7 @@ public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ct

_trainLabelColumn = ctx.LoadStringOrNull();

var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumn);
var schema = GetSchema(trainLabelColumn: _trainLabelColumn);
var args = new MultiClassClassifierScorer.Arguments();
_scorer = new MultiClassClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
}
Expand Down Expand Up @@ -261,7 +306,7 @@ public sealed class RegressionPredictionTransformer<TModel> : PredictionTransfor
private readonly GenericScorer _scorer;

public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer<TModel>)), model, inputSchema, new[] { featureColumn })
{
var schema = new RoleMappedSchema(inputSchema, null, featureColumn);
_scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema);
Expand All @@ -270,7 +315,7 @@ public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISche
internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer<TModel>)), ctx)
{
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
var schema = GetSchema();
_scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
}

Expand Down
50 changes: 50 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Training;

namespace Microsoft.ML.Core.Prediction
{
/// <summary>
/// Holds information relevant to trainers. It is passed to the constructor of the
/// <see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/>, carrying additional data
/// needed to fit the estimator. The additional data can be a validation set and/or an initial
/// predictor; both are optional.
/// </summary>
public class TrainerEstimatorContext
{
    /// <summary>
    /// The validation set. Can be <c>null</c>. Note that passing a non-<c>null</c> validation set into
    /// a trainer that does not support validation sets should not be considered an error condition. It
    /// should simply be ignored in that case.
    /// </summary>
    public IDataView ValidationSet { get; }

    /// <summary>
    /// The initial predictor, for incremental training. Can be <c>null</c>. Note that if a
    /// <see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/> implementor does not
    /// support incremental training, then it can ignore it similarly to how one would ignore
    /// <see cref="ValidationSet"/>. However, if the trainer does support incremental training and
    /// there is something wrong with a non-<c>null</c> value of this, then the trainer ought to
    /// throw an exception.
    /// </summary>
    public IPredictor InitialPredictor { get; }

    /// <summary>
    /// Initializes a new instance of <see cref="TrainerEstimatorContext"/> from the optional arguments.
    /// </summary>
    /// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified.</param>
    /// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified.</param>
    public TrainerEstimatorContext(IDataView validationSet = null, IPredictor initialPredictor = null)
    {
        // Both arguments are optional, so each is checked with the "or null" variant before being stored.
        Contracts.CheckValueOrNull(validationSet);
        ValidationSet = validationSet;

        Contracts.CheckValueOrNull(initialPredictor);
        InitialPredictor = initialPredictor;
    }
}
}
Loading