Convert some of the media REST code to async/await (#7110)

This commit is contained in:
Patrick Cloke 2020-03-20 07:20:02 -04:00 committed by GitHub
parent c2db6599c8
commit caec7d4fa0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 91 additions and 111 deletions

1
changelog.d/7110.misc Normal file
View file

@ -0,0 +1 @@
Convert some of synapse.rest.media to async/await.

View file

@ -24,7 +24,6 @@ from six import iteritems
import twisted.internet.error
import twisted.web.http
from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import (
@ -114,15 +113,14 @@ class MediaRepository(object):
"update_recently_accessed_media", self._update_recently_accessed
)
@defer.inlineCallbacks
def _update_recently_accessed(self):
async def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
yield self.store.update_cached_last_access_time(
await self.store.update_cached_last_access_time(
local_media, remote_media, self.clock.time_msec()
)
@ -138,8 +136,7 @@ class MediaRepository(object):
else:
self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def create_content(
async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
"""Store uploaded content for a local user and return the mxc URL
@ -158,11 +155,11 @@ class MediaRepository(object):
file_info = FileInfo(server_name=None, file_id=media_id)
fname = yield self.media_storage.store_file(content, file_info)
fname = await self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media(
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@ -171,12 +168,11 @@ class MediaRepository(object):
user_id=auth_user,
)
yield self._generate_thumbnails(None, media_id, media_id, media_type)
await self._generate_thumbnails(None, media_id, media_id, media_type)
return "mxc://%s/%s" % (self.server_name, media_id)
@defer.inlineCallbacks
def get_local_media(self, request, media_id, name):
async def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
@ -190,7 +186,7 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
@ -204,13 +200,12 @@ class MediaRepository(object):
file_info = FileInfo(None, media_id, url_cache=url_cache)
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
async def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
@ -236,8 +231,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@ -246,14 +241,13 @@ class MediaRepository(object):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)
@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
async def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
@ -274,8 +268,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@ -286,8 +280,7 @@ class MediaRepository(object):
return media_info
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
async def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@ -299,7 +292,7 @@ class MediaRepository(object):
Returns:
Deferred[(Responder, media_info)]
"""
media_info = yield self.store.get_cached_remote_media(server_name, media_id)
media_info = await self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
@ -317,19 +310,18 @@ class MediaRepository(object):
logger.info("Media is quarantined")
raise NotFoundError()
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info
# Failed to find the file anywhere, lets download it.
media_info = yield self._download_remote_file(server_name, media_id, file_id)
media_info = await self._download_remote_file(server_name, media_id, file_id)
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id, file_id):
async def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@ -351,7 +343,7 @@ class MediaRepository(object):
("/_matrix/media/v1/download", server_name, media_id)
)
try:
length, headers = yield self.client.get_file(
length, headers = await self.client.get_file(
server_name,
request_path,
output_stream=f,
@ -397,7 +389,7 @@ class MediaRepository(object):
)
raise SynapseError(502, "Failed to fetch remote media")
yield finish()
await finish()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
@ -405,7 +397,7 @@ class MediaRepository(object):
logger.info("Stored remote media in file %r", fname)
yield self.store.store_cached_remote_media(
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
@ -423,7 +415,7 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
await self._generate_thumbnails(server_name, media_id, file_id, media_type)
return media_info
@ -458,16 +450,15 @@ class MediaRepository(object):
return t_byte_source
@defer.inlineCallbacks
def generate_local_exact_thumbnail(
async def generate_local_exact_thumbnail(
self, media_id, t_width, t_height, t_method, t_type, url_cache
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@ -490,7 +481,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@ -500,22 +491,21 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return output_path
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(
async def generate_remote_exact_thumbnail(
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@ -537,7 +527,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@ -547,7 +537,7 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@ -560,8 +550,7 @@ class MediaRepository(object):
return output_path
@defer.inlineCallbacks
def _generate_thumbnails(
async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
"""Generate and store thumbnails for an image.
@ -582,7 +571,7 @@ class MediaRepository(object):
if not requirements:
return
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
@ -600,7 +589,7 @@ class MediaRepository(object):
return
if thumbnailer.transpose_method is not None:
m_width, m_height = yield defer_to_thread(
m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
@ -620,11 +609,11 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
@ -646,7 +635,7 @@ class MediaRepository(object):
url_cache=url_cache,
)
output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@ -656,7 +645,7 @@ class MediaRepository(object):
# Write to database
if server_name:
yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@ -667,15 +656,14 @@ class MediaRepository(object):
t_len,
)
else:
yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return {"width": m_width, "height": m_height}
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
old_media = yield self.store.get_remote_media_before(before_ts)
async def delete_old_remote_media(self, before_ts):
old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@ -689,7 +677,7 @@ class MediaRepository(object):
# TODO: Should we delete from the backup store
with (yield self.remote_media_linearizer.queue(key)):
with (await self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
@ -705,7 +693,7 @@ class MediaRepository(object):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
yield self.store.delete_remote_media(origin, media_id)
await self.store.delete_remote_media(origin, media_id)
deleted += 1
return {"deleted": deleted}

View file

@ -165,8 +165,7 @@ class PreviewUrlResource(DirectServeResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
@defer.inlineCallbacks
def _do_preview(self, url, user, ts):
async def _do_preview(self, url, user, ts):
"""Check the db, and download the URL and build a preview
Args:
@ -179,7 +178,7 @@ class PreviewUrlResource(DirectServeResource):
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
@ -192,13 +191,13 @@ class PreviewUrlResource(DirectServeResource):
og = og.encode("utf8")
return og
media_info = yield self._download_url(url, user)
media_info = await self._download_url(url, user)
logger.debug("got media_info of '%s'", media_info)
if _is_media(media_info["media_type"]):
file_id = media_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, media_info["media_type"], url_cache=True
)
@ -248,14 +247,14 @@ class PreviewUrlResource(DirectServeResource):
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if "og:image" in og and og["og:image"]:
image_info = yield self._download_url(
image_info = await self._download_url(
_rebase_url(og["og:image"], media_info["uri"]), user
)
if _is_media(image_info["media_type"]):
# TODO: make sure we don't choke on white-on-transparent images
file_id = image_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, image_info["media_type"], url_cache=True
)
if dims:
@ -293,7 +292,7 @@ class PreviewUrlResource(DirectServeResource):
jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
await self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
@ -305,8 +304,7 @@ class PreviewUrlResource(DirectServeResource):
return jsonog.encode("utf8")
@defer.inlineCallbacks
def _download_url(self, url, user):
async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@ -318,7 +316,7 @@ class PreviewUrlResource(DirectServeResource):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
logger.debug("Trying to get url '%s'", url)
length, headers, uri, code = yield self.client.get_file(
length, headers, uri, code = await self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size
)
except SynapseError:
@ -345,7 +343,7 @@ class PreviewUrlResource(DirectServeResource):
% (traceback.format_exception_only(sys.exc_info()[0], e),),
Codes.UNKNOWN,
)
yield finish()
await finish()
try:
if b"Content-Type" in headers:
@ -356,7 +354,7 @@ class PreviewUrlResource(DirectServeResource):
download_name = get_filename_from_headers(headers)
yield self.store.store_local_media(
await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@ -393,8 +391,7 @@ class PreviewUrlResource(DirectServeResource):
"expire_url_cache_data", self._expire_url_cache_data
)
@defer.inlineCallbacks
def _expire_url_cache_data(self):
async def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@ -403,12 +400,12 @@ class PreviewUrlResource(DirectServeResource):
logger.info("Running url preview cache expiry")
if not (yield self.store.db.updates.has_completed_background_updates()):
if not (await self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
# First we delete expired url cache entries
media_ids = yield self.store.get_expired_url_cache(now)
media_ids = await self.store.get_expired_url_cache(now)
removed_media = []
for media_id in media_ids:
@ -430,7 +427,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
yield self.store.delete_url_cache(removed_media)
await self.store.delete_url_cache(removed_media)
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
@ -440,7 +437,7 @@ class PreviewUrlResource(DirectServeResource):
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
expire_before = now - 2 * 24 * 60 * 60 * 1000
media_ids = yield self.store.get_url_cache_media_before(expire_before)
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
for media_id in media_ids:
@ -478,7 +475,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
yield self.store.delete_url_cache_media(removed_media)
await self.store.delete_url_cache_media(removed_media)
logger.info("Deleted %d media from url cache", len(removed_media))

View file

@ -16,8 +16,6 @@
import logging
from twisted.internet import defer
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource):
)
self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(
async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(request, responder, t_type, t_length)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_local_thumbnail(
async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info["thumbnail_width"] == desired_width
t_h = info["thumbnail_height"] == desired_height
@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(
async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
def _respond_remote_thumbnail(
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(request, responder, t_type, t_length)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)