diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 18f65f37cc..850d2f6b64 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -25,6 +25,13 @@ public abstract class OnnxContext /// A name that has not yet been returned from this function, starting with public abstract string GetNodeName(string prefix); + /// + /// Determine if a string has been used as ONNX variable name somewhere. + /// + /// examined string + /// True if the input argument has been used to denote an ONNX variable. Otherwise, False. + public abstract bool IsVariableDefined(string variableName); + /// /// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can /// safely call . diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 1d6b9b47a4..6dde8a251d 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Model.Onnx; using Microsoft.ML.Transforms; [assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(IDataTransform), typeof(CopyColumnsTransform), @@ -159,11 +160,13 @@ public override void Save(ModelSaveContext ctx) protected override IRowMapper MakeRowMapper(ISchema inputSchema) => new Mapper(this, inputSchema, ColumnPairs); - private sealed class Mapper : MapperBase + private sealed class Mapper : MapperBase, ISaveAsOnnx { private readonly ISchema _schema; private readonly (string Source, string Name)[] _columns; + public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental; + internal Mapper(CopyColumnsTransform parent, ISchema inputSchema, (string Source, string Name)[] columns) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { @@ -197,6 +200,20 @@ public override RowMapperColumnInfo[] GetOutputColumns() } return result; } + + public void SaveAsOnnx(OnnxContext ctx) + { + var opType = "CSharp"; + + foreach (var column in _columns) + { + var srcVariableName = ctx.GetVariableName(column.Source); + _schema.TryGetColumnIndex(column.Source, out int colIndex); + var dstVariableName = ctx.AddIntermediateVariable(_schema.GetColumnType(colIndex), column.Name); + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("type", LoaderSignature); + } + } } } } diff --git a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs index 21b3958fbe..3f0e6f2fb7 100644 --- a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs +++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs @@ -60,6 +60,8 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName, public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName); + public override bool IsVariableDefined(string variableName) => _variableNames.Contains(variableName); + /// /// Stops tracking a column. If removeVariable is true then it also removes the /// variable associated with it, this is useful in the event where an output variable is @@ -200,7 +202,7 @@ public string TryGetVariableName(string colName) /// /// IDataView column name. /// Unique variable name. - private string AddVariable(string colName) + public string AddVariable(string colName) { _host.CheckNonEmpty(colName, nameof(colName)); _columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains); @@ -226,16 +228,11 @@ public override string AddIntermediateVariable(ColumnType type, string colName, /// /// Adds an output variable to the list. /// - public string AddOutputVariable(ColumnType type, string colName, List dim = null) + public void AddOutputVariable(ColumnType type, string variableName, List dim = null) { _host.CheckValue(type, nameof(type)); - - if (!ContainsColumn(colName)) - AddVariable(colName); - - colName = GetVariableName(colName); - _outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim)); - return colName; + _host.CheckParam(IsVariableDefined(variableName), nameof(variableName)); + _outputs.Add(OnnxUtils.GetModelArgs(type, variableName, dim)); } /// diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index 84ff9beb81..241abc6fbb 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -235,13 +235,17 @@ private void Run(IChannel ch) if (end.Schema.IsHidden(i)) continue; - var idataviewColumnName = end.Schema.GetColumnName(i);; - if (_outputsToDrop.Contains(idataviewColumnName) || _inputsToDrop.Contains(idataviewColumnName)) + var idataviewColumnName = end.Schema.GetColumnName(i); + + // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in + // _inputToDrop should be removed too. + if (_inputsToDrop.Contains(idataviewColumnName) || _outputsToDrop.Contains(idataviewColumnName)) continue; var variableName = ctx.TryGetVariableName(idataviewColumnName); - if (variableName != null) - ctx.AddOutputVariable(end.Schema.GetColumnType(i), variableName); + var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true); + ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), ""); + ctx.AddOutputVariable(end.Schema.GetColumnType(i), trueVariableName); } var model = ctx.MakeModel(); diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json index f1ebe10358..b032fc1aaf 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json @@ -340,6 +340,36 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" + }, + { + "input": [ + "Probability" + ], + "output": [ + "Probability0" + ], + "name": "Identity1", + "opType": "Identity" } ], "name": "BinaryClassificationFastTreeSaveModelToOnnxTest", @@ -383,7 +413,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "FLOAT", @@ -401,7 +431,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT", @@ -419,7 +449,7 @@ } }, { - "name": "Probability", + "name": "Probability0", "type": { "tensorType": { "elemType": "FLOAT", diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json index 92d9816c37..217e7b1fbb 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json @@ -142,6 +142,36 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" + }, + { + "input": [ + "Probability" + ], + "output": [ + "Probability0" + ], + "name": "Identity1", + "opType": "Identity" } ], "name": "BinaryClassificationLRSaveModelToOnnxTest", @@ -167,7 +197,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "FLOAT", @@ -185,7 +215,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT", @@ -203,7 +233,7 @@ } }, { - "name": "Probability", + "name": "Probability0", "type": { "tensorType": { "elemType": "FLOAT", diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json index 3989a39903..578322d150 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json @@ -193,6 +193,36 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" + }, + { + "input": [ + "Probability" + ], + "output": [ + "Probability0" + ], + "name": "Identity1", + "opType": "Identity" } ], "name": "BinaryClassificationLightGBMSaveModelToOnnxTest", @@ -218,7 +248,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "FLOAT", @@ -236,7 +266,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT", @@ -254,7 +284,7 @@ } }, { - "name": "Probability", + "name": "Probability0", "type": { "tensorType": { "elemType": "FLOAT", diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json index bcbc839312..aa498a07ad 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json @@ -311,6 +311,36 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" + }, + { + "input": [ + "Probability" + ], + "output": [ + "Probability0" + ], + "name": "Identity1", + "opType": "Identity" } ], "name": "KeyToVectorBag", @@ -354,7 +384,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "FLOAT", @@ -372,7 +402,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT", @@ -390,7 +420,7 @@ } }, { - "name": "Probability", + "name": "Probability0", "type": { "tensorType": { "elemType": "FLOAT", diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json index b21412a2bb..f7976875f1 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json @@ -111,6 +111,26 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" } ], "name": "MultiClassificationLRSaveModelToOnnxTest", @@ -136,7 +156,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "INT64", @@ -154,7 +174,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT",