diff --git a/src/System.Management.Automation/engine/CmdletParameterBinderController.cs b/src/System.Management.Automation/engine/CmdletParameterBinderController.cs index 92977652e..acf3766a3 100644 --- a/src/System.Management.Automation/engine/CmdletParameterBinderController.cs +++ b/src/System.Management.Automation/engine/CmdletParameterBinderController.cs @@ -1690,8 +1690,8 @@ namespace System.Management.Automation // Set-ClusterOwnerNode -Owners foo,bar // Set-ClusterOwnerNode foo bar // Set-ClusterOwnerNode foo,bar - // we unwrap our List, but only if there is a single argument of type object[]. - if (valueFromRemainingArguments.Count == 1 && valueFromRemainingArguments[0] is object[]) + // we unwrap our List, but only if there is a single argument which is a collection. + if (valueFromRemainingArguments.Count == 1 && LanguagePrimitives.IsObjectEnumerable(valueFromRemainingArguments[0])) { cpi.SetArgumentValue(UnboundArguments[0].ArgumentExtent, valueFromRemainingArguments[0]); } diff --git a/src/System.Management.Automation/engine/LanguagePrimitives.cs b/src/System.Management.Automation/engine/LanguagePrimitives.cs index cc4ab3064..7db18d424 100644 --- a/src/System.Management.Automation/engine/LanguagePrimitives.cs +++ b/src/System.Management.Automation/engine/LanguagePrimitives.cs @@ -440,10 +440,24 @@ namespace System.Management.Automation internal static bool IsTypeEnumerable(Type type) { + if (type == null) { return false; } GetEnumerableDelegate getEnumerable = GetOrCalculateEnumerable(type); return (getEnumerable != LanguagePrimitives.ReturnNullEnumerable); } + /// + /// Returns True if the language considers obj to be IEnumerable + /// + /// + /// IEnumerable or IEnumerable-like object + /// + [SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "Since V1 code is already shipped, excluding this message.")] + public static bool IsObjectEnumerable(object obj) + { + return IsTypeEnumerable(PSObject.Base(obj)?.GetType()); + } + + /// /// Retrieves the IEnumerable of obj or null if the language does not consider obj to be IEnumerable /// @@ -453,26 +467,9 @@ namespace System.Management.Automation [SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "Since V1 code is already shipped, excluding this message.")] public static IEnumerable GetEnumerable(object obj) { - if (obj == null) - { - return null; - } - - Type objectType = obj.GetType(); - - // if the object passed is an PSObject, - // look at the base object. Notice that, if the - // object has been serialized, the base object - // will be there as an ArrayList if the original - // object was IEnumerable - if (objectType == typeof(PSObject)) - { - PSObject mshObj = (PSObject)obj; - obj = mshObj.BaseObject; - objectType = obj.GetType(); - } - - GetEnumerableDelegate getEnumerable = GetOrCalculateEnumerable(objectType); + obj = PSObject.Base(obj); + if (obj == null) { return null; } + GetEnumerableDelegate getEnumerable = GetOrCalculateEnumerable(obj.GetType()); return getEnumerable(obj); } diff --git a/test/powershell/engine/ParameterBinding/ParameterBinding.Tests.ps1 b/test/powershell/engine/ParameterBinding/ParameterBinding.Tests.ps1 index 0bf88e864..e3d841640 100644 --- a/test/powershell/engine/ParameterBinding/ParameterBinding.Tests.ps1 +++ b/test/powershell/engine/ParameterBinding/ParameterBinding.Tests.ps1 @@ -422,5 +422,25 @@ $result.Value[1] | Should Be 2 $result.Value[2] | Should Be 3 } + + It "Binds properly when collections of type other than object[] are used on an advanced function" { + $list = [Collections.Generic.List[int]](1..3) + $result = Test-BindingFunction $list + + $result.ArgumentCount | Should Be 3 + $result.Value[0] | Should Be 1 + $result.Value[1] | Should Be 2 + $result.Value[2] | Should Be 3 + } + + It "Binds properly when collections of type other than object[] are used on a cmdlet" { + $list = [Collections.Generic.List[int]](1..3) + $result = Test-BindingCmdlet $list + + $result.ArgumentCount | Should Be 3 + $result.Value[0] | Should Be 1 + $result.Value[1] | Should Be 2 + $result.Value[2] | Should Be 3 + } } }