package msc2836

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/base64"
	"encoding/json"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/setup/base"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/gomatrixserverlib"
	"github.com/matrix-org/util"
)

// eventInfo is the per-event metadata returned by the relationship queries.
// OriginServerTS may be the zero value when the timestamp was not read, e.g.
// DB.ParentForChild only fills EventID and RoomID.
type eventInfo struct {
	EventID        string
	OriginServerTS gomatrixserverlib.Timestamp
	RoomID         string
}

// Database persists MSC2836 relationships: parent->child edges keyed by rel_type,
// plus a node row per stored event holding its origin_server_ts, room ID,
// children count/hash and an "explored" flag.
type Database interface {
	// StoreRelation stores the parent->child and child->parent relationship for later querying.
	// Also stores the event metadata e.g timestamp.
	StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error
	// ChildrenForParent returns the events whose m.relationship points at `eventID` with the
	// provided `relType`. The returned slice is sorted by origin_server_ts: newest first when
	// `recentFirst` is true, oldest first otherwise.
	ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
	// ParentForChild returns the parent event for the given child `eventID`. The eventInfo is nil,
	// with a nil error, if there is no parent for this child event. The parent eventInfo can be
	// missing the timestamp (the SQL implementation only fills EventID and RoomID).
	ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error)
	// UpdateChildMetadata persists the children_count and children_hash from this event when they
	// supersede what was previously stored: either the count is greater, or the count is equal but
	// the hash differs. Whenever the metadata is updated, the event is marked unexplored. Events
	// that carry no children count are ignored.
	UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error
	// ChildMetadata returns the children_count and children_hash for the event ID in question.
	// Also returns the `explored` flag, which is set to true when MarkChildrenExplored is called and is set
	// back to `false` when newer metadata is stored via UpdateChildMetadata.
	// Returns zero values and a nil error if the event ID does not exist.
	ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error)
	// MarkChildrenExplored sets the 'explored' flag on this event to `true`.
	MarkChildrenExplored(ctx context.Context, eventID string) error
}

// DB is the SQL implementation of Database, shared by the PostgreSQL and SQLite
// backends. `writer` is the sqlutil.Writer chosen by the constructor
// (NewDummyWriter for PostgreSQL, NewExclusiveWriter for SQLite); the remaining
// fields are statements prepared once at construction time.
type DB struct {
	db                                     *sql.DB
	writer                                 sqlutil.Writer
	insertEdgeStmt                         *sql.Stmt
	insertNodeStmt                         *sql.Stmt
	selectChildrenForParentOldestFirstStmt *sql.Stmt
	selectChildrenForParentRecentFirstStmt *sql.Stmt
	selectParentForChildStmt               *sql.Stmt
	updateChildMetadataStmt                *sql.Stmt
	selectChildMetadataStmt                *sql.Stmt
	updateChildMetadataExploredStmt        *sql.Stmt
}

// NewDatabase opens the msc2836 relationship store. A PostgreSQL connection
// string selects the PostgreSQL backend; anything else falls back to SQLite.
func NewDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) {
	switch {
	case dbOpts.ConnectionString.IsPostgres():
		return newPostgresDatabase(base, dbOpts)
	default:
		return newSQLiteDatabase(base, dbOpts)
	}
}

// newPostgresDatabase connects to PostgreSQL with a dummy writer, creates the
// msc2836 tables if they do not exist yet and prepares every statement.
func newPostgresDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) {
	d := DB{}
	var err error
	if d.db, d.writer, err = base.DatabaseConnection(dbOpts, sqlutil.NewDummyWriter()); err != nil {
		return nil, err
	}
	edgesTableSQL := `
	CREATE TABLE IF NOT EXISTS msc2836_edges (
		parent_event_id TEXT NOT NULL,
		child_event_id TEXT NOT NULL,
		rel_type TEXT NOT NULL,
		parent_room_id TEXT NOT NULL,
		parent_servers TEXT NOT NULL,
		CONSTRAINT msc2836_edges_uniq UNIQUE (parent_event_id, child_event_id, rel_type)
	);
	`
	insertEdgeSQL := `
		INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
		VALUES($1, $2, $3, $4, $5)
		ON CONFLICT DO NOTHING
	`
	if err = d.prepareSchema(edgesTableSQL, insertEdgeSQL); err != nil {
		return nil, err
	}
	return &d, nil
}

// newSQLiteDatabase connects to SQLite with an exclusive writer, creates the
// msc2836 tables if they do not exist yet and prepares every statement.
func newSQLiteDatabase(base *base.BaseDendrite, dbOpts *config.DatabaseOptions) (Database, error) {
	d := DB{}
	var err error
	if d.db, d.writer, err = base.DatabaseConnection(dbOpts, sqlutil.NewExclusiveWriter()); err != nil {
		return nil, err
	}
	edgesTableSQL := `
	CREATE TABLE IF NOT EXISTS msc2836_edges (
		parent_event_id TEXT NOT NULL,
		child_event_id TEXT NOT NULL,
		rel_type TEXT NOT NULL,
		parent_room_id TEXT NOT NULL,
		parent_servers TEXT NOT NULL,
		UNIQUE (parent_event_id, child_event_id, rel_type)
	);
	`
	insertEdgeSQL := `
		INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
		VALUES($1, $2, $3, $4, $5)
		ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING
	`
	if err = d.prepareSchema(edgesTableSQL, insertEdgeSQL); err != nil {
		return nil, err
	}
	return &d, nil
}

// prepareSchema creates the msc2836 tables and prepares all statements on p.db.
// Only the edges table DDL and the edge insert differ between backends, so those
// are passed in; the nodes table and every other statement are shared. The
// statements are prepared in a fixed order and the first failure is returned.
func (p *DB) prepareSchema(edgesTableSQL, insertEdgeSQL string) error {
	nodesTableSQL := `
	CREATE TABLE IF NOT EXISTS msc2836_nodes (
		event_id TEXT PRIMARY KEY NOT NULL,
		origin_server_ts BIGINT NOT NULL,
		room_id TEXT NOT NULL,
		unsigned_children_count BIGINT NOT NULL,
		unsigned_children_hash TEXT NOT NULL,
		explored SMALLINT NOT NULL
	);
	`
	if _, err := p.db.Exec(edgesTableSQL + nodesTableSQL); err != nil {
		return err
	}

	// LEFT JOIN pulls the child's timestamp/room from its node row; the
	// direction (ASC/DESC) is appended per statement below.
	selectChildrenQuery := `
	SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
	LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
	WHERE parent_event_id = $1 AND rel_type = $2
	ORDER BY origin_server_ts
	`
	statements := []struct {
		dst   **sql.Stmt
		query string
	}{
		{&p.insertEdgeStmt, insertEdgeSQL},
		{&p.insertNodeStmt, `
		INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
		VALUES($1, $2, $3, $4, $5, $6)
		ON CONFLICT DO NOTHING
	`},
		{&p.selectChildrenForParentOldestFirstStmt, selectChildrenQuery + " ASC"},
		{&p.selectChildrenForParentRecentFirstStmt, selectChildrenQuery + " DESC"},
		{&p.selectParentForChildStmt, `
		SELECT parent_event_id, parent_room_id FROM msc2836_edges
		WHERE child_event_id = $1 AND rel_type = $2
	`},
		{&p.updateChildMetadataStmt, `
		UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
	`},
		{&p.selectChildMetadataStmt, `
		SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
	`},
		{&p.updateChildMetadataExploredStmt, `
		UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
	`},
	}
	for _, s := range statements {
		stmt, err := p.db.Prepare(s.query)
		if err != nil {
			return err
		}
		*s.dst = stmt
	}
	return nil
}

// StoreRelation records the edge described by ev's m.relationship content and a
// node row for ev itself (timestamp, room, children count/hash, unexplored).
// Both inserts run in one writer transaction and ignore duplicates. Events with
// no usable m.relationship are silently skipped.
func (p *DB) StoreRelation(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error {
	parentID, childID, relType := parentChildEventIDs(ev)
	if childID == "" || parentID == "" {
		return nil
	}
	roomID, servers := roomIDAndServers(ev)
	serversJSON, marshalErr := json.Marshal(servers)
	if marshalErr != nil {
		return marshalErr
	}
	childCount, childHash := extractChildMetadata(ev)
	encodedHash := base64.RawStdEncoding.EncodeToString(childHash)

	return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
		if _, edgeErr := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parentID, childID, relType, roomID, string(serversJSON)); edgeErr != nil {
			return edgeErr
		}
		util.GetLogger(ctx).Infof("StoreRelation child=%s parent=%s rel_type=%s", childID, parentID, relType)
		_, nodeErr := txn.Stmt(p.insertNodeStmt).ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID(), childCount, encodedHash, 0)
		return nodeErr
	})
}

// UpdateChildMetadata stores the children count/hash from ev's unsigned data and
// resets the event to unexplored when that metadata supersedes what is stored.
// Metadata supersedes the stored values when the count grew, or when the count is
// unchanged but the hash differs. Events carrying no children count are ignored.
//
// The UPDATE is issued through p.writer, as StoreRelation does, so it goes through
// the backend's writer (sqlutil.NewExclusiveWriter for SQLite).
// NOTE(review): the read of the current metadata and the update are not in the same
// transaction, so a concurrent update between them can still be overwritten.
func (p *DB) UpdateChildMetadata(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent) error {
	eventCount, eventHash := extractChildMetadata(ev)
	if eventCount == 0 {
		return nil // nothing to update with
	}

	// An unknown event reads back as count 0 with a nil error.
	count, hash, _, err := p.ChildMetadata(ctx, ev.EventID())
	if err != nil {
		return err
	}
	if eventCount < count || (eventCount == count && bytes.Equal(hash, eventHash)) {
		return nil // stored metadata is at least as new as this event's
	}

	encodedHash := base64.RawStdEncoding.EncodeToString(eventHash)
	return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
		// explored=0: new children have appeared, so the event must be re-explored.
		_, execErr := txn.Stmt(p.updateChildMetadataStmt).ExecContext(ctx, eventCount, encodedHash, 0, ev.EventID())
		return execErr
	})
}

// ChildMetadata reads the stored children count, decoded children hash and
// explored flag for eventID. An unknown eventID is not an error: zero values and a
// nil error are returned. The hash is stored unpadded base64 (RawStdEncoding).
func (p *DB) ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error) {
	var (
		encodedHash  string
		exploredFlag int
	)
	err = p.selectChildMetadataStmt.QueryRowContext(ctx, eventID).Scan(&count, &encodedHash, &exploredFlag)
	switch {
	case err == sql.ErrNoRows:
		return count, hash, explored, nil
	case err != nil:
		return count, hash, explored, err
	}
	explored = exploredFlag > 0
	hash, err = base64.RawStdEncoding.DecodeString(encodedHash)
	return count, hash, explored, err
}

// MarkChildrenExplored sets the explored flag on eventID's node row to 1. Like
// StoreRelation, the write goes through p.writer (sqlutil.NewExclusiveWriter for
// SQLite) rather than executing directly on the shared *sql.DB.
// Updating an unknown eventID affects no rows and is not an error.
func (p *DB) MarkChildrenExplored(ctx context.Context, eventID string) error {
	return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
		_, err := txn.Stmt(p.updateChildMetadataExploredStmt).ExecContext(ctx, 1, eventID)
		return err
	})
}

// ChildrenForParent returns the children linked to eventID with relType, ordered by
// origin_server_ts (newest first when recentFirst is true, oldest first otherwise).
// A nil slice with a nil error means no children were found. Any error raised while
// iterating the result set is returned rather than a truncated list.
func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) {
	stmt := p.selectChildrenForParentOldestFirstStmt
	if recentFirst {
		stmt = p.selectChildrenForParentRecentFirstStmt
	}
	rows, err := stmt.QueryContext(ctx, eventID, relType)
	if err != nil {
		return nil, err
	}
	defer rows.Close() // nolint: errcheck

	var children []eventInfo
	for rows.Next() {
		var evInfo eventInfo
		if err = rows.Scan(&evInfo.EventID, &evInfo.OriginServerTS, &evInfo.RoomID); err != nil {
			return nil, err
		}
		children = append(children, evInfo)
	}
	// rows.Next returns false both at the end of the result set and on error;
	// only rows.Err distinguishes the two.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return children, nil
}

// ParentForChild looks up the parent edge of eventID for relType. It returns
// (nil, nil) when no such edge exists. Only EventID and RoomID of the result are
// filled; OriginServerTS is left zero.
// NOTE(review): if several edges match, whichever row the database yields first wins.
func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) {
	parent := &eventInfo{}
	row := p.selectParentForChildStmt.QueryRowContext(ctx, eventID, relType)
	switch err := row.Scan(&parent.EventID, &parent.RoomID); err {
	case nil:
		return parent, nil
	case sql.ErrNoRows:
		return nil, nil
	default:
		return nil, err
	}
}

// parentChildEventIDs reads the `m.relationship` object from ev's content. It
// returns the referenced parent event ID, ev's own ID as the child, and the
// rel_type. If ev is nil, its content is not valid JSON, or either event_id or
// rel_type is empty, all three results are empty strings.
func parentChildEventIDs(ev *gomatrixserverlib.HeaderedEvent) (parent, child, relType string) {
	if ev == nil {
		return "", "", ""
	}
	var content struct {
		Relationship struct {
			RelType string `json:"rel_type"`
			EventID string `json:"event_id"`
		} `json:"m.relationship"`
	}
	if json.Unmarshal(ev.Content(), &content) != nil {
		return "", "", ""
	}
	rel := content.Relationship
	if rel.EventID != "" && rel.RelType != "" {
		return rel.EventID, ev.EventID(), rel.RelType
	}
	return "", "", ""
}

// roomIDAndServers reads `relationship_room_id` and `relationship_servers` from
// ev's unsigned data. servers is never nil: it is an empty slice when ev is nil,
// the unsigned data is not valid JSON, or the servers field is absent or null.
// This keeps the JSON stored by StoreRelation as `[]` rather than `null`.
func roomIDAndServers(ev *gomatrixserverlib.HeaderedEvent) (roomID string, servers []string) {
	servers = []string{}
	if ev == nil {
		return roomID, servers
	}
	var unsigned struct {
		RoomID  string   `json:"relationship_room_id"`
		Servers []string `json:"relationship_servers"`
	}
	if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil {
		return roomID, servers
	}
	// A missing or null field leaves Servers nil; keep the empty-slice default then.
	if unsigned.Servers != nil {
		servers = unsigned.Servers
	}
	return unsigned.RoomID, servers
}

// extractChildMetadata returns the sum of all per-type counts in ev's unsigned
// `children` object, together with its `children_hash` (parsed as
// gomatrixserverlib.Base64Bytes). It returns (0, nil) for a nil event or when the
// unsigned data is missing or not valid JSON. The nil guard matches the sibling
// helpers parentChildEventIDs and roomIDAndServers.
func extractChildMetadata(ev *gomatrixserverlib.HeaderedEvent) (count int, hash []byte) {
	if ev == nil {
		return 0, nil
	}
	var unsigned struct {
		Counts map[string]int                `json:"children"`
		Hash   gomatrixserverlib.Base64Bytes `json:"children_hash"`
	}
	if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil {
		// expected if there is no unsigned field at all
		return 0, nil
	}
	for _, c := range unsigned.Counts {
		count += c
	}
	return count, unsigned.Hash
}