From e5595f11175bbfb327e54342ebacc5c97eda0a90 Mon Sep 17 00:00:00 2001
From: Alex Stephen <alexstephen@google.com>
Date: Mon, 26 Aug 2019 10:27:44 -0700
Subject: [PATCH] gcp_compute refactor (#61249)

* wip

* it works!

* cache should work

* ran black on code

* wip

* now it works

* black
---
 lib/ansible/plugins/inventory/gcp_compute.py | 495 +++++++++++--------
 1 file changed, 283 insertions(+), 212 deletions(-)

diff --git a/lib/ansible/plugins/inventory/gcp_compute.py b/lib/ansible/plugins/inventory/gcp_compute.py
index c6525b34fdb..80330b1e633 100644
--- a/lib/ansible/plugins/inventory/gcp_compute.py
+++ b/lib/ansible/plugins/inventory/gcp_compute.py
@@ -1,10 +1,11 @@
 # Copyright (c) 2017 Ansible Project
 # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
 
-from __future__ import (absolute_import, division, print_function)
+from __future__ import absolute_import, division, print_function
+
 __metaclass__ = type
 
-DOCUMENTATION = '''
+DOCUMENTATION = """
     name: gcp_compute
     plugin_type: inventory
     short_description: Google Cloud Compute Engine inventory source
@@ -105,9 +106,9 @@ DOCUMENTATION = '''
           type: bool
           default: False
           version_added: '2.8'
-'''
+"""
 
-EXAMPLES = '''
+EXAMPLES = """
 plugin: gcp_compute
 zones: # populate inventory with instances in these regions
   - us-east1-a
@@ -133,14 +134,19 @@ compose:
   # Set an inventory parameter to use the Public IP address to connect to the host
   # For Private ip use "networkInterfaces[0].networkIP"
   ansible_host: networkInterfaces[0].accessConfigs[0].natIP
-'''
+"""
 
 import json
 
 from ansible.errors import AnsibleError, AnsibleParserError
 from ansible.module_utils._text import to_text
 from ansible.module_utils.basic import missing_required_lib
-from ansible.module_utils.gcp_utils import GcpSession, navigate_hash, GcpRequestException, HAS_GOOGLE_LIBRARIES
+from ansible.module_utils.gcp_utils import (
+    GcpSession,
+    navigate_hash,
+    GcpRequestException,
+    HAS_GOOGLE_LIBRARIES,
+)
 from ansible.plugins.inventory import BaseInventoryPlugin, Constructable, Cacheable
 
 
@@ -150,62 +156,211 @@ class GcpMockModule(object):
         self.params = params
 
     def fail_json(self, *args, **kwargs):
-        raise AnsibleError(kwargs['msg'])
+        raise AnsibleError(kwargs["msg"])
+
+
+class GcpInstance(object):
+    def __init__(self, json, hostname_ordering, project_disks, should_format=True):
+        self.hostname_ordering = hostname_ordering
+        self.project_disks = project_disks
+        self.json = json
+        if should_format:
+            self.convert()
+
+    def to_json(self):
+        return self.json
+
+    def convert(self):
+        if "zone" in self.json:
+            self.json["zone_selflink"] = self.json["zone"]
+            self.json["zone"] = self.json["zone"].split("/")[-1]
+        if "machineType" in self.json:
+            self.json["machineType_selflink"] = self.json["machineType"]
+            self.json["machineType"] = self.json["machineType"].split("/")[-1]
+
+        if "networkInterfaces" in self.json:
+            for network in self.json["networkInterfaces"]:
+                if "network" in network:
+                    network["network"] = self._format_network_info(network["network"])
+                if "subnetwork" in network:
+                    network["subnetwork"] = self._format_network_info(
+                        network["subnetwork"]
+                    )
+
+        if "metadata" in self.json:
+            # If no metadata, 'items' will be blank.
+            # We want the metadata hash overriden anyways for consistency.
+            self.json["metadata"] = self._format_metadata(
+                self.json["metadata"].get("items", {})
+            )
+
+        self.json["project"] = self.json["selfLink"].split("/")[6]
+        self.json["image"] = self._get_image()
+
+    def _format_network_info(self, address):
+        """
+            :param address: A GCP network address
+            :return a dict with network shortname and region
+        """
+        split = address.split("/")
+        region = ""
+        if "global" in split:
+            region = "global"
+        else:
+            region = split[8]
+        return {"region": region, "name": split[-1], "selfLink": address}
+
+    def _format_metadata(self, metadata):
+        """
+            :param metadata: A list of dicts where each dict has keys "key" and "value"
+            :return a dict with key/value pairs for each in list.
+        """
+        new_metadata = {}
+        for pair in metadata:
+            new_metadata[pair["key"]] = pair["value"]
+        return new_metadata
+
+    def hostname(self):
+        """
+            :return the hostname of this instance
+        """
+        for order in self.hostname_ordering:
+            name = None
+            if order == "public_ip":
+                name = self._get_publicip()
+            elif order == "private_ip":
+                name = self._get_privateip()
+            elif order == "name":
+                name = self.json[u"name"]
+            else:
+                raise AnsibleParserError("%s is not a valid hostname precedent" % order)
+
+            if name:
+                return name
+
+        raise AnsibleParserError("No valid name found for host")
+
+    def _get_publicip(self):
+        """
+            :return the publicIP of this instance or None
+        """
+        # Get public IP if exists
+        for interface in self.json["networkInterfaces"]:
+            if "accessConfigs" in interface:
+                for accessConfig in interface["accessConfigs"]:
+                    if "natIP" in accessConfig:
+                        return accessConfig[u"natIP"]
+        return None
+
+    def _get_image(self):
+        """
+            :param instance: A instance response from GCP
+            :return the image of this instance or None
+        """
+        image = None
+        if self.project_disks and "disks" in self.json:
+            for disk in self.json["disks"]:
+                if disk.get("boot"):
+                    image = self.project_disks[disk["source"]]
+        return image
+
+    def _get_privateip(self):
+        """
+            :param item: A host response from GCP
+            :return the privateIP of this instance or None
+        """
+        # Fallback: Get private IP
+        for interface in self.json[u"networkInterfaces"]:
+            if "networkIP" in interface:
+                return interface[u"networkIP"]
 
 
 class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
 
-    NAME = 'gcp_compute'
+    NAME = "gcp_compute"
 
-    _instances = r"https://www.googleapis.com/compute/v1/projects/%s/aggregated/instances"
+    _instances = (
+        r"https://www.googleapis.com/compute/v1/projects/%s/aggregated/instances"
+    )
 
     def __init__(self):
         super(InventoryModule, self).__init__()
 
-        self.group_prefix = 'gcp_'
+        self.group_prefix = "gcp_"
 
     def _populate_host(self, item):
-        '''
+        """
             :param item: A GCP instance
-        '''
-        hostname = self._get_hostname(item)
+        """
+        hostname = item.hostname()
         self.inventory.add_host(hostname)
-        for key in item:
+        for key in item.to_json():
             try:
-                self.inventory.set_variable(hostname, self.get_option('vars_prefix') + key, item[key])
+                self.inventory.set_variable(
+                    hostname, self.get_option("vars_prefix") + key, item.to_json()[key]
+                )
             except (ValueError, TypeError) as e:
-                self.display.warning("Could not set host info hostvar for %s, skipping %s: %s" % (hostname, key, to_text(e)))
-        self.inventory.add_child('all', hostname)
+                self.display.warning(
+                    "Could not set host info hostvar for %s, skipping %s: %s"
+                    % (hostname, key, to_text(e))
+                )
+        self.inventory.add_child("all", hostname)
 
     def verify_file(self, path):
-        '''
+        """
             :param path: the path to the inventory config file
             :return the contents of the config file
-        '''
+        """
         if super(InventoryModule, self).verify_file(path):
-            if path.endswith(('gcp.yml', 'gcp.yaml')):
+            if path.endswith(("gcp.yml", "gcp.yaml")):
                 return True
-            elif path.endswith(('gcp_compute.yml', 'gcp_compute.yaml')):
+            elif path.endswith(("gcp_compute.yml", "gcp_compute.yaml")):
                 return True
         return False
 
     def fetch_list(self, params, link, query):
-        '''
+        """
             :param params: a dict containing all of the fields relevant to build URL
             :param link: a formatted URL
             :param query: a formatted query string
             :return the JSON response containing a list of instances.
-        '''
-        response = self.auth_session.get(link, params={'filter': query})
-        return self._return_if_object(self.fake_module, response)
+        """
+        lists = []
+        resp = self._return_if_object(
+            self.fake_module, self.auth_session.get(link, params={"filter": query})
+        )
+        lists.append(resp.get("items"))
+        while resp.get("nextPageToken"):
+            resp = self._return_if_object(
+                self.fake_module,
+                self.auth_session.get(
+                    link,
+                    params={"filter": query, "pageToken": resp.get("nextPageToken")},
+                ),
+            )
+            lists.append(resp.get("items"))
+        return self.build_list(lists)
+
+    def build_list(self, lists):
+        arrays_for_zones = {}
+        for resp in lists:
+            for zone in resp:
+                if "instances" in resp[zone]:
+                    if zone in arrays_for_zones:
+                        arrays_for_zones[zone] = (
+                            arrays_for_zones[zone] + resp[zone]["instances"]
+                        )
+                    else:
+                        arrays_for_zones[zone] = resp[zone]["instances"]
+        return arrays_for_zones
 
     def _get_query_options(self, filters):
-        '''
+        """
             :param config_data: contents of the inventory config file
             :return A fully built query string
-        '''
+        """
         if not filters:
-            return ''
+            return ""
 
         if len(filters) == 1:
             return filters[0]
@@ -213,19 +368,19 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
             queries = []
             for f in filters:
                 # For multiple queries, all queries should have ()
-                if f[0] != '(' and f[-1] != ')':
-                    queries.append("(%s)" % ''.join(f))
+                if f[0] != "(" and f[-1] != ")":
+                    queries.append("(%s)" % "".join(f))
                 else:
                     queries.append(f)
 
-            return ' '.join(queries)
+            return " ".join(queries)
 
     def _return_if_object(self, module, response):
-        '''
+        """
             :param module: A GcpModule
             :param response: A Requests response object
             :return JSON response
-        '''
+        """
         # If not found, return nothing.
         if response.status_code == 404:
             return None
@@ -237,241 +392,155 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
         try:
             response.raise_for_status
             result = response.json()
-        except getattr(json.decoder, 'JSONDecodeError', ValueError) as inst:
+        except getattr(json.decoder, "JSONDecodeError", ValueError) as inst:
             module.fail_json(msg="Invalid JSON response with error: %s" % inst)
         except GcpRequestException as inst:
             module.fail_json(msg="Network error: %s" % inst)
 
-        if navigate_hash(result, ['error', 'errors']):
-            module.fail_json(msg=navigate_hash(result, ['error', 'errors']))
+        if navigate_hash(result, ["error", "errors"]):
+            module.fail_json(msg=navigate_hash(result, ["error", "errors"]))
 
         return result
 
-    def _format_items(self, items, project_disks):
-        '''
-            :param items: A list of hosts
-        '''
-        for host in items:
-            if 'zone' in host:
-                host['zone_selflink'] = host['zone']
-                host['zone'] = host['zone'].split('/')[-1]
-            if 'machineType' in host:
-                host['machineType_selflink'] = host['machineType']
-                host['machineType'] = host['machineType'].split('/')[-1]
-
-            if 'networkInterfaces' in host:
-                for network in host['networkInterfaces']:
-                    if 'network' in network:
-                        network['network'] = self._format_network_info(network['network'])
-                    if 'subnetwork' in network:
-                        network['subnetwork'] = self._format_network_info(network['subnetwork'])
-
-            if 'metadata' in host:
-                # If no metadata, 'items' will be blank.
-                # We want the metadata hash overriden anyways for consistency.
-                host['metadata'] = self._format_metadata(host['metadata'].get('items', {}))
-
-            host['project'] = host['selfLink'].split('/')[6]
-            host['image'] = self._get_image(host, project_disks)
-        return items
-
     def _add_hosts(self, items, config_data, format_items=True, project_disks=None):
-        '''
+        """
             :param items: A list of hosts
             :param config_data: configuration data
             :param format_items: format items or not
-        '''
+        """
         if not items:
             return
-        if format_items:
-            items = self._format_items(items, project_disks)
 
-        for host in items:
+        hostname_ordering = ["public_ip", "private_ip", "name"]
+        if self.get_option("hostnames"):
+            hostname_ordering = self.get_option("hostnames")
+
+        for host_json in items:
+            host = GcpInstance(
+                host_json, hostname_ordering, project_disks, format_items
+            )
             self._populate_host(host)
 
-            hostname = self._get_hostname(host)
-            self._set_composite_vars(self.get_option('compose'), host, hostname)
-            self._add_host_to_composed_groups(self.get_option('groups'), host, hostname)
-            self._add_host_to_keyed_groups(self.get_option('keyed_groups'), host, hostname)
-
-    def _format_network_info(self, address):
-        '''
-            :param address: A GCP network address
-            :return a dict with network shortname and region
-        '''
-        split = address.split('/')
-        region = ''
-        if 'global' in split:
-            region = 'global'
-        else:
-            region = split[8]
-        return {
-            'region': region,
-            'name': split[-1],
-            'selfLink': address
-        }
-
-    def _format_metadata(self, metadata):
-        '''
-            :param metadata: A list of dicts where each dict has keys "key" and "value"
-            :return a dict with key/value pairs for each in list.
-        '''
-        new_metadata = {}
-        for pair in metadata:
-            new_metadata[pair["key"]] = pair["value"]
-        return new_metadata
-
-    def _get_hostname(self, item):
-        '''
-            :param item: A host response from GCP
-            :return the hostname of this instance
-        '''
-        hostname_ordering = ['public_ip', 'private_ip', 'name']
-        if self.get_option('hostnames'):
-            hostname_ordering = self.get_option('hostnames')
-
-        for order in hostname_ordering:
-            name = None
-            if order == 'public_ip':
-                name = self._get_publicip(item)
-            elif order == 'private_ip':
-                name = self._get_privateip(item)
-            elif order == 'name':
-                name = item[u'name']
-            else:
-                raise AnsibleParserError("%s is not a valid hostname precedent" % order)
-
-            if name:
-                return name
-
-        raise AnsibleParserError("No valid name found for host")
-
-    def _get_publicip(self, item):
-        '''
-            :param item: A host response from GCP
-            :return the publicIP of this instance or None
-        '''
-        # Get public IP if exists
-        for interface in item['networkInterfaces']:
-            if 'accessConfigs' in interface:
-                for accessConfig in interface['accessConfigs']:
-                    if 'natIP' in accessConfig:
-                        return accessConfig[u'natIP']
-        return None
-
-    def _get_image(self, instance, project_disks):
-        '''
-            :param instance: A instance response from GCP
-            :return the image of this instance or None
-        '''
-        image = None
-        if project_disks and 'disks' in instance:
-            for disk in instance['disks']:
-                if disk.get('boot'):
-                    image = project_disks[disk["source"]]
-        return image
+            hostname = host.hostname()
+            self._set_composite_vars(
+                self.get_option("compose"), host.to_json(), hostname
+            )
+            self._add_host_to_composed_groups(
+                self.get_option("groups"), host.to_json(), hostname
+            )
+            self._add_host_to_keyed_groups(
+                self.get_option("keyed_groups"), host.to_json(), hostname
+            )
 
     def _get_project_disks(self, config_data, query):
-        '''
+        """
             project space disk images
-        '''
+        """
 
         try:
             self._project_disks
         except AttributeError:
             self._project_disks = {}
-            request_params = {'maxResults': 500, 'filter': query}
+            request_params = {"maxResults": 500, "filter": query}
 
-            for project in config_data['projects']:
+            for project in config_data["projects"]:
                 session_responses = []
                 page_token = True
                 while page_token:
                     response = self.auth_session.get(
-                        'https://www.googleapis.com/compute/v1/projects/{0}/aggregated/disks'.format(project),
-                        params=request_params
+                        "https://www.googleapis.com/compute/v1/projects/{0}/aggregated/disks".format(
+                            project
+                        ),
+                        params=request_params,
                     )
                     response_json = response.json()
-                    if 'nextPageToken' in response_json:
-                        request_params['pageToken'] = response_json['nextPageToken']
-                    elif 'pageToken' in request_params:
-                        del request_params['pageToken']
+                    if "nextPageToken" in response_json:
+                        request_params["pageToken"] = response_json["nextPageToken"]
+                    elif "pageToken" in request_params:
+                        del request_params["pageToken"]
 
-                    if 'items' in response_json:
+                    if "items" in response_json:
                         session_responses.append(response_json)
-                    page_token = 'pageToken' in request_params
+                    page_token = "pageToken" in request_params
 
             for response in session_responses:
-                if 'items' in response:
+                if "items" in response:
                     # example k would be a zone or region name
                     # example v would be { "disks" : [], "otherkey" : "..." }
-                    for zone_or_region, aggregate in response['items'].items():
-                        if 'zones' in zone_or_region:
-                            if 'disks' in aggregate:
-                                zone = zone_or_region.replace('zones/', '')
-                                for disk in aggregate['disks']:
-                                    if 'zones' in config_data and zone in config_data['zones']:
+                    for zone_or_region, aggregate in response["items"].items():
+                        if "zones" in zone_or_region:
+                            if "disks" in aggregate:
+                                zone = zone_or_region.replace("zones/", "")
+                                for disk in aggregate["disks"]:
+                                    if (
+                                        "zones" in config_data
+                                        and zone in config_data["zones"]
+                                    ):
                                         # If zones specified, only store those zones' data
-                                        if 'sourceImage' in disk:
-                                            self._project_disks[disk['selfLink']] = disk['sourceImage'].split('/')[-1]
+                                        if "sourceImage" in disk:
+                                            self._project_disks[
+                                                disk["selfLink"]
+                                            ] = disk["sourceImage"].split("/")[-1]
                                         else:
-                                            self._project_disks[disk['selfLink']] = disk['selfLink'].split('/')[-1]
+                                            self._project_disks[
+                                                disk["selfLink"]
+                                            ] = disk["selfLink"].split("/")[-1]
 
                                     else:
-                                        if 'sourceImage' in disk:
-                                            self._project_disks[disk['selfLink']] = disk['sourceImage'].split('/')[-1]
+                                        if "sourceImage" in disk:
+                                            self._project_disks[
+                                                disk["selfLink"]
+                                            ] = disk["sourceImage"].split("/")[-1]
                                         else:
-                                            self._project_disks[disk['selfLink']] = disk['selfLink'].split('/')[-1]
+                                            self._project_disks[
+                                                disk["selfLink"]
+                                            ] = disk["selfLink"].split("/")[-1]
 
         return self._project_disks
 
-    def _get_privateip(self, item):
-        '''
-            :param item: A host response from GCP
-            :return the privateIP of this instance or None
-        '''
-        # Fallback: Get private IP
-        for interface in item[u'networkInterfaces']:
-            if 'networkIP' in interface:
-                return interface[u'networkIP']
-
     def parse(self, inventory, loader, path, cache=True):
 
         if not HAS_GOOGLE_LIBRARIES:
-            raise AnsibleParserError('gce inventory plugin cannot start: %s' % missing_required_lib('google-auth'))
+            raise AnsibleParserError(
+                "gce inventory plugin cannot start: %s"
+                % missing_required_lib("google-auth")
+            )
 
         super(InventoryModule, self).parse(inventory, loader, path)
 
         config_data = {}
         config_data = self._read_config_data(path)
 
-        if self.get_option('use_contrib_script_compatible_sanitization'):
-            self._sanitize_group_name = self._legacy_script_compatible_group_sanitization
+        if self.get_option("use_contrib_script_compatible_sanitization"):
+            self._sanitize_group_name = (
+                self._legacy_script_compatible_group_sanitization
+            )
 
         # setup parameters as expected by 'fake module class' to reuse module_utils w/o changing the API
         params = {
-            'filters': self.get_option('filters'),
-            'projects': self.get_option('projects'),
-            'scopes': self.get_option('scopes'),
-            'zones': self.get_option('zones'),
-            'auth_kind': self.get_option('auth_kind'),
-            'service_account_file': self.get_option('service_account_file'),
-            'service_account_contents': self.get_option('service_account_contents'),
-            'service_account_email': self.get_option('service_account_email'),
+            "filters": self.get_option("filters"),
+            "projects": self.get_option("projects"),
+            "scopes": self.get_option("scopes"),
+            "zones": self.get_option("zones"),
+            "auth_kind": self.get_option("auth_kind"),
+            "service_account_file": self.get_option("service_account_file"),
+            "service_account_contents": self.get_option("service_account_contents"),
+            "service_account_email": self.get_option("service_account_email"),
         }
 
         self.fake_module = GcpMockModule(params)
-        self.auth_session = GcpSession(self.fake_module, 'compute')
+        self.auth_session = GcpSession(self.fake_module, "compute")
 
-        query = self._get_query_options(params['filters'])
+        query = self._get_query_options(params["filters"])
 
-        if self.get_option('retrieve_image_info'):
+        if self.get_option("retrieve_image_info"):
             project_disks = self._get_project_disks(config_data, query)
         else:
             project_disks = None
 
         # Cache logic
         if cache:
-            cache = self.get_option('cache')
+            cache = self.get_option("cache")
             cache_key = self.get_cache_key(path)
         else:
             cache_key = None
@@ -482,27 +551,29 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable):
                 results = self._cache[cache_key]
                 for project in results:
                     for zone in results[project]:
-                        self._add_hosts(results[project][zone], config_data, False, project_disks=project_disks)
+                        self._add_hosts(
+                            results[project][zone],
+                            config_data,
+                            False,
+                            project_disks=project_disks,
+                        )
             except KeyError:
                 cache_needs_update = True
 
         if not cache or cache_needs_update:
             cached_data = {}
-            for project in params['projects']:
+            for project in params["projects"]:
                 cached_data[project] = {}
-                params['project'] = project
-                zones = params['zones']
+                params["project"] = project
+                zones = params["zones"]
                 # Fetch all instances
                 link = self._instances % project
                 resp = self.fetch_list(params, link, query)
-                if 'items' in resp and resp['items']:
-                    for key, value in resp.get('items').items():
-                        if 'instances' in value:
-                            # Key is in format: "zones/europe-west1-b"
-                            zone = key[6:]
-                            if not zones or zone in zones:
-                                self._add_hosts(value['instances'], config_data, project_disks=project_disks)
-                                cached_data[project][zone] = value['instances']
+                for key, value in resp.items():
+                    zone = key[6:]
+                    if not zones or zone in zones:
+                        self._add_hosts(value, config_data, project_disks=project_disks)
+                        cached_data[project][zone] = value
 
         if cache_needs_update:
             self._cache[cache_key] = cached_data