Return 400 but not 500 when request archive with wrong format (#17691)

This commit is contained in:
Lunny Xiao 2021-11-18 03:47:35 +08:00 committed by GitHub
parent d8a8961b99
commit 81a4fc7528
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 9 deletions

View file

@ -0,0 +1,52 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package integrations
import (
"fmt"
"io"
"net/http"
"net/url"
"testing"
"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
)
func TestAPIDownloadArchive(t *testing.T) {
defer prepareTestEnv(t)()
repo := unittest.AssertExistsAndLoadBean(t, &models.Repository{ID: 1}).(*models.Repository)
user2 := unittest.AssertExistsAndLoadBean(t, &models.User{ID: 2}).(*models.User)
session := loginUser(t, user2.LowerName)
token := getTokenForLoggedInUser(t, session)
link, _ := url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master.zip", user2.Name, repo.Name))
link.RawQuery = url.Values{"token": {token}}.Encode()
resp := MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusOK)
bs, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.EqualValues(t, 320, len(bs))
link, _ = url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master.tar.gz", user2.Name, repo.Name))
link.RawQuery = url.Values{"token": {token}}.Encode()
resp = MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusOK)
bs, err = io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.EqualValues(t, 266, len(bs))
link, _ = url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master.bundle", user2.Name, repo.Name))
link.RawQuery = url.Values{"token": {token}}.Encode()
resp = MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusOK)
bs, err = io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.EqualValues(t, 382, len(bs))
link, _ = url.Parse(fmt.Sprintf("/api/v1/repos/%s/%s/archive/master", user2.Name, repo.Name))
link.RawQuery = url.Values{"token": {token}}.Encode()
MakeRequest(t, NewRequest(t, "GET", link.String()), http.StatusBadRequest)
}

View file

@ -373,7 +373,11 @@ func Download(ctx *context.Context) {
uri := ctx.Params("*") uri := ctx.Params("*")
aReq, err := archiver_service.NewRequest(ctx.Repo.Repository.ID, ctx.Repo.GitRepo, uri) aReq, err := archiver_service.NewRequest(ctx.Repo.Repository.ID, ctx.Repo.GitRepo, uri)
if err != nil { if err != nil {
ctx.ServerError("archiver_service.NewRequest", err) if errors.Is(err, archiver_service.ErrUnknownArchiveFormat{}) {
ctx.Error(http.StatusBadRequest, err.Error())
} else {
ctx.ServerError("archiver_service.NewRequest", err)
}
return return
} }
if aReq == nil { if aReq == nil {

View file

@ -39,6 +39,22 @@ type ArchiveRequest struct {
// the way to 64. // the way to 64.
var shaRegex = regexp.MustCompile(`^[0-9a-f]{4,64}$`) var shaRegex = regexp.MustCompile(`^[0-9a-f]{4,64}$`)
// ErrUnknownArchiveFormat request archive format is not supported
type ErrUnknownArchiveFormat struct {
RequestFormat string
}
// Error implements error
func (err ErrUnknownArchiveFormat) Error() string {
return fmt.Sprintf("unknown format: %s", err.RequestFormat)
}
// Is implements error
func (ErrUnknownArchiveFormat) Is(err error) bool {
_, ok := err.(ErrUnknownArchiveFormat)
return ok
}
// NewRequest creates an archival request, based on the URI. The // NewRequest creates an archival request, based on the URI. The
// resulting ArchiveRequest is suitable for being passed to ArchiveRepository() // resulting ArchiveRequest is suitable for being passed to ArchiveRepository()
// if it's determined that the request still needs to be satisfied. // if it's determined that the request still needs to be satisfied.
@ -59,7 +75,7 @@ func NewRequest(repoID int64, repo *git.Repository, uri string) (*ArchiveRequest
ext = ".bundle" ext = ".bundle"
r.Type = git.BUNDLE r.Type = git.BUNDLE
default: default:
return nil, fmt.Errorf("Unknown format: %s", uri) return nil, ErrUnknownArchiveFormat{RequestFormat: uri}
} }
r.refName = strings.TrimSuffix(uri, ext) r.refName = strings.TrimSuffix(uri, ext)

View file

@ -5,6 +5,7 @@
package archiver package archiver
import ( import (
"errors"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
@ -19,10 +20,6 @@ func TestMain(m *testing.M) {
unittest.MainTest(m, filepath.Join("..", "..")) unittest.MainTest(m, filepath.Join("..", ".."))
} }
func waitForCount(t *testing.T, num int) {
}
func TestArchive_Basic(t *testing.T) { func TestArchive_Basic(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase()) assert.NoError(t, unittest.PrepareTestDatabase())
@ -83,11 +80,8 @@ func TestArchive_Basic(t *testing.T) {
inFlight[2] = secondReq inFlight[2] = secondReq
ArchiveRepository(zipReq) ArchiveRepository(zipReq)
waitForCount(t, 1)
ArchiveRepository(tgzReq) ArchiveRepository(tgzReq)
waitForCount(t, 2)
ArchiveRepository(secondReq) ArchiveRepository(secondReq)
waitForCount(t, 3)
// Make sure sending an unprocessed request through doesn't affect the queue // Make sure sending an unprocessed request through doesn't affect the queue
// count. // count.
@ -132,3 +126,8 @@ func TestArchive_Basic(t *testing.T) {
assert.NotEqual(t, zipReq.GetArchiveName(), tgzReq.GetArchiveName()) assert.NotEqual(t, zipReq.GetArchiveName(), tgzReq.GetArchiveName())
assert.NotEqual(t, zipReq.GetArchiveName(), secondReq.GetArchiveName()) assert.NotEqual(t, zipReq.GetArchiveName(), secondReq.GetArchiveName())
} }
func TestErrUnknownArchiveFormat(t *testing.T) {
var err = ErrUnknownArchiveFormat{RequestFormat: "master"}
assert.True(t, errors.Is(err, ErrUnknownArchiveFormat{}))
}