From 8f8ddc3fb76a03dad93f5664314c2795dd69f390 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Andr=C3=A9=20Moreau?= Date: Wed, 28 Apr 2021 17:28:41 -0400 Subject: [PATCH] Add `LoadAssemblyFromNativeMemory` function to load assemblies from memory in a native PowerShell host (#14652) --- .../CoreCLR/CorePsAssemblyLoadContext.cs | 31 +++++ .../ExperimentalFeature.cs | 4 + test/xUnit/csharp/test_NativeInterop.cs | 107 ++++++++++++++++++ 3 files changed, 142 insertions(+) create mode 100644 test/xUnit/csharp/test_NativeInterop.cs diff --git a/src/System.Management.Automation/CoreCLR/CorePsAssemblyLoadContext.cs b/src/System.Management.Automation/CoreCLR/CorePsAssemblyLoadContext.cs index 3dcc9ebde..4a38eee45 100644 --- a/src/System.Management.Automation/CoreCLR/CorePsAssemblyLoadContext.cs +++ b/src/System.Management.Automation/CoreCLR/CorePsAssemblyLoadContext.cs @@ -587,4 +587,35 @@ namespace System.Management.Automation PowerShellAssemblyLoadContext.InitializeSingleton(basePaths); } } + + /// + /// Provides helper functions to faciliate calling managed code from a native PowerShell host. + /// + public static unsafe class PowerShellUnsafeAssemblyLoad + { + /// + /// Load an assembly in memory from unmanaged code. + /// + /// + /// This API is covered by the experimental feature 'PSLoadAssemblyFromNativeCode', + /// and it may be deprecated and removed in future. + /// + /// Unmanaged pointer to assembly data buffer. + /// Size in bytes of the assembly data buffer. + /// Returns zero on success and non-zero on failure. + [UnmanagedCallersOnly] + public static int LoadAssemblyFromNativeMemory(IntPtr data, int size) + { + try + { + using var stream = new UnmanagedMemoryStream((byte*)data, size); + AssemblyLoadContext.Default.LoadFromStream(stream); + return 0; + } + catch + { + return -1; + } + } + } } diff --git a/src/System.Management.Automation/engine/ExperimentalFeature/ExperimentalFeature.cs b/src/System.Management.Automation/engine/ExperimentalFeature/ExperimentalFeature.cs index b0c7c3d56..b2c527550 100644 --- a/src/System.Management.Automation/engine/ExperimentalFeature/ExperimentalFeature.cs +++ b/src/System.Management.Automation/engine/ExperimentalFeature/ExperimentalFeature.cs @@ -140,7 +140,11 @@ namespace System.Management.Automation new ExperimentalFeature( name: PSNativeCommandArgumentPassingFeatureName, description: "Use ArgumentList when invoking a native command"), + new ExperimentalFeature( + name: "PSLoadAssemblyFromNativeCode", + description: "Expose an API to allow assembly loading from native code"), }; + EngineExperimentalFeatures = new ReadOnlyCollection(engineFeatures); // Initialize the readonly dictionary 'EngineExperimentalFeatureMap'. diff --git a/test/xUnit/csharp/test_NativeInterop.cs b/test/xUnit/csharp/test_NativeInterop.cs new file mode 100644 index 000000000..685e8a23c --- /dev/null +++ b/test/xUnit/csharp/test_NativeInterop.cs @@ -0,0 +1,107 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Runtime.Loader; +using System.Management.Automation; +using Xunit; + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Emit; +using Microsoft.CodeAnalysis.Text; + +namespace PSTests.Sequential +{ + public static class NativeInterop + { + [Fact] + public static void TestLoadNativeInMemoryAssembly() + { + string tempDir = Path.Combine(Path.GetTempPath(), "TestLoadNativeInMemoryAssembly"); + string testDll = Path.Combine(tempDir, "test.dll"); + + if (!File.Exists(testDll)) + { + Directory.CreateDirectory(tempDir); + bool result = CreateTestDll(testDll); + Assert.True(result, "The call to 'CreateTestDll' should be successful and return true."); + Assert.True(File.Exists(testDll), "The test assembly should be created."); + } + + var asmName = AssemblyName.GetAssemblyName(testDll); + string asmFullName = SearchAssembly(asmName.Name); + Assert.Null(asmFullName); + + unsafe + { + int ret = LoadAssemblyTest(testDll); + Assert.Equal(0, ret); + } + + asmFullName = SearchAssembly(asmName.Name); + Assert.Equal(asmName.FullName, asmFullName); + } + + private static unsafe int LoadAssemblyTest(string assemblyPath) + { + // The 'LoadAssemblyFromNativeMemory' method is annotated with 'UnmanagedCallersOnly' attribute, + // so we have to use the 'unmanaged' function pointer to invoke it. + delegate* unmanaged funcPtr = &PowerShellUnsafeAssemblyLoad.LoadAssemblyFromNativeMemory; + + int length = 0; + IntPtr nativeMem = IntPtr.Zero; + + try + { + using (var fileStream = new FileStream(assemblyPath, FileMode.Open, FileAccess.Read)) + { + length = (int)fileStream.Length; + nativeMem = Marshal.AllocHGlobal(length); + + using var unmanagedStream = new UnmanagedMemoryStream((byte*)nativeMem, length, length, FileAccess.Write); + fileStream.CopyTo(unmanagedStream); + } + + // Call the function pointer. + return funcPtr(nativeMem, length); + } + finally + { + // Free the native memory + Marshal.FreeHGlobal(nativeMem); + } + } + + private static string SearchAssembly(string assemblyName) + { + Assembly asm = AssemblyLoadContext.Default.Assemblies.FirstOrDefault( + assembly => assembly.FullName.StartsWith(assemblyName, StringComparison.OrdinalIgnoreCase)); + + return asm?.FullName; + } + + private static bool CreateTestDll(string dllPath) + { + var parseOptions = CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.Latest); + var compilationOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary); + + List syntaxTrees = new(); + SourceText sourceText = SourceText.From("public class Utt { }"); + syntaxTrees.Add(CSharpSyntaxTree.ParseText(sourceText, parseOptions)); + + var refs = new List { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) }; + Compilation compilation = CSharpCompilation.Create( + Path.GetRandomFileName(), + syntaxTrees: syntaxTrees, + references: refs, + options: compilationOptions); + + using var fs = new FileStream(dllPath, FileMode.CreateNew, FileAccess.ReadWrite, FileShare.None); + EmitResult emitResult = compilation.Emit(peStream: fs, options: null); + return emitResult.Success; + } + } +}