mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-16 23:11:34 +01:00
Re-introduce federation /download endpoint (#17350)
This commit is contained in:
parent
f79dbd0f61
commit
a023538822
8 changed files with 588 additions and 11 deletions
2
changelog.d/17350.feature
Normal file
2
changelog.d/17350.feature
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md)
|
||||||
|
by adding a federation /download endpoint.
|
|
@ -33,6 +33,7 @@ from synapse.federation.transport.server.federation import (
|
||||||
FEDERATION_SERVLET_CLASSES,
|
FEDERATION_SERVLET_CLASSES,
|
||||||
FederationAccountStatusServlet,
|
FederationAccountStatusServlet,
|
||||||
FederationUnstableClientKeysClaimServlet,
|
FederationUnstableClientKeysClaimServlet,
|
||||||
|
FederationUnstableMediaDownloadServlet,
|
||||||
)
|
)
|
||||||
from synapse.http.server import HttpServer, JsonResource
|
from synapse.http.server import HttpServer, JsonResource
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
|
@ -315,6 +316,13 @@ def register_servlets(
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if servletclass == FederationUnstableMediaDownloadServlet:
|
||||||
|
if (
|
||||||
|
not hs.config.server.enable_media_repo
|
||||||
|
or not hs.config.experimental.msc3916_authenticated_media_enabled
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
servletclass(
|
servletclass(
|
||||||
hs=hs,
|
hs=hs,
|
||||||
authenticator=authenticator,
|
authenticator=authenticator,
|
||||||
|
|
|
@ -360,9 +360,25 @@ class BaseFederationServlet:
|
||||||
"request"
|
"request"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
if (
|
||||||
|
func.__self__.__class__.__name__ # type: ignore
|
||||||
|
== "FederationUnstableMediaDownloadServlet"
|
||||||
|
):
|
||||||
|
response = await func(
|
||||||
|
origin, content, request, *args, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
response = await func(
|
response = await func(
|
||||||
origin, content, request.args, *args, **kwargs
|
origin, content, request.args, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
func.__self__.__class__.__name__ # type: ignore
|
||||||
|
== "FederationUnstableMediaDownloadServlet"
|
||||||
|
):
|
||||||
|
response = await func(
|
||||||
|
origin, content, request, *args, **kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
response = await func(
|
response = await func(
|
||||||
origin, content, request.args, *args, **kwargs
|
origin, content, request.args, *args, **kwargs
|
||||||
|
|
|
@ -44,10 +44,13 @@ from synapse.federation.transport.server._base import (
|
||||||
)
|
)
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
parse_boolean_from_args,
|
parse_boolean_from_args,
|
||||||
|
parse_integer,
|
||||||
parse_integer_from_args,
|
parse_integer_from_args,
|
||||||
parse_string_from_args,
|
parse_string_from_args,
|
||||||
parse_strings_from_args,
|
parse_strings_from_args,
|
||||||
)
|
)
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
from synapse.media._base import DEFAULT_MAX_TIMEOUT_MS, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import SYNAPSE_VERSION
|
from synapse.util import SYNAPSE_VERSION
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
@ -787,6 +790,43 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
|
||||||
return 200, {"account_statuses": statuses, "failures": failures}
|
return 200, {"account_statuses": statuses, "failures": failures}
|
||||||
|
|
||||||
|
|
||||||
|
class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
|
||||||
|
"""
|
||||||
|
Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
|
||||||
|
a multipart/mixed response consisting of a JSON object and the requested media
|
||||||
|
item. This endpoint only returns local media.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATH = "/media/download/(?P<media_id>[^/]*)"
|
||||||
|
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
|
||||||
|
RATELIMIT = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
ratelimiter: FederationRateLimiter,
|
||||||
|
authenticator: Authenticator,
|
||||||
|
server_name: str,
|
||||||
|
):
|
||||||
|
super().__init__(hs, authenticator, ratelimiter, server_name)
|
||||||
|
self.media_repo = self.hs.get_media_repository()
|
||||||
|
|
||||||
|
async def on_GET(
|
||||||
|
self,
|
||||||
|
origin: Optional[str],
|
||||||
|
content: Literal[None],
|
||||||
|
request: SynapseRequest,
|
||||||
|
media_id: str,
|
||||||
|
) -> None:
|
||||||
|
max_timeout_ms = parse_integer(
|
||||||
|
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
|
||||||
|
)
|
||||||
|
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
|
||||||
|
await self.media_repo.get_local_media(
|
||||||
|
request, media_id, None, max_timeout_ms, federation=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
|
FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
|
||||||
FederationSendServlet,
|
FederationSendServlet,
|
||||||
FederationEventServlet,
|
FederationEventServlet,
|
||||||
|
@ -818,4 +858,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
|
||||||
FederationV1SendKnockServlet,
|
FederationV1SendKnockServlet,
|
||||||
FederationMakeKnockServlet,
|
FederationMakeKnockServlet,
|
||||||
FederationAccountStatusServlet,
|
FederationAccountStatusServlet,
|
||||||
|
FederationUnstableMediaDownloadServlet,
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,7 +25,16 @@ import os
|
||||||
import urllib
|
import urllib
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Awaitable,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -37,8 +46,13 @@ from synapse.api.errors import Codes, cs_error
|
||||||
from synapse.http.server import finish_request, respond_with_json
|
from synapse.http.server import finish_request, respond_with_json
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
from synapse.util import Clock
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.storage.databases.main.media_repository import LocalMedia
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# list all text content types that will have the charset default to UTF-8 when
|
# list all text content types that will have the charset default to UTF-8 when
|
||||||
|
@ -260,6 +274,68 @@ def _can_encode_filename_as_token(x: str) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def respond_with_multipart_responder(
|
||||||
|
clock: Clock,
|
||||||
|
request: SynapseRequest,
|
||||||
|
responder: "Optional[Responder]",
|
||||||
|
media_info: "LocalMedia",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Responds to requests originating from the federation media `/download` endpoint by
|
||||||
|
streaming a multipart/mixed response
|
||||||
|
|
||||||
|
Args:
|
||||||
|
clock:
|
||||||
|
request: the federation request to respond to
|
||||||
|
responder: the responder which will send the response
|
||||||
|
media_info: metadata about the media item
|
||||||
|
"""
|
||||||
|
if not responder:
|
||||||
|
respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
|
# If we have a responder we *must* use it as a context manager.
|
||||||
|
with responder:
|
||||||
|
if request._disconnected:
|
||||||
|
logger.warning(
|
||||||
|
"Not sending response to request %s, already disconnected.", request
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
from synapse.media.media_storage import MultipartFileConsumer
|
||||||
|
|
||||||
|
# note that currently the json_object is just {}, this will change when linked media
|
||||||
|
# is implemented
|
||||||
|
multipart_consumer = MultipartFileConsumer(
|
||||||
|
clock, request, media_info.media_type, {}, media_info.media_length
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Responding to media request with responder %s", responder)
|
||||||
|
if media_info.media_length is not None:
|
||||||
|
content_length = multipart_consumer.content_length()
|
||||||
|
assert content_length is not None
|
||||||
|
request.setHeader(b"Content-Length", b"%d" % (content_length,))
|
||||||
|
|
||||||
|
request.setHeader(
|
||||||
|
b"Content-Type",
|
||||||
|
b"multipart/mixed; boundary=%s" % multipart_consumer.boundary,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await responder.write_to_consumer(multipart_consumer)
|
||||||
|
except Exception as e:
|
||||||
|
# The majority of the time this will be due to the client having gone
|
||||||
|
# away. Unfortunately, Twisted simply throws a generic exception at us
|
||||||
|
# in that case.
|
||||||
|
logger.warning("Failed to write to consumer: %s %s", type(e), e)
|
||||||
|
|
||||||
|
# Unregister the producer, if it has one, so Twisted doesn't complain
|
||||||
|
if request.producer:
|
||||||
|
request.unregisterProducer()
|
||||||
|
|
||||||
|
finish_request(request)
|
||||||
|
|
||||||
|
|
||||||
async def respond_with_responder(
|
async def respond_with_responder(
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
responder: "Optional[Responder]",
|
responder: "Optional[Responder]",
|
||||||
|
|
|
@ -54,6 +54,7 @@ from synapse.media._base import (
|
||||||
ThumbnailInfo,
|
ThumbnailInfo,
|
||||||
get_filename_from_headers,
|
get_filename_from_headers,
|
||||||
respond_404,
|
respond_404,
|
||||||
|
respond_with_multipart_responder,
|
||||||
respond_with_responder,
|
respond_with_responder,
|
||||||
)
|
)
|
||||||
from synapse.media.filepath import MediaFilePaths
|
from synapse.media.filepath import MediaFilePaths
|
||||||
|
@ -429,6 +430,7 @@ class MediaRepository:
|
||||||
media_id: str,
|
media_id: str,
|
||||||
name: Optional[str],
|
name: Optional[str],
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
federation: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Responds to requests for local media, if exists, or returns 404.
|
"""Responds to requests for local media, if exists, or returns 404.
|
||||||
|
|
||||||
|
@ -440,6 +442,7 @@ class MediaRepository:
|
||||||
the filename in the Content-Disposition header of the response.
|
the filename in the Content-Disposition header of the response.
|
||||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
media to be uploaded.
|
media to be uploaded.
|
||||||
|
federation: whether the local media being fetched is for a federation request
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Resolves once a response has successfully been written to request
|
Resolves once a response has successfully been written to request
|
||||||
|
@ -460,6 +463,11 @@ class MediaRepository:
|
||||||
file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
|
file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
|
||||||
|
|
||||||
responder = await self.media_storage.fetch_media(file_info)
|
responder = await self.media_storage.fetch_media(file_info)
|
||||||
|
if federation:
|
||||||
|
await respond_with_multipart_responder(
|
||||||
|
self.clock, request, responder, media_info
|
||||||
|
)
|
||||||
|
else:
|
||||||
await respond_with_responder(
|
await respond_with_responder(
|
||||||
request, responder, media_type, media_length, upload_name
|
request, responder, media_type, media_length, upload_name
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,9 +19,12 @@
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
from contextlib import closing
|
||||||
|
from io import BytesIO
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
IO,
|
IO,
|
||||||
|
@ -30,24 +33,35 @@ from typing import (
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
BinaryIO,
|
BinaryIO,
|
||||||
Callable,
|
Callable,
|
||||||
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.internet import interfaces
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.internet.interfaces import IConsumer
|
from twisted.internet.interfaces import IConsumer
|
||||||
from twisted.protocols.basic import FileSender
|
from twisted.protocols.basic import FileSender
|
||||||
|
|
||||||
from synapse.api.errors import NotFoundError
|
from synapse.api.errors import NotFoundError
|
||||||
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
|
from synapse.logging.context import (
|
||||||
|
defer_to_thread,
|
||||||
|
make_deferred_yieldable,
|
||||||
|
run_in_background,
|
||||||
|
)
|
||||||
from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
|
from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.file_consumer import BackgroundFileConsumer
|
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||||
|
|
||||||
|
from ..types import JsonDict
|
||||||
from ._base import FileInfo, Responder
|
from ._base import FileInfo, Responder
|
||||||
from .filepath import MediaFilePaths
|
from .filepath import MediaFilePaths
|
||||||
|
|
||||||
|
@ -57,6 +71,8 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CRLF = b"\r\n"
|
||||||
|
|
||||||
|
|
||||||
class MediaStorage:
|
class MediaStorage:
|
||||||
"""Responsible for storing/fetching files from local sources.
|
"""Responsible for storing/fetching files from local sources.
|
||||||
|
@ -174,7 +190,7 @@ class MediaStorage:
|
||||||
and configured storage providers.
|
and configured storage providers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_info
|
file_info: Metadata about the media file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Returns a Responder if the file was found, otherwise None.
|
Returns a Responder if the file was found, otherwise None.
|
||||||
|
@ -316,7 +332,7 @@ class FileResponder(Responder):
|
||||||
"""Wraps an open file that can be sent to a request.
|
"""Wraps an open file that can be sent to a request.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
open_file: A file like object to be streamed ot the client,
|
open_file: A file like object to be streamed to the client,
|
||||||
is closed when finished streaming.
|
is closed when finished streaming.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -370,3 +386,240 @@ class ReadableFileWrapper:
|
||||||
|
|
||||||
# We yield to the reactor by sleeping for 0 seconds.
|
# We yield to the reactor by sleeping for 0 seconds.
|
||||||
await self.clock.sleep(0)
|
await self.clock.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(interfaces.IConsumer)
|
||||||
|
@implementer(interfaces.IPushProducer)
|
||||||
|
class MultipartFileConsumer:
|
||||||
|
"""Wraps a given consumer so that any data that gets written to it gets
|
||||||
|
converted to a multipart format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
clock: Clock,
|
||||||
|
wrapped_consumer: interfaces.IConsumer,
|
||||||
|
file_content_type: str,
|
||||||
|
json_object: JsonDict,
|
||||||
|
content_length: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
self.clock = clock
|
||||||
|
self.wrapped_consumer = wrapped_consumer
|
||||||
|
self.json_field = json_object
|
||||||
|
self.json_field_written = False
|
||||||
|
self.content_type_written = False
|
||||||
|
self.file_content_type = file_content_type
|
||||||
|
self.boundary = uuid4().hex.encode("ascii")
|
||||||
|
|
||||||
|
# The producer that registered with us, and if it's a push or pull
|
||||||
|
# producer.
|
||||||
|
self.producer: Optional["interfaces.IProducer"] = None
|
||||||
|
self.streaming: Optional[bool] = None
|
||||||
|
|
||||||
|
# Whether the wrapped consumer has asked us to pause.
|
||||||
|
self.paused = False
|
||||||
|
|
||||||
|
self.length = content_length
|
||||||
|
|
||||||
|
### IConsumer APIs ###
|
||||||
|
|
||||||
|
def registerProducer(
|
||||||
|
self, producer: "interfaces.IProducer", streaming: bool
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Register to receive data from a producer.
|
||||||
|
|
||||||
|
This sets self to be a consumer for a producer. When this object runs
|
||||||
|
out of data (as when a send(2) call on a socket succeeds in moving the
|
||||||
|
last data from a userspace buffer into a kernelspace buffer), it will
|
||||||
|
ask the producer to resumeProducing().
|
||||||
|
|
||||||
|
For L{IPullProducer} providers, C{resumeProducing} will be called once
|
||||||
|
each time data is required.
|
||||||
|
|
||||||
|
For L{IPushProducer} providers, C{pauseProducing} will be called
|
||||||
|
whenever the write buffer fills up and C{resumeProducing} will only be
|
||||||
|
called when it empties. The consumer will only call C{resumeProducing}
|
||||||
|
to balance a previous C{pauseProducing} call; the producer is assumed
|
||||||
|
to start in an un-paused state.
|
||||||
|
|
||||||
|
@param streaming: C{True} if C{producer} provides L{IPushProducer},
|
||||||
|
C{False} if C{producer} provides L{IPullProducer}.
|
||||||
|
|
||||||
|
@raise RuntimeError: If a producer is already registered.
|
||||||
|
"""
|
||||||
|
self.producer = producer
|
||||||
|
self.streaming = streaming
|
||||||
|
|
||||||
|
self.wrapped_consumer.registerProducer(self, True)
|
||||||
|
|
||||||
|
# kick off producing if `self.producer` is not a streaming producer
|
||||||
|
if not streaming:
|
||||||
|
self.resumeProducing()
|
||||||
|
|
||||||
|
def unregisterProducer(self) -> None:
|
||||||
|
"""
|
||||||
|
Stop consuming data from a producer, without disconnecting.
|
||||||
|
"""
|
||||||
|
self.wrapped_consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
|
||||||
|
self.wrapped_consumer.unregisterProducer()
|
||||||
|
self.paused = True
|
||||||
|
|
||||||
|
def write(self, data: bytes) -> None:
|
||||||
|
"""
|
||||||
|
The producer will write data by calling this method.
|
||||||
|
|
||||||
|
The implementation must be non-blocking and perform whatever
|
||||||
|
buffering is necessary. If the producer has provided enough data
|
||||||
|
for now and it is a L{IPushProducer}, the consumer may call its
|
||||||
|
C{pauseProducing} method.
|
||||||
|
"""
|
||||||
|
if not self.json_field_written:
|
||||||
|
self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
|
||||||
|
|
||||||
|
content_type = Header(b"Content-Type", b"application/json")
|
||||||
|
self.wrapped_consumer.write(bytes(content_type) + CRLF)
|
||||||
|
|
||||||
|
json_field = json.dumps(self.json_field)
|
||||||
|
json_bytes = json_field.encode("utf-8")
|
||||||
|
self.wrapped_consumer.write(CRLF + json_bytes)
|
||||||
|
self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
|
||||||
|
|
||||||
|
self.json_field_written = True
|
||||||
|
|
||||||
|
# if we haven't written the content type yet, do so
|
||||||
|
if not self.content_type_written:
|
||||||
|
type = self.file_content_type.encode("utf-8")
|
||||||
|
content_type = Header(b"Content-Type", type)
|
||||||
|
self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF)
|
||||||
|
self.content_type_written = True
|
||||||
|
|
||||||
|
self.wrapped_consumer.write(data)
|
||||||
|
|
||||||
|
### IPushProducer APIs ###
|
||||||
|
|
||||||
|
def stopProducing(self) -> None:
|
||||||
|
"""
|
||||||
|
Stop producing data.
|
||||||
|
|
||||||
|
This tells a producer that its consumer has died, so it must stop
|
||||||
|
producing data for good.
|
||||||
|
"""
|
||||||
|
assert self.producer is not None
|
||||||
|
|
||||||
|
self.paused = True
|
||||||
|
self.producer.stopProducing()
|
||||||
|
|
||||||
|
def pauseProducing(self) -> None:
|
||||||
|
"""
|
||||||
|
Pause producing data.
|
||||||
|
|
||||||
|
Tells a producer that it has produced too much data to process for
|
||||||
|
the time being, and to stop until C{resumeProducing()} is called.
|
||||||
|
"""
|
||||||
|
assert self.producer is not None
|
||||||
|
|
||||||
|
self.paused = True
|
||||||
|
|
||||||
|
if self.streaming:
|
||||||
|
cast("interfaces.IPushProducer", self.producer).pauseProducing()
|
||||||
|
else:
|
||||||
|
self.paused = True
|
||||||
|
|
||||||
|
def resumeProducing(self) -> None:
|
||||||
|
"""
|
||||||
|
Resume producing data.
|
||||||
|
|
||||||
|
This tells a producer to re-add itself to the main loop and produce
|
||||||
|
more data for its consumer.
|
||||||
|
"""
|
||||||
|
assert self.producer is not None
|
||||||
|
|
||||||
|
if self.streaming:
|
||||||
|
cast("interfaces.IPushProducer", self.producer).resumeProducing()
|
||||||
|
else:
|
||||||
|
# If the producer is not a streaming producer we need to start
|
||||||
|
# repeatedly calling `resumeProducing` in a loop.
|
||||||
|
run_in_background(self._resumeProducingRepeatedly)
|
||||||
|
|
||||||
|
def content_length(self) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Calculate the content length of the multipart response
|
||||||
|
in bytes.
|
||||||
|
"""
|
||||||
|
if not self.length:
|
||||||
|
return None
|
||||||
|
# calculate length of json field and content-type header
|
||||||
|
json_field = json.dumps(self.json_field)
|
||||||
|
json_bytes = json_field.encode("utf-8")
|
||||||
|
json_length = len(json_bytes)
|
||||||
|
|
||||||
|
type = self.file_content_type.encode("utf-8")
|
||||||
|
content_type = Header(b"Content-Type", type)
|
||||||
|
type_length = len(bytes(content_type))
|
||||||
|
|
||||||
|
# 154 is the length of the elements that aren't variable, ie
|
||||||
|
# CRLFs and boundary strings, etc
|
||||||
|
self.length += json_length + type_length + 154
|
||||||
|
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
### Internal APIs. ###
|
||||||
|
|
||||||
|
async def _resumeProducingRepeatedly(self) -> None:
|
||||||
|
assert self.producer is not None
|
||||||
|
assert not self.streaming
|
||||||
|
|
||||||
|
producer = cast("interfaces.IPullProducer", self.producer)
|
||||||
|
|
||||||
|
self.paused = False
|
||||||
|
while not self.paused:
|
||||||
|
producer.resumeProducing()
|
||||||
|
await self.clock.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
class Header:
|
||||||
|
"""
|
||||||
|
`Header` This class is a tiny wrapper that produces
|
||||||
|
request headers. We can't use standard python header
|
||||||
|
class because it encodes unicode fields using =? bla bla ?=
|
||||||
|
encoding, which is correct, but no one in HTTP world expects
|
||||||
|
that, everyone wants utf-8 raw bytes. (stolen from treq.multipart)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: bytes,
|
||||||
|
value: Any,
|
||||||
|
params: Optional[List[Tuple[Any, Any]]] = None,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.value = value
|
||||||
|
self.params = params or []
|
||||||
|
|
||||||
|
def add_param(self, name: Any, value: Any) -> None:
|
||||||
|
self.params.append((name, value))
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
with closing(BytesIO()) as h:
|
||||||
|
h.write(self.name + b": " + escape(self.value).encode("us-ascii"))
|
||||||
|
if self.params:
|
||||||
|
for name, val in self.params:
|
||||||
|
h.write(b"; ")
|
||||||
|
h.write(escape(name).encode("us-ascii"))
|
||||||
|
h.write(b"=")
|
||||||
|
h.write(b'"' + escape(val).encode("utf-8") + b'"')
|
||||||
|
h.seek(0)
|
||||||
|
return h.read()
|
||||||
|
|
||||||
|
|
||||||
|
def escape(value: Union[str, bytes]) -> str:
|
||||||
|
"""
|
||||||
|
This function prevents header values from corrupting the request,
|
||||||
|
a newline in the file name parameter makes form-data request unreadable
|
||||||
|
for a majority of parsers. (stolen from treq.multipart)
|
||||||
|
"""
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
value = value.decode("utf-8")
|
||||||
|
return value.replace("\r", "").replace("\n", "").replace('"', '\\"')
|
||||||
|
|
173
tests/federation/test_federation_media.py
Normal file
173
tests/federation/test_federation_media.py
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright (C) 2024 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
#
|
||||||
|
# Originally licensed under the Apache License, Version 2.0:
|
||||||
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||||
|
#
|
||||||
|
# [This file includes modifications made by New Vector Limited]
|
||||||
|
#
|
||||||
|
#
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.media.filepath import MediaFilePaths
|
||||||
|
from synapse.media.media_storage import MediaStorage
|
||||||
|
from synapse.media.storage_provider import (
|
||||||
|
FileStorageProviderBackend,
|
||||||
|
StorageProviderWrapper,
|
||||||
|
)
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import UserID
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
from tests.test_utils import SMALL_PNG
|
||||||
|
from tests.unittest import override_config
|
||||||
|
|
||||||
|
|
||||||
|
class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
super().prepare(reactor, clock, hs)
|
||||||
|
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
|
||||||
|
self.addCleanup(shutil.rmtree, self.test_dir)
|
||||||
|
self.primary_base_path = os.path.join(self.test_dir, "primary")
|
||||||
|
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
|
||||||
|
|
||||||
|
hs.config.media.media_store_path = self.primary_base_path
|
||||||
|
|
||||||
|
storage_providers = [
|
||||||
|
StorageProviderWrapper(
|
||||||
|
FileStorageProviderBackend(hs, self.secondary_base_path),
|
||||||
|
store_local=True,
|
||||||
|
store_remote=False,
|
||||||
|
store_synchronous=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||||
|
self.media_storage = MediaStorage(
|
||||||
|
hs, self.primary_base_path, self.filepaths, storage_providers
|
||||||
|
)
|
||||||
|
self.media_repo = hs.get_media_repository()
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{"experimental_features": {"msc3916_authenticated_media_enabled": True}}
|
||||||
|
)
|
||||||
|
def test_file_download(self) -> None:
|
||||||
|
content = io.BytesIO(b"file_to_stream")
|
||||||
|
content_uri = self.get_success(
|
||||||
|
self.media_repo.create_content(
|
||||||
|
"text/plain",
|
||||||
|
"test_upload",
|
||||||
|
content,
|
||||||
|
46,
|
||||||
|
UserID.from_string("@user_id:whatever.org"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# test with a text file
|
||||||
|
channel = self.make_signed_federation_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
|
||||||
|
)
|
||||||
|
self.pump()
|
||||||
|
self.assertEqual(200, channel.code)
|
||||||
|
|
||||||
|
content_type = channel.headers.getRawHeaders("content-type")
|
||||||
|
assert content_type is not None
|
||||||
|
assert "multipart/mixed" in content_type[0]
|
||||||
|
assert "boundary" in content_type[0]
|
||||||
|
|
||||||
|
# extract boundary
|
||||||
|
boundary = content_type[0].split("boundary=")[1]
|
||||||
|
# split on boundary and check that json field and expected value exist
|
||||||
|
stripped = channel.text_body.split("\r\n" + "--" + boundary)
|
||||||
|
# TODO: the json object expected will change once MSC3911 is implemented, currently
|
||||||
|
# {} is returned for all requests as a placeholder (per MSC3196)
|
||||||
|
found_json = any(
|
||||||
|
"\r\nContent-Type: application/json\r\n\r\n{}" in field
|
||||||
|
for field in stripped
|
||||||
|
)
|
||||||
|
self.assertTrue(found_json)
|
||||||
|
|
||||||
|
# check that the text file and expected value exist
|
||||||
|
found_file = any(
|
||||||
|
"\r\nContent-Type: text/plain\r\n\r\nfile_to_stream" in field
|
||||||
|
for field in stripped
|
||||||
|
)
|
||||||
|
self.assertTrue(found_file)
|
||||||
|
|
||||||
|
content = io.BytesIO(SMALL_PNG)
|
||||||
|
content_uri = self.get_success(
|
||||||
|
self.media_repo.create_content(
|
||||||
|
"image/png",
|
||||||
|
"test_png_upload",
|
||||||
|
content,
|
||||||
|
67,
|
||||||
|
UserID.from_string("@user_id:whatever.org"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# test with an image file
|
||||||
|
channel = self.make_signed_federation_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
|
||||||
|
)
|
||||||
|
self.pump()
|
||||||
|
self.assertEqual(200, channel.code)
|
||||||
|
|
||||||
|
content_type = channel.headers.getRawHeaders("content-type")
|
||||||
|
assert content_type is not None
|
||||||
|
assert "multipart/mixed" in content_type[0]
|
||||||
|
assert "boundary" in content_type[0]
|
||||||
|
|
||||||
|
# extract boundary
|
||||||
|
boundary = content_type[0].split("boundary=")[1]
|
||||||
|
# split on boundary and check that json field and expected value exist
|
||||||
|
body = channel.result.get("body")
|
||||||
|
assert body is not None
|
||||||
|
stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8"))
|
||||||
|
found_json = any(
|
||||||
|
b"\r\nContent-Type: application/json\r\n\r\n{}" in field
|
||||||
|
for field in stripped_bytes
|
||||||
|
)
|
||||||
|
self.assertTrue(found_json)
|
||||||
|
|
||||||
|
# check that the png file exists and matches what was uploaded
|
||||||
|
found_file = any(SMALL_PNG in field for field in stripped_bytes)
|
||||||
|
self.assertTrue(found_file)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{"experimental_features": {"msc3916_authenticated_media_enabled": False}}
|
||||||
|
)
|
||||||
|
def test_disable_config(self) -> None:
|
||||||
|
content = io.BytesIO(b"file_to_stream")
|
||||||
|
content_uri = self.get_success(
|
||||||
|
self.media_repo.create_content(
|
||||||
|
"text/plain",
|
||||||
|
"test_upload",
|
||||||
|
content,
|
||||||
|
46,
|
||||||
|
UserID.from_string("@user_id:whatever.org"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
channel = self.make_signed_federation_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
|
||||||
|
)
|
||||||
|
self.pump()
|
||||||
|
self.assertEqual(404, channel.code)
|
||||||
|
self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED")
|
Loading…
Reference in a new issue