diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
index 18f65f37cc..850d2f6b64 100644
--- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
+++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
@@ -25,6 +25,13 @@ public abstract class OnnxContext
/// A name that has not yet been returned from this function, starting with
public abstract string GetNodeName(string prefix);
+ ///
+ /// Determine if a string has been used as ONNX variable name somewhere.
+ ///
+ /// examined string
+ /// True if the input argument has been used to denote an ONNX variable. Otherwise, False.
+ public abstract bool IsVariableDefined(string variableName);
+
///
/// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can
/// safely call .
diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
index 1d6b9b47a4..6dde8a251d 100644
--- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs
@@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Transforms;
[assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(IDataTransform), typeof(CopyColumnsTransform),
@@ -159,11 +160,13 @@ public override void Save(ModelSaveContext ctx)
protected override IRowMapper MakeRowMapper(ISchema inputSchema)
=> new Mapper(this, inputSchema, ColumnPairs);
- private sealed class Mapper : MapperBase
+ private sealed class Mapper : MapperBase, ISaveAsOnnx
{
private readonly ISchema _schema;
private readonly (string Source, string Name)[] _columns;
+ public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental;
+
internal Mapper(CopyColumnsTransform parent, ISchema inputSchema, (string Source, string Name)[] columns)
: base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
{
@@ -197,6 +200,20 @@ public override RowMapperColumnInfo[] GetOutputColumns()
}
return result;
}
+
+ public void SaveAsOnnx(OnnxContext ctx)
+ {
+ var opType = "CSharp";
+
+ foreach (var column in _columns)
+ {
+ var srcVariableName = ctx.GetVariableName(column.Source);
+ _schema.TryGetColumnIndex(column.Source, out int colIndex);
+ var dstVariableName = ctx.AddIntermediateVariable(_schema.GetColumnType(colIndex), column.Name);
+ var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
+ node.AddAttribute("type", LoaderSignature);
+ }
+ }
}
}
}
diff --git a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs
index 21b3958fbe..3f0e6f2fb7 100644
--- a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs
+++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs
@@ -60,6 +60,8 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
+ public override bool IsVariableDefined(string variableName) => _variableNames.Contains(variableName);
+
///
/// Stops tracking a column. If removeVariable is true then it also removes the
/// variable associated with it, this is useful in the event where an output variable is
@@ -200,7 +202,7 @@ public string TryGetVariableName(string colName)
///
/// IDataView column name.
/// Unique variable name.
- private string AddVariable(string colName)
+ public string AddVariable(string colName)
{
_host.CheckNonEmpty(colName, nameof(colName));
_columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains);
@@ -226,16 +228,11 @@ public override string AddIntermediateVariable(ColumnType type, string colName,
///
/// Adds an output variable to the list.
///
- public string AddOutputVariable(ColumnType type, string colName, List dim = null)
+ public void AddOutputVariable(ColumnType type, string variableName, List dim = null)
{
_host.CheckValue(type, nameof(type));
-
- if (!ContainsColumn(colName))
- AddVariable(colName);
-
- colName = GetVariableName(colName);
- _outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim));
- return colName;
+ _host.CheckParam(IsVariableDefined(variableName), nameof(variableName));
+ _outputs.Add(OnnxUtils.GetModelArgs(type, variableName, dim));
}
///
diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
index 84ff9beb81..241abc6fbb 100644
--- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
+++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs
@@ -235,13 +235,17 @@ private void Run(IChannel ch)
if (end.Schema.IsHidden(i))
continue;
- var idataviewColumnName = end.Schema.GetColumnName(i);;
- if (_outputsToDrop.Contains(idataviewColumnName) || _inputsToDrop.Contains(idataviewColumnName))
+ var idataviewColumnName = end.Schema.GetColumnName(i);
+
+ // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
+ // _inputToDrop should be removed too.
+ if (_inputsToDrop.Contains(idataviewColumnName) || _outputsToDrop.Contains(idataviewColumnName))
continue;
var variableName = ctx.TryGetVariableName(idataviewColumnName);
- if (variableName != null)
- ctx.AddOutputVariable(end.Schema.GetColumnType(i), variableName);
+ var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
+ ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
+ ctx.AddOutputVariable(end.Schema.GetColumnType(i), trueVariableName);
}
var model = ctx.MakeModel();
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json
index f1ebe10358..b032fc1aaf 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json
@@ -340,6 +340,36 @@
}
],
"domain": "ai.onnx.ml"
+ },
+ {
+ "input": [
+ "PredictedLabel"
+ ],
+ "output": [
+ "PredictedLabel0"
+ ],
+ "name": "Identity",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Score"
+ ],
+ "output": [
+ "Score0"
+ ],
+ "name": "Identity0",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Probability"
+ ],
+ "output": [
+ "Probability0"
+ ],
+ "name": "Identity1",
+ "opType": "Identity"
}
],
"name": "BinaryClassificationFastTreeSaveModelToOnnxTest",
@@ -383,7 +413,7 @@
],
"output": [
{
- "name": "PredictedLabel",
+ "name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
@@ -401,7 +431,7 @@
}
},
{
- "name": "Score",
+ "name": "Score0",
"type": {
"tensorType": {
"elemType": "FLOAT",
@@ -419,7 +449,7 @@
}
},
{
- "name": "Probability",
+ "name": "Probability0",
"type": {
"tensorType": {
"elemType": "FLOAT",
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json
index 92d9816c37..217e7b1fbb 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json
@@ -142,6 +142,36 @@
}
],
"domain": "ai.onnx.ml"
+ },
+ {
+ "input": [
+ "PredictedLabel"
+ ],
+ "output": [
+ "PredictedLabel0"
+ ],
+ "name": "Identity",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Score"
+ ],
+ "output": [
+ "Score0"
+ ],
+ "name": "Identity0",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Probability"
+ ],
+ "output": [
+ "Probability0"
+ ],
+ "name": "Identity1",
+ "opType": "Identity"
}
],
"name": "BinaryClassificationLRSaveModelToOnnxTest",
@@ -167,7 +197,7 @@
],
"output": [
{
- "name": "PredictedLabel",
+ "name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
@@ -185,7 +215,7 @@
}
},
{
- "name": "Score",
+ "name": "Score0",
"type": {
"tensorType": {
"elemType": "FLOAT",
@@ -203,7 +233,7 @@
}
},
{
- "name": "Probability",
+ "name": "Probability0",
"type": {
"tensorType": {
"elemType": "FLOAT",
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json
index 3989a39903..578322d150 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json
@@ -193,6 +193,36 @@
}
],
"domain": "ai.onnx.ml"
+ },
+ {
+ "input": [
+ "PredictedLabel"
+ ],
+ "output": [
+ "PredictedLabel0"
+ ],
+ "name": "Identity",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Score"
+ ],
+ "output": [
+ "Score0"
+ ],
+ "name": "Identity0",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Probability"
+ ],
+ "output": [
+ "Probability0"
+ ],
+ "name": "Identity1",
+ "opType": "Identity"
}
],
"name": "BinaryClassificationLightGBMSaveModelToOnnxTest",
@@ -218,7 +248,7 @@
],
"output": [
{
- "name": "PredictedLabel",
+ "name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
@@ -236,7 +266,7 @@
}
},
{
- "name": "Score",
+ "name": "Score0",
"type": {
"tensorType": {
"elemType": "FLOAT",
@@ -254,7 +284,7 @@
}
},
{
- "name": "Probability",
+ "name": "Probability0",
"type": {
"tensorType": {
"elemType": "FLOAT",
diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json
index bcbc839312..aa498a07ad 100644
--- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json
+++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json
@@ -311,6 +311,36 @@
}
],
"domain": "ai.onnx.ml"
+ },
+ {
+ "input": [
+ "PredictedLabel"
+ ],
+ "output": [
+ "PredictedLabel0"
+ ],
+ "name": "Identity",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Score"
+ ],
+ "output": [
+ "Score0"
+ ],
+ "name": "Identity0",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Probability"
+ ],
+ "output": [
+ "Probability0"
+ ],
+ "name": "Identity1",
+ "opType": "Identity"
}
],
"name": "KeyToVectorBag",
@@ -354,7 +384,7 @@
],
"output": [
{
- "name": "PredictedLabel",
+ "name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "FLOAT",
@@ -372,7 +402,7 @@
}
},
{
- "name": "Score",
+ "name": "Score0",
"type": {
"tensorType": {
"elemType": "FLOAT",
@@ -390,7 +420,7 @@
}
},
{
- "name": "Probability",
+ "name": "Probability0",
"type": {
"tensorType": {
"elemType": "FLOAT",
diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json
index b21412a2bb..f7976875f1 100644
--- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json
+++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json
@@ -111,6 +111,26 @@
}
],
"domain": "ai.onnx.ml"
+ },
+ {
+ "input": [
+ "PredictedLabel"
+ ],
+ "output": [
+ "PredictedLabel0"
+ ],
+ "name": "Identity",
+ "opType": "Identity"
+ },
+ {
+ "input": [
+ "Score"
+ ],
+ "output": [
+ "Score0"
+ ],
+ "name": "Identity0",
+ "opType": "Identity"
}
],
"name": "MultiClassificationLRSaveModelToOnnxTest",
@@ -136,7 +156,7 @@
],
"output": [
{
- "name": "PredictedLabel",
+ "name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "INT64",
@@ -154,7 +174,7 @@
}
},
{
- "name": "Score",
+ "name": "Score0",
"type": {
"tensorType": {
"elemType": "FLOAT",