Skip to content
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ private static VersionInfo GetVersionInfo()

public override ISchema Schema { get { return _bindings; } }

public bool CanSaveOnnx => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx : false;
public bool CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;

public bool CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false;

Expand Down Expand Up @@ -338,7 +338,7 @@ public void SaveAsOnnx(OnnxContext ctx)
Host.CheckValue(ctx, nameof(ctx));
if (_mapper is ISaveAsOnnx onnx)
{
Host.Check(onnx.CanSaveOnnx, "Cannot be saved as ONNX.");
Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
onnx.SaveAsOnnx(ctx);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public interface ICanSaveOnnx
/// only detectable during runtime that would prevent its being savable. (E.g.,
/// it may wrap some other object that may or may not be savable.)
/// </summary>
bool CanSaveOnnx { get; }
bool CanSaveOnnx(OnnxContext ctx);
}

/// <summary>
Expand Down
8 changes: 8 additions & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

namespace Microsoft.ML.Runtime.Model.Onnx
{
public enum OnnxVersion { Latest, Experimental }

/// <summary>
/// A context for defining a ONNX output. The context internally contains the model-in-progress being built. This
/// same context object is iteratively given to exportable components via the <see cref="ICanSaveOnnx"/> interface
Expand Down Expand Up @@ -98,5 +100,11 @@ public abstract OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
=> CreateNode(opType, new[] { input }, new[] { output }, name, domain);

/// <summary>
/// Get the targeted ONNX version string. Only two values are allowed now: "latest" and "experimental".
/// </summary>
/// <returns></returns>
public abstract OnnxVersion GetOnnxVersion();
}
}
10 changes: 5 additions & 5 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBa
public ColumnType OutputType => _mapper.OutputType;
public ColumnType DistType => NumberType.Float;
public bool CanSavePfa => (_mapper as ICanSavePfa)?.CanSavePfa == true;
public bool CanSaveOnnx => (_mapper as ICanSaveOnnx)?.CanSaveOnnx == true;
public bool CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;

protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing<Float> predictor, ICalibrator calibrator)
: base(env, name, predictor, calibrator)
Expand Down Expand Up @@ -308,7 +308,7 @@ public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColu
return false;

var calibrator = Calibrator as ISingleCanSaveOnnx;
if (!(calibrator?.CanSaveOnnx == true && calibrator.SaveAsOnnx(ctx, new[] { outputNames[1], outputNames[2] }, featureColumnName)))
if (!(calibrator?.CanSaveOnnx(ctx) == true && calibrator.SaveAsOnnx(ctx, new[] { outputNames[1], outputNames[2] }, featureColumnName)))
ctx.RemoveVariable(outputNames[1], true);

return true;
Expand Down Expand Up @@ -617,7 +617,7 @@ private static VersionInfo GetVersionInfo()
/// </summary>
public bool CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true;

public bool CanSaveOnnx => (_bindable as ICanSaveOnnx)?.CanSaveOnnx == true;
public bool CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;

public SchemaBindableCalibratedPredictor(IHostEnvironment env, IPredictorProducing<Single> predictor, ICalibrator calibrator)
: base(env, LoaderSignature, predictor, calibrator)
Expand Down Expand Up @@ -663,7 +663,7 @@ public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] output
Host.CheckValue(ctx, nameof(ctx));
Host.CheckParam(Utils.Size(outputs) == 2, nameof(outputs), "Expected this to have two outputs");
Host.CheckValue(schema, nameof(schema));
Host.Check(CanSaveOnnx, "Called despite not being savable");
Host.Check(CanSaveOnnx(ctx), "Called despite not being savable");
return false;
}

Expand Down Expand Up @@ -1342,7 +1342,7 @@ private static VersionInfo GetVersionInfo()
public Double ParamA { get; }
public Double ParamB { get; }
public bool CanSavePfa => true;
public bool CanSaveOnnx => true;
public bool CanSaveOnnx(OnnxContext ctx) => true;

public PlattCalibrator(IHostEnvironment env, Double paramA, Double paramB)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/GenericScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ private static VersionInfo GetVersionInfo()

public bool CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true;

public bool CanSaveOnnx => (Bindable as ICanSaveOnnx)?.CanSaveOnnx == true;
public bool CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;

/// <summary>
/// The <see cref="SignatureDataScorer"/> entry point for creating a <see cref="GenericScorer"/>.
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Scorers/MultiClassClassifierScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveMod

public VectorType Type => _type;
public bool CanSavePfa => (_bindable as ICanSavePfa)?.CanSavePfa == true;
public bool CanSaveOnnx => (_bindable as ICanSaveOnnx)?.CanSaveOnnx == true;
public bool CanSaveOnnx(OnnxContext ctx) => (_bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
public ISchemaBindableMapper InnerBindable => _bindable;

private static VersionInfo GetVersionInfo()
Expand Down Expand Up @@ -207,7 +207,7 @@ public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] output
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(schema, nameof(schema));
Contracts.Check(CanSaveOnnx, "Cannot be saved as ONNX.");
Contracts.Check(CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
Contracts.Assert(_bindable is IBindableCanSaveOnnx);
return ((IBindableCanSaveOnnx)_bindable).SaveAsOnnx(ctx, schema, outputNames);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ protected override BindingsBase GetBindings()

public bool CanSavePfa => (Bindable as ICanSavePfa)?.CanSavePfa == true;

public bool CanSaveOnnx => (Bindable as ICanSaveOnnx)?.CanSaveOnnx == true;
public bool CanSaveOnnx(OnnxContext ctx) => (Bindable as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;

protected PredictedLabelScorerBase(ScorerArgumentsBase args, IHostEnvironment env, IDataView data,
ISchemaBoundMapper mapper, RoleMappedSchema trainSchema, string registrationName, string scoreColKind, string scoreColName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper

public bool CanSavePfa => (ValueMapper as ICanSavePfa)?.CanSavePfa == true;

public bool CanSaveOnnx => (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx == true;
public bool CanSaveOnnx(OnnxContext ctx) => (ValueMapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;

public SchemaBindablePredictorWrapperBase(IPredictor predictor)
{
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ private sealed class Mapper : IRowMapper, ISaveAsOnnx, ISaveAsPfa
private readonly ConcatTransform _parent;
private readonly BoundColumn[] _columns;

public bool CanSaveOnnx => true;
public bool CanSaveOnnx(OnnxContext ctx) => true;
public bool CanSavePfa => true;

public Mapper(ConcatTransform parent, ISchema inputSchema)
Expand Down Expand Up @@ -893,7 +893,7 @@ public void SaveAsPfa(BoundPfaContext ctx)
public void SaveAsOnnx(OnnxContext ctx)
{
_host.CheckValue(ctx, nameof(ctx));
Contracts.Assert(CanSaveOnnx);
Contracts.Assert(CanSaveOnnx(ctx));

string opType = "FeatureVectorizer";
for (int iinfo = 0; iinfo < _columns.Length; ++iinfo)
Expand Down
25 changes: 25 additions & 0 deletions src/Microsoft.ML.Data/Transforms/ConvertTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using Microsoft.ML.Runtime.Data.Conversion;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.EntryPoints;

Expand Down Expand Up @@ -374,6 +375,30 @@ public override void Save(ModelSaveContext ctx)
}
}

public override bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental;

protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
var opType = "CSharp";

for (int i = 0; i < _exes.Length; i++)
{
var ex = _exes[i];
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("type", LoaderSignature);
node.AddAttribute("to", (byte)ex.Kind);
if (ex.HasKeyRange)
{
var key = ex.TypeDst.ItemType.AsKey;
node.AddAttribute("min", key.Min);
node.AddAttribute("max", key.Count);
node.AddAttribute("contiguous", key.Contiguous);
}
}

return true;
}

private static bool TryCreateEx(IExceptionContext ectx, ColInfo info, DataKind kind, KeyRange range, out PrimitiveType itemType, out ColInfoEx ex)
{
ectx.AssertValue(info);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ private ValueGetter<VBuffer<float>> MakeGetterInd(IRow input, int iinfo)
};
}

public bool CanSaveOnnx => true;
public bool CanSaveOnnx(OnnxContext ctx) => true;

public bool CanSavePfa => true;

Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ private AffineColumnFunction(IHost host)
public abstract void Save(ModelSaveContext ctx);

public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);
public bool CanSaveOnnx => true;
public bool CanSaveOnnx(OnnxContext ctx) => true;
public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount);

public abstract Delegate GetGetter(IRow input, int icol);
Expand Down Expand Up @@ -503,7 +503,7 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return null;
}

public bool CanSaveOnnx => false;
public bool CanSaveOnnx(OnnxContext ctx) => false;

public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
=> throw Host.ExceptNotSupp();
Expand Down Expand Up @@ -636,7 +636,7 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return null;
}

public bool CanSaveOnnx => false;
public bool CanSaveOnnx(OnnxContext ctx) => false;

public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
=> throw Host.ExceptNotSupp();
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/NormalizeUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ internal interface IColumnFunction : ICanSaveModel

JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);

bool CanSaveOnnx { get; }
bool CanSaveOnnx(OnnxContext ctx);

bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount);
}
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Transforms/Normalizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa
{
private NormalizerTransformer _parent;

public bool CanSaveOnnx => true;
public bool CanSaveOnnx(OnnxContext ctx) => true;
public bool CanSavePfa => true;

public Mapper(NormalizerTransformer parent, ISchema schema)
Expand Down Expand Up @@ -562,12 +562,12 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColumnInfo info, string
Contracts.AssertValue(ctx);
Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length);
Contracts.Assert(_parent._columns[iinfo] == info);
Contracts.Assert(CanSaveOnnx);
Contracts.Assert(CanSaveOnnx(ctx));

if (info.InputType.ValueCount == 0)
return false;

if (info.ColumnFunction.CanSaveOnnx)
if (info.ColumnFunction.CanSaveOnnx(ctx))
{
string opType = "Scaler";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/TermTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ private sealed class Mapper : MapperBase, ISaveAsOnnx, ISaveAsPfa

private readonly BoundTermMap[] _termMap;

public bool CanSaveOnnx => true;
public bool CanSaveOnnx(OnnxContext ctx) => true;

public bool CanSavePfa => true;

Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/TransformBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ private sealed class ColumnTmp : OneToOneColumn

public virtual bool CanSavePfa => false;

public virtual bool CanSaveOnnx => false;
public virtual bool CanSaveOnnx(OnnxContext ctx) => false;

protected OneToOneTransformBase(IHostEnvironment env, string name, OneToOneColumn[] column,
IDataView input, Func<ColumnType, string> testType)
Expand Down Expand Up @@ -574,7 +574,7 @@ public void SaveAsPfa(BoundPfaContext ctx)
public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(CanSaveOnnx);
Host.Assert(CanSaveOnnx(ctx));

for (int iinfo = 0; iinfo < Infos.Length; ++iinfo)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2860,7 +2860,7 @@ public abstract class FastTreePredictionWrapper :
public ColumnType InputType { get; }
public ColumnType OutputType => NumberType.Float;
public bool CanSavePfa => true;
public bool CanSaveOnnx => true;
public bool CanSaveOnnx(OnnxContext ctx) => true;

protected FastTreePredictionWrapper(IHostEnvironment env, string name, Ensemble trainedEnsemble, int numFeatures, string innerArgs)
: base(env, name)
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.ML.Legacy/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3266,6 +3266,12 @@ public OneVersusAllPipelineStep(Output output)

namespace Legacy.Models
{
public enum OnnxVersion
{
Latest = 0,
Experimental = 1
}


/// <summary>
/// Converts the model to ONNX format.
Expand Down Expand Up @@ -3309,6 +3315,11 @@ public sealed partial class OnnxConverter
/// </summary>
public Var<Microsoft.ML.Runtime.EntryPoints.ITransformModel> Model { get; set; } = new Var<Microsoft.ML.Runtime.EntryPoints.ITransformModel>();

/// <summary>
/// The targeted ONNX version. It can be either "latest" or "experimental"
/// </summary>
public OnnxVersion OnnxVersion { get; set; } = OnnxVersion.Latest;

/// <summary>
/// The data file
/// </summary>
Expand Down
6 changes: 5 additions & 1 deletion src/Microsoft.ML.Onnx/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ internal sealed class OnnxContextImpl : OnnxContext
private readonly string _domain;
private readonly string _producerVersion;
private readonly long _modelVersion;
private readonly OnnxVersion _onnxVersion;

public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
string producerVersion, long modelVersion, string domain)
string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(OnnxContext));
Expand All @@ -52,6 +53,7 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
_producerVersion = producerVersion;
_modelVersion = modelVersion;
_domain = domain;
_onnxVersion = onnxVersion;
}

public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
Expand Down Expand Up @@ -251,5 +253,7 @@ public void AddInputVariable(ColumnType type, string colName)
/// </summary>
public ModelProto MakeModel()
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);

public override OnnxVersion GetOnnxVersion() => _onnxVersion;
}
}
26 changes: 25 additions & 1 deletion src/Microsoft.ML.Onnx/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,38 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName,
case DataKind.TX:
dataType = TensorProto.Types.DataType.String;
break;
case DataKind.I1:
dataType = TensorProto.Types.DataType.Int8;
break;
case DataKind.U1:
dataType = TensorProto.Types.DataType.Uint8;
break;
case DataKind.I2:
dataType = TensorProto.Types.DataType.Int16;
break;
case DataKind.U2:
dataType = TensorProto.Types.DataType.Uint16;
break;
case DataKind.I4:
dataType = TensorProto.Types.DataType.Int32;
break;
case DataKind.U4:
dataType = TensorProto.Types.DataType.Int64;
break;
case DataKind.I8:
dataType = TensorProto.Types.DataType.Int64;
break;
case DataKind.U8:
dataType = TensorProto.Types.DataType.Uint64;
break;
case DataKind.R4:
dataType = TensorProto.Types.DataType.Float;
break;
case DataKind.R8:
dataType = TensorProto.Types.DataType.Double;
break;
default:
Contracts.Assert(false, "Unknown type.");
Contracts.Assert(false, "Unsupported type: DataKind " + rawKind.ToString());
break;
}

Expand Down
Loading