diff --git a/test/units/playbook/role/test_include_role.py b/test/units/playbook/role/test_include_role.py
index 30d79d6f469..342fa7d73c7 100644
--- a/test/units/playbook/role/test_include_role.py
+++ b/test/units/playbook/role/test_include_role.py
@@ -23,6 +23,7 @@ from ansible.compat.tests import unittest
 from ansible.compat.tests.mock import patch
 
 from ansible.playbook import Play
+from ansible.playbook.role_include import IncludeRole
 from ansible.playbook.task import Task
 from ansible.vars.manager import VariableManager
 
@@ -30,15 +31,6 @@ from units.mock.loader import DictDataLoader
 from units.mock.path import mock_unfrackpath_noop
 
 
-def flatten_tasks(tasks):
-    for task in tasks:
-        if isinstance(task, Task):
-            yield task
-        else:
-            for t in flatten_tasks(task.block):
-                yield t
-
-
 class TestIncludeRole(unittest.TestCase):
 
     def setUp(self):
@@ -97,8 +89,21 @@ class TestIncludeRole(unittest.TestCase):
     def tearDown(self):
         pass
 
+    def flatten_tasks(self, tasks):
+        for task in tasks:
+            if isinstance(task, IncludeRole):
+                blocks, handlers = task.get_block_list(loader=self.loader)
+                for block in blocks:
+                    for t in self.flatten_tasks(block.block):
+                        yield t
+            elif isinstance(task, Task):
+                yield task
+            else:
+                for t in self.flatten_tasks(task.block):
+                    yield t
+
     def get_tasks_vars(self, play, tasks):
-        for task in flatten_tasks(tasks):
+        for task in self.flatten_tasks(tasks):
             role = task._role
             if not role:
                 continue
@@ -122,9 +127,12 @@ class TestIncludeRole(unittest.TestCase):
         ), loader=self.loader, variable_manager=self.var_manager)
 
         tasks = play.compile()
+        tested = False
         for role, task_vars in self.get_tasks_vars(play, tasks):
+            tested = True
             self.assertEqual(task_vars.get('l3_variable'), 'l3-main')
             self.assertEqual(task_vars.get('test_variable'), 'l3-main')
+        self.assertTrue(tested)
 
     @patch('ansible.playbook.role.definition.unfrackpath',
            mock_unfrackpath_noop)
@@ -140,9 +148,12 @@ class TestIncludeRole(unittest.TestCase):
             loader=self.loader, variable_manager=self.var_manager)
 
         tasks = play.compile()
+        tested = False
         for role, task_vars in self.get_tasks_vars(play, tasks):
+            tested = True
             self.assertEqual(task_vars.get('l3_variable'), 'l3-alt')
             self.assertEqual(task_vars.get('test_variable'), 'l3-alt')
+        self.assertTrue(tested)
 
     @patch('ansible.playbook.role.definition.unfrackpath',
            mock_unfrackpath_noop)
@@ -165,7 +176,9 @@ class TestIncludeRole(unittest.TestCase):
         ), loader=self.loader, variable_manager=self.var_manager)
 
         tasks = play.compile()
+        expected_roles = ['l1', 'l2', 'l3']
         for role, task_vars in self.get_tasks_vars(play, tasks):
+            expected_roles.remove(role)
             # Outer-most role must not have variables from inner roles yet
             if role == 'l1':
                 self.assertEqual(task_vars.get('l1_variable'), 'l1-main')
@@ -184,6 +197,9 @@ class TestIncludeRole(unittest.TestCase):
                 self.assertEqual(task_vars.get('l2_variable'), 'l2-main')
                 self.assertEqual(task_vars.get('l3_variable'), 'l3-main')
                 self.assertEqual(task_vars.get('test_variable'), 'l3-main')
+            else:
+                self.fail()
+        self.assertFalse(expected_roles)
 
     @patch('ansible.playbook.role.definition.unfrackpath',
            mock_unfrackpath_noop)
@@ -206,7 +222,9 @@ class TestIncludeRole(unittest.TestCase):
         ), loader=self.loader, variable_manager=self.var_manager)
 
         tasks = play.compile()
+        expected_roles = ['l1', 'l2', 'l3']
         for role, task_vars in self.get_tasks_vars(play, tasks):
+            expected_roles.remove(role)
             # Outer-most role must not have variables from inner roles yet
             if role == 'l1':
                 self.assertEqual(task_vars.get('l1_variable'), 'l1-alt')
@@ -225,3 +243,6 @@ class TestIncludeRole(unittest.TestCase):
                 self.assertEqual(task_vars.get('l2_variable'), 'l2-alt')
                 self.assertEqual(task_vars.get('l3_variable'), 'l3-alt')
                 self.assertEqual(task_vars.get('test_variable'), 'l3-alt')
+            else:
+                self.fail()
+        self.assertFalse(expected_roles)