Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
This commit is contained in:
commit
b55cdfaa31
1
changelog.d/4155.misc
Normal file
1
changelog.d/4155.misc
Normal file
|
@ -0,0 +1 @@
|
|||
add purge_history.sh and purge_remote_media.sh scripts to contrib/
|
1
changelog.d/4156.misc
Normal file
1
changelog.d/4156.misc
Normal file
|
@ -0,0 +1 @@
|
|||
HTTP tests have been refactored to contain less boilerplate.
|
1
changelog.d/4157.bugfix
Normal file
1
changelog.d/4157.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Loading URL previews from the DB cache on Postgres will no longer cause Unicode type errors when responding to the request, and URL previews will no longer fail if the remote server returns a Content-Type header with the chartype in quotes.
|
1
changelog.d/4161.bugfix
Normal file
1
changelog.d/4161.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
The hash_password script now works on Python 3.
|
16
contrib/purge_api/README.md
Normal file
16
contrib/purge_api/README.md
Normal file
|
@ -0,0 +1,16 @@
|
|||
Purge history API examples
|
||||
==========================
|
||||
|
||||
# `purge_history.sh`
|
||||
|
||||
A bash file, that uses the [purge history API](/docs/admin_api/README.rst) to
|
||||
purge all messages in a list of rooms up to a certain event. You can select a
|
||||
timeframe or a number of messages that you want to keep in the room.
|
||||
|
||||
Just configure the variables DOMAIN, ADMIN, ROOMS_ARRAY and TIME at the top of
|
||||
the script.
|
||||
|
||||
# `purge_remote_media.sh`
|
||||
|
||||
A bash file, that uses the [purge history API](/docs/admin_api/README.rst) to
|
||||
purge all old cached remote media.
|
141
contrib/purge_api/purge_history.sh
Normal file
141
contrib/purge_api/purge_history.sh
Normal file
|
@ -0,0 +1,141 @@
|
|||
#!/bin/bash
|
||||
|
||||
# this script will use the api:
|
||||
# https://github.com/matrix-org/synapse/blob/master/docs/admin_api/purge_history_api.rst
|
||||
#
|
||||
# It will purge all messages in a list of rooms up to a cetrain event
|
||||
|
||||
###################################################################################################
|
||||
# define your domain and admin user
|
||||
###################################################################################################
|
||||
# add this user as admin in your home server:
|
||||
DOMAIN=yourserver.tld
|
||||
# add this user as admin in your home server:
|
||||
ADMIN="@you_admin_username:$DOMAIN"
|
||||
|
||||
API_URL="$DOMAIN:8008/_matrix/client/r0"
|
||||
|
||||
###################################################################################################
|
||||
#choose the rooms to prune old messages from (add a free comment at the end)
|
||||
###################################################################################################
|
||||
# the room_id's you can get e.g. from your Riot clients "View Source" button on each message
|
||||
ROOMS_ARRAY=(
|
||||
'!DgvjtOljKujDBrxyHk:matrix.org#riot:matrix.org'
|
||||
'!QtykxKocfZaZOUrTwp:matrix.org#Matrix HQ'
|
||||
)
|
||||
|
||||
# ALTERNATIVELY:
|
||||
# you can select all the rooms that are not encrypted and loop over the result:
|
||||
# SELECT room_id FROM rooms WHERE room_id NOT IN (SELECT DISTINCT room_id FROM events WHERE type ='m.room.encrypted')
|
||||
# or
|
||||
# select all rooms with at least 100 members:
|
||||
# SELECT q.room_id FROM (select count(*) as numberofusers, room_id FROM current_state_events WHERE type ='m.room.member'
|
||||
# GROUP BY room_id) AS q LEFT JOIN room_aliases a ON q.room_id=a.room_id WHERE q.numberofusers > 100 ORDER BY numberofusers desc
|
||||
|
||||
###################################################################################################
|
||||
# evaluate the EVENT_ID before which should be pruned
|
||||
###################################################################################################
|
||||
# choose a time before which the messages should be pruned:
|
||||
TIME='12 months ago'
|
||||
# ALTERNATIVELY:
|
||||
# a certain time:
|
||||
# TIME='2016-08-31 23:59:59'
|
||||
|
||||
# creates a timestamp from the given time string:
|
||||
UNIX_TIMESTAMP=$(date +%s%3N --date='TZ="UTC+2" '"$TIME")
|
||||
|
||||
# ALTERNATIVELY:
|
||||
# prune all messages that are older than 1000 messages ago:
|
||||
# LAST_MESSAGES=1000
|
||||
# SQL_GET_EVENT="SELECT event_id from events WHERE type='m.room.message' AND room_id ='$ROOM' ORDER BY received_ts DESC LIMIT 1 offset $(($LAST_MESSAGES - 1))"
|
||||
|
||||
# ALTERNATIVELY:
|
||||
# select the EVENT_ID manually:
|
||||
#EVENT_ID='$1471814088343495zpPNI:matrix.org' # an example event from 21st of Aug 2016 by Matthew
|
||||
|
||||
###################################################################################################
|
||||
# make the admin user a server admin in the database with
|
||||
###################################################################################################
|
||||
# psql -A -t --dbname=synapse -c "UPDATE users SET admin=1 WHERE name LIKE '$ADMIN'"
|
||||
|
||||
###################################################################################################
|
||||
# database function
|
||||
###################################################################################################
|
||||
sql (){
|
||||
# for sqlite3:
|
||||
#sqlite3 homeserver.db "pragma busy_timeout=20000;$1" | awk '{print $2}'
|
||||
# for postgres:
|
||||
psql -A -t --dbname=synapse -c "$1" | grep -v 'Pager'
|
||||
}
|
||||
|
||||
###################################################################################################
|
||||
# get an access token
|
||||
###################################################################################################
|
||||
# for example externally by watching Riot in your browser's network inspector
|
||||
# or internally on the server locally, use this:
|
||||
TOKEN=$(sql "SELECT token FROM access_tokens WHERE user_id='$ADMIN' ORDER BY id DESC LIMIT 1")
|
||||
AUTH="Authorization: Bearer $TOKEN"
|
||||
|
||||
###################################################################################################
|
||||
# check, if your TOKEN works. For example this works:
|
||||
###################################################################################################
|
||||
# $ curl --header "$AUTH" "$API_URL/rooms/$ROOM/state/m.room.power_levels"
|
||||
|
||||
###################################################################################################
|
||||
# finally start pruning the room:
|
||||
###################################################################################################
|
||||
POSTDATA='{"delete_local_events":"true"}' # this will really delete local events, so the messages in the room really disappear unless they are restored by remote federation
|
||||
|
||||
for ROOM in "${ROOMS_ARRAY[@]}"; do
|
||||
echo "########################################### $(date) ################# "
|
||||
echo "pruning room: $ROOM ..."
|
||||
ROOM=${ROOM%#*}
|
||||
#set -x
|
||||
echo "check for alias in db..."
|
||||
# for postgres:
|
||||
sql "SELECT * FROM room_aliases WHERE room_id='$ROOM'"
|
||||
echo "get event..."
|
||||
# for postgres:
|
||||
EVENT_ID=$(sql "SELECT event_id FROM events WHERE type='m.room.message' AND received_ts<'$UNIX_TIMESTAMP' AND room_id='$ROOM' ORDER BY received_ts DESC LIMIT 1;")
|
||||
if [ "$EVENT_ID" == "" ]; then
|
||||
echo "no event $TIME"
|
||||
else
|
||||
echo "event: $EVENT_ID"
|
||||
SLEEP=2
|
||||
set -x
|
||||
# call purge
|
||||
OUT=$(curl --header "$AUTH" -s -d $POSTDATA POST "$API_URL/admin/purge_history/$ROOM/$EVENT_ID")
|
||||
PURGE_ID=$(echo "$OUT" |grep purge_id|cut -d'"' -f4 )
|
||||
if [ "$PURGE_ID" == "" ]; then
|
||||
# probably the history purge is already in progress for $ROOM
|
||||
: "continuing with next room"
|
||||
else
|
||||
while : ; do
|
||||
# get status of purge and sleep longer each time if still active
|
||||
sleep $SLEEP
|
||||
STATUS=$(curl --header "$AUTH" -s GET "$API_URL/admin/purge_history_status/$PURGE_ID" |grep status|cut -d'"' -f4)
|
||||
: "$ROOM --> Status: $STATUS"
|
||||
[[ "$STATUS" == "active" ]] || break
|
||||
SLEEP=$((SLEEP + 1))
|
||||
done
|
||||
fi
|
||||
set +x
|
||||
sleep 1
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
###################################################################################################
|
||||
# additionally
|
||||
###################################################################################################
|
||||
# to benefit from pruning large amounts of data, you need to call VACUUM to free the unused space.
|
||||
# This can take a very long time (hours) and the client have to be stopped while you do so:
|
||||
# $ synctl stop
|
||||
# $ sqlite3 -line homeserver.db "vacuum;"
|
||||
# $ synctl start
|
||||
|
||||
# This could be set, so you don't need to prune every time after deleting some rows:
|
||||
# $ sqlite3 homeserver.db "PRAGMA auto_vacuum = FULL;"
|
||||
# be cautious, it could make the database somewhat slow if there are a lot of deletions
|
||||
|
||||
exit
|
54
contrib/purge_api/purge_remote_media.sh
Normal file
54
contrib/purge_api/purge_remote_media.sh
Normal file
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
|
||||
DOMAIN=yourserver.tld
|
||||
# add this user as admin in your home server:
|
||||
ADMIN="@you_admin_username:$DOMAIN"
|
||||
|
||||
API_URL="$DOMAIN:8008/_matrix/client/r0"
|
||||
|
||||
# choose a time before which the messages should be pruned:
|
||||
# TIME='2016-08-31 23:59:59'
|
||||
TIME='12 months ago'
|
||||
|
||||
# creates a timestamp from the given time string:
|
||||
UNIX_TIMESTAMP=$(date +%s%3N --date='TZ="UTC+2" '"$TIME")
|
||||
|
||||
|
||||
###################################################################################################
|
||||
# database function
|
||||
###################################################################################################
|
||||
sql (){
|
||||
# for sqlite3:
|
||||
#sqlite3 homeserver.db "pragma busy_timeout=20000;$1" | awk '{print $2}'
|
||||
# for postgres:
|
||||
psql -A -t --dbname=synapse -c "$1" | grep -v 'Pager'
|
||||
}
|
||||
|
||||
###############################################################################
|
||||
# make the admin user a server admin in the database with
|
||||
###############################################################################
|
||||
# sql "UPDATE users SET admin=1 WHERE name LIKE '$ADMIN'"
|
||||
|
||||
###############################################################################
|
||||
# get an access token
|
||||
###############################################################################
|
||||
# for example externally by watching Riot in your browser's network inspector
|
||||
# or internally on the server locally, use this:
|
||||
TOKEN=$(sql "SELECT token FROM access_tokens WHERE user_id='$ADMIN' ORDER BY id DESC LIMIT 1")
|
||||
|
||||
###############################################################################
|
||||
# check, if your TOKEN works. For example this works:
|
||||
###############################################################################
|
||||
# curl --header "Authorization: Bearer $TOKEN" "$API_URL/rooms/$ROOM/state/m.room.power_levels"
|
||||
|
||||
###############################################################################
|
||||
# optional check size before
|
||||
###############################################################################
|
||||
# echo calculate used storage before ...
|
||||
# du -shc ../.synapse/media_store/*
|
||||
|
||||
###############################################################################
|
||||
# finally start pruning media:
|
||||
###############################################################################
|
||||
set -x # for debugging the generated string
|
||||
curl --header "Authorization: Bearer $TOKEN" -v POST "$API_URL/admin/purge_media_cache/?before_ts=$UNIX_TIMESTAMP"
|
|
@ -3,13 +3,15 @@
|
|||
import argparse
|
||||
import getpass
|
||||
import sys
|
||||
import unicodedata
|
||||
|
||||
import bcrypt
|
||||
import yaml
|
||||
|
||||
bcrypt_rounds=12
|
||||
bcrypt_rounds = 12
|
||||
password_pepper = ""
|
||||
|
||||
|
||||
def prompt_for_pass():
|
||||
password = getpass.getpass("Password: ")
|
||||
|
||||
|
@ -23,19 +25,27 @@ def prompt_for_pass():
|
|||
|
||||
return password
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Calculate the hash of a new password, so that passwords"
|
||||
" can be reset")
|
||||
description=(
|
||||
"Calculate the hash of a new password, so that passwords can be reset"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--password",
|
||||
"-p",
|
||||
"--password",
|
||||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--config",
|
||||
"-c",
|
||||
"--config",
|
||||
type=argparse.FileType('r'),
|
||||
help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
|
||||
help=(
|
||||
"Path to server config file. "
|
||||
"Used to read in bcrypt_rounds and password_pepper."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -49,4 +59,21 @@ if __name__ == "__main__":
|
|||
if not password:
|
||||
password = prompt_for_pass()
|
||||
|
||||
print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
|
||||
# On Python 2, make sure we decode it to Unicode before we normalise it
|
||||
if isinstance(password, bytes):
|
||||
try:
|
||||
password = password.decode(sys.stdin.encoding)
|
||||
except UnicodeDecodeError:
|
||||
print(
|
||||
"ERROR! Your password is not decodable using your terminal encoding (%s)."
|
||||
% (sys.stdin.encoding,)
|
||||
)
|
||||
|
||||
pw = unicodedata.normalize("NFKC", password)
|
||||
|
||||
hashed = bcrypt.hashpw(
|
||||
pw.encode('utf8') + password_pepper.encode("utf8"),
|
||||
bcrypt.gensalt(bcrypt_rounds),
|
||||
).decode('ascii')
|
||||
|
||||
print(hashed)
|
||||
|
|
|
@ -468,13 +468,13 @@ def set_cors_headers(request):
|
|||
Args:
|
||||
request (twisted.web.http.Request): The http request to add CORs to.
|
||||
"""
|
||||
request.setHeader("Access-Control-Allow-Origin", "*")
|
||||
request.setHeader(b"Access-Control-Allow-Origin", b"*")
|
||||
request.setHeader(
|
||||
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
|
||||
b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS"
|
||||
)
|
||||
request.setHeader(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Origin, X-Requested-With, Content-Type, Accept, Authorization"
|
||||
b"Access-Control-Allow-Headers",
|
||||
b"Origin, X-Requested-With, Content-Type, Accept, Authorization"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import cgi
|
||||
import datetime
|
||||
import errno
|
||||
|
@ -24,6 +25,7 @@ import shutil
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
import six
|
||||
from six import string_types
|
||||
from six.moves import urllib_parse as urlparse
|
||||
|
||||
|
@ -98,7 +100,7 @@ class PreviewUrlResource(Resource):
|
|||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
url = parse_string(request, "url")
|
||||
if "ts" in request.args:
|
||||
if b"ts" in request.args:
|
||||
ts = parse_integer(request, "ts")
|
||||
else:
|
||||
ts = self.clock.time_msec()
|
||||
|
@ -180,7 +182,12 @@ class PreviewUrlResource(Resource):
|
|||
cache_result["expires_ts"] > ts and
|
||||
cache_result["response_code"] / 100 == 2
|
||||
):
|
||||
defer.returnValue(cache_result["og"])
|
||||
# It may be stored as text in the database, not as bytes (such as
|
||||
# PostgreSQL). If so, encode it back before handing it on.
|
||||
og = cache_result["og"]
|
||||
if isinstance(og, six.text_type):
|
||||
og = og.encode('utf8')
|
||||
defer.returnValue(og)
|
||||
return
|
||||
|
||||
media_info = yield self._download_url(url, user)
|
||||
|
@ -213,14 +220,17 @@ class PreviewUrlResource(Resource):
|
|||
elif _is_html(media_info['media_type']):
|
||||
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
|
||||
|
||||
file = open(media_info['filename'])
|
||||
body = file.read()
|
||||
file.close()
|
||||
with open(media_info['filename'], 'rb') as file:
|
||||
body = file.read()
|
||||
|
||||
# clobber the encoding from the content-type, or default to utf-8
|
||||
# XXX: this overrides any <meta/> or XML charset headers in the body
|
||||
# which may pose problems, but so far seems to work okay.
|
||||
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
|
||||
match = re.match(
|
||||
r'.*; *charset="?(.*?)"?(;|$)',
|
||||
media_info['media_type'],
|
||||
re.I
|
||||
)
|
||||
encoding = match.group(1) if match else "utf-8"
|
||||
|
||||
og = decode_and_calc_og(body, media_info['uri'], encoding)
|
||||
|
|
|
@ -19,24 +19,17 @@ import json
|
|||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v1.admin import register_servlets
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock,
|
||||
make_request,
|
||||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
|
||||
|
||||
class UserRegisterTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
class UserRegisterTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
self.clock = ThreadedMemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
self.url = "/_matrix/client/r0/admin/register"
|
||||
|
||||
self.registration_handler = Mock()
|
||||
|
@ -50,17 +43,14 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
|
||||
self.secrets = Mock()
|
||||
|
||||
self.hs = setup_test_homeserver(
|
||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
self.hs = self.setup_test_homeserver()
|
||||
|
||||
self.hs.config.registration_shared_secret = u"shared"
|
||||
|
||||
self.hs.get_media_repository = Mock()
|
||||
self.hs.get_deactivate_account_handler = Mock()
|
||||
|
||||
self.resource = JsonResource(self.hs)
|
||||
register_servlets(self.hs, self.resource)
|
||||
return self.hs
|
||||
|
||||
def test_disabled(self):
|
||||
"""
|
||||
|
@ -69,8 +59,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
|
||||
request, channel = make_request("POST", self.url, b'{}')
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, b'{}')
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(
|
||||
|
@ -87,8 +77,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
|
||||
self.hs.get_secrets = Mock(return_value=secrets)
|
||||
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("GET", self.url)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.json_body, {"nonce": "abcd"})
|
||||
|
||||
|
@ -97,25 +87,25 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
Calling GET on the endpoint will return a randomised nonce, which will
|
||||
only last for SALT_TIMEOUT (60s).
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("GET", self.url)
|
||||
self.render(request)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
# 59 seconds
|
||||
self.clock.advance(59)
|
||||
self.reactor.advance(59)
|
||||
|
||||
body = json.dumps({"nonce": nonce})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('username must be specified', channel.json_body["error"])
|
||||
|
||||
# 61 seconds
|
||||
self.clock.advance(2)
|
||||
self.reactor.advance(2)
|
||||
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
||||
|
@ -124,8 +114,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
"""
|
||||
Only the provided nonce can be used, as it's checked in the MAC.
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("GET", self.url)
|
||||
self.render(request)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||
|
@ -141,8 +131,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
"mac": want_mac,
|
||||
}
|
||||
)
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("HMAC incorrect", channel.json_body["error"])
|
||||
|
@ -152,8 +142,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
When the correct nonce is provided, and the right key is provided, the
|
||||
user is registered.
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("GET", self.url)
|
||||
self.render(request)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||
|
@ -169,8 +159,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
"mac": want_mac,
|
||||
}
|
||||
)
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
||||
|
@ -179,8 +169,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
"""
|
||||
A valid unrecognised nonce.
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("GET", self.url)
|
||||
self.render(request)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||
|
@ -196,15 +186,15 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
"mac": want_mac,
|
||||
}
|
||||
)
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
||||
|
||||
# Now, try and reuse it
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
||||
|
@ -217,8 +207,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
"""
|
||||
|
||||
def nonce():
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("GET", self.url)
|
||||
self.render(request)
|
||||
return channel.json_body["nonce"]
|
||||
|
||||
#
|
||||
|
@ -227,8 +217,8 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
|
||||
# Must be present
|
||||
body = json.dumps({})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('nonce must be specified', channel.json_body["error"])
|
||||
|
@ -239,32 +229,32 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
|
||||
# Must be present
|
||||
body = json.dumps({"nonce": nonce()})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('username must be specified', channel.json_body["error"])
|
||||
|
||||
# Must be a string
|
||||
body = json.dumps({"nonce": nonce(), "username": 1234})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||
|
||||
# Must not have null bytes
|
||||
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||
|
||||
# Must not have null bytes
|
||||
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||
|
@ -275,16 +265,16 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
|
||||
# Must be present
|
||||
body = json.dumps({"nonce": nonce(), "username": "a"})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('password must be specified', channel.json_body["error"])
|
||||
|
||||
# Must be a string
|
||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||
|
@ -293,16 +283,16 @@ class UserRegisterTestCase(unittest.TestCase):
|
|||
body = json.dumps(
|
||||
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
|
||||
)
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||
|
||||
# Super long
|
||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||
|
|
|
@ -45,11 +45,11 @@ class CreateUserServletTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
handlers = Mock(registration_handler=self.registration_handler)
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
self.reactor = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.reactor)
|
||||
|
||||
self.hs = self.hs = setup_test_homeserver(
|
||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
|
||||
)
|
||||
self.hs.get_datastore = Mock(return_value=self.datastore)
|
||||
self.hs.get_handlers = Mock(return_value=handlers)
|
||||
|
@ -76,8 +76,8 @@ class CreateUserServletTestCase(unittest.TestCase):
|
|||
return_value=(user_id, token)
|
||||
)
|
||||
|
||||
request, channel = make_request(b"POST", url, request_data)
|
||||
render(request, res, self.clock)
|
||||
request, channel = make_request(self.reactor, b"POST", url, request_data)
|
||||
render(request, res, self.reactor)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"200")
|
||||
|
||||
|
|
|
@ -169,7 +169,7 @@ class RestHelper(object):
|
|||
path = path + "?access_token=%s" % tok
|
||||
|
||||
request, channel = make_request(
|
||||
"POST", path, json.dumps(content).encode('utf8')
|
||||
self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
|
||||
)
|
||||
render(request, self.resource, self.hs.get_reactor())
|
||||
|
||||
|
@ -217,7 +217,9 @@ class RestHelper(object):
|
|||
|
||||
data = {"membership": membership}
|
||||
|
||||
request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
|
||||
request, channel = make_request(
|
||||
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
|
||||
)
|
||||
|
||||
render(request, self.resource, self.hs.get_reactor())
|
||||
|
||||
|
@ -228,18 +230,6 @@ class RestHelper(object):
|
|||
|
||||
self.auth_user_id = temp_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, user_id):
|
||||
(code, response) = yield self.mock_resource.trigger(
|
||||
"POST",
|
||||
"/_matrix/client/r0/register",
|
||||
json.dumps(
|
||||
{"user": user_id, "password": "test", "type": "m.login.password"}
|
||||
),
|
||||
)
|
||||
self.assertEquals(200, code)
|
||||
defer.returnValue(response)
|
||||
|
||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
|
@ -251,7 +241,9 @@ class RestHelper(object):
|
|||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
|
||||
request, channel = make_request(
|
||||
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
|
||||
)
|
||||
render(request, self.resource, self.hs.get_reactor())
|
||||
|
||||
assert int(channel.result["code"]) == expect_code, (
|
||||
|
|
|
@ -13,84 +13,47 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha import filter
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock as MemoryReactorClock,
|
||||
make_request,
|
||||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||
|
||||
|
||||
class FilterTestCase(unittest.TestCase):
|
||||
class FilterTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
USER_ID = "@apple:test"
|
||||
user_id = "@apple:test"
|
||||
hijack_auth = True
|
||||
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
|
||||
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||
TO_REGISTER = [filter]
|
||||
servlets = [filter.register_servlets]
|
||||
|
||||
def setUp(self):
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
|
||||
self.hs = setup_test_homeserver(
|
||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
|
||||
self.auth = self.hs.get_auth()
|
||||
|
||||
def get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.USER_ID),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||
return synapse.types.create_requester(
|
||||
UserID.from_string(self.USER_ID), 1, False, None
|
||||
)
|
||||
|
||||
self.auth.get_user_by_access_token = get_user_by_access_token
|
||||
self.auth.get_user_by_req = get_user_by_req
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
self.filtering = self.hs.get_filtering()
|
||||
self.resource = JsonResource(self.hs)
|
||||
|
||||
for r in self.TO_REGISTER:
|
||||
r.register_servlets(self.hs, self.resource)
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.filtering = hs.get_filtering()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
def test_add_filter(self):
|
||||
request, channel = make_request(
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
||||
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
|
||||
self.EXAMPLE_FILTER_JSON,
|
||||
)
|
||||
render(request, self.resource, self.clock)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||
self.clock.advance(0)
|
||||
self.pump()
|
||||
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
|
||||
|
||||
def test_add_filter_for_other_user(self):
|
||||
request, channel = make_request(
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
|
||||
self.EXAMPLE_FILTER_JSON,
|
||||
)
|
||||
render(request, self.resource, self.clock)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"403")
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
@ -98,12 +61,12 @@ class FilterTestCase(unittest.TestCase):
|
|||
def test_add_filter_non_local_user(self):
|
||||
_is_mine = self.hs.is_mine
|
||||
self.hs.is_mine = lambda target_user: False
|
||||
request, channel = make_request(
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
||||
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
|
||||
self.EXAMPLE_FILTER_JSON,
|
||||
)
|
||||
render(request, self.resource, self.clock)
|
||||
self.render(request)
|
||||
|
||||
self.hs.is_mine = _is_mine
|
||||
self.assertEqual(channel.result["code"], b"403")
|
||||
|
@ -113,21 +76,21 @@ class FilterTestCase(unittest.TestCase):
|
|||
filter_id = self.filtering.add_user_filter(
|
||||
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
||||
)
|
||||
self.clock.advance(1)
|
||||
self.reactor.advance(1)
|
||||
filter_id = filter_id.result
|
||||
request, channel = make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
|
||||
request, channel = self.make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
|
||||
)
|
||||
render(request, self.resource, self.clock)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
|
||||
|
||||
def test_get_filter_non_existant(self):
|
||||
request, channel = make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
|
||||
request, channel = self.make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
|
||||
)
|
||||
render(request, self.resource, self.clock)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
@ -135,18 +98,18 @@ class FilterTestCase(unittest.TestCase):
|
|||
# Currently invalid params do not have an appropriate errcode
|
||||
# in errors.py
|
||||
def test_get_filter_invalid_id(self):
|
||||
request, channel = make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
|
||||
request, channel = self.make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
|
||||
)
|
||||
render(request, self.resource, self.clock)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
|
||||
# No ID also returns an invalid_id error
|
||||
def test_get_filter_no_id(self):
|
||||
request, channel = make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
|
||||
request, channel = self.make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
|
||||
)
|
||||
render(request, self.resource, self.clock)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
|
|
|
@ -3,22 +3,19 @@ import json
|
|||
from mock import Mock
|
||||
|
||||
from twisted.python import failure
|
||||
from twisted.test.proto_helpers import MemoryReactorClock
|
||||
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha.register import register_servlets
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import make_request, render, setup_test_homeserver
|
||||
|
||||
|
||||
class RegisterRestServletTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
self.url = b"/_matrix/client/r0/register"
|
||||
|
||||
self.appservice = None
|
||||
|
@ -46,9 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
identity_handler=self.identity_handler,
|
||||
login_handler=self.login_handler,
|
||||
)
|
||||
self.hs = setup_test_homeserver(
|
||||
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
self.hs = self.setup_test_homeserver()
|
||||
self.hs.get_auth = Mock(return_value=self.auth)
|
||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||
|
@ -58,8 +53,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.hs.config.registrations_require_3pid = []
|
||||
self.hs.config.auto_join_rooms = []
|
||||
|
||||
self.resource = JsonResource(self.hs)
|
||||
register_servlets(self.hs, self.resource)
|
||||
return self.hs
|
||||
|
||||
def test_POST_appservice_registration_valid(self):
|
||||
user_id = "@kermit:muppet"
|
||||
|
@ -69,10 +63,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||
request_data = json.dumps({"username": "kermit"})
|
||||
|
||||
request, channel = make_request(
|
||||
request, channel = self.make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
)
|
||||
render(request, self.resource, self.clock)
|
||||
self.render(request)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
det_data = {
|
||||
|
@ -85,25 +79,25 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
def test_POST_appservice_registration_invalid(self):
|
||||
self.appservice = None # no application service exists
|
||||
request_data = json.dumps({"username": "kermit"})
|
||||
request, channel = make_request(
|
||||
request, channel = self.make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
)
|
||||
render(request, self.resource, self.clock)
|
||||
self.render(request)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
|
||||
def test_POST_bad_password(self):
|
||||
request_data = json.dumps({"username": "kermit", "password": 666})
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.render(request)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||
self.assertEquals(channel.json_body["error"], "Invalid password")
|
||||
|
||||
def test_POST_bad_username(self):
|
||||
request_data = json.dumps({"username": 777, "password": "monkey"})
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.render(request)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||
self.assertEquals(channel.json_body["error"], "Invalid username")
|
||||
|
@ -121,8 +115,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||
self.device_handler.check_device_registered = Mock(return_value=device_id)
|
||||
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.render(request)
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
|
@ -143,8 +137,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
||||
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.render(request)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
|
||||
|
@ -155,8 +149,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.hs.config.allow_guest_access = True
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
|
||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
self.render(request)
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
|
@ -169,8 +163,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
def test_POST_disabled_guest_registration(self):
|
||||
self.hs.config.allow_guest_access = False
|
||||
|
||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
render(request, self.resource, self.clock)
|
||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
self.render(request)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
|
||||
|
|
164
tests/rest/media/v1/test_url_preview.py
Normal file
164
tests/rest/media/v1/test_url_preview.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.config.repository import MediaStorageProviderConfig
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
|
||||
hijack_auth = True
|
||||
user_id = "@test:user"
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
self.storage_path = self.mktemp()
|
||||
os.mkdir(self.storage_path)
|
||||
|
||||
config = self.default_config()
|
||||
config.url_preview_enabled = True
|
||||
config.max_spider_size = 9999999
|
||||
config.url_preview_url_blacklist = []
|
||||
config.media_store_path = self.storage_path
|
||||
|
||||
provider_config = {
|
||||
"module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
|
||||
"store_local": True,
|
||||
"store_synchronous": False,
|
||||
"store_remote": True,
|
||||
"config": {"directory": self.storage_path},
|
||||
}
|
||||
|
||||
loaded = list(load_module(provider_config)) + [
|
||||
MediaStorageProviderConfig(False, False, False)
|
||||
]
|
||||
|
||||
config.media_storage_providers = [loaded]
|
||||
|
||||
hs = self.setup_test_homeserver(config=config)
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
|
||||
self.fetches = []
|
||||
|
||||
def get_file(url, output_stream, max_size):
|
||||
"""
|
||||
Returns tuple[int,dict,str,int] of file length, response headers,
|
||||
absolute URI, and response code.
|
||||
"""
|
||||
|
||||
def write_to(r):
|
||||
data, response = r
|
||||
output_stream.write(data)
|
||||
return response
|
||||
|
||||
d = Deferred()
|
||||
d.addCallback(write_to)
|
||||
self.fetches.append((d, url))
|
||||
return d
|
||||
|
||||
client = Mock()
|
||||
client.get_file = get_file
|
||||
|
||||
self.media_repo = hs.get_media_repository_resource()
|
||||
preview_url = self.media_repo.children[b'preview_url']
|
||||
preview_url.client = client
|
||||
self.preview_url = preview_url
|
||||
|
||||
def test_cache_returns_correct_type(self):
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# We've made one fetch
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
|
||||
end_content = (
|
||||
b'<html><head>'
|
||||
b'<meta property="og:title" content="~matrix~" />'
|
||||
b'<meta property="og:description" content="hi" />'
|
||||
b'</head></html>'
|
||||
)
|
||||
|
||||
self.fetches[0][0].callback(
|
||||
(
|
||||
end_content,
|
||||
(
|
||||
len(end_content),
|
||||
{
|
||||
b"Content-Length": [b"%d" % (len(end_content))],
|
||||
b"Content-Type": [b'text/html; charset="utf8"'],
|
||||
},
|
||||
"https://example.com",
|
||||
200,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.pump()
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(
|
||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||
)
|
||||
|
||||
# Check the cache returns the correct response
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# Only one fetch, still, since we'll lean on the cache
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
|
||||
# Check the cache response has the same content
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(
|
||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||
)
|
||||
|
||||
# Clear the in-memory cache
|
||||
self.assertIn("matrix.org", self.preview_url._cache)
|
||||
self.preview_url._cache.pop("matrix.org")
|
||||
self.assertNotIn("matrix.org", self.preview_url._cache)
|
||||
|
||||
# Check the database cache returns the correct response
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# Only one fetch, still, since we'll lean on the cache
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
|
||||
# Check the cache response has the same content
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(
|
||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||
)
|
|
@ -34,6 +34,7 @@ class FakeChannel(object):
|
|||
wire).
|
||||
"""
|
||||
|
||||
_reactor = attr.ib()
|
||||
result = attr.ib(default=attr.Factory(dict))
|
||||
_producer = None
|
||||
|
||||
|
@ -56,6 +57,8 @@ class FakeChannel(object):
|
|||
self.result["headers"] = headers
|
||||
|
||||
def write(self, content):
|
||||
assert isinstance(content, bytes), "Should be bytes! " + repr(content)
|
||||
|
||||
if "body" not in self.result:
|
||||
self.result["body"] = b""
|
||||
|
||||
|
@ -63,6 +66,15 @@ class FakeChannel(object):
|
|||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self._producer = producer
|
||||
self.producerStreaming = streaming
|
||||
|
||||
def _produce():
|
||||
if self._producer:
|
||||
self._producer.resumeProducing()
|
||||
self._reactor.callLater(0.1, _produce)
|
||||
|
||||
if not streaming:
|
||||
self._reactor.callLater(0.0, _produce)
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self._producer is None:
|
||||
|
@ -105,7 +117,13 @@ class FakeSite:
|
|||
|
||||
|
||||
def make_request(
|
||||
method, path, content=b"", access_token=None, request=SynapseRequest, shorthand=True
|
||||
reactor,
|
||||
method,
|
||||
path,
|
||||
content=b"",
|
||||
access_token=None,
|
||||
request=SynapseRequest,
|
||||
shorthand=True,
|
||||
):
|
||||
"""
|
||||
Make a web request using the given method and path, feed it the
|
||||
|
@ -138,7 +156,7 @@ def make_request(
|
|||
content = content.encode('utf8')
|
||||
|
||||
site = FakeSite()
|
||||
channel = FakeChannel()
|
||||
channel = FakeChannel(reactor)
|
||||
|
||||
req = request(site, channel)
|
||||
req.process = lambda: b""
|
||||
|
|
|
@ -21,30 +21,20 @@ from mock import Mock, NonCallableMock
|
|||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha import register, sync
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock,
|
||||
make_request,
|
||||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
|
||||
|
||||
class TestMauLimit(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
self.clock = Clock(self.reactor)
|
||||
class TestMauLimit(unittest.HomeserverTestCase):
|
||||
|
||||
self.hs = setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
servlets = [register.register_servlets, sync.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
self.hs = self.setup_test_homeserver(
|
||||
"red",
|
||||
http_client=None,
|
||||
clock=self.clock,
|
||||
reactor=self.reactor,
|
||||
federation_client=Mock(),
|
||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||
)
|
||||
|
@ -63,10 +53,7 @@ class TestMauLimit(unittest.TestCase):
|
|||
self.hs.config.server_notices_mxid_display_name = None
|
||||
self.hs.config.server_notices_mxid_avatar_url = None
|
||||
self.hs.config.server_notices_room_name = "Test Server Notice Room"
|
||||
|
||||
self.resource = JsonResource(self.hs)
|
||||
register.register_servlets(self.hs, self.resource)
|
||||
sync.register_servlets(self.hs, self.resource)
|
||||
return self.hs
|
||||
|
||||
def test_simple_deny_mau(self):
|
||||
# Create and sync so that the MAU counts get updated
|
||||
|
@ -193,8 +180,8 @@ class TestMauLimit(unittest.TestCase):
|
|||
}
|
||||
)
|
||||
|
||||
request, channel = make_request("POST", "/register", request_data)
|
||||
render(request, self.resource, self.reactor)
|
||||
request, channel = self.make_request("POST", "/register", request_data)
|
||||
self.render(request)
|
||||
|
||||
if channel.code != 200:
|
||||
raise HttpResponseException(
|
||||
|
@ -206,10 +193,10 @@ class TestMauLimit(unittest.TestCase):
|
|||
return access_token
|
||||
|
||||
def do_sync_for_user(self, token):
|
||||
request, channel = make_request(
|
||||
request, channel = self.make_request(
|
||||
"GET", "/sync", access_token=token
|
||||
)
|
||||
render(request, self.resource, self.reactor)
|
||||
self.render(request)
|
||||
|
||||
if channel.code != 200:
|
||||
raise HttpResponseException(
|
||||
|
|
|
@ -57,7 +57,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
|
||||
)
|
||||
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
|
||||
request, channel = make_request(
|
||||
self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
|
||||
)
|
||||
render(request, res, self.reactor)
|
||||
|
||||
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
|
||||
|
@ -75,7 +77,7 @@ class JsonResourceTests(unittest.TestCase):
|
|||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
|
||||
render(request, res, self.reactor)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'500')
|
||||
|
@ -98,7 +100,7 @@ class JsonResourceTests(unittest.TestCase):
|
|||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
|
||||
render(request, res, self.reactor)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'500')
|
||||
|
@ -115,7 +117,7 @@ class JsonResourceTests(unittest.TestCase):
|
|||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
|
||||
render(request, res, self.reactor)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'403')
|
||||
|
@ -136,7 +138,7 @@ class JsonResourceTests(unittest.TestCase):
|
|||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/_matrix/foobar")
|
||||
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
|
||||
render(request, res, self.reactor)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'400')
|
||||
|
|
|
@ -23,7 +23,6 @@ from synapse.rest.client.v2_alpha.register import register_servlets
|
|||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import make_request
|
||||
|
||||
|
||||
class TermsTestCase(unittest.HomeserverTestCase):
|
||||
|
@ -92,7 +91,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.render(request)
|
||||
|
||||
# We don't bother checking that the response is correct - we'll leave that to
|
||||
|
@ -110,7 +109,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
|
|||
},
|
||||
}
|
||||
)
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.render(request)
|
||||
|
||||
# We're interested in getting a response that looks like a successful
|
||||
|
|
|
@ -189,11 +189,11 @@ class HomeserverTestCase(TestCase):
|
|||
for servlet in self.servlets:
|
||||
servlet(self.hs, self.resource)
|
||||
|
||||
from tests.rest.client.v1.utils import RestHelper
|
||||
|
||||
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
|
||||
|
||||
if hasattr(self, "user_id"):
|
||||
from tests.rest.client.v1.utils import RestHelper
|
||||
|
||||
self.helper = RestHelper(self.hs, self.resource, self.user_id)
|
||||
|
||||
if self.hijack_auth:
|
||||
|
||||
def get_user_by_access_token(token=None, allow_guest=False):
|
||||
|
@ -285,7 +285,9 @@ class HomeserverTestCase(TestCase):
|
|||
if isinstance(content, dict):
|
||||
content = json.dumps(content).encode('utf8')
|
||||
|
||||
return make_request(method, path, content, access_token, request, shorthand)
|
||||
return make_request(
|
||||
self.reactor, method, path, content, access_token, request, shorthand
|
||||
)
|
||||
|
||||
def render(self, request):
|
||||
"""
|
||||
|
|
2
tox.ini
2
tox.ini
|
@ -122,7 +122,7 @@ skip_install = True
|
|||
basepython = python3.6
|
||||
deps =
|
||||
flake8
|
||||
commands = /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}"
|
||||
commands = /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}"
|
||||
|
||||
[testenv:check_isort]
|
||||
skip_install = True
|
||||
|
|
Loading…
Reference in a new issue