fix: use better code to group UID and stopwatches

- Instead of having code that relied on the result being sorted (which
wasn't specified in the query and therefore not safe to assume so). Use
a map where it doesn't care if the result that we get from the database
is sorted or not.
- Added unit test.

(cherry picked from commit e4eb82b738)
This commit is contained in:
Gusted 2024-11-16 15:51:59 +01:00 committed by forgejo-backport-action
parent 004fe296cc
commit 35435c573a
4 changed files with 58 additions and 22 deletions

View file

@ -0,0 +1,11 @@
-
id: 3
user_id: 1
issue_id: 2
created_unix: 1500988004
-
id: 4
user_id: 3
issue_id: 0
created_unix: 1500988003

View file

@ -60,34 +60,19 @@ func getStopwatch(ctx context.Context, userID, issueID int64) (sw *Stopwatch, ex
return sw, exists, err return sw, exists, err
} }
// UserIDCount is a simple coalition of UserID and Count
type UserStopwatch struct {
UserID int64
StopWatches []*Stopwatch
}
// GetUIDsAndNotificationCounts between the two provided times // GetUIDsAndNotificationCounts between the two provided times
func GetUIDsAndStopwatch(ctx context.Context) ([]*UserStopwatch, error) { func GetUIDsAndStopwatch(ctx context.Context) (map[int64][]*Stopwatch, error) {
sws := []*Stopwatch{} sws := []*Stopwatch{}
if err := db.GetEngine(ctx).Where("issue_id != 0").Find(&sws); err != nil { if err := db.GetEngine(ctx).Where("issue_id != 0").Find(&sws); err != nil {
return nil, err return nil, err
} }
res := map[int64][]*Stopwatch{}
if len(sws) == 0 { if len(sws) == 0 {
return []*UserStopwatch{}, nil return res, nil
} }
lastUserID := int64(-1)
res := []*UserStopwatch{}
for _, sw := range sws { for _, sw := range sws {
if lastUserID == sw.UserID { res[sw.UserID] = append(res[sw.UserID], sw)
lastUserStopwatch := res[len(res)-1]
lastUserStopwatch.StopWatches = append(lastUserStopwatch.StopWatches, sw)
} else {
res = append(res, &UserStopwatch{
UserID: sw.UserID,
StopWatches: []*Stopwatch{sw},
})
}
} }
return res, nil return res, nil
} }

View file

@ -4,12 +4,14 @@
package issues_test package issues_test
import ( import (
"path/filepath"
"testing" "testing"
"code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues" issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/unittest" "code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user" user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/timeutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -77,3 +79,41 @@ func TestCreateOrStopIssueStopwatch(t *testing.T) {
unittest.AssertNotExistsBean(t, &issues_model.Stopwatch{UserID: 2, IssueID: 2}) unittest.AssertNotExistsBean(t, &issues_model.Stopwatch{UserID: 2, IssueID: 2})
unittest.AssertExistsAndLoadBean(t, &issues_model.TrackedTime{UserID: 2, IssueID: 2}) unittest.AssertExistsAndLoadBean(t, &issues_model.TrackedTime{UserID: 2, IssueID: 2})
} }
func TestGetUIDsAndStopwatch(t *testing.T) {
defer unittest.OverrideFixtures(
unittest.FixturesOptions{
Dir: filepath.Join(setting.AppWorkPath, "models/fixtures/"),
Base: setting.AppWorkPath,
Dirs: []string{"models/issues/TestGetUIDsAndStopwatch/"},
},
)()
require.NoError(t, unittest.PrepareTestDatabase())
uidStopwatches, err := issues_model.GetUIDsAndStopwatch(db.DefaultContext)
require.NoError(t, err)
assert.EqualValues(t, map[int64][]*issues_model.Stopwatch{
1: {
{
ID: 1,
UserID: 1,
IssueID: 1,
CreatedUnix: timeutil.TimeStamp(1500988001),
},
{
ID: 3,
UserID: 1,
IssueID: 2,
CreatedUnix: timeutil.TimeStamp(1500988004),
},
},
2: {
{
ID: 2,
UserID: 2,
IssueID: 2,
CreatedUnix: timeutil.TimeStamp(1500988002),
},
},
}, uidStopwatches)
}

View file

@ -90,8 +90,8 @@ loop:
return return
} }
for _, userStopwatches := range usersStopwatches { for uid, stopwatches := range usersStopwatches {
apiSWs, err := convert.ToStopWatches(ctx, userStopwatches.StopWatches) apiSWs, err := convert.ToStopWatches(ctx, stopwatches)
if err != nil { if err != nil {
if !issues_model.IsErrIssueNotExist(err) { if !issues_model.IsErrIssueNotExist(err) {
log.Error("Unable to APIFormat stopwatches: %v", err) log.Error("Unable to APIFormat stopwatches: %v", err)
@ -103,7 +103,7 @@ loop:
log.Error("Unable to marshal stopwatches: %v", err) log.Error("Unable to marshal stopwatches: %v", err)
continue continue
} }
m.SendMessage(userStopwatches.UserID, &Event{ m.SendMessage(uid, &Event{
Name: "stopwatches", Name: "stopwatches",
Data: string(dataBs), Data: string(dataBs),
}) })