# Copyright 2014-2016 OpenMarket Ltd # Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import itertools import random import re import string from collections.abc import Iterable from typing import Optional, Tuple from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$") # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically # says "there is no grammar for media ids" # # The server_name part of this is purposely lax: use parse_and_validate_mxc for # additional validation. # MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. 
rand = random.SystemRandom()


def random_string(length: int) -> str:
    """Return a random string of `length` ASCII letters.

    Characters are drawn using the module-level SystemRandom instance.
    """
    return "".join(rand.choice(string.ascii_letters) for _ in range(length))


def random_string_with_symbols(length: int) -> str:
    """Return a random string of `length` characters drawn from digits, ASCII
    letters and a fixed set of punctuation symbols.

    Characters are drawn using the module-level SystemRandom instance.
    """
    return "".join(rand.choice(_string_with_symbols) for _ in range(length))


def is_ascii(s: bytes) -> bool:
    """Check whether the given bytes consist solely of ASCII characters.

    Args:
        s: the bytes to check

    Returns:
        True if `s` round-trips through the ASCII codec, False otherwise.
    """
    try:
        s.decode("ascii").encode("ascii")
    except (UnicodeDecodeError, UnicodeEncodeError):
        return False
    return True


def assert_valid_client_secret(client_secret: str) -> None:
    """Validate that a given string matches the client_secret defined by the spec

    Args:
        client_secret: the value to validate

    Raises:
        SynapseError (400, INVALID_PARAM) if the secret is empty, longer than
        255 characters, or contains characters outside the permitted set.
    """
    if (
        len(client_secret) <= 0
        or len(client_secret) > 255
        # fullmatch rather than match: the pattern ends in `$`, which also
        # matches just before a trailing newline, so match() would let
        # "secret\n" through.
        or CLIENT_SECRET_REGEX.fullmatch(client_secret) is None
    ):
        raise SynapseError(
            400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
        )


def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
    """Split a server name into host/port parts.

    Args:
        server_name: server name to parse

    Returns:
        host/port parts. The port is None if the server name did not include
        one.

    Raises:
        ValueError if the server name could not be parsed.
    """
    try:
        if server_name[-1] == "]":
            # ipv6 literal, hopefully
            return server_name, None

        domain_port = server_name.rsplit(":", 1)
        domain = domain_port[0]
        port = int(domain_port[1]) if domain_port[1:] else None
        return domain, port
    except Exception as e:
        # Covers the empty string (IndexError) and non-numeric ports
        # (ValueError from int()); keep the underlying cause attached.
        raise ValueError("Invalid server name '%s'" % server_name) from e


VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")


def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
    """Split a server name into host/port parts and do some basic validation.

    Args:
        server_name: server name to parse

    Returns:
        host/port parts.

    Raises:
        ValueError if the server name could not be parsed.
    """
    host, port = parse_server_name(server_name)

    # these tests don't need to be bulletproof as we'll find out soon enough
    # if somebody is giving us invalid data. What we *do* need is to be sure
    # that nobody is sneaking IP literals in that look like hostnames, etc.

    # A name such as ":8448" parses to an empty host. Reject it here so that
    # callers see the documented ValueError rather than an IndexError from the
    # indexing below.
    if not host:
        raise ValueError("Server name '%s' has an empty host" % (server_name,))

    # look for ipv6 literals
    if host[0] == "[":
        if host[-1] != "]":
            raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
        return host, port

    # otherwise it should only be alphanumerics.
    if not VALID_HOST_REGEX.match(host):
        raise ValueError(
            "Server name '%s' contains invalid characters" % (server_name,)
        )

    return host, port


def valid_id_server_location(id_server: str) -> bool:
    """Check whether an identity server location, such as the one passed as the
    `id_server` parameter to `/_matrix/client/r0/account/3pid/bind`, is valid.

    A valid identity server location consists of a valid hostname and optional
    port number, optionally followed by any number of `/` delimited path
    components, without any fragment or query string parts.

    Args:
        id_server: identity server location string to validate

    Returns:
        True if valid, False otherwise.
    """

    components = id_server.split("/", 1)

    host = components[0]

    try:
        parse_and_validate_server_name(host)
    except ValueError:
        return False

    if len(components) < 2:
        # no path
        return True

    path = components[1]
    return "#" not in path and "?" not in path


def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
    """Parse the given string as an MXC URI

    Checks that the "server name" part is a valid server name

    Args:
        mxc: the (alleged) MXC URI to be checked

    Returns:
        hostname, port, media id

    Raises:
        ValueError if the URI cannot be parsed
    """
    m = MXC_REGEX.match(mxc)
    if not m:
        raise ValueError("mxc URI %r did not match expected format" % (mxc,))
    server_name = m.group(1)
    media_id = m.group(2)
    host, port = parse_and_validate_server_name(server_name)
    return host, port, media_id


def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
    """If iterable has maxitems or fewer, return the stringification of a list
    containing those items.

    Otherwise, return the stringification of a list with the first maxitems items,
    followed by "...".

    Args:
        iterable: iterable to truncate
        maxitems: number of items to return before truncating
    """

    # Pull one more item than we intend to show, so that we can tell whether
    # truncation is needed without consuming the whole iterable.
    items = list(itertools.islice(iterable, maxitems + 1))
    if len(items) <= maxitems:
        return str(items)
    return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"


def strtobool(val: str) -> bool:
    """Convert a string representation of truth to True or False

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
    are 'n', 'no', 'f', 'false', 'off', and '0'.  Raises ValueError if
    'val' is anything else.

    This is lifted from distutils.util.strtobool, with the exception that it actually
    returns a bool, rather than an int.
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    elif val in ("n", "no", "f", "false", "off", "0"):
        return False
    else:
        raise ValueError("invalid truth value %r" % (val,))


_BASE62 = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"


def base62_encode(num: int, minwidth: int = 1) -> str:
    """Encode a number using base62

    Args:
        num: number to be encoded; must not be negative
        minwidth: width to pad to, if the number is small

    Returns:
        The base62 representation of `num`, left-padded with "0" to at least
        `minwidth` characters.

    Raises:
        ValueError if `num` is negative.
    """
    if num < 0:
        # Floor division never brings a negative number to zero
        # (divmod(-1, 62) == (-1, 61)), so the loop below would spin forever.
        raise ValueError("Cannot base62-encode negative number %r" % (num,))

    res = ""
    while num:
        num, rem = divmod(num, 62)
        res = _BASE62[rem] + res

    # pad to minimum width
    pad = "0" * (minwidth - len(res))
    return pad + res