Fix Python mocks (#4074)

The original version of this code caused inconsistencies in the event
loop associated with a given thread. These changes elimintate the event
loop shenanigans the mocks were trying to play by updating _sync_await
to create an event loop if none exists in the current thread.

It's possible that this will cause problems if the tests run on a
different thread than the original program, as the tests are likely to
end up waiting on outputs created by the program, which is not supported
in Python.

Also adds test coverage of the mocking/testing support in Python.
This commit is contained in:
Luke Hoban 2020-03-12 21:09:47 -07:00 committed by GitHub
parent ef6f0d4de4
commit 9da774e180
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 119 additions and 15 deletions

2
.gitignore vendored
View file

@ -21,3 +21,5 @@ coverage.cov
# By default, we don't check in yarn.lock files
**/yarn.lock
.mypy_cache

View file

@ -33,12 +33,8 @@ if TYPE_CHECKING:
from ..resource import Resource
loop = None
def test(fn):
def wrapper(*args, **kwargs):
asyncio.set_event_loop(loop)
_sync_await(run_pulumi_func(lambda: _sync_await(Output.from_input(fn(*args, **kwargs)).future())))
return wrapper
@ -93,7 +89,6 @@ class MockMonitor:
ret = self.mocks.call(request.tok, args, request.provider)
asyncio.set_event_loop(loop)
ret_proto = _sync_await(asyncio.ensure_future(rpc.serialize_properties(ret, {})))
fields = {"failures": None, "return": ret_proto}
@ -104,7 +99,6 @@ class MockMonitor:
_, state = self.mocks.new_resource(request.type, request.name, state, request.provider, request.id)
asyncio.set_event_loop(loop)
props_proto = _sync_await(asyncio.ensure_future(rpc.serialize_properties(state, {})))
urn = self.make_urn(request.parent, request.type, request.name)
@ -115,7 +109,6 @@ class MockMonitor:
id_, state = self.mocks.new_resource(request.type, request.name, inputs, request.provider, request.importId)
asyncio.set_event_loop(loop)
obj_proto = _sync_await(rpc.serialize_properties(state, {}))
urn = self.make_urn(request.parent, request.type, request.name)
@ -158,7 +151,3 @@ def set_mocks(mocks: Mocks,
dry_run=preview,
test_mode_enabled=True)
configure(settings)
# Make sure we have an event loop.
global loop
loop = asyncio.get_event_loop()

View file

@ -23,6 +23,7 @@ from typing import Callable, Any, Dict, List
from ..resource import ComponentResource, Resource, ResourceTransformation
from .settings import get_project, get_stack, get_root_resource, is_dry_run, set_root_resource
from .rpc_manager import RPC_MANAGER
from .sync_await import _all_tasks, _get_current_task
from .. import log
from . import known_types
@ -55,9 +56,9 @@ async def run_pulumi_func(func: Callable):
# We will occasionally start tasks deliberately that we know will never complete. We must
# cancel them before shutting down the event loop.
log.debug("Canceling all outstanding tasks")
for task in asyncio.Task.all_tasks():
for task in _all_tasks():
# Don't kill ourselves, that would be silly.
if task == asyncio.Task.current_task():
if task == _get_current_task():
continue
task.cancel()

View file

@ -29,9 +29,13 @@ if sys.version_info[0] == 3 and sys.version_info[1] < 7:
_enter_task = enter_task
_leave_task = leave_task
_all_tasks = asyncio.Task.all_tasks
_get_current_task = asyncio.Task.current_task
else:
_enter_task = asyncio.tasks._enter_task # type: ignore
_leave_task = asyncio.tasks._leave_task # type: ignore
_all_tasks = asyncio.all_tasks # type: ignore
_get_current_task = asyncio.current_task # type: ignore
def _sync_await(awaitable: Awaitable[Any]) -> Any:
@ -41,7 +45,15 @@ def _sync_await(awaitable: Awaitable[Any]) -> Any:
"""
# Fetch the current event loop and ensure a future.
loop = asyncio.get_event_loop()
loop = None
try:
loop = asyncio.get_event_loop()
except RuntimeError:
pass
if loop is None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
fut = asyncio.ensure_future(awaitable)
# If the loop is not running, we can just use run_until_complete. Without this, we would need to duplicate a fair
@ -51,7 +63,7 @@ def _sync_await(awaitable: Awaitable[Any]) -> Any:
# If we are executing inside a task, pretend we've returned from its current callback--effectively yielding to
# the event loop--by calling _leave_task.
task = asyncio.Task.current_task(loop)
task = _get_current_task(loop)
if task is not None:
_leave_task(loop, task)

View file

@ -0,0 +1 @@
venv

View file

@ -0,0 +1,19 @@
# Copyright 2016-2018, Pulumi Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pulumi
import resources
pulumi.export("outprop", resources.mycomponent.outprop)
pulumi.export("public_ip", resources.myinstance.public_ip)

View file

@ -0,0 +1,23 @@
import pulumi
from pulumi import Output
class MyComponent(pulumi.ComponentResource):
outprop: pulumi.Output[str]
def __init__(self, name, inprop: pulumi.Input[str] = None, opts = None):
super().__init__('pkg:index:MyComponent', name, None, opts)
if inprop is None:
raise TypeError("Missing required property 'inprop'")
self.outprop = pulumi.Output.from_input(inprop).apply(lambda x: f"output: {x}")
class Instance(pulumi.CustomResource):
public_ip: pulumi.Output[str]
def __init__(self, resource_name, name: pulumi.Input[str] = None, opts = None):
if name is None:
raise TypeError("Missing required property 'name'")
__props__: dict = dict()
__props__["public_ip"] = None
__props__["name"] = name
super(Instance, self).__init__('aws:ec2/instance:Instance', resource_name, __props__, opts)
mycomponent = MyComponent("mycomponent", inprop="hello")
myinstance = Instance("instance", name="myvm")

View file

@ -0,0 +1,57 @@
# Copyright 2016-2018, Pulumi Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pulumi
class MyMocks(pulumi.runtime.Mocks):
def call(self, token, args, provider):
return {}
def new_resource(self, type_, name, inputs, provider, id_):
if type_ == 'aws:ec2/securityGroup:SecurityGroup':
state = {
'arn': 'arn:aws:ec2:us-west-2:123456789012:security-group/sg-12345678',
'name': inputs['name'] if 'name' in inputs else name + '-sg',
}
return ['sg-12345678', dict(inputs, **state)]
elif type_ == 'aws:ec2/instance:Instance':
state = {
'arn': 'arn:aws:ec2:us-west-2:123456789012:instance/i-1234567890abcdef0',
'instanceState': 'running',
'primaryNetworkInterfaceId': 'eni-12345678',
'privateDns': 'ip-10-0-1-17.ec2.internal',
'public_dns': 'ec2-203-0-113-12.compute-1.amazonaws.com',
'public_ip': '203.0.113.12',
}
return ['i-1234567890abcdef0', dict(inputs, **state)]
else:
return ['', {}]
pulumi.runtime.set_mocks(MyMocks())
# Now actually import the code that creates resources, and then test it.
import resources
class TestingWithMocks(unittest.TestCase):
@pulumi.runtime.test
def test_component(self):
def check_outprop(outprop):
self.assertEqual(outprop, 'output: hello')
return resources.mycomponent.outprop.apply(check_outprop)
@pulumi.runtime.test
def test_custom(self):
def check_ip(ip):
self.assertEqual(ip, '203.0.113.12')
return resources.myinstance.public_ip.apply(check_ip)