diff --git a/appservice/inthttp/client.go b/appservice/inthttp/client.go index 0a8baea99..3ae2c9278 100644 --- a/appservice/inthttp/client.go +++ b/appservice/inthttp/client.go @@ -7,7 +7,6 @@ import ( "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/opentracing/opentracing-go" ) // HTTP paths for the internal HTTP APIs @@ -42,11 +41,10 @@ func (h *httpAppServiceQueryAPI) RoomAliasExists( request *api.RoomAliasExistsRequest, response *api.RoomAliasExistsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceRoomAliasExists") - defer span.Finish() - - apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "RoomAliasExists", h.appserviceURL+AppServiceRoomAliasExistsPath, + h.httpClient, ctx, request, response, + ) } // UserIDExists implements AppServiceQueryAPI @@ -55,9 +53,8 @@ func (h *httpAppServiceQueryAPI) UserIDExists( request *api.UserIDExistsRequest, response *api.UserIDExistsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceUserIDExists") - defer span.Finish() - - apiURL := h.appserviceURL + AppServiceUserIDExistsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "UserIDExists", h.appserviceURL+AppServiceUserIDExistsPath, + h.httpClient, ctx, request, response, + ) } diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go index 645b43871..01d9f9895 100644 --- a/appservice/inthttp/server.go +++ b/appservice/inthttp/server.go @@ -1,43 +1,20 @@ package inthttp import ( - "encoding/json" - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/util" ) // AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux. func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( AppServiceRoomAliasExistsPath, - httputil.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse { - var request api.RoomAliasExistsRequest - var response api.RoomAliasExistsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := a.RoomAliasExists(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", a.RoomAliasExists), ) + internalAPIMux.Handle( AppServiceUserIDExistsPath, - httputil.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse { - var request api.UserIDExistsRequest - var response api.UserIDExistsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := a.UserIDExists(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists), ) } diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go index 70bac61dc..be7d13a96 100644 --- a/clientapi/jsonerror/jsonerror.go +++ b/clientapi/jsonerror/jsonerror.go @@ -15,11 +15,13 @@ package jsonerror import ( + "context" "fmt" "net/http" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) // MatrixError represents the "standard error response" in Matrix. @@ -213,3 +215,15 @@ func NotTrusted(serverName string) *MatrixError { Err: fmt.Sprintf("Untrusted server '%s'", serverName), } } + +// InternalAPIError is returned when Dendrite failed to reach an internal API. +func InternalAPIError(ctx context.Context, err error) util.JSONResponse { + logrus.WithContext(ctx).WithError(err).Error("Error reaching an internal API") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: &MatrixError{ + ErrCode: "M_INTERNAL_SERVER_ERROR", + Err: "Dendrite encountered an error reaching an internal API.", + }, + } +} diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 523b88c99..a8dd0e64f 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -30,13 +30,15 @@ func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserv } } res := &roomserverAPI.PerformAdminEvacuateRoomResponse{} - rsAPI.PerformAdminEvacuateRoom( + if err := rsAPI.PerformAdminEvacuateRoom( req.Context(), &roomserverAPI.PerformAdminEvacuateRoomRequest{ RoomID: roomID, }, res, - ) + ); err != nil { + return util.ErrorResponse(err) + } if err := res.Error; err != nil { return err.JSONResponse() } @@ -67,13 +69,15 @@ func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserv } } res := &roomserverAPI.PerformAdminEvacuateUserResponse{} - rsAPI.PerformAdminEvacuateUser( + if err := rsAPI.PerformAdminEvacuateUser( req.Context(), &roomserverAPI.PerformAdminEvacuateUserRequest{ UserID: userID, }, res, - ) + ); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if err := res.Error; err != nil { return err.JSONResponse() } diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 3f92b7ba6..874908639 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -556,10 +556,12 @@ func createRoom( if r.Visibility == "public" { // expose this room in the published room list var pubRes roomserverAPI.PerformPublishResponse - rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ + if err := rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ RoomID: roomID, Visibility: "public", - }, &pubRes) + }, &pubRes); err != nil { + return jsonerror.InternalAPIError(ctx, err) + } if pubRes.Error != nil { // treat as non-fatal since the room is already made by this point util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public") diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 53ba3f190..836d9e152 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -302,10 +302,12 @@ func SetVisibility( } var publishRes roomserverAPI.PerformPublishResponse - rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ + if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ RoomID: roomID, Visibility: v.Visibility, - }, &publishRes) + }, &publishRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if publishRes.Error != nil { util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed") return publishRes.Error.JSONResponse() diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 4e6acebc3..c50e552bd 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -81,8 +81,9 @@ func JoinRoomByIDOrAlias( done := make(chan util.JSONResponse, 1) go func() { defer close(done) - rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes) - if joinRes.Error != nil { + if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { + done <- jsonerror.InternalAPIError(req.Context(), err) + } else if joinRes.Error != nil { done <- joinRes.Error.JSONResponse() } else { done <- util.JSONResponse{ diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index 28c80415b..b6f8fe1b9 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -91,10 +91,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de // Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { var queryResp userapi.QueryKeyBackupResponse - userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ UserID: device.UserID, Version: version, - }, &queryResp) + }, &queryResp); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if queryResp.Error != "" { return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) } @@ -233,13 +235,15 @@ func GetBackupKeys( req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string, ) util.JSONResponse { var queryResp userapi.QueryKeyBackupResponse - userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ UserID: device.UserID, Version: version, ReturnKeys: true, KeysForRoomID: roomID, KeysForSessionID: sessionID, - }, &queryResp) + }, &queryResp); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if queryResp.Error != "" { return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) } diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 8fbb86f7a..2570db09c 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -72,7 +72,9 @@ func UploadCrossSigningDeviceKeys( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID - keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) + if err := keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if err := uploadRes.Error; err != nil { switch { @@ -114,7 +116,9 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie } uploadReq.UserID = device.UserID - keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes) + if err := keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if err := uploadRes.Error; err != nil { switch { diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index fdda34a53..b7a76b47e 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -62,7 +62,9 @@ func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devi } var uploadRes api.PerformUploadKeysResponse - keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes) + if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil { + return util.ErrorResponse(err) + } if uploadRes.Error != nil { util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys") return jsonerror.InternalServerError() @@ -107,12 +109,14 @@ func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devic return *resErr } queryRes := api.QueryKeysResponse{} - keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ + if err := keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ UserID: device.UserID, UserToDevices: r.DeviceKeys, Timeout: r.GetTimeout(), // TODO: Token? - }, &queryRes) + }, &queryRes); err != nil { + return util.ErrorResponse(err) + } return util.JSONResponse{ Code: 200, JSON: map[string]interface{}{ @@ -145,10 +149,12 @@ func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse { return *resErr } claimRes := api.PerformClaimKeysResponse{} - keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ + if err := keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ OneTimeKeys: r.OneTimeKeys, Timeout: r.GetTimeout(), - }, &claimRes) + }, &claimRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if claimRes.Error != nil { util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys") return jsonerror.InternalServerError() diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index d0eeccf17..9b2592eb5 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -17,6 +17,7 @@ package routing import ( "net/http" + "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -54,7 +55,9 @@ func PeekRoomByIDOrAlias( } // Ask the roomserver to perform the peek. - rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes) + if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil { + return util.ErrorResponse(err) + } if peekRes.Error != nil { return peekRes.Error.JSONResponse() } @@ -89,7 +92,9 @@ func UnpeekRoomByID( } unpeekRes := roomserverAPI.PerformUnpeekResponse{} - rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes) + if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if unpeekRes.Error != nil { return unpeekRes.Error.JSONResponse() } diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go index 744e2d889..34c7eb004 100644 --- a/clientapi/routing/upgrade_room.go +++ b/clientapi/routing/upgrade_room.go @@ -64,7 +64,9 @@ func UpgradeRoom( } upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{} - rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp) + if err := rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if upgradeResp.Error != nil { if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom { diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 53d4701f3..292ed55ad 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -110,7 +110,7 @@ type FederationClientError struct { Blacklisted bool } -func (e *FederationClientError) Error() string { +func (e FederationClientError) Error() string { return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 8884e34c6..bdcb9f57c 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -32,11 +32,12 @@ type fedRoomserverAPI struct { } // PerformJoin will call this function -func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { +func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) error { if f.inputRoomEvents == nil { - return + return nil } f.inputRoomEvents(ctx, req, res) + return nil } // keychange consumer calls this diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 295ddc495..812d3c6da 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -10,7 +10,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" - "github.com/opentracing/opentracing-go" ) // HTTP paths for the internal HTTP API @@ -48,7 +47,11 @@ func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client, if httpClient == nil { return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is ") } - return &httpFederationInternalAPI{federationSenderURL, httpClient, cache}, nil + return &httpFederationInternalAPI{ + federationAPIURL: federationSenderURL, + httpClient: httpClient, + cache: cache, + }, nil } type httpFederationInternalAPI struct { @@ -63,11 +66,10 @@ func (h *httpFederationInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeaveRequest") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformLeaveRequestPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformLeave", h.federationAPIURL+FederationAPIPerformLeaveRequestPath, + h.httpClient, ctx, request, response, + ) } // Handle sending an invite to a remote server. @@ -76,11 +78,10 @@ func (h *httpFederationInternalAPI) PerformInvite( request *api.PerformInviteRequest, response *api.PerformInviteResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInviteRequest") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformInviteRequestPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformInvite", h.federationAPIURL+FederationAPIPerformInviteRequestPath, + h.httpClient, ctx, request, response, + ) } // Handle starting a peek on a remote server. @@ -89,11 +90,10 @@ func (h *httpFederationInternalAPI) PerformOutboundPeek( request *api.PerformOutboundPeekRequest, response *api.PerformOutboundPeekResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOutboundPeekRequest") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformOutboundPeekRequestPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformOutboundPeek", h.federationAPIURL+FederationAPIPerformOutboundPeekRequestPath, + h.httpClient, ctx, request, response, + ) } // QueryJoinedHostServerNamesInRoom implements FederationInternalAPI @@ -102,11 +102,10 @@ func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIQueryJoinedHostServerNamesInRoomPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryJoinedHostServerNamesInRoom", h.federationAPIURL+FederationAPIQueryJoinedHostServerNamesInRoomPath, + h.httpClient, ctx, request, response, + ) } // Handle an instruction to make_join & send_join with a remote server. @@ -115,12 +114,10 @@ func (h *httpFederationInternalAPI) PerformJoin( request *api.PerformJoinRequest, response *api.PerformJoinResponse, ) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoinRequest") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformJoinRequestPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { + if err := httputil.CallInternalRPCAPI( + "PerformJoinRequest", h.federationAPIURL+FederationAPIPerformJoinRequestPath, + h.httpClient, ctx, request, response, + ); err != nil { response.LastError = &gomatrix.HTTPError{ Message: err.Error(), Code: 0, @@ -135,11 +132,10 @@ func (h *httpFederationInternalAPI) PerformDirectoryLookup( request *api.PerformDirectoryLookupRequest, response *api.PerformDirectoryLookupResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDirectoryLookup") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformDirectoryLookupRequestPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformDirectoryLookup", h.federationAPIURL+FederationAPIPerformDirectoryLookupRequestPath, + h.httpClient, ctx, request, response, + ) } // Handle an instruction to broadcast an EDU to all servers in rooms we are joined to. @@ -148,101 +144,61 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU( request *api.PerformBroadcastEDURequest, response *api.PerformBroadcastEDUResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformBroadcastEDUPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformBroadcastEDU", h.federationAPIURL+FederationAPIPerformBroadcastEDUPath, + h.httpClient, ctx, request, response, + ) } type getUserDevices struct { S gomatrixserverlib.ServerName UserID string - Res *gomatrixserverlib.RespUserDevices - Err *api.FederationClientError } func (h *httpFederationInternalAPI) GetUserDevices( ctx context.Context, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetUserDevices") - defer span.Finish() - - var result gomatrixserverlib.RespUserDevices - request := getUserDevices{ - S: s, - UserID: userID, - } - var response getUserDevices - apiURL := h.federationAPIURL + FederationAPIGetUserDevicesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return result, err - } - if response.Err != nil { - return result, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( + "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, + ctx, &getUserDevices{ + S: s, + UserID: userID, + }, + ) } type claimKeys struct { S gomatrixserverlib.ServerName OneTimeKeys map[string]map[string]string - Res *gomatrixserverlib.RespClaimKeys - Err *api.FederationClientError } func (h *httpFederationInternalAPI) ClaimKeys( ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "ClaimKeys") - defer span.Finish() - - var result gomatrixserverlib.RespClaimKeys - request := claimKeys{ - S: s, - OneTimeKeys: oneTimeKeys, - } - var response claimKeys - apiURL := h.federationAPIURL + FederationAPIClaimKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return result, err - } - if response.Err != nil { - return result, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( + "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, + ctx, &claimKeys{ + S: s, + OneTimeKeys: oneTimeKeys, + }, + ) } type queryKeys struct { S gomatrixserverlib.ServerName Keys map[string][]string - Res *gomatrixserverlib.RespQueryKeys - Err *api.FederationClientError } func (h *httpFederationInternalAPI) QueryKeys( ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys") - defer span.Finish() - - var result gomatrixserverlib.RespQueryKeys - request := queryKeys{ - S: s, - Keys: keys, - } - var response queryKeys - apiURL := h.federationAPIURL + FederationAPIQueryKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return result, err - } - if response.Err != nil { - return result, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( + "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, + ctx, &queryKeys{ + S: s, + Keys: keys, + }, + ) } type backfill struct { @@ -250,32 +206,20 @@ type backfill struct { RoomID string Limit int EventIDs []string - Res *gomatrixserverlib.Transaction - Err *api.FederationClientError } func (h *httpFederationInternalAPI) Backfill( ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (gomatrixserverlib.Transaction, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill") - defer span.Finish() - - request := backfill{ - S: s, - RoomID: roomID, - Limit: limit, - EventIDs: eventIDs, - } - var response backfill - apiURL := h.federationAPIURL + FederationAPIBackfillPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.Transaction{}, err - } - if response.Err != nil { - return gomatrixserverlib.Transaction{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( + "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, + ctx, &backfill{ + S: s, + RoomID: roomID, + Limit: limit, + EventIDs: eventIDs, + }, + ) } type lookupState struct { @@ -283,63 +227,39 @@ type lookupState struct { RoomID string EventID string RoomVersion gomatrixserverlib.RoomVersion - Res *gomatrixserverlib.RespState - Err *api.FederationClientError } func (h *httpFederationInternalAPI) LookupState( ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (gomatrixserverlib.RespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState") - defer span.Finish() - - request := lookupState{ - S: s, - RoomID: roomID, - EventID: eventID, - RoomVersion: roomVersion, - } - var response lookupState - apiURL := h.federationAPIURL + FederationAPILookupStatePath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.RespState{}, err - } - if response.Err != nil { - return gomatrixserverlib.RespState{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( + "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, + ctx, &lookupState{ + S: s, + RoomID: roomID, + EventID: eventID, + RoomVersion: roomVersion, + }, + ) } type lookupStateIDs struct { S gomatrixserverlib.ServerName RoomID string EventID string - Res *gomatrixserverlib.RespStateIDs - Err *api.FederationClientError } func (h *httpFederationInternalAPI) LookupStateIDs( ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ) (gomatrixserverlib.RespStateIDs, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs") - defer span.Finish() - - request := lookupStateIDs{ - S: s, - RoomID: roomID, - EventID: eventID, - } - var response lookupStateIDs - apiURL := h.federationAPIURL + FederationAPILookupStateIDsPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.RespStateIDs{}, err - } - if response.Err != nil { - return gomatrixserverlib.RespStateIDs{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( + "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, + ctx, &lookupStateIDs{ + S: s, + RoomID: roomID, + EventID: eventID, + }, + ) } type lookupMissingEvents struct { @@ -347,64 +267,38 @@ type lookupMissingEvents struct { RoomID string Missing gomatrixserverlib.MissingEvents RoomVersion gomatrixserverlib.RoomVersion - Res struct { - Events []gomatrixserverlib.RawJSON `json:"events"` - } - Err *api.FederationClientError } func (h *httpFederationInternalAPI) LookupMissingEvents( ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "LookupMissingEvents") - defer span.Finish() - - request := lookupMissingEvents{ - S: s, - RoomID: roomID, - Missing: missing, - RoomVersion: roomVersion, - } - apiURL := h.federationAPIURL + FederationAPILookupMissingEventsPath - err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &request) - if err != nil { - return res, err - } - if request.Err != nil { - return res, request.Err - } - res.Events = request.Res.Events - return res, nil + return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( + "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, + ctx, &lookupMissingEvents{ + S: s, + RoomID: roomID, + Missing: missing, + RoomVersion: roomVersion, + }, + ) } type getEvent struct { S gomatrixserverlib.ServerName EventID string - Res *gomatrixserverlib.Transaction - Err *api.FederationClientError } func (h *httpFederationInternalAPI) GetEvent( ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ) (gomatrixserverlib.Transaction, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent") - defer span.Finish() - - request := getEvent{ - S: s, - EventID: eventID, - } - var response getEvent - apiURL := h.federationAPIURL + FederationAPIGetEventPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.Transaction{}, err - } - if response.Err != nil { - return gomatrixserverlib.Transaction{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( + "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, + ctx, &getEvent{ + S: s, + EventID: eventID, + }, + ) } type getEventAuth struct { @@ -412,135 +306,86 @@ type getEventAuth struct { RoomVersion gomatrixserverlib.RoomVersion RoomID string EventID string - Res *gomatrixserverlib.RespEventAuth - Err *api.FederationClientError } func (h *httpFederationInternalAPI) GetEventAuth( ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (gomatrixserverlib.RespEventAuth, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetEventAuth") - defer span.Finish() - - request := getEventAuth{ - S: s, - RoomVersion: roomVersion, - RoomID: roomID, - EventID: eventID, - } - var response getEventAuth - apiURL := h.federationAPIURL + FederationAPIGetEventAuthPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.RespEventAuth{}, err - } - if response.Err != nil { - return gomatrixserverlib.RespEventAuth{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( + "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, + ctx, &getEventAuth{ + S: s, + RoomVersion: roomVersion, + RoomID: roomID, + EventID: eventID, + }, + ) } func (h *httpFederationInternalAPI) QueryServerKeys( ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerKeys") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIQueryServerKeysPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryServerKeys", h.federationAPIURL+FederationAPIQueryServerKeysPath, + h.httpClient, ctx, req, res, + ) } type lookupServerKeys struct { S gomatrixserverlib.ServerName KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp - ServerKeys []gomatrixserverlib.ServerKeys - Err *api.FederationClientError } func (h *httpFederationInternalAPI) LookupServerKeys( ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) ([]gomatrixserverlib.ServerKeys, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys") - defer span.Finish() - - request := lookupServerKeys{ - S: s, - KeyRequests: keyRequests, - } - var response lookupServerKeys - apiURL := h.federationAPIURL + FederationAPILookupServerKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return []gomatrixserverlib.ServerKeys{}, err - } - if response.Err != nil { - return []gomatrixserverlib.ServerKeys{}, response.Err - } - return response.ServerKeys, nil + return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, *api.FederationClientError]( + "LookupServerKeys", h.federationAPIURL+FederationAPILookupServerKeysPath, h.httpClient, + ctx, &lookupServerKeys{ + S: s, + KeyRequests: keyRequests, + }, + ) } type eventRelationships struct { S gomatrixserverlib.ServerName Req gomatrixserverlib.MSC2836EventRelationshipsRequest RoomVer gomatrixserverlib.RoomVersion - Res gomatrixserverlib.MSC2836EventRelationshipsResponse - Err *api.FederationClientError } func (h *httpFederationInternalAPI) MSC2836EventRelationships( ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships") - defer span.Finish() - - request := eventRelationships{ - S: s, - Req: r, - RoomVer: roomVersion, - } - var response eventRelationships - apiURL := h.federationAPIURL + FederationAPIEventRelationshipsPath - err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return res, err - } - if response.Err != nil { - return res, response.Err - } - return response.Res, nil + return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( + "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, + ctx, &eventRelationships{ + S: s, + Req: r, + RoomVer: roomVersion, + }, + ) } type spacesReq struct { S gomatrixserverlib.ServerName SuggestedOnly bool RoomID string - Res gomatrixserverlib.MSC2946SpacesResponse - Err *api.FederationClientError } func (h *httpFederationInternalAPI) MSC2946Spaces( ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces") - defer span.Finish() - - request := spacesReq{ - S: dst, - SuggestedOnly: suggestedOnly, - RoomID: roomID, - } - var response spacesReq - apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath - err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return res, err - } - if response.Err != nil { - return res, response.Err - } - return response.Res, nil + return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( + "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, + ctx, &spacesReq{ + S: dst, + SuggestedOnly: suggestedOnly, + RoomID: roomID, + }, + ) } func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing { @@ -614,11 +459,10 @@ func (h *httpFederationInternalAPI) InputPublicKeys( request *api.InputPublicKeysRequest, response *api.InputPublicKeysResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIInputPublicKeyPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "InputPublicKey", h.federationAPIURL+FederationAPIInputPublicKeyPath, + h.httpClient, ctx, request, response, + ) } func (h *httpFederationInternalAPI) QueryPublicKeys( @@ -626,9 +470,8 @@ func (h *httpFederationInternalAPI) QueryPublicKeys( request *api.QueryPublicKeysRequest, response *api.QueryPublicKeysResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIQueryPublicKeyPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryPublicKeys", h.federationAPIURL+FederationAPIQueryPublicKeyPath, + h.httpClient, ctx, request, response, + ) } diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 28e52b32d..a8b829a71 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -1,12 +1,14 @@ package inthttp import ( + "context" "encoding/json" "net/http" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -15,372 +17,180 @@ import ( func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIQueryJoinedHostServerNamesInRoomPath, - httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryJoinedHostServerNamesInRoomRequest - var response api.QueryJoinedHostServerNamesInRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := intAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle( - FederationAPIPerformJoinRequestPath, - httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse { - var request api.PerformJoinRequest - var response api.PerformJoinResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - intAPI.PerformJoin(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle( - FederationAPIPerformLeaveRequestPath, - httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse { - var request api.PerformLeaveRequest - var response api.PerformLeaveResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.PerformLeave(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", intAPI.QueryJoinedHostServerNamesInRoom), ) + internalAPIMux.Handle( FederationAPIPerformInviteRequestPath, - httputil.MakeInternalAPI("PerformInviteRequest", func(req *http.Request) util.JSONResponse { - var request api.PerformInviteRequest - var response api.PerformInviteResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.PerformInvite(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", intAPI.PerformInvite), ) + + internalAPIMux.Handle( + FederationAPIPerformLeaveRequestPath, + httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", intAPI.PerformLeave), + ) + internalAPIMux.Handle( FederationAPIPerformDirectoryLookupRequestPath, - httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse { - var request api.PerformDirectoryLookupRequest - var response api.PerformDirectoryLookupResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.PerformDirectoryLookup(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", intAPI.PerformDirectoryLookup), ) + internalAPIMux.Handle( FederationAPIPerformBroadcastEDUPath, - httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse { - var request api.PerformBroadcastEDURequest - var response api.PerformBroadcastEDUResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), ) + + internalAPIMux.Handle( + FederationAPIPerformJoinRequestPath, + httputil.MakeInternalRPCAPI( + "FederationAPIPerformJoinRequest", + func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error { + intAPI.PerformJoin(ctx, req, res) + return nil + }, + ), + ) + internalAPIMux.Handle( FederationAPIGetUserDevicesPath, - httputil.MakeInternalAPI("GetUserDevices", func(req *http.Request) util.JSONResponse { - var request getUserDevices - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.GetUserDevices(req.Context(), request.S, request.UserID) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIGetUserDevices", + func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { + res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIClaimKeysPath, - httputil.MakeInternalAPI("ClaimKeys", func(req *http.Request) util.JSONResponse { - var request claimKeys - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.ClaimKeys(req.Context(), request.S, request.OneTimeKeys) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIClaimKeys", + func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { + res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIQueryKeysPath, - httputil.MakeInternalAPI("QueryKeys", func(req *http.Request) util.JSONResponse { - var request queryKeys - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.QueryKeys(req.Context(), request.S, request.Keys) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIQueryKeys", + func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { + res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIBackfillPath, - httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse { - var request backfill - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIBackfill", + func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { + res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPILookupStatePath, - httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse { - var request lookupState - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPILookupState", + func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { + res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPILookupStateIDsPath, - httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse { - var request lookupStateIDs - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPILookupStateIDs", + func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { + res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPILookupMissingEventsPath, - httputil.MakeInternalAPI("LookupMissingEvents", func(req *http.Request) util.JSONResponse { - var request lookupMissingEvents - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.LookupMissingEvents(req.Context(), request.S, request.RoomID, request.Missing, request.RoomVersion) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - for _, event := range res.Events { - js, err := json.Marshal(event) - if err != nil { - return util.MessageResponse(http.StatusInternalServerError, err.Error()) - } - request.Res.Events = append(request.Res.Events, js) - } - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPILookupMissingEvents", + func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { + res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIGetEventPath, - httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse { - var request getEvent - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIGetEvent", + func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { + res, err := intAPI.GetEvent(ctx, req.S, req.EventID) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIGetEventAuthPath, - httputil.MakeInternalAPI("GetEventAuth", func(req *http.Request) util.JSONResponse { - var request getEventAuth - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.GetEventAuth(req.Context(), request.S, request.RoomVersion, request.RoomID, request.EventID) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIGetEventAuth", + func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { + res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIQueryServerKeysPath, - httputil.MakeInternalAPI("QueryServerKeys", func(req *http.Request) util.JSONResponse { - var request api.QueryServerKeysRequest - var response api.QueryServerKeysResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.QueryServerKeys(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", intAPI.QueryServerKeys), ) + internalAPIMux.Handle( FederationAPILookupServerKeysPath, - httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse { - var request lookupServerKeys - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.ServerKeys = res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPILookupServerKeys", + func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) { + res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIEventRelationshipsPath, - httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse { - var request eventRelationships - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIMSC2836EventRelationships", + func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { + res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPISpacesSummaryPath, - httputil.MakeInternalAPI("MSC2946SpacesSummary", func(req *http.Request) util.JSONResponse { - var request spacesReq - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIMSC2946SpacesSummary", + func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { + res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) + return &res, federationClientError(err) + }, + ), ) + + // TODO: Look at this shape internalAPIMux.Handle(FederationAPIQueryPublicKeyPath, - httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", func(req *http.Request) util.JSONResponse { request := api.QueryPublicKeysRequest{} response := api.QueryPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -394,8 +204,10 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + + // TODO: Look at this shape internalAPIMux.Handle(FederationAPIInputPublicKeyPath, - httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("FederationAPIInputPublicKeys", func(req *http.Request) util.JSONResponse { request := api.InputPublicKeysRequest{} response := api.InputPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -408,3 +220,18 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { }), ) } + +func federationClientError(err error) error { + switch ferr := err.(type) { + case nil: + return nil + case api.FederationClientError: + return &ferr + case *api.FederationClientError: + return ferr + default: + return &api.FederationClientError{ + Err: err.Error(), + } + } +} diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 2f9da1f25..ce8b06b70 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -30,9 +30,11 @@ func GetUserDevices( userID string, ) util.JSONResponse { var res keyapi.QueryDeviceMessagesResponse - keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ + if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ UserID: userID, - }, &res) + }, &res); err != nil { + return util.ErrorResponse(err) + } if res.Error != nil { util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed") return jsonerror.InternalServerError() @@ -47,7 +49,9 @@ func GetUserDevices( for _, dev := range res.Devices { sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID)) } - keyAPI.QuerySignatures(req.Context(), sigReq, sigRes) + if err := keyAPI.QuerySignatures(req.Context(), sigReq, sigRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } response := gomatrixserverlib.RespUserDevices{ UserID: userID, diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 9ad3bd8eb..b48eaf78e 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -392,7 +392,7 @@ func SendJoin( // the room, so set SendAsServer to cfg.Matrix.ServerName if !alreadyJoined { var response api.InputRoomEventsResponse - rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ + if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, @@ -401,7 +401,9 @@ func SendJoin( TransactionID: nil, }, }, - }, &response) + }, &response); err != nil { + return jsonerror.InternalAPIError(httpReq.Context(), err) + } if response.ErrMsg != "" { util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed") if response.NotAllowed { diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index b1a9b6710..b03d4c1d6 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -19,7 +19,7 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/httputil" + clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/keyserver/api" @@ -61,9 +61,11 @@ func QueryDeviceKeys( } var queryRes api.QueryKeysResponse - keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ + if err := keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ UserToDevices: qkr.DeviceKeys, - }, &queryRes) + }, &queryRes); err != nil { + return jsonerror.InternalAPIError(httpReq.Context(), err) + } if queryRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys") return jsonerror.InternalServerError() @@ -113,9 +115,11 @@ func ClaimOneTimeKeys( } var claimRes api.PerformClaimKeysResponse - keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ + if err := keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ OneTimeKeys: cor.OneTimeKeys, - }, &claimRes) + }, &claimRes); err != nil { + return jsonerror.InternalAPIError(httpReq.Context(), err) + } if claimRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys") return jsonerror.InternalServerError() @@ -184,7 +188,7 @@ func NotaryKeys( ) util.JSONResponse { if req == nil { req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} - if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { + if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { return *reqErr } } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index dbaf68e5b..8e43ce959 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -277,7 +277,7 @@ func SendLeave( // We are responsible for notifying other servers that the user has left // the room, so set SendAsServer to cfg.Matrix.ServerName var response api.InputRoomEventsResponse - rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ + if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, @@ -286,7 +286,9 @@ func SendLeave( TransactionID: nil, }, }, - }, &response) + }, &response); err != nil { + return jsonerror.InternalAPIError(httpReq.Context(), err) + } if response.ErrMsg != "" { util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).WithField("not_allowed", response.NotAllowed).Error("producer.SendEvents failed") diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 43003be38..87b6fa33e 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -458,7 +458,9 @@ func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverli UserID: updatePayload.UserID, } uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} - t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { + return err + } if uploadRes.Error != nil { return uploadRes.Error } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index a111580c7..1c796f542 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -64,11 +64,12 @@ func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) { +) error { t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) for _, ire := range request.InputRoomEvents { fmt.Println("InputRoomEvents: ", ire.Event.EventID()) } + return nil } // Query the latest events and state for a room from the room server. diff --git a/internal/httputil/http.go b/internal/httputil/http.go index 4527e2b95..1e07ee33c 100644 --- a/internal/httputil/http.go +++ b/internal/httputil/http.go @@ -19,19 +19,21 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "strings" - "github.com/matrix-org/dendrite/userapi/api" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" ) -// PostJSON performs a POST request with JSON on an internal HTTP API -func PostJSON( +// PostJSON performs a POST request with JSON on an internal HTTP API. +// The error will match the errtype if returned from the remote API, or +// will be a different type if there was a problem reaching the API. +func PostJSON[reqtype, restype any, errtype error]( ctx context.Context, span opentracing.Span, httpClient *http.Client, - apiURL string, request, response interface{}, + apiURL string, request *reqtype, response *restype, ) error { jsonBytes, err := json.Marshal(request) if err != nil { @@ -69,17 +71,23 @@ func PostJSON( if err != nil { return err } - if res.StatusCode != http.StatusOK { - var errorBody struct { - Message string `json:"message"` - } - if _, ok := response.(*api.PerformKeyBackupResponse); ok { // TODO: remove this, once cross-boundary errors are a thing - return nil - } - if msgerr := json.NewDecoder(res.Body).Decode(&errorBody); msgerr == nil { - return fmt.Errorf("internal API: %d from %s: %s", res.StatusCode, apiURL, errorBody.Message) - } - return fmt.Errorf("internal API: %d from %s", res.StatusCode, apiURL) + var body []byte + body, err = io.ReadAll(res.Body) + if err != nil { + return err } - return json.NewDecoder(res.Body).Decode(response) + if res.StatusCode != http.StatusOK { + if len(body) == 0 { + return fmt.Errorf("HTTP %d from %s (no response body)", res.StatusCode, apiURL) + } + var reserr errtype + if err = json.Unmarshal(body, reserr); err != nil { + return fmt.Errorf("HTTP %d from %s", res.StatusCode, apiURL) + } + return reserr + } + if err = json.Unmarshal(body, response); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + return nil } diff --git a/internal/httputil/internalapi.go b/internal/httputil/internalapi.go new file mode 100644 index 000000000..385092d9c --- /dev/null +++ b/internal/httputil/internalapi.go @@ -0,0 +1,93 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "reflect" + + "github.com/matrix-org/util" + opentracing "github.com/opentracing/opentracing-go" +) + +type InternalAPIError struct { + Type string + Message string +} + +func (e InternalAPIError) Error() string { + return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message) +} + +func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype, *restype) error) http.Handler { + return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { + var request reqtype + var response restype + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := f(req.Context(), &request, &response); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: &InternalAPIError{ + Type: reflect.TypeOf(err).String(), + Message: fmt.Sprintf("%s", err), + }, + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: &response, + } + }) +} + +func MakeInternalProxyAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype) (*restype, error)) http.Handler { + return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { + var request reqtype + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + response, err := f(req.Context(), &request) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: response, + } + }) +} + +func CallInternalRPCAPI[reqtype, restype any](name, url string, client *http.Client, ctx context.Context, request *reqtype, response *restype) error { + span, ctx := opentracing.StartSpanFromContext(ctx, name) + defer span.Finish() + + return PostJSON[reqtype, restype, InternalAPIError](ctx, span, client, url, request, response) +} + +func CallInternalProxyAPI[reqtype, restype any, errtype error](name, url string, client *http.Client, ctx context.Context, request *reqtype) (restype, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, name) + defer span.Finish() + + var response restype + return response, PostJSON[reqtype, restype, errtype](ctx, span, client, url, request, &response) +} diff --git a/keyserver/api/api.go b/keyserver/api/api.go index c0a1eedbb..9ba3988b9 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -38,32 +38,32 @@ type KeyInternalAPI interface { // API functions required by the clientapi type ClientKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) - PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error + PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error } // API functions required by the userapi type UserKeyAPI interface { - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) - PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error } // API functions required by the syncapi type SyncKeyAPI interface { - QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) - QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) + QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error } type FederationKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) - QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error } // KeyError is returned if there was a problem performing/querying the server diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 08bbfedb8..99859dff6 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos } // nolint:gocyclo -func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { +func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { // Find the keys to store. byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} toStore := types.CrossSigningKeyMap{} @@ -115,7 +115,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "Master key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return + return nil } byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey @@ -131,7 +131,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "Self-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return + return nil } byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey @@ -146,7 +146,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "User-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return + return nil } byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey @@ -161,7 +161,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "No keys were supplied in the request", IsMissingParam: true, } - return + return nil } // We can't have a self-signing or user-signing key without a master @@ -174,7 +174,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P res.Error = &api.KeyError{ Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } - return + return nil } // If we still can't find a master key for the user then stop the upload. @@ -185,7 +185,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "No master key was found", IsMissingParam: true, } - return + return nil } } @@ -212,7 +212,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } } if !changed { - return + return nil } // Store the keys. @@ -220,7 +220,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } - return + return nil } // Now upload any signatures that were included with the keys. @@ -238,7 +238,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), } - return + return nil } } } @@ -255,17 +255,18 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P update.SelfSigningKey = &ssk } if update.MasterKey == nil && update.SelfSigningKey == nil { - return + return nil } if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } - return + return nil } + return nil } -func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) { +func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { // Before we do anything, we need the master and self-signing keys for this user. // Then we can verify the signatures make sense. queryReq := &api.QueryKeysRequest{ @@ -276,7 +277,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req for userID := range req.Signatures { queryReq.UserToDevices[userID] = []string{} } - a.QueryKeys(ctx, queryReq, queryRes) + _ = a.QueryKeys(ctx, queryReq, queryRes) selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} @@ -322,14 +323,14 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req res.Error = &api.KeyError{ Err: fmt.Sprintf("a.processSelfSignatures: %s", err), } - return + return nil } if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.processOtherSignatures: %s", err), } - return + return nil } // Finally, generate a notification that we updated the signatures. @@ -345,9 +346,10 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } - return + return nil } } + return nil } func (a *KeyInternalAPI) processSelfSignatures( @@ -520,7 +522,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( } } -func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) { +func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { for targetUserID, forTargetUser := range req.TargetIDs { keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) if err != nil && err != sql.ErrNoRows { @@ -559,7 +561,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), } - return + return nil } for sourceUserID, forSourceUser := range sigMap { @@ -581,4 +583,5 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign } } } + return nil } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index e317755e4..80efbec51 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -119,7 +119,7 @@ type DeviceListUpdaterDatabase interface { } type DeviceListUpdaterAPI interface { - PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) + PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error } // KeyChangeProducer is the interface for producers.KeyChange useful for testing. @@ -421,7 +421,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam uploadReq.SelfSigningKey = *res.SelfSigningKey } } - u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) } err = u.updateDeviceList(&res) if err != nil { diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 0c2405f34..0520a9e66 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -113,8 +113,8 @@ func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys type mockDeviceListUpdaterAPI struct { } -func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { - +func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { + return nil } type roundTripper struct { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 91f011517..41b4d44a4 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -48,18 +48,20 @@ func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { a.UserAPI = i } -func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { +func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) if err != nil { res.Error = &api.KeyError{ Err: err.Error(), } + return nil } res.Offset = latest res.UserIDs = userIDs + return nil } -func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { res.KeyErrors = make(map[string]map[string]*api.KeyError) if len(req.DeviceKeys) > 0 { a.uploadLocalDeviceKeys(ctx, req, res) @@ -67,9 +69,10 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform if len(req.OneTimeKeys) > 0 { a.uploadOneTimeKeys(ctx, req, res) } + return nil } -func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { +func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) // wrap request map in a top-level by-domain map @@ -113,6 +116,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC if len(domainToDeviceKeys) > 0 { a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) } + return nil } func (a *KeyInternalAPI) claimRemoteKeys( @@ -172,32 +176,34 @@ func (a *KeyInternalAPI) claimRemoteKeys( util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") } -func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) { +func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to delete device keys: %s", err), } } + return nil } -func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) { +func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to query OTK counts: %s", err), } - return + return nil } res.Count = *count + return nil } -func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { +func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), } - return + return nil } maxStreamID := int64(0) for _, m := range msgs { @@ -215,10 +221,11 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query } res.Devices = result res.StreamID = maxStreamID + return nil } // nolint:gocyclo -func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { +func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) @@ -244,7 +251,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), } - return + return nil } // pull out display names after we have the keys so we handle wildcards correctly @@ -318,7 +325,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return + return nil } logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") continue @@ -344,7 +351,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return + return nil } logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") continue @@ -372,6 +379,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } } } + return nil } func (a *KeyInternalAPI) remoteKeysFromDatabase( diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index dac61d1ea..7a7131145 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/opentracing/opentracing-go" ) // HTTP paths for the internal HTTP APIs @@ -68,168 +67,108 @@ func (h *httpKeyInternalAPI) PerformClaimKeys( ctx context.Context, request *api.PerformClaimKeysRequest, response *api.PerformClaimKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys") - defer span.Finish() - - apiURL := h.apiURL + PerformClaimKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformClaimKeys", h.apiURL+PerformClaimKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) PerformDeleteKeys( ctx context.Context, request *api.PerformDeleteKeysRequest, response *api.PerformDeleteKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys") - defer span.Finish() - - apiURL := h.apiURL + PerformClaimKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformDeleteKeys", h.apiURL+PerformDeleteKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) PerformUploadKeys( ctx context.Context, request *api.PerformUploadKeysRequest, response *api.PerformUploadKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadKeys") - defer span.Finish() - - apiURL := h.apiURL + PerformUploadKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformUploadKeys", h.apiURL+PerformUploadKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QueryKeys( ctx context.Context, request *api.QueryKeysRequest, response *api.QueryKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys") - defer span.Finish() - - apiURL := h.apiURL + QueryKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QueryKeys", h.apiURL+QueryKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QueryOneTimeKeys( ctx context.Context, request *api.QueryOneTimeKeysRequest, response *api.QueryOneTimeKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys") - defer span.Finish() - - apiURL := h.apiURL + QueryOneTimeKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QueryOneTimeKeys", h.apiURL+QueryOneTimeKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QueryDeviceMessages( ctx context.Context, request *api.QueryDeviceMessagesRequest, response *api.QueryDeviceMessagesResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceMessages") - defer span.Finish() - - apiURL := h.apiURL + QueryDeviceMessagesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QueryDeviceMessages", h.apiURL+QueryDeviceMessagesPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QueryKeyChanges( ctx context.Context, request *api.QueryKeyChangesRequest, response *api.QueryKeyChangesResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges") - defer span.Finish() - - apiURL := h.apiURL + QueryKeyChangesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QueryKeyChanges", h.apiURL+QueryKeyChangesPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) PerformUploadDeviceKeys( ctx context.Context, request *api.PerformUploadDeviceKeysRequest, response *api.PerformUploadDeviceKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceKeys") - defer span.Finish() - - apiURL := h.apiURL + PerformUploadDeviceKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformUploadDeviceKeys", h.apiURL+PerformUploadDeviceKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures( ctx context.Context, request *api.PerformUploadDeviceSignaturesRequest, response *api.PerformUploadDeviceSignaturesResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceSignatures") - defer span.Finish() - - apiURL := h.apiURL + PerformUploadDeviceSignaturesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformUploadDeviceSignatures", h.apiURL+PerformUploadDeviceSignaturesPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QuerySignatures( ctx context.Context, request *api.QuerySignaturesRequest, response *api.QuerySignaturesResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySignatures") - defer span.Finish() - - apiURL := h.apiURL + QuerySignaturesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QuerySignatures", h.apiURL+QuerySignaturesPath, + h.httpClient, ctx, request, response, + ) } diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index 5bf5976a8..4e5f9fba4 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -15,124 +15,59 @@ package inthttp import ( - "encoding/json" - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/util" ) func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { - internalAPIMux.Handle(PerformClaimKeysPath, - httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse { - request := api.PerformClaimKeysRequest{} - response := api.PerformClaimKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformClaimKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformClaimKeysPath, + httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", s.PerformClaimKeys), ) - internalAPIMux.Handle(PerformDeleteKeysPath, - httputil.MakeInternalAPI("performDeleteKeys", func(req *http.Request) util.JSONResponse { - request := api.PerformDeleteKeysRequest{} - response := api.PerformDeleteKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformDeleteKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformDeleteKeysPath, + httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", s.PerformDeleteKeys), ) - internalAPIMux.Handle(PerformUploadKeysPath, - httputil.MakeInternalAPI("performUploadKeys", func(req *http.Request) util.JSONResponse { - request := api.PerformUploadKeysRequest{} - response := api.PerformUploadKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformUploadKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformUploadKeysPath, + httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", s.PerformUploadKeys), ) - internalAPIMux.Handle(PerformUploadDeviceKeysPath, - httputil.MakeInternalAPI("performUploadDeviceKeys", func(req *http.Request) util.JSONResponse { - request := api.PerformUploadDeviceKeysRequest{} - response := api.PerformUploadDeviceKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformUploadDeviceKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformUploadDeviceKeysPath, + httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", s.PerformUploadDeviceKeys), ) - internalAPIMux.Handle(PerformUploadDeviceSignaturesPath, - httputil.MakeInternalAPI("performUploadDeviceSignatures", func(req *http.Request) util.JSONResponse { - request := api.PerformUploadDeviceSignaturesRequest{} - response := api.PerformUploadDeviceSignaturesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformUploadDeviceSignatures(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformUploadDeviceSignaturesPath, + httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", s.PerformUploadDeviceSignatures), ) - internalAPIMux.Handle(QueryKeysPath, - httputil.MakeInternalAPI("queryKeys", func(req *http.Request) util.JSONResponse { - request := api.QueryKeysRequest{} - response := api.QueryKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryKeysPath, + httputil.MakeInternalRPCAPI("KeyserverQueryKeys", s.QueryKeys), ) - internalAPIMux.Handle(QueryOneTimeKeysPath, - httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse { - request := api.QueryOneTimeKeysRequest{} - response := api.QueryOneTimeKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryOneTimeKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryOneTimeKeysPath, + httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", s.QueryOneTimeKeys), ) - internalAPIMux.Handle(QueryDeviceMessagesPath, - httputil.MakeInternalAPI("queryDeviceMessages", func(req *http.Request) util.JSONResponse { - request := api.QueryDeviceMessagesRequest{} - response := api.QueryDeviceMessagesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryDeviceMessages(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryDeviceMessagesPath, + httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", s.QueryDeviceMessages), ) - internalAPIMux.Handle(QueryKeyChangesPath, - httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse { - request := api.QueryKeyChangesRequest{} - response := api.QueryKeyChangesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryKeyChanges(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryKeyChangesPath, + httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", s.QueryKeyChanges), ) - internalAPIMux.Handle(QuerySignaturesPath, - httputil.MakeInternalAPI("querySignatures", func(req *http.Request) util.JSONResponse { - request := api.QuerySignaturesRequest{} - response := api.QuerySignaturesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QuerySignatures(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QuerySignaturesPath, + httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures), ) } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 38baa617f..ee0212ecf 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -40,7 +40,7 @@ type InputRoomEventsAPI interface { ctx context.Context, req *InputRoomEventsRequest, res *InputRoomEventsResponse, - ) + ) error } // Query the latest events and state for a room from the room server. @@ -139,15 +139,15 @@ type ClientRoomserverAPI interface { GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error // PerformRoomUpgrade upgrades a room to a newer version - PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) - PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) - PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) - PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) - PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) + PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error + PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error + PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error + PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error + PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error - PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) + PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error - PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) + PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error // PerformForget forgets a rooms history for a specific user PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error @@ -158,7 +158,7 @@ type UserRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error - PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) + PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error } type FederationRoomserverAPI interface { diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 211f320ff..18a617331 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -35,9 +35,10 @@ func (t *RoomserverInternalAPITrace) InputRoomEvents( ctx context.Context, req *InputRoomEventsRequest, res *InputRoomEventsResponse, -) { - t.Impl.InputRoomEvents(ctx, req, res) - util.GetLogger(ctx).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.InputRoomEvents(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformInvite( @@ -45,44 +46,49 @@ func (t *RoomserverInternalAPITrace) PerformInvite( req *PerformInviteRequest, res *PerformInviteResponse, ) error { - util.GetLogger(ctx).Infof("PerformInvite req=%+v res=%+v", js(req), js(res)) - return t.Impl.PerformInvite(ctx, req, res) + err := t.Impl.PerformInvite(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformInvite req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformPeek( ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse, -) { - t.Impl.PerformPeek(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPeek req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformPeek(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformPeek req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformUnpeek( ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse, -) { - t.Impl.PerformUnpeek(ctx, req, res) - util.GetLogger(ctx).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformUnpeek(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformRoomUpgrade( ctx context.Context, req *PerformRoomUpgradeRequest, res *PerformRoomUpgradeResponse, -) { - t.Impl.PerformRoomUpgrade(ctx, req, res) - util.GetLogger(ctx).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformRoomUpgrade(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformJoin( ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse, -) { - t.Impl.PerformJoin(ctx, req, res) - util.GetLogger(ctx).Infof("PerformJoin req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformJoin(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformJoin req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformLeave( @@ -99,27 +105,30 @@ func (t *RoomserverInternalAPITrace) PerformPublish( ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse, -) { - t.Impl.PerformPublish(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformPublish(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom( ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse, -) { - t.Impl.PerformAdminEvacuateRoom(ctx, req, res) - util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformAdminEvacuateRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser( ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse, -) { - t.Impl.PerformAdminEvacuateUser(ctx, req, res) - util.GetLogger(ctx).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformAdminEvacuateUser(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformInboundPeek( @@ -128,7 +137,7 @@ func (t *RoomserverInternalAPITrace) PerformInboundPeek( res *PerformInboundPeekResponse, ) error { err := t.Impl.PerformInboundPeek(ctx, req, res) - util.GetLogger(ctx).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res)) + util.GetLogger(ctx).WithError(err).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res)) return err } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 344e9b079..bc2f28176 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -90,7 +90,9 @@ func SendInputRoomEvents( Asynchronous: async, } var response InputRoomEventsResponse - rsAPI.InputRoomEvents(ctx, &request, &response) + if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { + return err + } return response.Err() } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 339c9796c..8d24f3c59 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -337,18 +337,18 @@ func (r *Inputer) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) { +) error { // Queue up the event into the roomserver. replySub, err := r.queueInputRoomEvents(ctx, request) if err != nil { response.ErrMsg = err.Error() - return + return nil } // If we aren't waiting for synchronous responses then we can // give up here, there is nothing further to do. if replySub == nil { - return + return nil } // Otherwise, we'll want to sit and wait for the responses @@ -360,12 +360,14 @@ func (r *Inputer) InputRoomEvents( msg, err := replySub.NextMsgWithContext(ctx) if err != nil { response.ErrMsg = err.Error() - return + return nil } if len(msg.Data) > 0 { response.ErrMsg = string(msg.Data) } } + + return nil } var roomserverInputBackpressure = prometheus.NewGaugeVec( diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 6c7c6c98b..cb6b22d32 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -43,21 +43,21 @@ func (r *Admin) PerformAdminEvacuateRoom( ctx context.Context, req *api.PerformAdminEvacuateRoomRequest, res *api.PerformAdminEvacuateRoomResponse, -) { +) error { roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err), } - return + return nil } if roomInfo == nil || roomInfo.IsStub() { res.Error = &api.PerformError{ Code: api.PerformErrorNoRoom, Msg: fmt.Sprintf("Room %s not found", req.RoomID), } - return + return nil } memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) @@ -66,7 +66,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err), } - return + return nil } memberEvents, err := r.DB.Events(ctx, memberNIDs) @@ -75,7 +75,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.Events: %s", err), } - return + return nil } inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) @@ -89,7 +89,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err), } - return + return nil } prevEvents := latestRes.LatestEvents @@ -104,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("json.Unmarshal: %s", err), } - return + return nil } memberContent.Membership = gomatrixserverlib.Leave @@ -122,7 +122,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("json.Marshal: %s", err), } - return + return nil } eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) @@ -131,7 +131,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err), } - return + return nil } event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) @@ -140,7 +140,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err), } - return + return nil } inputEvents = append(inputEvents, api.InputRoomEvent{ @@ -160,28 +160,28 @@ func (r *Admin) PerformAdminEvacuateRoom( Asynchronous: true, } inputRes := &api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) + return r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) } func (r *Admin) PerformAdminEvacuateUser( ctx context.Context, req *api.PerformAdminEvacuateUserRequest, res *api.PerformAdminEvacuateUserResponse, -) { +) error { _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("Malformed user ID: %s", err), } - return + return nil } if domain != r.Cfg.Matrix.ServerName { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: "Can only evacuate local users using this endpoint", } - return + return nil } roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Join) @@ -190,7 +190,7 @@ func (r *Admin) PerformAdminEvacuateUser( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err), } - return + return nil } inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Invite) @@ -199,7 +199,7 @@ func (r *Admin) PerformAdminEvacuateUser( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err), } - return + return nil } for _, roomID := range append(roomIDs, inviteRoomIDs...) { @@ -214,7 +214,7 @@ func (r *Admin) PerformAdminEvacuateUser( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err), } - return + return nil } if len(outputEvents) == 0 { continue @@ -224,9 +224,10 @@ func (r *Admin) PerformAdminEvacuateUser( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err), } - return + return nil } res.Affected = append(res.Affected, roomID) } + return nil } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index e1ff4eabb..483e78c3f 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -241,7 +241,9 @@ func (r *Inviter) PerformInvite( }, } inputRes := &api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) + if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { + return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) + } if err = inputRes.Err(); err != nil { res.Error = &api.PerformError{ Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 1445b4088..43be54beb 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -52,7 +52,7 @@ func (r *Joiner) PerformJoin( ctx context.Context, req *rsAPI.PerformJoinRequest, res *rsAPI.PerformJoinResponse, -) { +) error { logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomIDOrAlias, "user_id": req.UserID, @@ -71,11 +71,12 @@ func (r *Joiner) PerformJoin( Msg: err.Error(), } } - return + return nil } logger.Info("User joined room successfully") res.RoomID = roomID res.JoinedVia = joinedVia + return nil } func (r *Joiner) performJoin( @@ -291,7 +292,12 @@ func (r *Joiner) performJoinRoomByID( }, } inputRes := rsAPI.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) + if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorNoOperation, + Msg: fmt.Sprintf("InputRoomEvents failed: %s", err), + } + } if err = inputRes.Err(); err != nil { return "", "", &rsAPI.PerformError{ Code: rsAPI.PerformErrorNotAllowed, diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 56e7240d0..036404cd2 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -186,7 +186,9 @@ func (r *Leaver) performLeaveRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) + if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) + } if err = inputRes.Err(); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 5560916b2..74d87a5b4 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -44,7 +44,7 @@ func (r *Peeker) PerformPeek( ctx context.Context, req *api.PerformPeekRequest, res *api.PerformPeekResponse, -) { +) error { roomID, err := r.performPeek(ctx, req) if err != nil { perr, ok := err.(*api.PerformError) @@ -57,6 +57,7 @@ func (r *Peeker) PerformPeek( } } res.RoomID = roomID + return nil } func (r *Peeker) performPeek( diff --git a/roomserver/internal/perform/perform_publish.go b/roomserver/internal/perform/perform_publish.go index 6ff42ac1a..1631fc657 100644 --- a/roomserver/internal/perform/perform_publish.go +++ b/roomserver/internal/perform/perform_publish.go @@ -29,11 +29,12 @@ func (r *Publisher) PerformPublish( ctx context.Context, req *api.PerformPublishRequest, res *api.PerformPublishResponse, -) { +) error { err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public") if err != nil { res.Error = &api.PerformError{ Msg: err.Error(), } } + return nil } diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 1fe8d5a0f..49e9067c9 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -41,7 +41,7 @@ func (r *Unpeeker) PerformUnpeek( ctx context.Context, req *api.PerformUnpeekRequest, res *api.PerformUnpeekResponse, -) { +) error { if err := r.performUnpeek(ctx, req); err != nil { perr, ok := err.(*api.PerformError) if ok { @@ -52,6 +52,7 @@ func (r *Unpeeker) PerformUnpeek( } } } + return nil } func (r *Unpeeker) performUnpeek( diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 393d7dd14..d6dc9708c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -45,12 +45,13 @@ func (r *Upgrader) PerformRoomUpgrade( ctx context.Context, req *api.PerformRoomUpgradeRequest, res *api.PerformRoomUpgradeResponse, -) { +) error { res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req) if res.Error != nil { res.NewRoomID = "" logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed") } + return nil } func (r *Upgrader) performRoomUpgrade( @@ -286,22 +287,24 @@ func publishNewRoomAndUnpublishOldRoom( ) { // expose this room in the published room list var pubNewRoomRes api.PerformPublishResponse - URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ + if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ RoomID: newRoomID, Visibility: "public", - }, &pubNewRoomRes) - if pubNewRoomRes.Error != nil { + }, &pubNewRoomRes); err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to reach internal API") + } else if pubNewRoomRes.Error != nil { // treat as non-fatal since the room is already made by this point util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public") } var unpubOldRoomRes api.PerformPublishResponse // remove the old room from the published room list - URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ + if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ RoomID: oldRoomID, Visibility: "private", - }, &unpubOldRoomRes) - if unpubOldRoomRes.Error != nil { + }, &unpubOldRoomRes); err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to reach internal API") + } else if unpubOldRoomRes.Error != nil { // treat as non-fatal since the room is already made by this point util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private") } diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 2fa8afc49..d16f67c69 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -3,7 +3,6 @@ package inthttp import ( "context" "errors" - "fmt" "net/http" asAPI "github.com/matrix-org/dendrite/appservice/api" @@ -14,7 +13,6 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" - "github.com/opentracing/opentracing-go" ) const ( @@ -106,11 +104,10 @@ func (h *httpRoomserverInternalAPI) SetRoomAlias( request *api.SetRoomAliasRequest, response *api.SetRoomAliasResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverSetRoomAliasPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "SetRoomAlias", h.roomserverURL+RoomserverSetRoomAliasPath, + h.httpClient, ctx, request, response, + ) } // GetRoomIDForAlias implements RoomserverAliasAPI @@ -119,11 +116,10 @@ func (h *httpRoomserverInternalAPI) GetRoomIDForAlias( request *api.GetRoomIDForAliasRequest, response *api.GetRoomIDForAliasResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "GetRoomIDForAlias", h.roomserverURL+RoomserverGetRoomIDForAliasPath, + h.httpClient, ctx, request, response, + ) } // GetAliasesForRoomID implements RoomserverAliasAPI @@ -132,11 +128,10 @@ func (h *httpRoomserverInternalAPI) GetAliasesForRoomID( request *api.GetAliasesForRoomIDRequest, response *api.GetAliasesForRoomIDResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "GetAliasesForRoomID", h.roomserverURL+RoomserverGetAliasesForRoomIDPath, + h.httpClient, ctx, request, response, + ) } // RemoveRoomAlias implements RoomserverAliasAPI @@ -145,11 +140,10 @@ func (h *httpRoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "RemoveRoomAlias", h.roomserverURL+RoomserverRemoveRoomAliasPath, + h.httpClient, ctx, request, response, + ) } // InputRoomEvents implements RoomserverInputAPI @@ -157,15 +151,14 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverInputRoomEventsPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { +) error { + if err := httputil.CallInternalRPCAPI( + "InputRoomEvents", h.roomserverURL+RoomserverInputRoomEventsPath, + h.httpClient, ctx, request, response, + ); err != nil { response.ErrMsg = err.Error() } + return nil } func (h *httpRoomserverInternalAPI) PerformInvite( @@ -173,45 +166,32 @@ func (h *httpRoomserverInternalAPI) PerformInvite( request *api.PerformInviteRequest, response *api.PerformInviteResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInvite") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformInvitePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformInvite", h.roomserverURL+RoomserverPerformInvitePath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformJoin( ctx context.Context, request *api.PerformJoinRequest, response *api.PerformJoinResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoin") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformJoinPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformJoin", h.roomserverURL+RoomserverPerformJoinPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformPeek( ctx context.Context, request *api.PerformPeekRequest, response *api.PerformPeekResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPeek") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformPeekPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformPeek", h.roomserverURL+RoomserverPerformPeekPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformInboundPeek( @@ -219,45 +199,32 @@ func (h *httpRoomserverInternalAPI) PerformInboundPeek( request *api.PerformInboundPeekRequest, response *api.PerformInboundPeekResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInboundPeek") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformInboundPeekPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformInboundPeek", h.roomserverURL+RoomserverPerformInboundPeekPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformUnpeek( ctx context.Context, request *api.PerformUnpeekRequest, response *api.PerformUnpeekResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUnpeek") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformUnpeekPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformUnpeek", h.roomserverURL+RoomserverPerformUnpeekPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformRoomUpgrade( ctx context.Context, request *api.PerformRoomUpgradeRequest, response *api.PerformRoomUpgradeResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformRoomUpgrade") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformRoomUpgradePath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformRoomUpgrade", h.roomserverURL+RoomserverPerformRoomUpgradePath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformLeave( @@ -265,62 +232,43 @@ func (h *httpRoomserverInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeave") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformLeavePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformLeave", h.roomserverURL+RoomserverPerformLeavePath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformPublish( ctx context.Context, - req *api.PerformPublishRequest, - res *api.PerformPublishResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPublish") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformPublishPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } + request *api.PerformPublishRequest, + response *api.PerformPublishResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformPublish", h.roomserverURL+RoomserverPerformPublishPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom( ctx context.Context, - req *api.PerformAdminEvacuateRoomRequest, - res *api.PerformAdminEvacuateRoomResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } + request *api.PerformAdminEvacuateRoomRequest, + response *api.PerformAdminEvacuateRoomResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformAdminEvacuateRoom", h.roomserverURL+RoomserverPerformAdminEvacuateRoomPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser( ctx context.Context, - req *api.PerformAdminEvacuateUserRequest, - res *api.PerformAdminEvacuateUserResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateUser") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateUserPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } + request *api.PerformAdminEvacuateUserRequest, + response *api.PerformAdminEvacuateUserResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformAdminEvacuateUser", h.roomserverURL+RoomserverPerformAdminEvacuateUserPath, + h.httpClient, ctx, request, response, + ) } // QueryLatestEventsAndState implements RoomserverQueryAPI @@ -329,11 +277,10 @@ func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryLatestEventsAndState", h.roomserverURL+RoomserverQueryLatestEventsAndStatePath, + h.httpClient, ctx, request, response, + ) } // QueryStateAfterEvents implements RoomserverQueryAPI @@ -342,11 +289,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryStateAfterEvents", h.roomserverURL+RoomserverQueryStateAfterEventsPath, + h.httpClient, ctx, request, response, + ) } // QueryEventsByID implements RoomserverQueryAPI @@ -355,11 +301,10 @@ func (h *httpRoomserverInternalAPI) QueryEventsByID( request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryEventsByID", h.roomserverURL+RoomserverQueryEventsByIDPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryPublishedRooms( @@ -367,11 +312,10 @@ func (h *httpRoomserverInternalAPI) QueryPublishedRooms( request *api.QueryPublishedRoomsRequest, response *api.QueryPublishedRoomsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublishedRooms") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryPublishedRoomsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryPublishedRooms", h.roomserverURL+RoomserverQueryPublishedRoomsPath, + h.httpClient, ctx, request, response, + ) } // QueryMembershipForUser implements RoomserverQueryAPI @@ -380,11 +324,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipForUser( request *api.QueryMembershipForUserRequest, response *api.QueryMembershipForUserResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryMembershipForUser", h.roomserverURL+RoomserverQueryMembershipForUserPath, + h.httpClient, ctx, request, response, + ) } // QueryMembershipsForRoom implements RoomserverQueryAPI @@ -393,11 +336,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom( request *api.QueryMembershipsForRoomRequest, response *api.QueryMembershipsForRoomResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryMembershipsForRoom", h.roomserverURL+RoomserverQueryMembershipsForRoomPath, + h.httpClient, ctx, request, response, + ) } // QueryMembershipsForRoom implements RoomserverQueryAPI @@ -406,11 +348,10 @@ func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom( request *api.QueryServerJoinedToRoomRequest, response *api.QueryServerJoinedToRoomResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerJoinedToRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryServerJoinedToRoomPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryServerJoinedToRoom", h.roomserverURL+RoomserverQueryServerJoinedToRoomPath, + h.httpClient, ctx, request, response, + ) } // QueryServerAllowedToSeeEvent implements RoomserverQueryAPI @@ -419,11 +360,10 @@ func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent( request *api.QueryServerAllowedToSeeEventRequest, response *api.QueryServerAllowedToSeeEventResponse, ) (err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryServerAllowedToSeeEvent", h.roomserverURL+RoomserverQueryServerAllowedToSeeEventPath, + h.httpClient, ctx, request, response, + ) } // QueryMissingEvents implements RoomServerQueryAPI @@ -432,11 +372,10 @@ func (h *httpRoomserverInternalAPI) QueryMissingEvents( request *api.QueryMissingEventsRequest, response *api.QueryMissingEventsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryMissingEvents", h.roomserverURL+RoomserverQueryMissingEventsPath, + h.httpClient, ctx, request, response, + ) } // QueryStateAndAuthChain implements RoomserverQueryAPI @@ -445,11 +384,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain( request *api.QueryStateAndAuthChainRequest, response *api.QueryStateAndAuthChainResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryStateAndAuthChain", h.roomserverURL+RoomserverQueryStateAndAuthChainPath, + h.httpClient, ctx, request, response, + ) } // PerformBackfill implements RoomServerQueryAPI @@ -458,11 +396,10 @@ func (h *httpRoomserverInternalAPI) PerformBackfill( request *api.PerformBackfillRequest, response *api.PerformBackfillResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBackfill") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformBackfillPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformBackfill", h.roomserverURL+RoomserverPerformBackfillPath, + h.httpClient, ctx, request, response, + ) } // QueryRoomVersionCapabilities implements RoomServerQueryAPI @@ -471,11 +408,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities( request *api.QueryRoomVersionCapabilitiesRequest, response *api.QueryRoomVersionCapabilitiesResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionCapabilities") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryRoomVersionCapabilities", h.roomserverURL+RoomserverQueryRoomVersionCapabilitiesPath, + h.httpClient, ctx, request, response, + ) } // QueryRoomVersionForRoom implements RoomServerQueryAPI @@ -488,16 +424,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( response.RoomVersion = roomVersion return nil } - - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionForRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err == nil { - h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion) - } - return err + return httputil.CallInternalRPCAPI( + "QueryRoomVersionForRoom", h.roomserverURL+RoomserverQueryRoomVersionForRoomPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryCurrentState( @@ -505,11 +435,10 @@ func (h *httpRoomserverInternalAPI) QueryCurrentState( request *api.QueryCurrentStateRequest, response *api.QueryCurrentStateResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryCurrentState", h.roomserverURL+RoomserverQueryCurrentStatePath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryRoomsForUser( @@ -517,11 +446,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomsForUser( request *api.QueryRoomsForUserRequest, response *api.QueryRoomsForUserResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryRoomsForUser", h.roomserverURL+RoomserverQueryRoomsForUserPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryBulkStateContent( @@ -529,68 +457,75 @@ func (h *httpRoomserverInternalAPI) QueryBulkStateContent( request *api.QueryBulkStateContentRequest, response *api.QueryBulkStateContentResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryBulkStateContent", h.roomserverURL+RoomserverQueryBulkStateContentPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QuerySharedUsers( - ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse, + ctx context.Context, + request *api.QuerySharedUsersRequest, + response *api.QuerySharedUsersResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QuerySharedUsers", h.roomserverURL+RoomserverQuerySharedUsersPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryKnownUsers( - ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse, + ctx context.Context, + request *api.QueryKnownUsersRequest, + response *api.QueryKnownUsersResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryKnownUsers", h.roomserverURL+RoomserverQueryKnownUsersPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryAuthChain( - ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse, + ctx context.Context, + request *api.QueryAuthChainRequest, + response *api.QueryAuthChainResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAuthChain") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryAuthChainPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryAuthChain", h.roomserverURL+RoomserverQueryAuthChainPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( - ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, + ctx context.Context, + request *api.QueryServerBannedFromRoomRequest, + response *api.QueryServerBannedFromRoomResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryServerBannedFromRoom", h.roomserverURL+RoomserverQueryServerBannedFromRoomPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed( - ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse, + ctx context.Context, + request *api.QueryRestrictedJoinAllowedRequest, + response *api.QueryRestrictedJoinAllowedResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRestrictedJoinAllowed") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRestrictedJoinAllowed - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryRestrictedJoinAllowed", h.roomserverURL+RoomserverQueryRestrictedJoinAllowed, + h.httpClient, ctx, request, response, + ) } -func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformForgetPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpRoomserverInternalAPI) PerformForget( + ctx context.Context, + request *api.PerformForgetRequest, + response *api.PerformForgetResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformForget", h.roomserverURL+RoomserverPerformForgetPath, + h.httpClient, ctx, request, response, + ) } diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 993381585..e325d76a5 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -1,499 +1,196 @@ package inthttp import ( - "encoding/json" - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/util" ) // AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux. // nolint: gocyclo func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { - internalAPIMux.Handle(RoomserverInputRoomEventsPath, - httputil.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse { - var request api.InputRoomEventsRequest - var response api.InputRoomEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.InputRoomEvents(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + RoomserverInputRoomEventsPath, + httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", r.InputRoomEvents), ) - internalAPIMux.Handle(RoomserverPerformInvitePath, - httputil.MakeInternalAPI("performInvite", func(req *http.Request) util.JSONResponse { - var request api.PerformInviteRequest - var response api.PerformInviteResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.PerformInvite(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformInvitePath, + httputil.MakeInternalRPCAPI("RoomserverPerformInvite", r.PerformInvite), ) - internalAPIMux.Handle(RoomserverPerformJoinPath, - httputil.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse { - var request api.PerformJoinRequest - var response api.PerformJoinResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformJoin(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformJoinPath, + httputil.MakeInternalRPCAPI("RoomserverPerformJoin", r.PerformJoin), ) - internalAPIMux.Handle(RoomserverPerformLeavePath, - httputil.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse { - var request api.PerformLeaveRequest - var response api.PerformLeaveResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.PerformLeave(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformLeavePath, + httputil.MakeInternalRPCAPI("RoomserverPerformLeave", r.PerformLeave), ) - internalAPIMux.Handle(RoomserverPerformPeekPath, - httputil.MakeInternalAPI("performPeek", func(req *http.Request) util.JSONResponse { - var request api.PerformPeekRequest - var response api.PerformPeekResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformPeek(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformPeekPath, + httputil.MakeInternalRPCAPI("RoomserverPerformPeek", r.PerformPeek), ) - internalAPIMux.Handle(RoomserverPerformInboundPeekPath, - httputil.MakeInternalAPI("performInboundPeek", func(req *http.Request) util.JSONResponse { - var request api.PerformInboundPeekRequest - var response api.PerformInboundPeekResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.PerformInboundPeek(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformInboundPeekPath, + httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", r.PerformInboundPeek), ) - internalAPIMux.Handle(RoomserverPerformPeekPath, - httputil.MakeInternalAPI("performUnpeek", func(req *http.Request) util.JSONResponse { - var request api.PerformUnpeekRequest - var response api.PerformUnpeekResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformUnpeek(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformUnpeekPath, + httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", r.PerformUnpeek), ) - internalAPIMux.Handle(RoomserverPerformRoomUpgradePath, - httputil.MakeInternalAPI("performRoomUpgrade", func(req *http.Request) util.JSONResponse { - var request api.PerformRoomUpgradeRequest - var response api.PerformRoomUpgradeResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformRoomUpgrade(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformRoomUpgradePath, + httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", r.PerformRoomUpgrade), ) - internalAPIMux.Handle(RoomserverPerformPublishPath, - httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse { - var request api.PerformPublishRequest - var response api.PerformPublishResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformPublish(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformPublishPath, + httputil.MakeInternalRPCAPI("RoomserverPerformPublish", r.PerformPublish), ) - internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath, - httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse { - var request api.PerformAdminEvacuateRoomRequest - var response api.PerformAdminEvacuateRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformAdminEvacuateRoom(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformAdminEvacuateRoomPath, + httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", r.PerformAdminEvacuateRoom), ) - internalAPIMux.Handle(RoomserverPerformAdminEvacuateUserPath, - httputil.MakeInternalAPI("performAdminEvacuateUser", func(req *http.Request) util.JSONResponse { - var request api.PerformAdminEvacuateUserRequest - var response api.PerformAdminEvacuateUserResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformAdminEvacuateUser(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformAdminEvacuateUserPath, + httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser), ) + internalAPIMux.Handle( RoomserverQueryPublishedRoomsPath, - httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse { - var request api.QueryPublishedRoomsRequest - var response api.QueryPublishedRoomsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryPublishedRooms(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms), ) + internalAPIMux.Handle( RoomserverQueryLatestEventsAndStatePath, - httputil.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { - var request api.QueryLatestEventsAndStateRequest - var response api.QueryLatestEventsAndStateResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", r.QueryLatestEventsAndState), ) + internalAPIMux.Handle( RoomserverQueryStateAfterEventsPath, - httputil.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse { - var request api.QueryStateAfterEventsRequest - var response api.QueryStateAfterEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", r.QueryStateAfterEvents), ) + internalAPIMux.Handle( RoomserverQueryEventsByIDPath, - httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { - var request api.QueryEventsByIDRequest - var response api.QueryEventsByIDResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", r.QueryEventsByID), ) + internalAPIMux.Handle( RoomserverQueryMembershipForUserPath, - httputil.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse { - var request api.QueryMembershipForUserRequest - var response api.QueryMembershipForUserResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", r.QueryMembershipForUser), ) + internalAPIMux.Handle( RoomserverQueryMembershipsForRoomPath, - httputil.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryMembershipsForRoomRequest - var response api.QueryMembershipsForRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", r.QueryMembershipsForRoom), ) + internalAPIMux.Handle( RoomserverQueryServerJoinedToRoomPath, - httputil.MakeInternalAPI("queryServerJoinedToRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryServerJoinedToRoomRequest - var response api.QueryServerJoinedToRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryServerJoinedToRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", r.QueryServerJoinedToRoom), ) + internalAPIMux.Handle( RoomserverQueryServerAllowedToSeeEventPath, - httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { - var request api.QueryServerAllowedToSeeEventRequest - var response api.QueryServerAllowedToSeeEventResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", r.QueryServerAllowedToSeeEvent), ) + internalAPIMux.Handle( RoomserverQueryMissingEventsPath, - httputil.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse { - var request api.QueryMissingEventsRequest - var response api.QueryMissingEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", r.QueryMissingEvents), ) + internalAPIMux.Handle( RoomserverQueryStateAndAuthChainPath, - httputil.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse { - var request api.QueryStateAndAuthChainRequest - var response api.QueryStateAndAuthChainResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", r.QueryStateAndAuthChain), ) + internalAPIMux.Handle( RoomserverPerformBackfillPath, - httputil.MakeInternalAPI("PerformBackfill", func(req *http.Request) util.JSONResponse { - var request api.PerformBackfillRequest - var response api.PerformBackfillResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.PerformBackfill(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", r.PerformBackfill), ) + internalAPIMux.Handle( RoomserverPerformForgetPath, - httputil.MakeInternalAPI("PerformForget", func(req *http.Request) util.JSONResponse { - var request api.PerformForgetRequest - var response api.PerformForgetResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.PerformForget(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverPerformForget", r.PerformForget), ) + internalAPIMux.Handle( RoomserverQueryRoomVersionCapabilitiesPath, - httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { - var request api.QueryRoomVersionCapabilitiesRequest - var response api.QueryRoomVersionCapabilitiesResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", r.QueryRoomVersionCapabilities), ) + internalAPIMux.Handle( RoomserverQueryRoomVersionForRoomPath, - httputil.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryRoomVersionForRoomRequest - var response api.QueryRoomVersionForRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryRoomVersionForRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", r.QueryRoomVersionForRoom), ) + internalAPIMux.Handle( RoomserverSetRoomAliasPath, - httputil.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse { - var request api.SetRoomAliasRequest - var response api.SetRoomAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", r.SetRoomAlias), ) + internalAPIMux.Handle( RoomserverGetRoomIDForAliasPath, - httputil.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse { - var request api.GetRoomIDForAliasRequest - var response api.GetRoomIDForAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", r.GetRoomIDForAlias), ) + internalAPIMux.Handle( RoomserverGetAliasesForRoomIDPath, - httputil.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { - var request api.GetAliasesForRoomIDRequest - var response api.GetAliasesForRoomIDResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", r.GetAliasesForRoomID), ) + internalAPIMux.Handle( RoomserverRemoveRoomAliasPath, - httputil.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { - var request api.RemoveRoomAliasRequest - var response api.RemoveRoomAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", r.RemoveRoomAlias), ) - internalAPIMux.Handle(RoomserverQueryCurrentStatePath, - httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse { - request := api.QueryCurrentStateRequest{} - response := api.QueryCurrentStateResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryCurrentStatePath, + httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", r.QueryCurrentState), ) - internalAPIMux.Handle(RoomserverQueryRoomsForUserPath, - httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse { - request := api.QueryRoomsForUserRequest{} - response := api.QueryRoomsForUserResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryRoomsForUserPath, + httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", r.QueryRoomsForUser), ) - internalAPIMux.Handle(RoomserverQueryBulkStateContentPath, - httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { - request := api.QueryBulkStateContentRequest{} - response := api.QueryBulkStateContentResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryBulkStateContentPath, + httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", r.QueryBulkStateContent), ) - internalAPIMux.Handle(RoomserverQuerySharedUsersPath, - httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse { - request := api.QuerySharedUsersRequest{} - response := api.QuerySharedUsersResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQuerySharedUsersPath, + httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", r.QuerySharedUsers), ) - internalAPIMux.Handle(RoomserverQueryKnownUsersPath, - httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { - request := api.QueryKnownUsersRequest{} - response := api.QueryKnownUsersResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryKnownUsersPath, + httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", r.QueryKnownUsers), ) - internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath, - httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse { - request := api.QueryServerBannedFromRoomRequest{} - response := api.QueryServerBannedFromRoomResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryServerBannedFromRoomPath, + httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", r.QueryServerBannedFromRoom), ) - internalAPIMux.Handle(RoomserverQueryAuthChainPath, - httputil.MakeInternalAPI("queryAuthChain", func(req *http.Request) util.JSONResponse { - request := api.QueryAuthChainRequest{} - response := api.QueryAuthChainResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryAuthChain(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryAuthChainPath, + httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", r.QueryAuthChain), ) - internalAPIMux.Handle(RoomserverQueryRestrictedJoinAllowed, - httputil.MakeInternalAPI("queryRestrictedJoinAllowed", func(req *http.Request) util.JSONResponse { - request := api.QueryRestrictedJoinAllowedRequest{} - response := api.QueryRestrictedJoinAllowedResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryRestrictedJoinAllowed(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryRestrictedJoinAllowed, + httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed), ) } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index eeded4275..3e9d90a1f 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -164,9 +164,9 @@ func TestMSC2836(t *testing.T) { // make everyone joined to each other's rooms nopRsAPI := &testRoomserverAPI{ userToJoinedRooms: map[string][]string{ - alice: []string{roomID}, - bob: []string{roomID}, - charlie: []string{roomID}, + alice: {roomID}, + bob: {roomID}, + charlie: {roomID}, }, events: map[string]*gomatrixserverlib.HeaderedEvent{ eventA.EventID(): eventA, diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 4bf54cae0..23824e366 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -31,7 +31,7 @@ import ( // DeviceOTKCounts adds one-time key counts to the /sync response func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { var queryRes keyapi.QueryOneTimeKeysResponse - keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ + _ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ UserID: userID, DeviceID: deviceID, }, &queryRes) @@ -73,7 +73,7 @@ func DeviceListCatchup( offset = int64(from) } var queryRes keyapi.QueryKeyChangesResponse - keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ + _ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ Offset: offset, ToOffset: toOffset, }, &queryRes) diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index c7d8df740..80d2811be 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -22,31 +22,41 @@ var ( type mockKeyAPI struct{} -func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) { +func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error { + return nil } func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} // PerformClaimKeys claims one-time keys for use in pre-key messages -func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) { +func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error { + return nil } -func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) { +func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error { + return nil } -func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) { +func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error { + return nil } -func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) { +func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error { + return nil } -func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) { +func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error { + return nil } -func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { +func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { + return nil } -func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { +func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { + return nil } -func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) { +func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error { + return nil } -func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) { +func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error { + return nil } type mockRoomserverAPI struct { diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index b10864ff5..931fef883 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -78,10 +78,11 @@ type syncKeyAPI struct { keyapi.SyncKeyAPI } -func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { +func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { + return nil } -func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { - +func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { + return nil } func TestSyncAPIAccessTokens(t *testing.T) { diff --git a/userapi/api/api.go b/userapi/api/api.go index df9408acb..388f97cb4 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -100,7 +100,7 @@ type ClientUserAPI interface { QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error - QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 6d8d28007..7e2f69615 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -94,9 +94,10 @@ func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *Per util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res)) return err } -func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) { - t.Impl.QueryKeyBackup(ctx, req, res) +func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error { + err := t.Impl.QueryKeyBackup(ctx, req, res) util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) + return err } func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error { err := t.Impl.QueryProfile(ctx, req, res) diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 422eb076e..78b226d46 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -192,7 +192,9 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) } deleteRes := &keyapi.PerformDeleteKeysResponse{} - a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes) + if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { + return err + } if err := deleteRes.Error; err != nil { return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err) } @@ -211,10 +213,12 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er } var uploadRes keyapi.PerformUploadKeysResponse - a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ UserID: userID, DeviceKeys: deviceKeys, - }, &uploadRes) + }, &uploadRes); err != nil { + return err + } if uploadRes.Error != nil { return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error) } @@ -268,7 +272,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { // display name has changed: update the device key var uploadRes keyapi.PerformUploadKeysResponse - a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ UserID: req.RequestingUserID, DeviceKeys: []keyapi.DeviceKeys{ { @@ -279,7 +283,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf }, }, OnlyDisplayNameUpdates: true, - }, &uploadRes) + }, &uploadRes); err != nil { + return err + } if uploadRes.Error != nil { return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error) } @@ -479,7 +485,9 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), } evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} - a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes) + if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { + return err + } if err := evacuateRes.Error; err != nil { logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation") } @@ -538,9 +546,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform if req.Version == "" { res.BadInput = true res.Error = "must specify a version to delete" - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) @@ -549,9 +554,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } res.Exists = exists res.Version = req.Version - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } // Create metadata @@ -562,9 +564,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } res.Exists = err == nil res.Version = version - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } // Update metadata @@ -575,16 +574,10 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } res.Exists = err == nil res.Version = req.Version - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } // Upload Keys for a specific version metadata a.uploadBackupKeys(ctx, req, res) - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } @@ -627,16 +620,16 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform res.KeyETag = etag } -func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { +func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error { version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { if err == sql.ErrNoRows { res.Exists = false - return + return nil } res.Error = fmt.Sprintf("failed to query key backup: %s", err) - return + return nil } res.Algorithm = algorithm res.AuthData = authData @@ -648,15 +641,16 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB if err != nil { res.Error = fmt.Sprintf("failed to count keys: %s", err) } - return + return nil } result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) if err != nil { res.Error = fmt.Sprintf("failed to query keys: %s", err) - return + return nil } res.Keys = result + return nil } func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 23c335cf2..a375d6caa 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/opentracing/opentracing-go" ) // HTTP paths for the internal HTTP APIs @@ -84,11 +83,10 @@ type httpUserInternalAPI struct { } func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData") - defer span.Finish() - - apiURL := h.apiURL + InputAccountDataPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "InputAccountData", h.apiURL+InputAccountDataPath, + h.httpClient, ctx, req, res, + ) } func (h *httpUserInternalAPI) PerformAccountCreation( @@ -96,11 +94,10 @@ func (h *httpUserInternalAPI) PerformAccountCreation( request *api.PerformAccountCreationRequest, response *api.PerformAccountCreationResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountCreation") - defer span.Finish() - - apiURL := h.apiURL + PerformAccountCreationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformAccountCreation", h.apiURL+PerformAccountCreationPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformPasswordUpdate( @@ -108,11 +105,10 @@ func (h *httpUserInternalAPI) PerformPasswordUpdate( request *api.PerformPasswordUpdateRequest, response *api.PerformPasswordUpdateResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate") - defer span.Finish() - - apiURL := h.apiURL + PerformPasswordUpdatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformPasswordUpdate", h.apiURL+PerformPasswordUpdatePath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformDeviceCreation( @@ -120,11 +116,10 @@ func (h *httpUserInternalAPI) PerformDeviceCreation( request *api.PerformDeviceCreationRequest, response *api.PerformDeviceCreationResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceCreation") - defer span.Finish() - - apiURL := h.apiURL + PerformDeviceCreationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformDeviceCreation", h.apiURL+PerformDeviceCreationPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformDeviceDeletion( @@ -132,47 +127,54 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion( request *api.PerformDeviceDeletionRequest, response *api.PerformDeviceDeletionResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion") - defer span.Finish() - - apiURL := h.apiURL + PerformDeviceDeletionPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformDeviceDeletion", h.apiURL+PerformDeviceDeletionPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformLastSeenUpdate( ctx context.Context, - req *api.PerformLastSeenUpdateRequest, - res *api.PerformLastSeenUpdateResponse, + request *api.PerformLastSeenUpdateRequest, + response *api.PerformLastSeenUpdateResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLastSeen") - defer span.Finish() - - apiURL := h.apiURL + PerformLastSeenUpdatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "PerformLastSeen", h.apiURL+PerformLastSeenUpdatePath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate") - defer span.Finish() - - apiURL := h.apiURL + PerformDeviceUpdatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformDeviceUpdate( + ctx context.Context, + request *api.PerformDeviceUpdateRequest, + response *api.PerformDeviceUpdateResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformDeviceUpdate", h.apiURL+PerformDeviceUpdatePath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountDeactivation") - defer span.Finish() - - apiURL := h.apiURL + PerformAccountDeactivationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformAccountDeactivation( + ctx context.Context, + request *api.PerformAccountDeactivationRequest, + response *api.PerformAccountDeactivationResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformAccountDeactivation", h.apiURL+PerformAccountDeactivationPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation") - defer span.Finish() - - apiURL := h.apiURL + PerformOpenIDTokenCreationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +func (h *httpUserInternalAPI) PerformOpenIDTokenCreation( + ctx context.Context, + request *api.PerformOpenIDTokenCreationRequest, + response *api.PerformOpenIDTokenCreationResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformOpenIDTokenCreation", h.apiURL+PerformOpenIDTokenCreationPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) QueryProfile( @@ -180,11 +182,10 @@ func (h *httpUserInternalAPI) QueryProfile( request *api.QueryProfileRequest, response *api.QueryProfileResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryProfile") - defer span.Finish() - - apiURL := h.apiURL + QueryProfilePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryProfile", h.apiURL+QueryProfilePath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) QueryDeviceInfos( @@ -192,11 +193,10 @@ func (h *httpUserInternalAPI) QueryDeviceInfos( request *api.QueryDeviceInfosRequest, response *api.QueryDeviceInfosResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos") - defer span.Finish() - - apiURL := h.apiURL + QueryDeviceInfosPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryDeviceInfos", h.apiURL+QueryDeviceInfosPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) QueryAccessToken( @@ -204,72 +204,87 @@ func (h *httpUserInternalAPI) QueryAccessToken( request *api.QueryAccessTokenRequest, response *api.QueryAccessTokenResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccessToken") - defer span.Finish() - - apiURL := h.apiURL + QueryAccessTokenPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryAccessToken", h.apiURL+QueryAccessTokenPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDevices") - defer span.Finish() - - apiURL := h.apiURL + QueryDevicesPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryDevices( + ctx context.Context, + request *api.QueryDevicesRequest, + response *api.QueryDevicesResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryDevices", h.apiURL+QueryDevicesPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccountData") - defer span.Finish() - - apiURL := h.apiURL + QueryAccountDataPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryAccountData( + ctx context.Context, + request *api.QueryAccountDataRequest, + response *api.QueryAccountDataResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountData", h.apiURL+QueryAccountDataPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles") - defer span.Finish() - - apiURL := h.apiURL + QuerySearchProfilesPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QuerySearchProfiles( + ctx context.Context, + request *api.QuerySearchProfilesRequest, + response *api.QuerySearchProfilesResponse, +) error { + return httputil.CallInternalRPCAPI( + "QuerySearchProfiles", h.apiURL+QuerySearchProfilesPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken") - defer span.Finish() - - apiURL := h.apiURL + QueryOpenIDTokenPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryOpenIDToken( + ctx context.Context, + request *api.QueryOpenIDTokenRequest, + response *api.QueryOpenIDTokenResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryOpenIDToken", h.apiURL+QueryOpenIDTokenPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup") - defer span.Finish() - - apiURL := h.apiURL + PerformKeyBackupPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = err.Error() - } - return nil -} -func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") - defer span.Finish() - - apiURL := h.apiURL + QueryKeyBackupPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = err.Error() - } +func (h *httpUserInternalAPI) PerformKeyBackup( + ctx context.Context, + request *api.PerformKeyBackupRequest, + response *api.PerformKeyBackupResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformKeyBackup", h.apiURL+PerformKeyBackupPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications") - defer span.Finish() +func (h *httpUserInternalAPI) QueryKeyBackup( + ctx context.Context, + request *api.QueryKeyBackupRequest, + response *api.QueryKeyBackupResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryKeyBackup", h.apiURL+QueryKeyBackupPath, + h.httpClient, ctx, request, response, + ) +} - return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res) +func (h *httpUserInternalAPI) QueryNotifications( + ctx context.Context, + request *api.QueryNotificationsRequest, + response *api.QueryNotificationsResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryNotifications", h.apiURL+QueryNotificationsPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformPusherSet( @@ -277,27 +292,32 @@ func (h *httpUserInternalAPI) PerformPusherSet( request *api.PerformPusherSetRequest, response *struct{}, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet") - defer span.Finish() - - apiURL := h.apiURL + PerformPusherSetPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformPusherSet", h.apiURL+PerformPusherSetPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion") - defer span.Finish() - - apiURL := h.apiURL + PerformPusherDeletionPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformPusherDeletion( + ctx context.Context, + request *api.PerformPusherDeletionRequest, + response *struct{}, +) error { + return httputil.CallInternalRPCAPI( + "PerformPusherDeletion", h.apiURL+PerformPusherDeletionPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers") - defer span.Finish() - - apiURL := h.apiURL + QueryPushersPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryPushers( + ctx context.Context, + request *api.QueryPushersRequest, + response *api.QueryPushersResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryPushers", h.apiURL+QueryPushersPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformPushRulesPut( @@ -305,89 +325,117 @@ func (h *httpUserInternalAPI) PerformPushRulesPut( request *api.PerformPushRulesPutRequest, response *struct{}, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut") - defer span.Finish() - - apiURL := h.apiURL + PerformPushRulesPutPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformPushRulesPut", h.apiURL+PerformPushRulesPutPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules") - defer span.Finish() - - apiURL := h.apiURL + QueryPushRulesPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryPushRules( + ctx context.Context, + request *api.QueryPushRulesRequest, + response *api.QueryPushRulesResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryPushRules", h.apiURL+QueryPushRulesPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetAvatarURLPath) - defer span.Finish() - - apiURL := h.apiURL + PerformSetAvatarURLPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) SetAvatarURL( + ctx context.Context, + request *api.PerformSetAvatarURLRequest, + response *api.PerformSetAvatarURLResponse, +) error { + return httputil.CallInternalRPCAPI( + "SetAvatarURL", h.apiURL+PerformSetAvatarURLPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryNumericLocalpartPath) - defer span.Finish() - - apiURL := h.apiURL + QueryNumericLocalpartPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, struct{}{}, res) +func (h *httpUserInternalAPI) QueryNumericLocalpart( + ctx context.Context, + response *api.QueryNumericLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, + h.httpClient, ctx, &struct{}{}, response, + ) } -func (h *httpUserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountAvailabilityPath) - defer span.Finish() - - apiURL := h.apiURL + QueryAccountAvailabilityPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryAccountAvailability( + ctx context.Context, + request *api.QueryAccountAvailabilityRequest, + response *api.QueryAccountAvailabilityResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountAvailability", h.apiURL+QueryAccountAvailabilityPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountByPasswordPath) - defer span.Finish() - - apiURL := h.apiURL + QueryAccountByPasswordPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryAccountByPassword( + ctx context.Context, + request *api.QueryAccountByPasswordRequest, + response *api.QueryAccountByPasswordResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountByPassword", h.apiURL+QueryAccountByPasswordPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *struct{}) error { - span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetDisplayNamePath) - defer span.Finish() - - apiURL := h.apiURL + PerformSetDisplayNamePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) SetDisplayName( + ctx context.Context, + request *api.PerformUpdateDisplayNameRequest, + response *struct{}, +) error { + return httputil.CallInternalRPCAPI( + "SetDisplayName", h.apiURL+PerformSetDisplayNamePath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForThreePIDPath) - defer span.Finish() - - apiURL := h.apiURL + QueryLocalpartForThreePIDPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryLocalpartForThreePID( + ctx context.Context, + request *api.QueryLocalpartForThreePIDRequest, + response *api.QueryLocalpartForThreePIDResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryLocalpartForThreePID", h.apiURL+QueryLocalpartForThreePIDPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryThreePIDsForLocalpartPath) - defer span.Finish() - - apiURL := h.apiURL + QueryThreePIDsForLocalpartPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart( + ctx context.Context, + request *api.QueryThreePIDsForLocalpartRequest, + response *api.QueryThreePIDsForLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryThreePIDsForLocalpart", h.apiURL+QueryThreePIDsForLocalpartPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error { - span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetThreePIDPath) - defer span.Finish() - - apiURL := h.apiURL + PerformForgetThreePIDPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformForgetThreePID( + ctx context.Context, + request *api.PerformForgetThreePIDRequest, + response *struct{}, +) error { + return httputil.CallInternalRPCAPI( + "PerformForgetThreePID", h.apiURL+PerformForgetThreePIDPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { - span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveThreePIDAssociationPath) - defer span.Finish() - - apiURL := h.apiURL + PerformSaveThreePIDAssociationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( + ctx context.Context, + request *api.PerformSaveThreePIDAssociationRequest, + response *struct{}, +) error { + return httputil.CallInternalRPCAPI( + "PerformSaveThreePIDAssociation", h.apiURL+PerformSaveThreePIDAssociationPath, + h.httpClient, ctx, request, response, + ) } diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go index 366a97099..211b1b7a1 100644 --- a/userapi/inthttp/client_logintoken.go +++ b/userapi/inthttp/client_logintoken.go @@ -19,7 +19,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/opentracing/opentracing-go" ) const ( @@ -33,11 +32,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenCreation( request *api.PerformLoginTokenCreationRequest, response *api.PerformLoginTokenCreationResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation") - defer span.Finish() - - apiURL := h.apiURL + PerformLoginTokenCreationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformLoginTokenCreation", h.apiURL+PerformLoginTokenCreationPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformLoginTokenDeletion( @@ -45,11 +43,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenDeletion( request *api.PerformLoginTokenDeletionRequest, response *api.PerformLoginTokenDeletionResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion") - defer span.Finish() - - apiURL := h.apiURL + PerformLoginTokenDeletionPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformLoginTokenDeletion", h.apiURL+PerformLoginTokenDeletionPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) QueryLoginToken( @@ -57,9 +54,8 @@ func (h *httpUserInternalAPI) QueryLoginToken( request *api.QueryLoginTokenRequest, response *api.QueryLoginTokenResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken") - defer span.Finish() - - apiURL := h.apiURL + QueryLoginTokenPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryLoginToken", h.apiURL+QueryLoginTokenPath, + h.httpClient, ctx, request, response, + ) } diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index ad532b901..99148b760 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -15,8 +15,6 @@ package inthttp import ( - "encoding/json" - "fmt" "net/http" "github.com/gorilla/mux" @@ -29,339 +27,134 @@ import ( func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { addRoutesLoginToken(internalAPIMux, s) - internalAPIMux.Handle(PerformAccountCreationPath, - httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { - request := api.PerformAccountCreationRequest{} - response := api.PerformAccountCreationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformAccountCreation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformPasswordUpdatePath, - httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse { - request := api.PerformPasswordUpdateRequest{} - response := api.PerformPasswordUpdateResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformDeviceCreationPath, - httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { - request := api.PerformDeviceCreationRequest{} - response := api.PerformDeviceCreationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformDeviceCreation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformLastSeenUpdatePath, - httputil.MakeInternalAPI("performLastSeenUpdate", func(req *http.Request) util.JSONResponse { - request := api.PerformLastSeenUpdateRequest{} - response := api.PerformLastSeenUpdateResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformLastSeenUpdate(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformDeviceUpdatePath, - httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse { - request := api.PerformDeviceUpdateRequest{} - response := api.PerformDeviceUpdateResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformDeviceDeletionPath, - httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse { - request := api.PerformDeviceDeletionRequest{} - response := api.PerformDeviceDeletionResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformDeviceDeletion(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformAccountDeactivationPath, - httputil.MakeInternalAPI("performAccountDeactivation", func(req *http.Request) util.JSONResponse { - request := api.PerformAccountDeactivationRequest{} - response := api.PerformAccountDeactivationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformAccountDeactivation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformOpenIDTokenCreationPath, - httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse { - request := api.PerformOpenIDTokenCreationRequest{} - response := api.PerformOpenIDTokenCreationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryProfilePath, - httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { - request := api.QueryProfileRequest{} - response := api.QueryProfileResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryProfile(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryAccessTokenPath, - httputil.MakeInternalAPI("queryAccessToken", func(req *http.Request) util.JSONResponse { - request := api.QueryAccessTokenRequest{} - response := api.QueryAccessTokenResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryAccessToken(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryDevicesPath, - httputil.MakeInternalAPI("queryDevices", func(req *http.Request) util.JSONResponse { - request := api.QueryDevicesRequest{} - response := api.QueryDevicesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryDevices(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryAccountDataPath, - httputil.MakeInternalAPI("queryAccountData", func(req *http.Request) util.JSONResponse { - request := api.QueryAccountDataRequest{} - response := api.QueryAccountDataResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryAccountData(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryDeviceInfosPath, - httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse { - request := api.QueryDeviceInfosRequest{} - response := api.QueryDeviceInfosResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QuerySearchProfilesPath, - httputil.MakeInternalAPI("querySearchProfiles", func(req *http.Request) util.JSONResponse { - request := api.QuerySearchProfilesRequest{} - response := api.QuerySearchProfilesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QuerySearchProfiles(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryOpenIDTokenPath, - httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse { - request := api.QueryOpenIDTokenRequest{} - response := api.QueryOpenIDTokenResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(InputAccountDataPath, - httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse { - request := api.InputAccountDataRequest{} - response := api.InputAccountDataResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.InputAccountData(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryKeyBackupPath, - httputil.MakeInternalAPI("queryKeyBackup", func(req *http.Request) util.JSONResponse { - request := api.QueryKeyBackupRequest{} - response := api.QueryKeyBackupResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryKeyBackup(req.Context(), &request, &response) - if response.Error != "" { - return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", response.Error)) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformKeyBackupPath, - httputil.MakeInternalAPI("performKeyBackup", func(req *http.Request) util.JSONResponse { - request := api.PerformKeyBackupRequest{} - response := api.PerformKeyBackupResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - err := s.PerformKeyBackup(req.Context(), &request, &response) - if err != nil { - return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryNotificationsPath, - httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse { - var request api.QueryNotificationsRequest - var response api.QueryNotificationsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryNotifications(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformAccountCreationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", s.PerformAccountCreation), ) - internalAPIMux.Handle(PerformPusherSetPath, - httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse { - request := api.PerformPusherSetRequest{} - response := struct{}{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformPusherDeletionPath, - httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse { - request := api.PerformPusherDeletionRequest{} - response := struct{}{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformPasswordUpdatePath, + httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", s.PerformPasswordUpdate), ) - internalAPIMux.Handle(QueryPushersPath, - httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse { - request := api.QueryPushersRequest{} - response := api.QueryPushersResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryPushers(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformDeviceCreationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", s.PerformDeviceCreation), ) - internalAPIMux.Handle(PerformPushRulesPutPath, - httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse { - request := api.PerformPushRulesPutRequest{} - response := struct{}{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformLastSeenUpdatePath, + httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", s.PerformLastSeenUpdate), ) - internalAPIMux.Handle(QueryPushRulesPath, - httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse { - request := api.QueryPushRulesRequest{} - response := api.QueryPushRulesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryPushRules(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformDeviceUpdatePath, + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", s.PerformDeviceUpdate), ) - internalAPIMux.Handle(PerformSetAvatarURLPath, - httputil.MakeInternalAPI("performSetAvatarURL", func(req *http.Request) util.JSONResponse { - request := api.PerformSetAvatarURLRequest{} - response := api.PerformSetAvatarURLResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.SetAvatarURL(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformDeviceDeletionPath, + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", s.PerformDeviceDeletion), ) + + internalAPIMux.Handle( + PerformAccountDeactivationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", s.PerformAccountDeactivation), + ) + + internalAPIMux.Handle( + PerformOpenIDTokenCreationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", s.PerformOpenIDTokenCreation), + ) + + internalAPIMux.Handle( + QueryProfilePath, + httputil.MakeInternalRPCAPI("UserAPIQueryProfile", s.QueryProfile), + ) + + internalAPIMux.Handle( + QueryAccessTokenPath, + httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", s.QueryAccessToken), + ) + + internalAPIMux.Handle( + QueryDevicesPath, + httputil.MakeInternalRPCAPI("UserAPIQueryDevices", s.QueryDevices), + ) + + internalAPIMux.Handle( + QueryAccountDataPath, + httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", s.QueryAccountData), + ) + + internalAPIMux.Handle( + QueryDeviceInfosPath, + httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", s.QueryDeviceInfos), + ) + + internalAPIMux.Handle( + QuerySearchProfilesPath, + httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", s.QuerySearchProfiles), + ) + + internalAPIMux.Handle( + QueryOpenIDTokenPath, + httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", s.QueryOpenIDToken), + ) + + internalAPIMux.Handle( + InputAccountDataPath, + httputil.MakeInternalRPCAPI("UserAPIInputAccountData", s.InputAccountData), + ) + + internalAPIMux.Handle( + QueryKeyBackupPath, + httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", s.QueryKeyBackup), + ) + + internalAPIMux.Handle( + PerformKeyBackupPath, + httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", s.PerformKeyBackup), + ) + + internalAPIMux.Handle( + QueryNotificationsPath, + httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", s.QueryNotifications), + ) + + internalAPIMux.Handle( + PerformPusherSetPath, + httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", s.PerformPusherSet), + ) + + internalAPIMux.Handle( + PerformPusherDeletionPath, + httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", s.PerformPusherDeletion), + ) + + internalAPIMux.Handle( + QueryPushersPath, + httputil.MakeInternalRPCAPI("UserAPIQueryPushers", s.QueryPushers), + ) + + internalAPIMux.Handle( + PerformPushRulesPutPath, + httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", s.PerformPushRulesPut), + ) + + internalAPIMux.Handle( + QueryPushRulesPath, + httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", s.QueryPushRules), + ) + + internalAPIMux.Handle( + PerformSetAvatarURLPath, + httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), + ) + + // TODO: Look at the shape of this internalAPIMux.Handle(QueryNumericLocalpartPath, - httputil.MakeInternalAPI("queryNumericLocalpart", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse { response := api.QueryNumericLocalpartResponse{} if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil { return util.ErrorResponse(err) @@ -369,92 +162,39 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle(QueryAccountAvailabilityPath, - httputil.MakeInternalAPI("queryAccountAvailability", func(req *http.Request) util.JSONResponse { - request := api.QueryAccountAvailabilityRequest{} - response := api.QueryAccountAvailabilityResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryAccountAvailability(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryAccountAvailabilityPath, + httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", s.QueryAccountAvailability), ) - internalAPIMux.Handle(QueryAccountByPasswordPath, - httputil.MakeInternalAPI("queryAccountByPassword", func(req *http.Request) util.JSONResponse { - request := api.QueryAccountByPasswordRequest{} - response := api.QueryAccountByPasswordResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryAccountByPassword(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryAccountByPasswordPath, + httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", s.QueryAccountByPassword), ) - internalAPIMux.Handle(PerformSetDisplayNamePath, - httputil.MakeInternalAPI("performSetDisplayName", func(req *http.Request) util.JSONResponse { - request := api.PerformUpdateDisplayNameRequest{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.SetDisplayName(req.Context(), &request, &struct{}{}); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}} - }), + + internalAPIMux.Handle( + PerformSetDisplayNamePath, + httputil.MakeInternalRPCAPI("UserAPISetDisplayName", s.SetDisplayName), ) - internalAPIMux.Handle(QueryLocalpartForThreePIDPath, - httputil.MakeInternalAPI("queryLocalpartForThreePID", func(req *http.Request) util.JSONResponse { - request := api.QueryLocalpartForThreePIDRequest{} - response := api.QueryLocalpartForThreePIDResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryLocalpartForThreePID(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryLocalpartForThreePIDPath, + httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", s.QueryLocalpartForThreePID), ) - internalAPIMux.Handle(QueryThreePIDsForLocalpartPath, - httputil.MakeInternalAPI("queryThreePIDsForLocalpart", func(req *http.Request) util.JSONResponse { - request := api.QueryThreePIDsForLocalpartRequest{} - response := api.QueryThreePIDsForLocalpartResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryThreePIDsForLocalpart(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryThreePIDsForLocalpartPath, + httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", s.QueryThreePIDsForLocalpart), ) - internalAPIMux.Handle(PerformForgetThreePIDPath, - httputil.MakeInternalAPI("performForgetThreePID", func(req *http.Request) util.JSONResponse { - request := api.PerformForgetThreePIDRequest{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformForgetThreePID(req.Context(), &request, &struct{}{}); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}} - }), + + internalAPIMux.Handle( + PerformForgetThreePIDPath, + httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", s.PerformForgetThreePID), ) - internalAPIMux.Handle(PerformSaveThreePIDAssociationPath, - httputil.MakeInternalAPI("performSaveThreePIDAssociation", func(req *http.Request) util.JSONResponse { - request := api.PerformSaveThreePIDAssociationRequest{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformSaveThreePIDAssociation(req.Context(), &request, &struct{}{}); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}} - }), + + internalAPIMux.Handle( + PerformSaveThreePIDAssociationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", s.PerformSaveThreePIDAssociation), ) } diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go index 1f2eb34b9..b57348413 100644 --- a/userapi/inthttp/server_logintoken.go +++ b/userapi/inthttp/server_logintoken.go @@ -15,54 +15,25 @@ package inthttp import ( - "encoding/json" - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) // addRoutesLoginToken adds routes for all login token API calls. func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { - internalAPIMux.Handle(PerformLoginTokenCreationPath, - httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse { - request := api.PerformLoginTokenCreationRequest{} - response := api.PerformLoginTokenCreationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformLoginTokenCreationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", s.PerformLoginTokenCreation), ) - internalAPIMux.Handle(PerformLoginTokenDeletionPath, - httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse { - request := api.PerformLoginTokenDeletionRequest{} - response := api.PerformLoginTokenDeletionResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformLoginTokenDeletionPath, + httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", s.PerformLoginTokenDeletion), ) - internalAPIMux.Handle(QueryLoginTokenPath, - httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse { - request := api.QueryLoginTokenRequest{} - response := api.QueryLoginTokenResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryLoginTokenPath, + httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", s.QueryLoginToken), ) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 40e37c5d6..31a69793b 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -117,16 +117,20 @@ func TestQueryProfile(t *testing.T) { }, } - runCases := func(testAPI api.UserInternalAPI) { + runCases := func(testAPI api.UserInternalAPI, http bool) { + mode := "monolith" + if http { + mode = "HTTP" + } for _, tc := range testCases { var gotRes api.QueryProfileResponse gotErr := testAPI.QueryProfile(context.TODO(), &tc.req, &gotRes) if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil { - t.Errorf("QueryProfile error, got %s want %s", gotErr, tc.wantErr) + t.Errorf("QueryProfile %s error, got %s want %s", mode, gotErr, tc.wantErr) continue } if !reflect.DeepEqual(tc.wantRes, gotRes) { - t.Errorf("QueryProfile response got %+v want %+v", gotRes, tc.wantRes) + t.Errorf("QueryProfile %s response got %+v want %+v", mode, gotRes, tc.wantRes) } } } @@ -140,10 +144,10 @@ func TestQueryProfile(t *testing.T) { if err != nil { t.Fatalf("failed to create HTTP client") } - runCases(httpAPI) + runCases(httpAPI, true) }) t.Run("Monolith", func(t *testing.T) { - runCases(userAPI) + runCases(userAPI, false) }) }