-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Field-aware factorization machine to estimator #912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
b53d09e
d7c942d
6358777
0e31686
67f41a3
d4f5413
26691e3
5890b11
174e75d
65a6296
492a890
e86cfb9
3d4858d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,11 +3,13 @@ | |
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.IO; | ||
| using Microsoft.ML.Runtime; | ||
| using Microsoft.ML.Runtime.Data; | ||
| using Microsoft.ML.Runtime.Data.IO; | ||
| using Microsoft.ML.Runtime.Model; | ||
| using static Microsoft.ML.Runtime.Data.RoleMappedSchema; | ||
|
|
||
| [assembly: LoadableClass(typeof(BinaryPredictionTransformer<IPredictorProducing<float>>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel), | ||
| "", BinaryPredictionTransformer.LoaderSignature)] | ||
|
|
@@ -30,23 +32,33 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer | |
| protected readonly ISchemaBindableMapper BindableMapper; | ||
| protected readonly ISchema TrainSchema; | ||
|
|
||
| public string FeatureColumn { get; } | ||
| public string[] FeatureColumn { get; } | ||
|
||
|
|
||
| public ColumnType FeatureColumnType { get; } | ||
| public ColumnType[] FeatureColumnType { get; } | ||
|
|
||
| public TModel Model { get; } | ||
|
|
||
| public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) | ||
| public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string[] featureColumns) | ||
| { | ||
| Contracts.CheckValue(host, nameof(host)); | ||
| Host = host; | ||
| Host.CheckValue(trainSchema, nameof(trainSchema)); | ||
| Host.CheckValue(featureColumns, nameof(featureColumns)); | ||
|
|
||
| int featCount = featureColumns.Length; | ||
| Host.Check(featCount >= 0 , "Empty features column."); | ||
|
|
||
| Model = model; | ||
| FeatureColumn = featureColumn; | ||
| if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) | ||
| throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); | ||
| FeatureColumnType = trainSchema.GetColumnType(col); | ||
| FeatureColumn = featureColumns; | ||
| FeatureColumnType = new ColumnType[featCount]; | ||
|
|
||
| int i = 0; | ||
| foreach (var feat in featureColumns) | ||
| { | ||
| if (!trainSchema.TryGetColumnIndex(feat, out int col)) | ||
| throw Host.ExceptSchemaMismatch(nameof(featureColumns), RoleMappedSchema.ColumnRole.Feature.Value, feat); | ||
| FeatureColumnType[i++] = trainSchema.GetColumnType(col); | ||
| } | ||
|
|
||
| TrainSchema = trainSchema; | ||
| BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); | ||
|
|
@@ -62,7 +74,8 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) | |
| // *** Binary format *** | ||
| // model: prediction model. | ||
| // stream: empty data view that contains train schema. | ||
| // id of string: feature column. | ||
| // count of features | ||
| // id of string: feature columns. | ||
|
|
||
| // Clone the stream with the schema into memory. | ||
| var ms = new MemoryStream(); | ||
|
|
@@ -75,10 +88,19 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) | |
| var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms); | ||
| TrainSchema = loader.Schema; | ||
|
|
||
| FeatureColumn = ctx.LoadString(); | ||
| if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col)) | ||
| throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn); | ||
| FeatureColumnType = TrainSchema.GetColumnType(col); | ||
| // count of feature columns. FAFM uses more than one. | ||
| int featCount = int.Parse(ctx.LoadString()); | ||
|
|
||
| FeatureColumn = new string[featCount]; | ||
| FeatureColumnType = new ColumnType[featCount]; | ||
|
|
||
| for (int i = 0; i < featCount; i++) | ||
| { | ||
| FeatureColumn[i] = ctx.LoadString(); | ||
| if (!TrainSchema.TryGetColumnIndex(FeatureColumn[i], out int col)) | ||
| throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn[i]); | ||
| FeatureColumnType[i] = TrainSchema.GetColumnType(col); | ||
| } | ||
|
|
||
| BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); | ||
| } | ||
|
|
@@ -87,10 +109,15 @@ public ISchema GetOutputSchema(ISchema inputSchema) | |
| { | ||
| Host.CheckValue(inputSchema, nameof(inputSchema)); | ||
|
|
||
| if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col)) | ||
| throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null); | ||
| if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType)) | ||
| throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), inputSchema.GetColumnType(col).ToString()); | ||
| for (int i=0; i< FeatureColumn.Length; i++) | ||
| { | ||
| var feat = FeatureColumn[i]; | ||
| if (!inputSchema.TryGetColumnIndex(feat, out int col)) | ||
| throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnType[i].ToString(), null); | ||
|
|
||
| if (!inputSchema.GetColumnType(col).Equals(FeatureColumnType[i])) | ||
| throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnType[i].ToString(), inputSchema.GetColumnType(col).ToString()); | ||
| } | ||
|
|
||
| return Transform(new EmptyDataView(Host, inputSchema)).Schema; | ||
| } | ||
|
|
@@ -109,6 +136,7 @@ protected virtual void SaveCore(ModelSaveContext ctx) | |
| // *** Binary format *** | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
whenever you save or load, need
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| // model: prediction model. | ||
|
||
| // stream: empty data view that contains train schema. | ||
| // number of feature columns | ||
| // id of string: feature column. | ||
|
|
||
| ctx.SaveModel(Model, DirModel); | ||
|
|
@@ -121,7 +149,24 @@ protected virtual void SaveCore(ModelSaveContext ctx) | |
| } | ||
| }); | ||
|
|
||
| ctx.SaveString(FeatureColumn); | ||
| int featCount = FeatureColumn.Length; | ||
|
|
||
| ctx.SaveString(featCount.ToString()); | ||
| for(int i=0; i< featCount; i++) | ||
| ctx.SaveString(FeatureColumn[i]); | ||
| } | ||
|
|
||
| protected RoleMappedSchema GetSchema(ISchema inputSchema = null, string trainLabelColumn = null) | ||
| { | ||
| var roles = new List<KeyValuePair<ColumnRole, string>>(); | ||
| foreach (var feat in FeatureColumn) | ||
| roles.Add(new KeyValuePair<ColumnRole, string>(ColumnRole.Feature, feat)); | ||
|
|
||
| if(trainLabelColumn !=null) | ||
| roles.Add(new KeyValuePair<ColumnRole, string>(ColumnRole.Label, trainLabelColumn)); | ||
|
|
||
| var schema = new RoleMappedSchema(inputSchema ?? TrainSchema, roles); | ||
| return schema; | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -133,12 +178,12 @@ public sealed class BinaryPredictionTransformer<TModel> : PredictionTransformerB | |
| public readonly string ThresholdColumn; | ||
| public readonly float Threshold; | ||
|
|
||
| public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, | ||
| public BinaryPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string[] featureColumn, | ||
| float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score) | ||
| : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), model, inputSchema, featureColumn) | ||
| { | ||
| Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn)); | ||
| var schema = new RoleMappedSchema(inputSchema, null, featureColumn); | ||
| var schema = GetSchema(inputSchema); | ||
| Threshold = threshold; | ||
| ThresholdColumn = thresholdColumn; | ||
|
|
||
|
|
@@ -157,7 +202,7 @@ public BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) | |
| Threshold = ctx.Reader.ReadSingle(); | ||
| ThresholdColumn = ctx.LoadString(); | ||
|
|
||
| var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); | ||
| var schema = GetSchema(); | ||
| var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn }; | ||
| _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); | ||
| } | ||
|
|
@@ -201,7 +246,7 @@ public sealed class MulticlassPredictionTransformer<TModel> : PredictionTransfor | |
| private readonly string _trainLabelColumn; | ||
|
|
||
| public MulticlassPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, string labelColumn) | ||
| : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), model, inputSchema, featureColumn) | ||
| : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), model, inputSchema, new[] { featureColumn }) | ||
| { | ||
| Host.CheckValueOrNull(labelColumn); | ||
|
|
||
|
|
@@ -220,7 +265,7 @@ public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ct | |
|
|
||
| _trainLabelColumn = ctx.LoadStringOrNull(); | ||
|
|
||
| var schema = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumn); | ||
| var schema = GetSchema(trainLabelColumn: _trainLabelColumn); | ||
| var args = new MultiClassClassifierScorer.Arguments(); | ||
| _scorer = new MultiClassClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); | ||
| } | ||
|
|
@@ -261,7 +306,7 @@ public sealed class RegressionPredictionTransformer<TModel> : PredictionTransfor | |
| private readonly GenericScorer _scorer; | ||
|
|
||
| public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn) | ||
| : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer<TModel>)), model, inputSchema, featureColumn) | ||
| : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer<TModel>)), model, inputSchema, new[] { featureColumn }) | ||
| { | ||
| var schema = new RoleMappedSchema(inputSchema, null, featureColumn); | ||
| _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, inputSchema), BindableMapper.Bind(Host, schema), schema); | ||
|
|
@@ -270,7 +315,7 @@ public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISche | |
| internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx) | ||
| : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer<TModel>)), ctx) | ||
| { | ||
| var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn); | ||
| var schema = GetSchema(); | ||
| _scorer = new GenericScorer(Host, new GenericScorer.Arguments(), new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema); | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.Text; | ||
| using Microsoft.ML.Runtime; | ||
| using Microsoft.ML.Runtime.Data; | ||
| using Microsoft.ML.Runtime.Training; | ||
|
|
||
| namespace Microsoft.ML.Core.Prediction | ||
| { | ||
| /// <summary> | ||
| /// Holds information relevant to trainers. It is passed to the constructor of the<see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/> | ||
| /// holding additional data needed to fit the estimator. The additional data can be a validation set or an initial model. | ||
| /// This holds at least a training set, as well as optioonally a predictor. | ||
| /// </summary> | ||
| public class TrainerEstimatorContext | ||
| { | ||
| /// <summary> | ||
| /// The validation set. Can be <c>null</c>. Note that passing a non-<c>null</c> validation set into | ||
| /// a trainer that does not support validation sets should not be considered an error condition. It | ||
| /// should simply be ignored in that case. | ||
| /// </summary> | ||
| public IDataView ValidationSet { get; } | ||
|
|
||
| /// <summary> | ||
| /// The initial predictor, for incremental training. Note that if a <see cref="ITrainerEstimator{IPredictionTransformer, IPredictor}"/> implementor | ||
| /// does not support incremental training, then it can ignore it similarly to how one would ignore | ||
| /// <see cref="ValidationSet"/>. However, if the trainer does support incremental training and there | ||
| /// is something wrong with a non-<c>null</c> value of this, then the trainer ought to throw an exception. | ||
| /// </summary> | ||
| public IPredictor InitialPredictor { get; } | ||
|
|
||
| /// <summary> | ||
| /// Initializes a new instance of <see cref="TrainerEstimatorContext"/>, given a training set and optional other arguments. | ||
| /// </summary> | ||
| /// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified</param> | ||
| /// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified</param> | ||
| public TrainerEstimatorContext(IDataView validationSet = null, IPredictor initialPredictor = null) | ||
| { | ||
| Contracts.CheckValueOrNull(validationSet); | ||
| Contracts.CheckValueOrNull(initialPredictor); | ||
|
|
||
| ValidationSet = validationSet; | ||
| InitialPredictor = initialPredictor; | ||
| } | ||
| } | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we even need an interface for FFM, since for the time being it's the only trainer that accepts multiple feature columns. #Closed
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want to inquire about all trainers, it is useful to have them extend one interface. #Closed