From d67e40684f43b0eb744cad26e0265002f033dbc3 Mon Sep 17 00:00:00 2001 From: Jason Song Date: Mon, 3 Apr 2023 16:42:38 +0800 Subject: [PATCH] Improve LoadUnitConfig to handle invalid or duplicate units (#23736) The old code just parses an invalid key to `TypeInvalid` and uses it as normal, and duplicate keys will be kept. So this PR will ignore invalid key and log warning and also deduplicate valid units. --- models/unit/unit.go | 39 ++++++++++++++++------------ models/unit/unit_test.go | 53 ++++++++++++++++++++++++++++++++++++++ routers/api/v1/org/team.go | 2 +- 3 files changed, 77 insertions(+), 17 deletions(-) create mode 100644 models/unit/unit_test.go diff --git a/models/unit/unit.go b/models/unit/unit.go index 883f443cbe..3d5a8842cd 100644 --- a/models/unit/unit.go +++ b/models/unit/unit.go @@ -151,7 +151,11 @@ func validateDefaultRepoUnits(defaultUnits, settingDefaultUnits []Type) []Type { // LoadUnitConfig load units from settings func LoadUnitConfig() { - DisabledRepoUnits = FindUnitTypes(setting.Repository.DisabledRepoUnits...) + var invalidKeys []string + DisabledRepoUnits, invalidKeys = FindUnitTypes(setting.Repository.DisabledRepoUnits...) + if len(invalidKeys) > 0 { + log.Warn("Invalid keys in disabled repo units: %s", strings.Join(invalidKeys, ", ")) + } // Check that must units are not disabled for i, disabledU := range DisabledRepoUnits { if !disabledU.CanDisable() { @@ -160,9 +164,15 @@ func LoadUnitConfig() { } } - setDefaultRepoUnits := FindUnitTypes(setting.Repository.DefaultRepoUnits...) + setDefaultRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DefaultRepoUnits...) + if len(invalidKeys) > 0 { + log.Warn("Invalid keys in default repo units: %s", strings.Join(invalidKeys, ", ")) + } DefaultRepoUnits = validateDefaultRepoUnits(DefaultRepoUnits, setDefaultRepoUnits) - setDefaultForkRepoUnits := FindUnitTypes(setting.Repository.DefaultForkRepoUnits...) + setDefaultForkRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DefaultForkRepoUnits...) + if len(invalidKeys) > 0 { + log.Warn("Invalid keys in default fork repo units: %s", strings.Join(invalidKeys, ", ")) + } DefaultForkRepoUnits = validateDefaultRepoUnits(DefaultForkRepoUnits, setDefaultForkRepoUnits) } @@ -334,22 +344,19 @@ var ( } ) -// FindUnitTypes give the unit key names and return unit -func FindUnitTypes(nameKeys ...string) (res []Type) { +// FindUnitTypes give the unit key names and return valid unique units and invalid keys +func FindUnitTypes(nameKeys ...string) (res []Type, invalidKeys []string) { + m := map[Type]struct{}{} for _, key := range nameKeys { - var found bool - for t, u := range Units { - if strings.EqualFold(key, u.NameKey) { - res = append(res, t) - found = true - break - } - } - if !found { - res = append(res, TypeInvalid) + t := TypeFromKey(key) + if t == TypeInvalid { + invalidKeys = append(invalidKeys, key) + } else if _, ok := m[t]; !ok { + res = append(res, t) + m[t] = struct{}{} } } - return res + return res, invalidKeys } // TypeFromKey give the unit key name and return unit diff --git a/models/unit/unit_test.go b/models/unit/unit_test.go new file mode 100644 index 0000000000..50d7817197 --- /dev/null +++ b/models/unit/unit_test.go @@ -0,0 +1,53 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unit + +import ( + "testing" + + "code.gitea.io/gitea/modules/setting" + + "github.com/stretchr/testify/assert" +) + +func TestLoadUnitConfig(t *testing.T) { + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) { + DisabledRepoUnits = disabledRepoUnits + DefaultRepoUnits = defaultRepoUnits + DefaultForkRepoUnits = defaultForkRepoUnits + }(DisabledRepoUnits, DefaultRepoUnits, DefaultForkRepoUnits) + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) { + setting.Repository.DisabledRepoUnits = disabledRepoUnits + setting.Repository.DefaultRepoUnits = defaultRepoUnits + setting.Repository.DefaultForkRepoUnits = defaultForkRepoUnits + }(setting.Repository.DisabledRepoUnits, setting.Repository.DefaultRepoUnits, setting.Repository.DefaultForkRepoUnits) + + t.Run("regular", func(t *testing.T) { + setting.Repository.DisabledRepoUnits = []string{"repo.issues"} + setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls"} + setting.Repository.DefaultForkRepoUnits = []string{"repo.releases"} + LoadUnitConfig() + assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits) + assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits) + assert.Equal(t, []Type{TypeCode, TypeReleases}, DefaultForkRepoUnits) + }) + t.Run("invalid", func(t *testing.T) { + setting.Repository.DisabledRepoUnits = []string{"repo.issues", "invalid.1"} + setting.Repository.DefaultRepoUnits = []string{"repo.code", "invalid.2", "repo.releases", "repo.issues", "repo.pulls"} + setting.Repository.DefaultForkRepoUnits = []string{"invalid.3", "repo.releases"} + LoadUnitConfig() + assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits) + assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits) + assert.Equal(t, []Type{TypeCode, TypeReleases}, DefaultForkRepoUnits) + }) + t.Run("duplicate", func(t *testing.T) { + setting.Repository.DisabledRepoUnits = []string{"repo.issues", "repo.issues"} + setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls", "repo.code"} + setting.Repository.DefaultForkRepoUnits = []string{"repo.releases", "repo.releases"} + LoadUnitConfig() + assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits) + assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits) + assert.Equal(t, []Type{TypeCode, TypeReleases}, DefaultForkRepoUnits) + }) +} diff --git a/routers/api/v1/org/team.go b/routers/api/v1/org/team.go index 0c6926759a..597f846206 100644 --- a/routers/api/v1/org/team.go +++ b/routers/api/v1/org/team.go @@ -135,7 +135,7 @@ func GetTeam(ctx *context.APIContext) { } func attachTeamUnits(team *organization.Team, units []string) { - unitTypes := unit_model.FindUnitTypes(units...) + unitTypes, _ := unit_model.FindUnitTypes(units...) team.Units = make([]*organization.TeamUnit, 0, len(units)) for _, tp := range unitTypes { team.Units = append(team.Units, &organization.TeamUnit{