Add LoadAssemblyFromNativeMemory function to load assemblies from memory in a native PowerShell host (#14652)

This commit is contained in:
Marc-André Moreau 2021-04-28 17:28:41 -04:00 committed by GitHub
parent 59715d5ba9
commit 8f8ddc3fb7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 142 additions and 0 deletions

View file

@ -587,4 +587,35 @@ namespace System.Management.Automation
PowerShellAssemblyLoadContext.InitializeSingleton(basePaths);
}
}
/// <summary>
/// Provides helper functions to faciliate calling managed code from a native PowerShell host.
/// </summary>
public static unsafe class PowerShellUnsafeAssemblyLoad
{
/// <summary>
/// Load an assembly in memory from unmanaged code.
/// </summary>
/// <remarks>
/// This API is covered by the experimental feature 'PSLoadAssemblyFromNativeCode',
/// and it may be deprecated and removed in future.
/// </remarks>
/// <param name="data">Unmanaged pointer to assembly data buffer.</param>
/// <param name="size">Size in bytes of the assembly data buffer.</param>
/// <returns>Returns zero on success and non-zero on failure.</returns>
[UnmanagedCallersOnly]
public static int LoadAssemblyFromNativeMemory(IntPtr data, int size)
{
try
{
using var stream = new UnmanagedMemoryStream((byte*)data, size);
AssemblyLoadContext.Default.LoadFromStream(stream);
return 0;
}
catch
{
return -1;
}
}
}
}

View file

@ -140,7 +140,11 @@ namespace System.Management.Automation
new ExperimentalFeature(
name: PSNativeCommandArgumentPassingFeatureName,
description: "Use ArgumentList when invoking a native command"),
new ExperimentalFeature(
name: "PSLoadAssemblyFromNativeCode",
description: "Expose an API to allow assembly loading from native code"),
};
EngineExperimentalFeatures = new ReadOnlyCollection<ExperimentalFeature>(engineFeatures);
// Initialize the readonly dictionary 'EngineExperimentalFeatureMap'.

View file

@ -0,0 +1,107 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Runtime.Loader;
using System.Management.Automation;
using Xunit;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Emit;
using Microsoft.CodeAnalysis.Text;
namespace PSTests.Sequential
{
public static class NativeInterop
{
[Fact]
public static void TestLoadNativeInMemoryAssembly()
{
string tempDir = Path.Combine(Path.GetTempPath(), "TestLoadNativeInMemoryAssembly");
string testDll = Path.Combine(tempDir, "test.dll");
if (!File.Exists(testDll))
{
Directory.CreateDirectory(tempDir);
bool result = CreateTestDll(testDll);
Assert.True(result, "The call to 'CreateTestDll' should be successful and return true.");
Assert.True(File.Exists(testDll), "The test assembly should be created.");
}
var asmName = AssemblyName.GetAssemblyName(testDll);
string asmFullName = SearchAssembly(asmName.Name);
Assert.Null(asmFullName);
unsafe
{
int ret = LoadAssemblyTest(testDll);
Assert.Equal(0, ret);
}
asmFullName = SearchAssembly(asmName.Name);
Assert.Equal(asmName.FullName, asmFullName);
}
private static unsafe int LoadAssemblyTest(string assemblyPath)
{
// The 'LoadAssemblyFromNativeMemory' method is annotated with 'UnmanagedCallersOnly' attribute,
// so we have to use the 'unmanaged' function pointer to invoke it.
delegate* unmanaged<IntPtr, int, int> funcPtr = &PowerShellUnsafeAssemblyLoad.LoadAssemblyFromNativeMemory;
int length = 0;
IntPtr nativeMem = IntPtr.Zero;
try
{
using (var fileStream = new FileStream(assemblyPath, FileMode.Open, FileAccess.Read))
{
length = (int)fileStream.Length;
nativeMem = Marshal.AllocHGlobal(length);
using var unmanagedStream = new UnmanagedMemoryStream((byte*)nativeMem, length, length, FileAccess.Write);
fileStream.CopyTo(unmanagedStream);
}
// Call the function pointer.
return funcPtr(nativeMem, length);
}
finally
{
// Free the native memory
Marshal.FreeHGlobal(nativeMem);
}
}
private static string SearchAssembly(string assemblyName)
{
Assembly asm = AssemblyLoadContext.Default.Assemblies.FirstOrDefault(
assembly => assembly.FullName.StartsWith(assemblyName, StringComparison.OrdinalIgnoreCase));
return asm?.FullName;
}
private static bool CreateTestDll(string dllPath)
{
var parseOptions = CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.Latest);
var compilationOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary);
List<SyntaxTree> syntaxTrees = new();
SourceText sourceText = SourceText.From("public class Utt { }");
syntaxTrees.Add(CSharpSyntaxTree.ParseText(sourceText, parseOptions));
var refs = new List<PortableExecutableReference> { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) };
Compilation compilation = CSharpCompilation.Create(
Path.GetRandomFileName(),
syntaxTrees: syntaxTrees,
references: refs,
options: compilationOptions);
using var fs = new FileStream(dllPath, FileMode.CreateNew, FileAccess.ReadWrite, FileShare.None);
EmitResult emitResult = compilation.Emit(peStream: fs, options: null);
return emitResult.Success;
}
}
}