diff --git a/database/portal.go b/database/portal.go index cb6c050..62802d0 100644 --- a/database/portal.go +++ b/database/portal.go @@ -18,6 +18,7 @@ package database import ( "database/sql" + "fmt" log "maunium.net/go/maulogger/v2" @@ -63,32 +64,26 @@ func (pq *PortalQuery) New() *Portal { } } -func (pq *PortalQuery) GetAll() []*Portal { - return pq.getAll("SELECT * FROM portal") -} +const portalColumns = "jid, receiver, mxid, name, topic, avatar, avatar_url, encrypted, first_event_id, next_batch_id, relay_user_id, expiration_time" -func (pq *PortalQuery) GetAllForUser(userID id.UserID) []*Portal { - return pq.getAll(` - SELECT p.* FROM portal p - LEFT JOIN user_portal up ON p.jid=up.portal_jid AND p.receiver=up.portal_receiver - WHERE mxid<>'' AND up.user_mxid=$1 - `, userID) +func (pq *PortalQuery) GetAll() []*Portal { + return pq.getAll(fmt.Sprintf("SELECT %s FROM portal", portalColumns)) } func (pq *PortalQuery) GetByJID(key PortalKey) *Portal { - return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver) + return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns), key.JID, key.Receiver) } func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { - return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid) + return pq.get(fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns), mxid) } func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal { - return pq.getAll("SELECT * FROM portal WHERE jid=$1", jid.ToNonAD()) + return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns), jid.ToNonAD()) } func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal { - return pq.getAll("SELECT * FROM portal WHERE receiver=$1 AND jid LIKE '%@s.whatsapp.net'", receiver.ToNonAD()) + return pq.getAll(fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%@s.whatsapp.net'", portalColumns), receiver.ToNonAD()) } func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) { diff --git a/go.mod b/go.mod index d37c403..1a5ccff 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( golang.org/x/net v0.0.0-20220513224357-95641704303c google.golang.org/protobuf v1.28.0 maunium.net/go/maulogger/v2 v2.3.2 - maunium.net/go/mautrix v0.11.1-0.20220623172243-579a4753b77a + maunium.net/go/mautrix v0.11.1-0.20220624140129-4eb8a89ebea6 ) require ( diff --git a/go.sum b/go.sum index e810f55..7245fc1 100644 --- a/go.sum +++ b/go.sum @@ -107,5 +107,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0= maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= -maunium.net/go/mautrix v0.11.1-0.20220623172243-579a4753b77a h1:KcGbhXMmBO1WOLwaDRf4awKYCQcNp0178Km0qEHlj0s= -maunium.net/go/mautrix v0.11.1-0.20220623172243-579a4753b77a/go.mod h1:CiKpMhAx5QZFHK03jpWb0iKI3sGU8x6+LfsOjDrcO8I= +maunium.net/go/mautrix v0.11.1-0.20220624140129-4eb8a89ebea6 h1:MWtDhsSaNFR4pT6+sqQQBxykTh1c+awvXdKuv+igsso= +maunium.net/go/mautrix v0.11.1-0.20220624140129-4eb8a89ebea6/go.mod h1:CiKpMhAx5QZFHK03jpWb0iKI3sGU8x6+LfsOjDrcO8I= diff --git a/portal.go b/portal.go index 6278da5..e457d53 100644 --- a/portal.go +++ b/portal.go @@ -118,10 +118,6 @@ func (br *WABridge) GetAllPortals() []*Portal { return br.dbPortalsToPortals(br.DB.Portal.GetAll()) } -func (br *WABridge) GetAllPortalsForUser(userID id.UserID) []*Portal { - return br.dbPortalsToPortals(br.DB.Portal.GetAllForUser(userID)) -} - func (br *WABridge) GetAllPortalsByJID(jid types.JID) []*Portal { return br.dbPortalsToPortals(br.DB.Portal.GetAllByJID(jid)) }