0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2024-12-15 16:33:44 +01:00
dendrite/internal/httputil/http.go

94 lines
2.5 KiB
Go
Raw Normal View History

// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package httputil
2018-07-17 16:36:04 +02:00
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
2018-07-17 16:36:04 +02:00
"net/http"
"net/url"
"strings"
2018-07-17 16:36:04 +02:00
"github.com/opentracing/opentracing-go"
2018-07-17 16:36:04 +02:00
"github.com/opentracing/opentracing-go/ext"
)
// PostJSON performs a POST request with JSON on an internal HTTP API.
// The error will match the errtype if returned from the remote API, or
// will be a different type if there was a problem reaching the API.
func PostJSON[reqtype, restype any, errtype error](
2018-07-17 16:36:04 +02:00
ctx context.Context, span opentracing.Span, httpClient *http.Client,
apiURL string, request *reqtype, response *restype,
2018-07-17 16:36:04 +02:00
) error {
jsonBytes, err := json.Marshal(request)
if err != nil {
return err
}
parsedAPIURL, err := url.Parse(apiURL)
if err != nil {
return err
}
parsedAPIURL.Path = InternalPathPrefix + strings.TrimLeft(parsedAPIURL.Path, "/")
apiURL = parsedAPIURL.String()
2018-07-17 16:36:04 +02:00
req, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewReader(jsonBytes))
if err != nil {
return err
}
// Mark the span as being an RPC client.
ext.SpanKindRPCClient.Set(span)
carrier := opentracing.HTTPHeadersCarrier(req.Header)
tracer := opentracing.GlobalTracer()
if err = tracer.Inject(span.Context(), opentracing.HTTPHeaders, carrier); err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
res, err := httpClient.Do(req.WithContext(ctx))
if res != nil {
defer (func() { err = res.Body.Close() })()
}
if err != nil {
return err
}
var body []byte
body, err = io.ReadAll(res.Body)
if err != nil {
return err
}
2018-07-17 16:36:04 +02:00
if res.StatusCode != http.StatusOK {
if len(body) == 0 {
return fmt.Errorf("HTTP %d from %s (no response body)", res.StatusCode, apiURL)
2018-07-17 16:36:04 +02:00
}
var reserr errtype
if err = json.Unmarshal(body, &reserr); err != nil {
return fmt.Errorf("HTTP %d from %s - %w", res.StatusCode, apiURL, err)
}
return reserr
}
if err = json.Unmarshal(body, response); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
2018-07-17 16:36:04 +02:00
}
return nil
2018-07-17 16:36:04 +02:00
}