#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

from typing import Any, List, Optional

from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
    _MAX_STATE_UPDATES_PER_ROOM,
    EventsStreamAllStateRow,
    EventsStreamCurrentStateRow,
    EventsStreamEventRow,
    EventsStreamRow,
)
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock

from tests.replication._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event


class EventsStreamTestCase(BaseStreamTestCase):
    servlets = [
        admin.register_servlets,
        login.register_servlets,
        room.register_servlets,
    ]

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        super().prepare(reactor, clock, hs)
        self.user_id = self.register_user("u1", "pass")
        self.user_tok = self.login("u1", "pass")

        self.reconnect()

        self.room_id = self.helper.create_room_as(tok=self.user_tok)
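        # discard the rows generated by the room creation, so that each test
        # starts with a clean slate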
        self.test_handler.received_rdata_rows.clear()

    def test_update_function_event_row_limit(self) -> None:
        """Test replication with many non-state events

        Checks that all events are correctly replicated when there are lots of
        event rows to be replicated.
        """
        # disconnect, so that we can stack up some changes
        self.disconnect()

        # generate lots of non-state events. We inject them using inject_event
        # so that they are not sent out over replication until we call self.replicate().
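        # We generate one more event than _STREAM_UPDATE_TARGET_ROW_COUNT, so that
        # the stream cannot send all the updates in a single batch.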
        events = [
            self._inject_test_event()
            for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
        ]

        # also one state event
        state_event = self._inject_state_event()

        # check we're testing what we think we are: no rows should yet have been
        # received
        self.assertEqual([], self.test_handler.received_rdata_rows)

        # now reconnect to pull the updates
        self.reconnect()
        self.replicate()

        # we should have received all the expected rows in the right order (as
        # well as various cache invalidation updates which we ignore)
        received_rows = [
            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
        ]

        for event in events:
            stream_name, token, row = received_rows.pop(0)
            self.assertEqual("events", stream_name)
            self.assertIsInstance(row, EventsStreamRow)
            self.assertEqual(row.type, "ev")
            self.assertIsInstance(row.data, EventsStreamEventRow)
            self.assertEqual(row.data.event_id, event.event_id)

        # the state event should come through next: first as an "ev" row, and
        # then as a "state" row.
        stream_name, token, row = received_rows.pop(0)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "ev")
        self.assertIsInstance(row.data, EventsStreamEventRow)
        self.assertEqual(row.data.event_id, state_event.event_id)

        stream_name, token, row = received_rows.pop(0)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "state")
        self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
        self.assertEqual(row.data.event_id, state_event.event_id)

        self.assertEqual([], received_rows)

    @parameterized.expand(
        [(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)]
    )
    def test_update_function_huge_state_change(
        self, num_state_changes: int, collapse_state_changes: bool
    ) -> None:
        """Test replication with many state events

        Ensures that all events are correctly replicated when there are lots of
        state change rows to be replicated.
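
        When the number of state changes reaches _MAX_STATE_UPDATES_PER_ROOM, we
        expect the stream to collapse them into a single "state-all" row rather
        than sending one row per change.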

        Args:
            num_state_changes: The number of state changes to create.
            collapse_state_changes: Whether the state changes are expected to be
                collapsed or not.
        """

        # we want to generate lots of state changes at a single stream ID.
        #
        # We do this by having two branches in the DAG. On one, we have a moderator
        # who generates lots of state; on the other, we de-op the moderator,
        # thus invalidating all the state.
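        # (Once the moderator is de-opped on the other branch, the state events
        # they sent no longer pass auth, and so drop out of the current state.)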

        OTHER_USER = "@other_user:localhost"

        # have the user join
        self.get_success(
            inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
        )

        # Update existing power levels with mod at PL50
        pls = self.helper.get_state(
            self.room_id, EventTypes.PowerLevels, tok=self.user_tok
        )
        pls["users"][OTHER_USER] = 50
        self.helper.send_state(
            self.room_id,
            EventTypes.PowerLevels,
            pls,
            tok=self.user_tok,
        )

        # this is the point in the DAG where we make a fork
        fork_point = self.get_success(
            self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
        )
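        # fork_point is the set of current forward extremities: injecting an event
        # with these as its prev_events creates a second branch in the DAG.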

        events = [
            self._inject_state_event(sender=OTHER_USER)
            for _ in range(num_state_changes)
        ]

        self.replicate()
        # all those events and state changes should have landed
        self.assertGreaterEqual(
            len(self.test_handler.received_rdata_rows), 2 * len(events)
        )
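        # (each state event produces two rows: an "ev" row and a "state" row,
        # hence the factor of two)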

        # disconnect, so that we can stack up the changes
        self.disconnect()
        self.test_handler.received_rdata_rows.clear()

        # a state event which doesn't get rolled back, to check that the state
        # before the huge update comes through ok
        state1 = self._inject_state_event()

        # roll back all the state by de-modding the user
        prev_events = fork_point
        pls["users"][OTHER_USER] = 0
        pl_event = self.get_success(
            inject_event(
                self.hs,
                prev_event_ids=list(prev_events),
                type=EventTypes.PowerLevels,
                state_key="",
                sender=self.user_id,
                room_id=self.room_id,
                content=pls,
            )
        )

        # one more bit of state that doesn't get rolled back
        state2 = self._inject_state_event()

        # check we're testing what we think we are: no rows should yet have been
        # received
        self.assertEqual([], self.test_handler.received_rdata_rows)

        # now reconnect to pull the updates
        self.reconnect()
        self.replicate()

        # we should have received all the expected rows in the right order (as
        # well as various cache invalidation updates which we ignore)
        #
        # we expect:
        #
        # - two rows for state1
        # - the PL event row, plus state rows for the PL event and each
        #       of the states that got reverted.
        # - two rows for state2

        received_rows = [
            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
        ]

        # first check the first two rows, which should be the state1 event.
        stream_name, token, row = received_rows.pop(0)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "ev")
        self.assertIsInstance(row.data, EventsStreamEventRow)
        self.assertEqual(row.data.event_id, state1.event_id)

        stream_name, token, row = received_rows.pop(0)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "state")
        self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
        self.assertEqual(row.data.event_id, state1.event_id)

        # now the last two rows, which should be the state2 event.
        stream_name, token, row = received_rows.pop(-2)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "ev")
        self.assertIsInstance(row.data, EventsStreamEventRow)
        self.assertEqual(row.data.event_id, state2.event_id)

        stream_name, token, row = received_rows.pop(-1)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "state")
        self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
        self.assertEqual(row.data.event_id, state2.event_id)

        # Depending on the number of state changes, the remaining rows are either
        # the individual state rows, or a single collapsed row.
        if collapse_state_changes:
            # that should leave us with two rows: one for the PL event itself, and
            # a single row into which all the state changes have been collapsed.
            self.assertEqual(len(received_rows), 2)

            stream_name, token, row = received_rows.pop(0)
            self.assertEqual("events", stream_name)
            self.assertIsInstance(row, EventsStreamRow)
            self.assertEqual(row.type, "ev")
            self.assertIsInstance(row.data, EventsStreamEventRow)
            self.assertEqual(row.data.event_id, pl_event.event_id)

            stream_name, token, row = received_rows.pop(0)
            self.assertIsInstance(row, EventsStreamRow)
            self.assertEqual(row.type, "state-all")
            self.assertIsInstance(row.data, EventsStreamAllStateRow)
            self.assertEqual(row.data.room_id, state2.room_id)
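            # (a "state-all" row carries only the room ID: rather than listing each
            # change, it signals that all the current state for the room may have
            # changed)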

        else:
            # that should leave us with the rows for the PL event
            self.assertEqual(len(received_rows), len(events) + 2)

            stream_name, token, row = received_rows.pop(0)
            self.assertEqual("events", stream_name)
            self.assertIsInstance(row, EventsStreamRow)
            self.assertEqual(row.type, "ev")
            self.assertIsInstance(row.data, EventsStreamEventRow)
            self.assertEqual(row.data.event_id, pl_event.event_id)

            # the state rows are unsorted
            state_rows: List[EventsStreamCurrentStateRow] = []
            for stream_name, _, row in received_rows:
                self.assertEqual("events", stream_name)
                self.assertIsInstance(row, EventsStreamRow)
                self.assertEqual(row.type, "state")
                self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
                state_rows.append(row.data)

            state_rows.sort(key=lambda r: r.state_key)

            sr = state_rows.pop(0)
            self.assertEqual(sr.type, EventTypes.PowerLevels)
            self.assertEqual(sr.event_id, pl_event.event_id)
            for sr in state_rows:
                self.assertEqual(sr.type, "test_state_event")
                # "None" indicates the state has been deleted
                self.assertIsNone(sr.event_id)

    def test_update_function_state_row_limit(self) -> None:
        """Test replication with many state events over several stream ids."""

        # we want to generate lots of state changes but, for this test, spread
        # them out over several stream IDs.
        #
        # We do this by having two branches in the DAG. On one, we have four moderators,
        # each of whom generates lots of state; on the other, we de-op the users,
        # thus invalidating all the state.

        NUM_USERS = 4
        STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // NUM_USERS + 1
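        # (so that, in total, we create just over _STREAM_UPDATE_TARGET_ROW_COUNT
        # state changes, split across the users)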

        user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]

        # have the users join
        for u in user_ids:
            self.get_success(
                inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
            )

        # Update existing power levels with the mods at PL50
        pls = self.helper.get_state(
            self.room_id, EventTypes.PowerLevels, tok=self.user_tok
        )
        pls["users"].update({u: 50 for u in user_ids})
        self.helper.send_state(
            self.room_id,
            EventTypes.PowerLevels,
            pls,
            tok=self.user_tok,
        )

        # this is the point in the DAG where we make a fork
        fork_point = self.get_success(
            self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
        )

        events: List[EventBase] = []
        for user in user_ids:
            events.extend(
                self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
            )

        self.replicate()

        # all those events and state changes should have landed
        self.assertGreaterEqual(
            len(self.test_handler.received_rdata_rows), 2 * len(events)
        )

        # disconnect, so that we can stack up the changes
        self.disconnect()
        self.test_handler.received_rdata_rows.clear()

        # now roll back all that state by de-modding the users
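        # Each PL event is chained off the previous one, so each user's de-op (and
        # hence each batch of reverted state) lands at a distinct stream position.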
        prev_events = list(fork_point)
        pl_events = []
        for u in user_ids:
            pls["users"][u] = 0
            e = self.get_success(
                inject_event(
                    self.hs,
                    prev_event_ids=prev_events,
                    type=EventTypes.PowerLevels,
                    state_key="",
                    sender=self.user_id,
                    room_id=self.room_id,
                    content=pls,
                )
            )
            prev_events = [e.event_id]
            pl_events.append(e)

        # check we're testing what we think we are: no rows should yet have been
        # received
        self.assertEqual([], self.test_handler.received_rdata_rows)

        # now reconnect to pull the updates
        self.reconnect()
        self.replicate()

        # we should have received all the expected rows in the right order (as
        # well as various cache invalidation updates which we ignore)
        received_rows = [
            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
        ]
        self.assertGreaterEqual(len(received_rows), len(events))
        for i in range(NUM_USERS):
            # for each user, we expect the PL event row, followed by state rows for
            # the PL event and each of the states that got reverted.
            stream_name, token, row = received_rows.pop(0)
            self.assertEqual("events", stream_name)
            self.assertIsInstance(row, EventsStreamRow)
            self.assertEqual(row.type, "ev")
            self.assertIsInstance(row.data, EventsStreamEventRow)
            self.assertEqual(row.data.event_id, pl_events[i].event_id)

            # the state rows are unsorted
            state_rows: List[EventsStreamCurrentStateRow] = []
            for _ in range(STATES_PER_USER + 1):
                stream_name, token, row = received_rows.pop(0)
                self.assertEqual("events", stream_name)
                self.assertIsInstance(row, EventsStreamRow)
                self.assertEqual(row.type, "state")
                self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
                state_rows.append(row.data)

            state_rows.sort(key=lambda r: r.state_key)

            sr = state_rows.pop(0)
            self.assertEqual(sr.type, EventTypes.PowerLevels)
            self.assertEqual(sr.event_id, pl_events[i].event_id)
            for sr in state_rows:
                self.assertEqual(sr.type, "test_state_event")
                # "None" indicates the state has been deleted
                self.assertIsNone(sr.event_id)

        self.assertEqual([], received_rows)

    def test_backwards_stream_id(self) -> None:
        """
        Test that RDATA whose stream position is behind the current position gets
        discarded.
        """
        # disconnect, so that we can stack up some changes
        self.disconnect()

        # Generate an event. We inject it using inject_event so that it is not
        # sent out over replication until we call self.replicate().
        event = self._inject_test_event()

        # check we're testing what we think we are: no rows should yet have been
        # received
        self.assertEqual([], self.test_handler.received_rdata_rows)

        # now reconnect to pull the updates
        self.reconnect()
        self.replicate()

        # We should have received the expected single row (as well as various
        # cache invalidation updates which we ignore).
        received_rows = [
            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
        ]

        # There should be a single received row.
        self.assertEqual(len(received_rows), 1)

        stream_name, token, row = received_rows[0]
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "ev")
        self.assertIsInstance(row.data, EventsStreamEventRow)
        self.assertEqual(row.data.event_id, event.event_id)

        # Reset the data.
        self.test_handler.received_rdata_rows = []

        # Save the current token for later.
        worker_events_stream = self.worker_hs.get_replication_streams()["events"]
        prev_token = worker_events_stream.current_token("master")

        # Manually send an old RDATA command, which should get dropped. This
        # re-uses the row from above, but with an earlier stream token.
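        # (RdataCommand takes the stream name, the name of the writer instance, the
        # stream token, and the row itself; a token of 1 is well behind the current
        # position)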
        self.hs.get_replication_command_handler().send_command(
            RdataCommand("events", "master", 1, row)
        )

        # No updates should have been received (the command was discarded as old).
        received_rows = [
            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
        ]
        self.assertEqual(len(received_rows), 0)

        # Ensure the stream has not gone backwards.
        current_token = worker_events_stream.current_token("master")
        self.assertGreaterEqual(current_token, prev_token)

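    # Counter used by the _inject_* helpers below to generate unique event bodies
    # and state keys.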
    event_count = 0

    def _inject_test_event(
        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
    ) -> EventBase:
        if sender is None:
            sender = self.user_id

        if body is None:
            body = "event %i" % (self.event_count,)
            self.event_count += 1

        return self.get_success(
            inject_event(
                self.hs,
                room_id=self.room_id,
                sender=sender,
                type="test_event",
                content={"body": body},
                **kwargs,
            )
        )

    def _inject_state_event(
        self,
        body: Optional[str] = None,
        state_key: Optional[str] = None,
        sender: Optional[str] = None,
    ) -> EventBase:
        if sender is None:
            sender = self.user_id

        if state_key is None:
            state_key = "state_%i" % (self.event_count,)
            self.event_count += 1

        if body is None:
            body = "state event %s" % (state_key,)

        return self.get_success(
            inject_event(
                self.hs,
                room_id=self.room_id,
                sender=sender,
                type="test_state_event",
                state_key=state_key,
                content={"body": body},
            )
        )