Skip to content

Commit 5684398

Browse files
authored
Save ConvertTransform as ONNX Operator and Control the Use of Experimental Features with a Flag (#947)
* Some changes for adding experimental conversion to ONNX 1. Introduce a new argument to SaveOnnx, which is OnnxVersion. Two values are currently allowed, "Latest" and "Experimental". Note that "Latest" means that the produced ONNX model meets the latest ONNX release while "Experimental" may produce things not officially supported in ONNX. 2. For (1), the interface of saving ONNX is slightly changed. Now, CanSaveOnnx requires an OnnxContext as its input argument. 3. Add exporter for ConvertTransform. It doesn't use standard ONNX operator. * Update CSharpAPI and entry point * Address one comment * Update old APIs to reflect enum's change * Extend doc string for targeted version * Address comments * Update API
1 parent 8ca1c93 commit 5684398

File tree

26 files changed

+139
-44
lines changed

26 files changed

+139
-44
lines changed

src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ private static VersionInfo GetVersionInfo()
247247

248248
public override ISchema Schema { get { return _bindings; } }
249249

250-
public bool CanSaveOnnx => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx : false;
250+
public bool CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;
251251

252252
public bool CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false;
253253

@@ -339,7 +339,7 @@ public void SaveAsOnnx(OnnxContext ctx)
339339
Host.CheckValue(ctx, nameof(ctx));
340340
if (_mapper is ISaveAsOnnx onnx)
341341
{
342-
Host.Check(onnx.CanSaveOnnx, "Cannot be saved as ONNX.");
342+
Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
343343
onnx.SaveAsOnnx(ctx);
344344
}
345345
}

src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public interface ICanSaveOnnx
1515
/// only detectable during runtime that would prevent its being savable. (For example,
1616
/// it may wrap some other object that may or may not be savable.)
1717
/// </summary>
18-
bool CanSaveOnnx { get; }
18+
bool CanSaveOnnx(OnnxContext ctx);
1919
}
2020

2121
/// <summary>

src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
namespace Microsoft.ML.Runtime.Model.Onnx
99
{
10+
public enum OnnxVersion { Stable=0, Experimental=1 }
11+
1012
/// <summary>
1113
/// A context for defining a ONNX output. The context internally contains the model-in-progress being built. This
1214
/// same context object is iteratively given to exportable components via the <see cref="ICanSaveOnnx"/> interface
@@ -99,6 +101,12 @@ public abstract OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
99101
public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
100102
=> CreateNode(opType, new[] { input }, new[] { output }, name, domain);
101103

104+
/// <summary>
105+
/// Get the targeted ONNX version string. Only two values are allowed now: "Stable" and "Experimental".
106+
/// </summary>
107+
/// <returns></returns>
108+
public abstract OnnxVersion GetOnnxVersion();
109+
102110
/// <summary>
103111
/// Retrieve the shape of an ONNX variable. Returns null if no shape for the specified variable can be found.
104112
/// </summary>

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBa
225225
public ColumnType OutputType => _mapper.OutputType;
226226
public ColumnType DistType => NumberType.Float;
227227
public bool CanSavePfa => (_mapper as ICanSavePfa)?.CanSavePfa == true;
228-
public bool CanSaveOnnx => (_mapper as ICanSaveOnnx)?.CanSaveOnnx == true;
228+
public bool CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
229229

230230
protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing<Float> predictor, ICalibrator calibrator)
231231
: base(env, name, predictor, calibrator)
@@ -308,7 +308,7 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColu
308308
return false;
309309

310310
var calibrator = Calibrator as ISingleCanSaveOnnx;
311-
if (!(calibrator?.CanSaveOnnx == true && calibrator.SaveAsOnnx(ctx, new[] { outputNames[1], outputNames[2] }, featureColumnName)))
311+
if (!(calibrator?.CanSaveOnnx(ctx) == true && calibrator.SaveAsOnnx(ctx, new[] { outputNames[1], outputNames[2] }, featureColumnName)))
312312
ctx.RemoveVariable(outputNames[1], true);
313313

314314
return true;
@@ -622,7 +622,7 @@ private static VersionInfo GetVersionInfo()
622622
/// </summary>
623623
public bool CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true;
624624

625-
public bool CanSaveOnnx => (_bindable as ICanSaveOnnx)?.CanSaveOnnx == true;
625+
public bool CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
626626

627627
public SchemaBindableCalibratedPredictor(IHostEnvironment env, IPredictorProducing<Single> predictor, ICalibrator calibrator)
628628
: base(env, LoaderSignature, predictor, calibrator)
@@ -668,7 +668,7 @@ public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] output
668668
Host.CheckValue(ctx, nameof(ctx));
669669
Host.CheckParam(Utils.Size(outputs) == 2, nameof(outputs), "Expected this to have two outputs");
670670
Host.CheckValue(schema, nameof(schema));
671-
Host.Check(CanSaveOnnx, "Called despite not being savable");
671+
Host.Check(CanSaveOnnx(ctx), "Called despite not being savable");
672672
return false;
673673
}
674674

@@ -1349,7 +1349,7 @@ private static VersionInfo GetVersionInfo()
13491349
public Double ParamA { get; }
13501350
public Double ParamB { get; }
13511351
public bool CanSavePfa => true;
1352-
public bool CanSaveOnnx => true;
1352+
public bool CanSaveOnnx(OnnxContext ctx) => true;
13531353

13541354
public PlattCalibrator(IHostEnvironment env, Double paramA, Double paramB)
13551355
{

src/Microsoft.ML.Data/Scorers/GenericScorer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ private static VersionInfo GetVersionInfo()
141141

142142
public bool CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true;
143143

144-
public bool CanSaveOnnx => (Bindable as ICanSaveOnnx)?.CanSaveOnnx == true;
144+
public bool CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
145145

146146
/// <summary>
147147
/// The <see cref="SignatureDataScorer"/> entry point for creating a <see cref="GenericScorer"/>.

src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveMod
7878

7979
public VectorType Type => _type;
8080
public bool CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true;
81-
public bool CanSaveOnnx => (_bindable as ICanSaveOnnx)?.CanSaveOnnx == true;
81+
public bool CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
8282
public ISchemaBindableMapper InnerBindable => _bindable;
8383

8484
private static VersionInfo GetVersionInfo()
@@ -209,7 +209,7 @@ public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] output
209209
{
210210
Contracts.CheckValue(ctx, nameof(ctx));
211211
Contracts.CheckValue(schema, nameof(schema));
212-
Contracts.Check(CanSaveOnnx, "Cannot be saved as ONNX.");
212+
Contracts.Check(CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
213213
Contracts.Assert(_bindable is IBindableCanSaveOnnx);
214214
return ((IBindableCanSaveOnnx)_bindable).SaveAsOnnx(ctx, schema, outputNames);
215215
}

src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ protected override BindingsBase GetBindings()
284284

285285
public bool CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true;
286286

287-
public bool CanSaveOnnx => (Bindable as ICanSaveOnnx)?.CanSaveOnnx == true;
287+
public bool CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
288288

289289
protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data,
290290
ISchemaBoundMapper mapper, RoleMappedSchema trainSchema, string registrationName, string scoreColKind, string scoreColName,

src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper
4646

4747
public bool CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true;
4848

49-
public bool CanSaveOnnx => (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx == true;
49+
public bool CanSaveOnnx(OnnxContext ctx) => (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
5050

5151
public SchemaBindablePredictorWrapperBase(IPredictor predictor)
5252
{

src/Microsoft.ML.Data/Transforms/ConcatTransform.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ private sealed class Mapper : IRowMapper, ISaveAsOnnx, ISaveAsPfa
431431
private readonly ConcatTransform _parent;
432432
private readonly BoundColumn[] _columns;
433433

434-
public bool CanSaveOnnx => true;
434+
public bool CanSaveOnnx(OnnxContext ctx) => true;
435435
public bool CanSavePfa => true;
436436

437437
public Mapper(ConcatTransform parent, ISchema inputSchema)
@@ -902,7 +902,7 @@ public void SaveAsPfa(BoundPfaContext ctx)
902902
public void SaveAsOnnx(OnnxContext ctx)
903903
{
904904
_host.CheckValue(ctx, nameof(ctx));
905-
Contracts.Assert(CanSaveOnnx);
905+
Contracts.Assert(CanSaveOnnx(ctx));
906906

907907
string opType = "FeatureVectorizer";
908908
for (int iinfo = 0; iinfo < _columns.Length; ++iinfo)

src/Microsoft.ML.Data/Transforms/ConvertTransform.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
using Microsoft.ML.Runtime.Data.Conversion;
1717
using Microsoft.ML.Runtime.Internal.Utilities;
1818
using Microsoft.ML.Runtime.Model;
19+
using Microsoft.ML.Runtime.Model.Onnx;
1920
using Microsoft.ML.Runtime.Command;
2021
using Microsoft.ML.Runtime.EntryPoints;
2122

@@ -375,6 +376,30 @@ public override void Save(ModelSaveContext ctx)
375376
}
376377
}
377378

379+
public override bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental;
380+
381+
protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
382+
{
383+
var opType = "CSharp";
384+
385+
for (int i = 0; i < _exes.Length; i++)
386+
{
387+
var ex = _exes[i];
388+
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
389+
node.AddAttribute("type", LoaderSignature);
390+
node.AddAttribute("to", (byte)ex.Kind);
391+
if (ex.HasKeyRange)
392+
{
393+
var key = ex.TypeDst.ItemType.AsKey;
394+
node.AddAttribute("min", key.Min);
395+
node.AddAttribute("max", key.Count);
396+
node.AddAttribute("contiguous", key.Contiguous);
397+
}
398+
}
399+
400+
return true;
401+
}
402+
378403
private static bool TryCreateEx(IExceptionContext ectx, ColInfo info, DataKind kind, KeyRange range, out PrimitiveType itemType, out ColInfoEx ex)
379404
{
380405
ectx.AssertValue(info);

0 commit comments

Comments
 (0)