From b06d1124f72a1be2aa0906f837559191399aa121 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 17 Jul 2017 17:20:57 +0100 Subject: [PATCH] Factor out runTransaction to common code (#162) --- .../dendrite/clientapi/auth/auth.go | 9 +++- .../clientapi/auth/storage/devices/storage.go | 25 ++--------- .../matrix-org/dendrite/common/httpapi.go | 6 +-- .../matrix-org/dendrite/common/sql.go | 41 +++++++++++++++++++ .../federationsender/storage/storage.go | 21 +--------- .../dendrite/syncapi/storage/syncserver.go | 25 ++--------- 6 files changed, 58 insertions(+), 69 deletions(-) create mode 100644 src/github.com/matrix-org/dendrite/common/sql.go diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go index 9f350b4b0..a661c1f81 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go @@ -24,7 +24,6 @@ import ( "strings" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/util" ) @@ -40,10 +39,16 @@ var UnknownDeviceID = "unknown-device" // 32 bytes => 256 bits var tokenByteLength = 32 +// DeviceDatabase represents a device database. +type DeviceDatabase interface { + // Lookup the device matching the given access token. + GetDeviceByAccessToken(token string) (*authtypes.Device, error) +} + // VerifyAccessToken verifies that an access token was supplied in the given HTTP request // and returns the device it corresponds to. Returns resErr (an error response which can be // sent to the client) if the token is invalid or there was a problem querying the database. -func VerifyAccessToken(req *http.Request, deviceDB *devices.Database) (device *authtypes.Device, resErr *util.JSONResponse) { +func VerifyAccessToken(req *http.Request, deviceDB DeviceDatabase) (device *authtypes.Device, resErr *util.JSONResponse) { token, err := extractAccessToken(req) if err != nil { resErr = &util.JSONResponse{ diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go index 8f39cad3c..a58c26346 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/devices/storage.go @@ -18,6 +18,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/gomatrixserverlib" ) @@ -53,7 +54,7 @@ func (d *Database) GetDeviceByAccessToken(token string) (*authtypes.Device, erro // an error will be returned. // Returns the device on success. func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *authtypes.Device, returnErr error) { - returnErr = runTransaction(d.db, func(txn *sql.Tx) error { + returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error // Revoke existing token for this device if err = d.devices.deleteDevice(txn, deviceID, localpart); err != nil { @@ -74,30 +75,10 @@ func (d *Database) CreateDevice(localpart, deviceID, accessToken string) (dev *a // If the device doesn't exist, it will not return an error // If something went wrong during the deletion, it will return the SQL error func (d *Database) RemoveDevice(deviceID string, localpart string) error { - return runTransaction(d.db, func(txn *sql.Tx) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { if err := d.devices.deleteDevice(txn, deviceID, localpart); err != sql.ErrNoRows { return err } return nil }) } - -// TODO: factor out to common -func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { - txn, err := db.Begin() - if err != nil { - return - } - defer func() { - if r := recover(); r != nil { - txn.Rollback() - panic(r) - } else if err != nil { - txn.Rollback() - } else { - err = txn.Commit() - } - }() - err = fn(txn) - return -} diff --git a/src/github.com/matrix-org/dendrite/common/httpapi.go b/src/github.com/matrix-org/dendrite/common/httpapi.go index 0ab33925f..6298c7b18 100644 --- a/src/github.com/matrix-org/dendrite/common/httpapi.go +++ b/src/github.com/matrix-org/dendrite/common/httpapi.go @@ -1,16 +1,16 @@ package common import ( + "net/http" + "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" - "net/http" ) // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which checks the access token in the request. -func MakeAuthAPI(metricsName string, deviceDB *devices.Database, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler { +func MakeAuthAPI(metricsName string, deviceDB auth.DeviceDatabase, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler { h := util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { device, resErr := auth.VerifyAccessToken(req, deviceDB) if resErr != nil { diff --git a/src/github.com/matrix-org/dendrite/common/sql.go b/src/github.com/matrix-org/dendrite/common/sql.go new file mode 100644 index 000000000..cabbe6662 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/common/sql.go @@ -0,0 +1,41 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "database/sql" +) + +// WithTransaction runs a block of code passing in an SQL transaction +// If the code returns an error or panics then the transactions is rolledback +// Otherwise the transaction is committed. +func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { + txn, err := db.Begin() + if err != nil { + return + } + defer func() { + if r := recover(); r != nil { + txn.Rollback() + panic(r) + } else if err != nil { + txn.Rollback() + } else { + err = txn.Commit() + } + }() + err = fn(txn) + return +} diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go b/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go index 2f98093e4..a10210076 100644 --- a/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go @@ -77,7 +77,7 @@ func (d *Database) UpdateRoom( addHosts []types.JoinedHost, removeHosts []string, ) (joinedHosts []types.JoinedHost, err error) { - err = runTransaction(d.db, func(txn *sql.Tx) error { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { if err = d.insertRoom(txn, roomID); err != nil { return err } @@ -105,22 +105,3 @@ func (d *Database) UpdateRoom( }) return } - -func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { - txn, err := db.Begin() - if err != nil { - return - } - defer func() { - if r := recover(); r != nil { - txn.Rollback() - panic(r) - } else if err != nil { - txn.Rollback() - } else { - err = txn.Commit() - } - }() - err = fn(txn) - return -} diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go index a1efa1f27..27afd1c05 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go @@ -92,7 +92,7 @@ func (d *SyncServerDatabase) Events(eventIDs []string) ([]gomatrixserverlib.Even func (d *SyncServerDatabase) WriteEvent( ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, ) (streamPos types.StreamPosition, returnErr error) { - returnErr = runTransaction(d.db, func(txn *sql.Tx) error { + returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error pos, err := d.events.insertEvent(txn, ev, addStateEventIDs, removeStateEventIDs) if err != nil { @@ -162,7 +162,7 @@ func (d *SyncServerDatabase) SyncStreamPosition() (types.StreamPosition, error) // IncrementalSync returns all the data needed in order to create an incremental sync response. func (d *SyncServerDatabase) IncrementalSync(userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int) (res *types.Response, returnErr error) { - returnErr = runTransaction(d.db, func(txn *sql.Tx) error { + returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { // Work out which rooms to return in the response. This is done by getting not only the currently // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // This works out what the 'state' key should be for each room as well as which membership block @@ -223,7 +223,7 @@ func (d *SyncServerDatabase) CompleteSync(userID string, numRecentEventsPerRoom // a consistent view of the database throughout. This includes extracting the sync stream position. // This does have the unfortunate side-effect that all the matrixy logic resides in this function, // but it's better to not hide the fact that this is being done in a transaction. - returnErr = runTransaction(d.db, func(txn *sql.Tx) error { + returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { // Get the current stream position which we will base the sync response on. id, err := d.events.selectMaxID(txn) if err != nil { @@ -479,22 +479,3 @@ func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string { } return "" } - -func runTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { - txn, err := db.Begin() - if err != nil { - return - } - defer func() { - if r := recover(); r != nil { - txn.Rollback() - panic(r) - } else if err != nil { - txn.Rollback() - } else { - err = txn.Commit() - } - }() - err = fn(txn) - return -}