mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2024-11-04 01:10:49 +01:00
Return 400 but not 500 when request archive with wrong format (#17691)
This commit is contained in:
parent
d8a8961b99
commit
81a4fc7528
4 changed files with 80 additions and 9 deletions
52
integrations/api_repo_archive_test.go
Normal file
52
integrations/api_repo_archive_test.go
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a MIT-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package integrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"code.gitea.io/gitea/models"
|
||||||
|
"code.gitea.io/gitea/models/unittest"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAPIDownloadArchive(t *testing.T) {
|
||||||
|
defer prepareTestEnv(t)()
|
||||||
|
|
||||||
|
repo := unittest.AssertExistsAndLoadBean(t, &models.Repository{ID: 1}).(*models.Repository)
|
||||||
|
user2 := unittest.AssertExistsAndLoadBean(t, &models.User{ID: 2}).(*models.User)
|
||||||
|
session := loginUser(t, user2.LowerName)
|
||||||
|
token := getTokenForLoggedInUser(t, session)
|
||||||
|
|
||||||
|
link, _ := url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master.zip", user2.Name, repo.Name))
|
||||||
|
link.RawQuery = url.Values{"token": {token}}.Encode()
|
||||||
|
resp := MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusOK)
|
||||||
|
bs, err := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 320, len(bs))
|
||||||
|
|
||||||
|
link, _ = url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master.tar.gz", user2.Name, repo.Name))
|
||||||
|
link.RawQuery = url.Values{"token": {token}}.Encode()
|
||||||
|
resp = MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusOK)
|
||||||
|
bs, err = io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 266, len(bs))
|
||||||
|
|
||||||
|
link, _ = url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master.bundle", user2.Name, repo.Name))
|
||||||
|
link.RawQuery = url.Values{"token": {token}}.Encode()
|
||||||
|
resp = MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusOK)
|
||||||
|
bs, err = io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 382, len(bs))
|
||||||
|
|
||||||
|
link, _ = url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master", user2.Name, repo.Name))
|
||||||
|
link.RawQuery = url.Values{"token": {token}}.Encode()
|
||||||
|
MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusBadRequest)
|
||||||
|
}
|
|
@ -373,7 +373,11 @@ func Download(ctx *context.Context) {
|
||||||
uri := ctx.Params("*")
|
uri := ctx.Params("*")
|
||||||
aReq, err := archiver_service.NewRequest(ctx.Repo.Repository.ID, ctx.Repo.GitRepo, uri)
|
aReq, err := archiver_service.NewRequest(ctx.Repo.Repository.ID, ctx.Repo.GitRepo, uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.ServerError("archiver_service.NewRequest", err)
|
if errors.Is(err, archiver_service.ErrUnknownArchiveFormat{}) {
|
||||||
|
ctx.Error(http.StatusBadRequest, err.Error())
|
||||||
|
} else {
|
||||||
|
ctx.ServerError("archiver_service.NewRequest", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if aReq == nil {
|
if aReq == nil {
|
||||||
|
|
|
@ -39,6 +39,22 @@ type ArchiveRequest struct {
|
||||||
// the way to 64.
|
// the way to 64.
|
||||||
var shaRegex = regexp.MustCompile(`^[0-9a-f]{4,64}$`)
|
var shaRegex = regexp.MustCompile(`^[0-9a-f]{4,64}$`)
|
||||||
|
|
||||||
|
// ErrUnknownArchiveFormat request archive format is not supported
|
||||||
|
type ErrUnknownArchiveFormat struct {
|
||||||
|
RequestFormat string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements error
|
||||||
|
func (err ErrUnknownArchiveFormat) Error() string {
|
||||||
|
return fmt.Sprintf("unknown format: %s", err.RequestFormat)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is implements error
|
||||||
|
func (ErrUnknownArchiveFormat) Is(err error) bool {
|
||||||
|
_, ok := err.(ErrUnknownArchiveFormat)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
// NewRequest creates an archival request, based on the URI. The
|
// NewRequest creates an archival request, based on the URI. The
|
||||||
// resulting ArchiveRequest is suitable for being passed to ArchiveRepository()
|
// resulting ArchiveRequest is suitable for being passed to ArchiveRepository()
|
||||||
// if it's determined that the request still needs to be satisfied.
|
// if it's determined that the request still needs to be satisfied.
|
||||||
|
@ -59,7 +75,7 @@ func NewRequest(repoID int64, repo *git.Repository, uri string) (*ArchiveRequest
|
||||||
ext = ".bundle"
|
ext = ".bundle"
|
||||||
r.Type = git.BUNDLE
|
r.Type = git.BUNDLE
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknown format: %s", uri)
|
return nil, ErrUnknownArchiveFormat{RequestFormat: uri}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.refName = strings.TrimSuffix(uri, ext)
|
r.refName = strings.TrimSuffix(uri, ext)
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
package archiver
|
package archiver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -19,10 +20,6 @@ func TestMain(m *testing.M) {
|
||||||
unittest.MainTest(m, filepath.Join("..", ".."))
|
unittest.MainTest(m, filepath.Join("..", ".."))
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForCount(t *testing.T, num int) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestArchive_Basic(t *testing.T) {
|
func TestArchive_Basic(t *testing.T) {
|
||||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||||
|
|
||||||
|
@ -83,11 +80,8 @@ func TestArchive_Basic(t *testing.T) {
|
||||||
inFlight[2] = secondReq
|
inFlight[2] = secondReq
|
||||||
|
|
||||||
ArchiveRepository(zipReq)
|
ArchiveRepository(zipReq)
|
||||||
waitForCount(t, 1)
|
|
||||||
ArchiveRepository(tgzReq)
|
ArchiveRepository(tgzReq)
|
||||||
waitForCount(t, 2)
|
|
||||||
ArchiveRepository(secondReq)
|
ArchiveRepository(secondReq)
|
||||||
waitForCount(t, 3)
|
|
||||||
|
|
||||||
// Make sure sending an unprocessed request through doesn't affect the queue
|
// Make sure sending an unprocessed request through doesn't affect the queue
|
||||||
// count.
|
// count.
|
||||||
|
@ -132,3 +126,8 @@ func TestArchive_Basic(t *testing.T) {
|
||||||
assert.NotEqual(t, zipReq.GetArchiveName(), tgzReq.GetArchiveName())
|
assert.NotEqual(t, zipReq.GetArchiveName(), tgzReq.GetArchiveName())
|
||||||
assert.NotEqual(t, zipReq.GetArchiveName(), secondReq.GetArchiveName())
|
assert.NotEqual(t, zipReq.GetArchiveName(), secondReq.GetArchiveName())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestErrUnknownArchiveFormat(t *testing.T) {
|
||||||
|
var err = ErrUnknownArchiveFormat{RequestFormat: "master"}
|
||||||
|
assert.True(t, errors.Is(err, ErrUnknownArchiveFormat{}))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue