Merged
Changes from 13 commits
249 changes: 249 additions & 0 deletions src/Common/AssemblyLoadingUtils.cs
@@ -0,0 +1,249 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.IO;
using System.IO.Compression;
using System.Reflection;

namespace Microsoft.ML.Runtime
{
internal static class AssemblyLoadingUtils
{
/// <summary>
/// Make sure the given assemblies are loaded and that their loadable classes have been catalogued.
/// </summary>
public static void LoadAndRegister(IHostEnvironment env, string[] assemblies)
{
Contracts.AssertValue(env);

if (Utils.Size(assemblies) > 0)
{
foreach (string path in assemblies)
{
Exception ex = null;
try
{
// REVIEW: Will LoadFrom ever return null?
Contracts.CheckNonEmpty(path, nameof(path));
var assem = LoadAssembly(env, path);
if (assem != null)
continue;
}
catch (Exception e)
{
ex = e;
Member

Why do we try to load zip files? Why not either load or unzip depending on the extension?

Member Author

Same as above - existing behavior from the command line code that was removed from ComponentCatalog.
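
One way to read the suggestion above is to decide up front, from the file extension, whether a path is an assembly or an archive, instead of falling back to zip handling after a failed load. A minimal sketch, assuming extensions can be trusted; `ExtraPathKind` and `ExtraPathClassifier` are hypothetical names and are not part of this PR:

```csharp
using System;
using System.IO;

// Hypothetical helper types, not part of this PR.
internal enum ExtraPathKind { Assembly, ZipArchive, Unknown }

internal static class ExtraPathClassifier
{
    // Classifies a path by its extension only; the file is never opened.
    public static ExtraPathKind ClassifyPath(string path)
    {
        string ext = Path.GetExtension(path);
        if (string.Equals(ext, ".dll", StringComparison.OrdinalIgnoreCase))
            return ExtraPathKind.Assembly;
        if (string.Equals(ext, ".zip", StringComparison.OrdinalIgnoreCase))
            return ExtraPathKind.ZipArchive;
        return ExtraPathKind.Unknown;
    }
}
```

This would change the existing behavior for archives that do not carry a `.zip` extension, which is the compatibility concern raised in the reply.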

}

// If it is a zip file, load it that way.
ZipArchive zip;
try
{
zip = ZipFile.OpenRead(path);
}
catch (Exception e)
{
// Couldn't load as an assembly and not a zip, so warn the user.
ex = ex ?? e;
Console.Error.WriteLine("Warning: Could not load '{0}': {1}", path, ex.Message);
Member

We should not be writing to the Console from libraries. And more generally, why do we swallow this error?

Member Author

This code was copied out of the ComponentCatalog and moved to an internal class that is only called from the command line. (see src/Microsoft.ML.Maml/HelpCommand.cs and src/Microsoft.ML.ResultProcessor/ResultProcessor.cs) I wasn't going to change the behavior, because I'm sure someone is depending on it.

continue;
}

string dir;
try
{
dir = CreateTempDirectory();
}
catch (Exception e)
{
throw Contracts.ExceptIO(e, "Creating temp directory for extra assembly zip extraction failed: '{0}'", path);
}

try
{
zip.ExtractToDirectory(dir);
}
catch (Exception e)
{
throw Contracts.ExceptIO(e, "Extracting extra assembly zip failed: '{0}'", path);
}

LoadAssembliesInDir(env, dir, false);
Member

Why do we use zip files and just blast them into a folder instead of using NuGet? That only works if the extensions have no other dependencies. The moment they do, we need to worry about versions. NuGet (theoretically) handles all of this.

Member Author

This is existing behavior that I am refactoring out of ComponentCatalog so it isn't part of the public API anymore. It is only called from 2 places in the command-line (HelpCommand and ResultProcessor).

ComponentCatalog.CacheClassesExtra(_extraAssemblies);

// extra DLLs for dynamic loading
[Argument(ArgumentType.Multiple, HelpText = "Extra DLLs", ShortName = "dll")]
public string[] ExtraAssemblies = null;

}
Member

Who deletes the temp path and when?

Member Author

I guess nobody. This is an internal API that is only called from 2 places: HelpCommand and ResultProcessor, so it isn't part of our public API after this change.
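
If cleanup were wanted later, one option is to give the extraction directory an owner that deletes it on dispose. A minimal sketch under that assumption; `TempDirectory` is a hypothetical type that is not part of this PR. On Windows, files belonging to loaded assemblies stay locked, so the delete is best-effort and failures are deliberately ignored here:

```csharp
using System;
using System.IO;

// Hypothetical helper, not part of this PR: owns a temp directory and tries to remove it on Dispose.
internal sealed class TempDirectory : IDisposable
{
    public string FullPath { get; }

    public TempDirectory(string prefix)
    {
        // Same naming scheme as GetTempPath() in the diff: prefix + GUID under the system temp folder.
        FullPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), prefix + Guid.NewGuid().ToString()));
        Directory.CreateDirectory(FullPath);
    }

    public void Dispose()
    {
        try
        {
            if (Directory.Exists(FullPath))
                Directory.Delete(FullPath, recursive: true);
        }
        catch (IOException) { }                 // e.g. a loaded assembly still locks a file
        catch (UnauthorizedAccessException) { } // e.g. read-only or in-use files
    }
}
```

A caller would wrap the extraction in `using (var tmp = new TempDirectory("MLNET_")) { ... }`, keeping in mind that assemblies loaded from the directory outlive the `using` block.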

}
}

public static IDisposable CreateAssemblyRegistrar(IHostEnvironment env, string loadAssembliesPath = null)
Member

So the only thing that can be done with the registrar is to dispose it?

Member Author

Correct. This is an internal method that is only invoked from the command line. It creates an object that listens for new assemblies being loaded and registers them automatically.

{
Contracts.CheckValue(env, nameof(env));
env.CheckValueOrNull(loadAssembliesPath);

return new AssemblyRegistrar(env, loadAssembliesPath);
}

public static void RegisterCurrentLoadedAssemblies(IHostEnvironment env)
{
Contracts.CheckValue(env, nameof(env));

foreach (Assembly a in AppDomain.CurrentDomain.GetAssemblies())
Member

This will attempt to register lots of assemblies (e.g. all framework assemblies). Isn't it wasteful?

Member Author

Good point. I brought back the code that checks if the assembly references the assembly containing the LoadableClassAttributeBase class.

bool found = false;
var targetName = target.GetName();
foreach (var name in assembly.GetReferencedAssemblies())
{
if (name.Name == targetName.Name)
{
found = true;
break;
}
}
if (!found)
continue;

{
// Ignore dynamic assemblies.
if (a.IsDynamic)
continue;

env.ComponentCatalog.RegisterAssembly(a);
}
}

private static string CreateTempDirectory()
{
string dir = GetTempPath();
Directory.CreateDirectory(dir);
return dir;
}

private static string GetTempPath()
{
Guid guid = Guid.NewGuid();
return Path.GetFullPath(Path.Combine(Path.GetTempPath(), "MLNET_" + guid.ToString()));
}

private static readonly string[] _filePrefixesToAvoid = new string[] {
"api-ms-win",
"clr",
"coreclr",
"dbgshim",
"ext-ms-win",
"microsoft.bond.",
"microsoft.cosmos.",
"microsoft.csharp",
"microsoft.data.",
"microsoft.hpc.",
"microsoft.live.",
"microsoft.platformbuilder.",
"microsoft.visualbasic",
"microsoft.visualstudio.",
"microsoft.win32",
"microsoft.windowsapicodepack.",
"microsoft.windowsazure.",
"mscor",
"msvc",
"petzold.",
"roslyn.",
"sho",
"sni",
"sqm",
"system.",
"zlib",
};

private static bool ShouldSkipPath(string path)
{
string name = Path.GetFileName(path).ToLowerInvariant();
switch (name)
{
case "cqo.dll":
Member

How did we come up with the list? Are we sure the list does not contain items that are no longer needed? It might at least be worth adding a comment.

Member Author

This is the list in the current code. I just moved the list out of the ComponentCatalog class.

case "fasttreenative.dll":
case "libiomp5md.dll":
case "libvw.dll":
case "matrixinterf.dll":
case "microsoft.ml.neuralnetworks.gpucuda.dll":
case "mklimports.dll":
case "microsoft.research.controls.decisiontrees.dll":
case "microsoft.ml.neuralnetworks.sse.dll":
case "neuraltreeevaluator.dll":
case "optimizationbuilderdotnet.dll":
case "parallelcommunicator.dll":
case "microsoft.ml.runtime.runtests.dll":
case "scopecompiler.dll":
case "tbb.dll":
case "internallearnscope.dll":
case "unmanagedlib.dll":
case "vcclient.dll":
case "libxgboost.dll":
case "zedgraph.dll":
case "__scopecodegen__.dll":
case "cosmosClientApi.dll":
return true;
}

foreach (var s in _filePrefixesToAvoid)
{
if (name.StartsWith(s))
Member

I think you should use the invariant culture/comparison

Member Author

Good call. I changed this to be StringComparison.OrdinalIgnoreCase
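
For reference, a prefix check with an explicit ordinal, case-insensitive comparison could look roughly like the sketch below. `PrefixFilter` and its shortened prefix list are illustrative assumptions, not the code that was committed:

```csharp
using System;

// Illustrative only: a shortened copy of the prefix list, checked with an explicit comparison.
internal static class PrefixFilter
{
    private static readonly string[] _prefixes = { "api-ms-win", "clr", "system." };

    public static bool HasAvoidedPrefix(string fileName)
    {
        foreach (string prefix in _prefixes)
        {
            // OrdinalIgnoreCase compares without culture rules, so the result
            // does not depend on the current culture of the machine.
            if (fileName.StartsWith(prefix, StringComparison.OrdinalIgnoreCase))
                return true;
        }
        return false;
    }
}
```

With an explicit case-insensitive comparison, the match no longer depends on the name having been lowercased first.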

return true;
}

return false;
}

private static void LoadAssembliesInDir(IHostEnvironment env, string dir, bool filter)
{
if (!Directory.Exists(dir))
return;
Member

Why do you silently return? Why not throw, or at least debug-assert?

Member Author

Because there are cases where this is called on directories that may not exist, for example here, where we call it on the AutoLoad folder. (Again, that was all existing behavior that I'm refactoring out of ComponentCatalog.)


// Load all dlls in the given directory.
var paths = Directory.EnumerateFiles(dir, "*.dll");
foreach (string path in paths)
{
if (filter && ShouldSkipPath(path))
Member

I think we should at least log/trace when we skip files. If somebody creates an extension whose name starts with "Clr" (not very far-fetched), it will silently fail to load.

Member Author

I've logged an Info message for this.

continue;

LoadAssembly(env, path);
}
}

/// <summary>
/// Given an assembly path, load the assembly and register it with the ComponentCatalog.
/// </summary>
private static Assembly LoadAssembly(IHostEnvironment env, string path)
{
try
{
var assembly = Assembly.LoadFrom(path);
env.ComponentCatalog.RegisterAssembly(assembly);
return assembly;
}
catch (Exception)
Member

Similarly, why do we try to swallow all these errors?

Member Author

I've changed it to only swallow errors from Assembly.LoadFrom, and log an error, which is the existing behavior.
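
The narrowing described in the reply can be sketched as follows. This is not the committed code: `NarrowCatchSketch`, `register`, and `logError` are hypothetical stand-ins for the catalog registration and the logging channel, chosen so the example is self-contained:

```csharp
using System;
using System.Reflection;

// Hypothetical sketch: only Assembly.LoadFrom is guarded; registration errors propagate.
internal static class NarrowCatchSketch
{
    public static Assembly LoadAndRegister(string path, Action<Assembly> register, Action<string> logError)
    {
        Assembly assembly;
        try
        {
            assembly = Assembly.LoadFrom(path);
        }
        catch (Exception e)
        {
            // Load failures are reported and swallowed, matching the described behavior.
            logError($"Could not load '{path}': {e.Message}");
            return null;
        }

        // Outside the try block: an exception thrown here reaches the caller.
        register(assembly);
        return assembly;
    }
}
```

Keeping `register` outside the `try` means a failure while cataloguing an assembly is no longer hidden behind a `null` return.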

{
return null;
}
}

private sealed class AssemblyRegistrar : IDisposable
{
private readonly IHostEnvironment _env;

public AssemblyRegistrar(IHostEnvironment env, string path)
{
_env = env;

RegisterCurrentLoadedAssemblies(_env);

if (!string.IsNullOrEmpty(path))
{
LoadAssembliesInDir(_env, path, true);
path = Path.Combine(path, "AutoLoad");
LoadAssembliesInDir(_env, path, true);
}

AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainAssemblyLoad;
}

public void Dispose()
{
AppDomain.CurrentDomain.AssemblyLoad -= CurrentDomainAssemblyLoad;
}

private void CurrentDomainAssemblyLoad(object sender, AssemblyLoadEventArgs args)
{
// Don't try to index dynamic generated assembly
if (args.LoadedAssembly.IsDynamic)
return;

_env.ComponentCatalog.RegisterAssembly(args.LoadedAssembly);
}
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Api/ComponentCreation.cs
@@ -439,7 +439,7 @@ private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs ar
{
env.CheckValue(args, nameof(args));

-var classes = ComponentCatalog.FindLoadableClasses<TArgs, TSig>();
+var classes = env.ComponentCatalog.FindLoadableClasses<TArgs, TSig>();
if (classes.Length == 0)
throw env.Except("Couldn't find a {0} class that accepts {1} as arguments.", typeof(TRes).Name, typeof(TArgs).FullName);
if (classes.Length > 1)
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Api/SerializableLambdaTransform.cs
@@ -29,7 +29,8 @@ public static VersionInfo GetVersionInfo()
verWrittenCur: 0x00010001,
verReadableCur: 0x00010001,
verWeCanReadBack: 0x00010001,
-loaderSignature: LoaderSignature);
+loaderSignature: LoaderSignature,
+loaderAssemblyName: typeof(SerializableLambdaTransform).Assembly.FullName);
}

public const string LoaderSignature = "UserLambdaMapTransform";