diff --git a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs index c12c3f55eb..d0b5dfd607 100644 --- a/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs +++ b/src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs @@ -247,7 +247,7 @@ private static VersionInfo GetVersionInfo() public override ISchema Schema { get { return _bindings; } } - public bool CanSaveOnnx => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx : false; + public bool CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false; public bool CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false; @@ -339,7 +339,7 @@ public void SaveAsOnnx(OnnxContext ctx) Host.CheckValue(ctx, nameof(ctx)); if (_mapper is ISaveAsOnnx onnx) { - Host.Check(onnx.CanSaveOnnx, "Cannot be saved as ONNX."); + Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX."); onnx.SaveAsOnnx(ctx); } } diff --git a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs index ebb6bbbf0a..37d938e234 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs @@ -15,7 +15,7 @@ public interface ICanSaveOnnx /// only detectable during runtime that would prevent its being savable. (For example, /// it may wrap some other object that may or may not be savable.) /// - bool CanSaveOnnx { get; } + bool CanSaveOnnx(OnnxContext ctx); } /// diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 44d64bc87d..18f65f37cc 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -7,6 +7,8 @@ namespace Microsoft.ML.Runtime.Model.Onnx { + public enum OnnxVersion { Stable=0, Experimental=1 } + /// /// A context for defining a ONNX output. The context internally contains the model-in-progress being built. This /// same context object is iteratively given to exportable components via the interface @@ -99,6 +101,12 @@ public abstract OnnxNode CreateNode(string opType, IEnumerable inputs, public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null) => CreateNode(opType, new[] { input }, new[] { output }, name, domain); + /// + /// Get the targeted ONNX version string. Only two values are allowed now: "Stable" and "Experimental". + /// + /// + public abstract OnnxVersion GetOnnxVersion(); + /// /// Retrieve the shape of an ONNX variable. Returns null if no shape for the specified variable can be found. /// diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index d20a112ba6..4c5f4c373a 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -225,7 +225,7 @@ public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBa public ColumnType OutputType => _mapper.OutputType; public ColumnType DistType => NumberType.Float; public bool CanSavePfa => (_mapper as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx => (_mapper as ICanSaveOnnx)?.CanSaveOnnx == true; + public bool CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing predictor, ICalibrator calibrator) : base(env, name, predictor, calibrator) @@ -308,7 +308,7 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColu return false; var calibrator = Calibrator as ISingleCanSaveOnnx; - if (!(calibrator?.CanSaveOnnx == true && calibrator.SaveAsOnnx(ctx, new[] { outputNames[1], outputNames[2] }, featureColumnName))) + if (!(calibrator?.CanSaveOnnx(ctx) == true && calibrator.SaveAsOnnx(ctx, new[] { outputNames[1], outputNames[2] }, featureColumnName))) ctx.RemoveVariable(outputNames[1], true); return true; @@ -622,7 +622,7 @@ private static VersionInfo GetVersionInfo() /// public bool CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx => (_bindable as ICanSaveOnnx)?.CanSaveOnnx == true; + public bool CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; public SchemaBindableCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator calibrator) : base(env, LoaderSignature, predictor, calibrator) @@ -668,7 +668,7 @@ public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] output Host.CheckValue(ctx, nameof(ctx)); Host.CheckParam(Utils.Size(outputs) == 2, nameof(outputs), "Expected this to have two outputs"); Host.CheckValue(schema, nameof(schema)); - Host.Check(CanSaveOnnx, "Called despite not being savable"); + Host.Check(CanSaveOnnx(ctx), "Called despite not being savable"); return false; } @@ -1349,7 +1349,7 @@ private static VersionInfo GetVersionInfo() public Double ParamA { get; } public Double ParamB { get; } public bool CanSavePfa => true; - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; public PlattCalibrator(IHostEnvironment env, Double paramA, Double paramB) { diff --git a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs index ad595ea706..7a8fb6ad34 100644 --- a/src/Microsoft.ML.Data/Scorers/GenericScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/GenericScorer.cs @@ -141,7 +141,7 @@ private static VersionInfo GetVersionInfo() public bool CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx => (Bindable as ICanSaveOnnx)?.CanSaveOnnx == true; + public bool CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; /// /// The entry point for creating a . diff --git a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs index 4d48552a66..46747c4b3a 100644 --- a/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs @@ -78,7 +78,7 @@ public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveMod public VectorType Type => _type; public bool CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx => (_bindable as ICanSaveOnnx)?.CanSaveOnnx == true; + public bool CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; public ISchemaBindableMapper InnerBindable => _bindable; private static VersionInfo GetVersionInfo() @@ -209,7 +209,7 @@ public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] output { Contracts.CheckValue(ctx, nameof(ctx)); Contracts.CheckValue(schema, nameof(schema)); - Contracts.Check(CanSaveOnnx, "Cannot be saved as ONNX."); + Contracts.Check(CanSaveOnnx(ctx), "Cannot be saved as ONNX."); Contracts.Assert(_bindable is IBindableCanSaveOnnx); return ((IBindableCanSaveOnnx)_bindable).SaveAsOnnx(ctx, schema, outputNames); } diff --git a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs index a9c251dd60..a10567df6c 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs @@ -284,7 +284,7 @@ protected override BindingsBase GetBindings() public bool CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx => (Bindable as ICanSaveOnnx)?.CanSaveOnnx == true; + public bool CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema, string registrationName, string scoreColKind, string scoreColName, diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index c64a887db7..4be45f683f 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -46,7 +46,7 @@ public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper public bool CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true; - public bool CanSaveOnnx => (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx == true; + public bool CanSaveOnnx(OnnxContext ctx) => (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true; public SchemaBindablePredictorWrapperBase(IPredictor predictor) { diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index ac50c8c82c..d10941b905 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -431,7 +431,7 @@ private sealed class Mapper : IRowMapper, ISaveAsOnnx, ISaveAsPfa private readonly ConcatTransform _parent; private readonly BoundColumn[] _columns; - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; public bool CanSavePfa => true; public Mapper(ConcatTransform parent, ISchema inputSchema) @@ -902,7 +902,7 @@ public void SaveAsPfa(BoundPfaContext ctx) public void SaveAsOnnx(OnnxContext ctx) { _host.CheckValue(ctx, nameof(ctx)); - Contracts.Assert(CanSaveOnnx); + Contracts.Assert(CanSaveOnnx(ctx)); string opType = "FeatureVectorizer"; for (int iinfo = 0; iinfo < _columns.Length; ++iinfo) diff --git a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs index 62906ffa55..6f46a894b3 100644 --- a/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConvertTransform.cs @@ -16,6 +16,7 @@ using Microsoft.ML.Runtime.Data.Conversion; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Model.Onnx; using Microsoft.ML.Runtime.Command; using Microsoft.ML.Runtime.EntryPoints; @@ -375,6 +376,30 @@ public override void Save(ModelSaveContext ctx) } } + public override bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental; + + protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName) + { + var opType = "CSharp"; + + for (int i = 0; i < _exes.Length; i++) + { + var ex = _exes[i]; + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("type", LoaderSignature); + node.AddAttribute("to", (byte)ex.Kind); + if (ex.HasKeyRange) + { + var key = ex.TypeDst.ItemType.AsKey; + node.AddAttribute("min", key.Min); + node.AddAttribute("max", key.Count); + node.AddAttribute("contiguous", key.Contiguous); + } + } + + return true; + } + private static bool TryCreateEx(IExceptionContext ectx, ColInfo info, DataKind kind, KeyRange range, out PrimitiveType itemType, out ColInfoEx ex) { ectx.AssertValue(info); diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index a1e6b1f915..0b257587d3 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -605,7 +605,7 @@ private ValueGetter> MakeGetterInd(IRow input, int iinfo) }; } - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; public bool CanSavePfa => true; diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index a91c4645eb..74d635d791 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -384,7 +384,7 @@ private AffineColumnFunction(IHost host) public abstract void Save(ModelSaveContext ctx); public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken); - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount); public abstract Delegate GetGetter(IRow input, int icol); @@ -503,7 +503,7 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return null; } - public bool CanSaveOnnx => false; + public bool CanSaveOnnx(OnnxContext ctx) => false; public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) => throw Host.ExceptNotSupp(); @@ -637,7 +637,7 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken) return null; } - public bool CanSaveOnnx => false; + public bool CanSaveOnnx(OnnxContext ctx) => false; public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount) => throw Host.ExceptNotSupp(); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs index 8ef572af2b..82562282c4 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs @@ -60,7 +60,7 @@ internal interface IColumnFunction : ICanSaveModel JToken PfaInfo(BoundPfaContext ctx, JToken srcToken); - bool CanSaveOnnx { get; } + bool CanSaveOnnx(OnnxContext ctx); bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount); } diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index 4c3d9fd51d..4dfd33c113 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -453,7 +453,7 @@ private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa { private NormalizerTransformer _parent; - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; public bool CanSavePfa => true; public Mapper(NormalizerTransformer parent, ISchema schema) @@ -563,12 +563,12 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColumnInfo info, string Contracts.AssertValue(ctx); Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); Contracts.Assert(_parent._columns[iinfo] == info); - Contracts.Assert(CanSaveOnnx); + Contracts.Assert(CanSaveOnnx(ctx)); if (info.InputType.ValueCount == 0) return false; - if (info.ColumnFunction.CanSaveOnnx) + if (info.ColumnFunction.CanSaveOnnx(ctx)) { string opType = "Scaler"; var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index b35c39d808..3726fec48a 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -722,7 +722,7 @@ private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa private readonly BoundTermMap[] _termMap; - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; public bool CanSavePfa => true; diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index b66ba2a83a..0310c986ee 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -470,7 +470,7 @@ private sealed class ColumnTmp : OneToOneColumn public virtual bool CanSavePfa => false; - public virtual bool CanSaveOnnx => false; + public virtual bool CanSaveOnnx(OnnxContext ctx) => false; protected OneToOneTransformBase(IHostEnvironment env, string name, OneToOneColumn[] column, IDataView input, Func testType) @@ -575,7 +575,7 @@ public void SaveAsPfa(BoundPfaContext ctx) public void SaveAsOnnx(OnnxContext ctx) { Host.CheckValue(ctx, nameof(ctx)); - Host.Assert(CanSaveOnnx); + Host.Assert(CanSaveOnnx(ctx)); for (int iinfo = 0; iinfo < Infos.Length; ++iinfo) { diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 49f6787938..1f34a3f522 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -2859,7 +2859,7 @@ public abstract class FastTreePredictionWrapper : public ColumnType InputType { get; } public ColumnType OutputType => NumberType.Float; public bool CanSavePfa => true; - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; protected FastTreePredictionWrapper(IHostEnvironment env, string name, Ensemble trainedEnsemble, int numFeatures, string innerArgs) : base(env, name) diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index 95a1f31899..42518890a5 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -3362,6 +3362,12 @@ public OneVersusAllPipelineStep(Output output) namespace Legacy.Models { + public enum OnnxVersion + { + Stable = 0, + Experimental = 1 + } + /// /// Converts the model to ONNX format. @@ -3405,6 +3411,11 @@ public sealed partial class OnnxConverter /// public Var Model { get; set; } = new Var(); + /// + /// The targeted ONNX version. It can be either "Stable" or "Experimental". If "Experimental" is used, produced model can contain components that is not officially supported in ONNX standard. + /// + public OnnxVersion OnnxVersion { get; set; } = OnnxVersion.Stable; + /// /// The data file /// diff --git a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs index a426d4881f..21b3958fbe 100644 --- a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs +++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs @@ -32,9 +32,10 @@ internal sealed class OnnxContextImpl : OnnxContext private readonly string _domain; private readonly string _producerVersion; private readonly long _modelVersion; + private readonly OnnxVersion _onnxVersion; public OnnxContextImpl(IHostEnvironment env, string name, string producerName, - string producerVersion, long modelVersion, string domain) + string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(OnnxContext)); @@ -54,6 +55,7 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName, _producerVersion = producerVersion; _modelVersion = modelVersion; _domain = domain; + _onnxVersion = onnxVersion; } public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName); @@ -330,5 +332,11 @@ public override string AddInitializer(IEnumerable values, IEnumerable public ModelProto MakeModel() => OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues, _initializers); + + /// + /// Return either "Experimental" or "Stable". The string "Experimental" indicates that some experimental features which are + /// not officially supported in the official ONNX standard. Otherwise, only official ONNX features should be used. + /// + public override OnnxVersion GetOnnxVersion() => _onnxVersion; } } diff --git a/src/Microsoft.ML.Onnx/OnnxUtils.cs b/src/Microsoft.ML.Onnx/OnnxUtils.cs index d57d08336d..83623ab0ba 100644 --- a/src/Microsoft.ML.Onnx/OnnxUtils.cs +++ b/src/Microsoft.ML.Onnx/OnnxUtils.cs @@ -311,14 +311,39 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName, case DataKind.TX: dataType = TensorProto.Types.DataType.String; break; + case DataKind.I1: + dataType = TensorProto.Types.DataType.Int8; + break; + case DataKind.U1: + dataType = TensorProto.Types.DataType.Uint8; + break; + case DataKind.I2: + dataType = TensorProto.Types.DataType.Int16; + break; + case DataKind.U2: + dataType = TensorProto.Types.DataType.Uint16; + break; + case DataKind.I4: + dataType = TensorProto.Types.DataType.Int32; + break; case DataKind.U4: dataType = TensorProto.Types.DataType.Int64; break; + case DataKind.I8: + dataType = TensorProto.Types.DataType.Int64; + break; + case DataKind.U8: + dataType = TensorProto.Types.DataType.Uint64; + break; case DataKind.R4: dataType = TensorProto.Types.DataType.Float; break; + case DataKind.R8: + dataType = TensorProto.Types.DataType.Double; + break; default: - Contracts.Assert(false, "Unknown type."); + string msg = "Unsupported type: DataKind " + rawKind.ToString(); + Contracts.Check(false, msg); break; } diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index 131c327369..84ff9beb81 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -57,6 +57,9 @@ public sealed class Arguments : DataCommand.ArgumentsBase [Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)] public ITransformModel Model; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The targeted ONNX version. It can be either \"Stable\" or \"Experimental\". If \"Experimental\" is used, produced model can contain components that is not officially supported in ONNX standard.", SortOrder = 11)] + public OnnxVersion OnnxVersion; } private readonly string _outputModelPath; @@ -110,7 +113,7 @@ public override void Run() } } - private void GetPipe(IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList transforms) + private void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataView source, out IDataView trueEnd, out LinkedList transforms) { Host.AssertValue(end); source = trueEnd = (end as CompositeDataLoader)?.View ?? end; @@ -119,7 +122,7 @@ private void GetPipe(IChannel ch, IDataView end, out IDataView source, out IData while (transform != null) { ITransformCanSaveOnnx onnxTransform = transform as ITransformCanSaveOnnx; - if (onnxTransform == null || !onnxTransform.CanSaveOnnx) + if (onnxTransform == null || !onnxTransform.CanSaveOnnx(ctx)) { ch.Warning("Had to stop walkback of pipeline at {0} since it cannot save itself as ONNX.", transform.GetType().Name); while (source as IDataTransform != null) @@ -159,18 +162,19 @@ private void Run(IChannel ch) else view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema)); + // Create the ONNX context for storing global information + var assembly = System.Reflection.Assembly.GetExecutingAssembly(); + var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location); + var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion, + ModelVersion, _domain, Args.OnnxVersion); + // Get the transform chain. IDataView source; IDataView end; LinkedList transforms; - GetPipe(ch, view, out source, out end, out transforms); + GetPipe(ctx, ch, view, out source, out end, out transforms); Host.Assert(transforms.Count == 0 || transforms.Last.Value == end); - var assembly = System.Reflection.Assembly.GetExecutingAssembly(); - var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location); - - var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion, - ModelVersion, _domain); // If we have a predictor, try to get the scorer for it. if (rawPred != null) { @@ -187,7 +191,7 @@ private void Run(IChannel ch) var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema); var scoreOnnx = scorePipe as ITransformCanSaveOnnx; - if (scoreOnnx?.CanSaveOnnx == true) + if (scoreOnnx?.CanSaveOnnx(ctx) == true) { Host.Assert(scorePipe.Source == end); end = scorePipe; @@ -221,7 +225,7 @@ private void Run(IChannel ch) //Create graph nodes, outputs and intermediate values. foreach (var trans in transforms) { - Host.Assert(trans.CanSaveOnnx); + Host.Assert(trans.CanSaveOnnx(ctx)); trans.SaveAsOnnx(ctx); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs index ed41a14174..db1075fc2d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs @@ -101,7 +101,7 @@ IEnumerator IEnumerable.GetEnumerator() public bool CanSavePfa => true; - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; /// /// Constructs a new linear predictor. diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs index 2d6e8daa7b..608d5aaf00 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/MulticlassLogisticRegression.cs @@ -392,7 +392,7 @@ private static VersionInfo GetVersionInfo() public ColumnType InputType { get; } public ColumnType OutputType { get; } public bool CanSavePfa => true; - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; internal MulticlassLogisticRegressionPredictor(IHostEnvironment env, ref VBuffer weights, int numClasses, int numFeatures, string[] labelNames, LinearModelStatistics stats = null) : base(env, RegistrationName) diff --git a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs index 38d91bcbea..5fd853080c 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceTransform.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceTransform.cs @@ -564,7 +564,6 @@ protected override IRowMapper MakeRowMapper(ISchema schema) private sealed class Mapper : MapperBase, ISaveAsOnnx { - private sealed class ColInfo { public readonly string Name; @@ -584,7 +583,7 @@ public ColInfo(string name, string source, ColumnType type) private readonly ColumnType[] _types; // The isNA delegates, parallel to Infos. private readonly Delegate[] _isNAs; - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => true; public Mapper(NAReplaceTransform parent, ISchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index c189660df5..edd0fa93f4 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -2479,6 +2479,21 @@ "Required": true, "SortOrder": 10.0, "IsNullable": false + }, + { + "Name": "OnnxVersion", + "Type": { + "Kind": "Enum", + "Values": [ + "Stable", + "Experimental" + ] + }, + "Desc": "The targeted ONNX version. It can be either \"Stable\" or \"Experimental\". If \"Experimental\" is used, produced model can contain components that is not officially supported in ONNX standard.", + "Required": false, + "SortOrder": 11.0, + "IsNullable": false, + "Default": "Stable" } ], "Outputs": [] diff --git a/test/Microsoft.ML.Tests/OnnxTests.cs b/test/Microsoft.ML.Tests/OnnxTests.cs index 11604969a1..ace43ee442 100644 --- a/test/Microsoft.ML.Tests/OnnxTests.cs +++ b/test/Microsoft.ML.Tests/OnnxTests.cs @@ -59,7 +59,7 @@ public void InitializerCreationTest() using (var env = new ConsoleEnvironment()) { // Create the actual implementation - var ctxImpl = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test"); + var ctxImpl = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test", Runtime.Model.Onnx.OnnxVersion.Stable); // Use implementation as in the actual conversion code var ctx = ctxImpl as OnnxContext;