Clean up Segment code

This commit is contained in:
Tulir Asokan 2022-05-16 13:46:18 +03:00
parent 512ca6be89
commit 7c0cf0513a
6 changed files with 50 additions and 51 deletions

View file

@ -47,7 +47,6 @@ type Config struct {
Provisioning struct { Provisioning struct {
Prefix string `yaml:"prefix"` Prefix string `yaml:"prefix"`
SharedSecret string `yaml:"shared_secret"` SharedSecret string `yaml:"shared_secret"`
SegmentKey string `yaml:"segment_key"`
} `yaml:"provisioning"` } `yaml:"provisioning"`
ID string `yaml:"id"` ID string `yaml:"id"`
@ -65,6 +64,8 @@ type Config struct {
HSToken string `yaml:"hs_token"` HSToken string `yaml:"hs_token"`
} `yaml:"appservice"` } `yaml:"appservice"`
SegmentKey string `yaml:"segment_key"`
Metrics struct { Metrics struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
Listen string `yaml:"listen"` Listen string `yaml:"listen"`

View file

@ -55,7 +55,6 @@ func (helper *UpgradeHelper) doUpgrade() {
} else { } else {
helper.Copy(Str, "appservice", "provisioning", "shared_secret") helper.Copy(Str, "appservice", "provisioning", "shared_secret")
} }
helper.Copy(Str|Null, "appservice", "provisioning", "segment_key")
helper.Copy(Str, "appservice", "id") helper.Copy(Str, "appservice", "id")
helper.Copy(Str, "appservice", "bot", "username") helper.Copy(Str, "appservice", "bot", "username")
helper.Copy(Str, "appservice", "bot", "displayname") helper.Copy(Str, "appservice", "bot", "displayname")
@ -64,6 +63,8 @@ func (helper *UpgradeHelper) doUpgrade() {
helper.Copy(Str, "appservice", "as_token") helper.Copy(Str, "appservice", "as_token")
helper.Copy(Str, "appservice", "hs_token") helper.Copy(Str, "appservice", "hs_token")
helper.Copy(Str|Null, "segment_key")
helper.Copy(Bool, "metrics", "enabled") helper.Copy(Bool, "metrics", "enabled")
helper.Copy(Str, "metrics", "listen") helper.Copy(Str, "metrics", "listen")
@ -170,6 +171,7 @@ func (helper *UpgradeHelper) addSpaces() {
helper.addSpaceBeforeComment("appservice", "provisioning") helper.addSpaceBeforeComment("appservice", "provisioning")
helper.addSpaceBeforeComment("appservice", "id") helper.addSpaceBeforeComment("appservice", "id")
helper.addSpaceBeforeComment("appservice", "as_token") helper.addSpaceBeforeComment("appservice", "as_token")
helper.addSpaceBeforeComment("segment_key")
helper.addSpaceBeforeComment("metrics") helper.addSpaceBeforeComment("metrics")
helper.addSpaceBeforeComment("whatsapp") helper.addSpaceBeforeComment("whatsapp")
helper.addSpaceBeforeComment("bridge") helper.addSpaceBeforeComment("bridge")

View file

@ -50,11 +50,6 @@ appservice:
# Shared secret for authentication. If set to "generate", a random secret will be generated, # Shared secret for authentication. If set to "generate", a random secret will be generated,
# or if set to "disable", the provisioning API will be disabled. # or if set to "disable", the provisioning API will be disabled.
shared_secret: generate shared_secret: generate
# Segment API key to enable analytics tracking for web server
# endpoints. Set to null to disable.
# Currently the only events are login start, QR code retrieve, and login
# success/failure.
segment_key: null
# The unique ID of this appservice. # The unique ID of this appservice.
id: whatsapp id: whatsapp
@ -76,6 +71,9 @@ appservice:
as_token: "This value is generated when generating the registration" as_token: "This value is generated when generating the registration"
hs_token: "This value is generated when generating the registration" hs_token: "This value is generated when generating the registration"
# Segment API key to track some events, like provisioning API login and encryption errors.
segment_key: null
# Prometheus config. # Prometheus config.
metrics: metrics:
# Enable prometheus metrics? # Enable prometheus metrics?

View file

@ -272,6 +272,12 @@ func (bridge *Bridge) Init() {
bridge.StateStore = database.NewSQLStateStore(bridge.DB) bridge.StateStore = database.NewSQLStateStore(bridge.DB)
bridge.AS.StateStore = bridge.StateStore bridge.AS.StateStore = bridge.StateStore
Segment.log = bridge.Log.Sub("Segment")
Segment.key = bridge.Config.SegmentKey
if Segment.IsEnabled() {
Segment.log.Infoln("Segment metrics are enabled")
}
bridge.WAContainer = sqlstore.NewWithDB(bridge.DB.DB, bridge.Config.AppService.Database.Type, nil) bridge.WAContainer = sqlstore.NewWithDB(bridge.DB.DB, bridge.Config.AppService.Database.Type, nil)
bridge.WAContainer.DatabaseErrorHandler = bridge.DB.HandleSignalStoreError bridge.WAContainer.DatabaseErrorHandler = bridge.DB.HandleSignalStoreError

View file

@ -43,15 +43,11 @@ import (
type ProvisioningAPI struct { type ProvisioningAPI struct {
bridge *Bridge bridge *Bridge
log log.Logger log log.Logger
segment *Segment
} }
func (prov *ProvisioningAPI) Init() { func (prov *ProvisioningAPI) Init() {
prov.log = prov.bridge.Log.Sub("Provisioning") prov.log = prov.bridge.Log.Sub("Provisioning")
// Set up segment
prov.segment = NewSegment(prov.bridge.Config.AppService.Provisioning.SegmentKey, prov.log)
prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.AppService.Provisioning.Prefix) prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.AppService.Provisioning.Prefix)
r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.AppService.Provisioning.Prefix).Subrouter() r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.AppService.Provisioning.Prefix).Subrouter()
r.Use(prov.AuthMiddleware) r.Use(prov.AuthMiddleware)
@ -573,7 +569,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
} }
} }
user.log.Debugln("Started login via provisioning API") user.log.Debugln("Started login via provisioning API")
prov.segment.Track(user.MXID, "$login_start") Segment.Track(user.MXID, "$login_start")
for { for {
select { select {
@ -582,7 +578,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
case whatsmeow.QRChannelSuccess.Event: case whatsmeow.QRChannelSuccess.Event:
jid := user.Client.Store.ID jid := user.Client.Store.ID
user.log.Debugln("Successful login as", jid, "via provisioning API") user.log.Debugln("Successful login as", jid, "via provisioning API")
prov.segment.Track(user.MXID, "$login_success") Segment.Track(user.MXID, "$login_success")
_ = c.WriteJSON(map[string]interface{}{ _ = c.WriteJSON(map[string]interface{}{
"success": true, "success": true,
"jid": jid, "jid": jid,
@ -597,7 +593,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
case whatsmeow.QRChannelTimeout.Event: case whatsmeow.QRChannelTimeout.Event:
user.log.Debugln("Login via provisioning API timed out") user.log.Debugln("Login via provisioning API timed out")
errCode := "login timed out" errCode := "login timed out"
prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{
Error: "QR code scan timed out. Please try again.", Error: "QR code scan timed out. Please try again.",
ErrCode: errCode, ErrCode: errCode,
@ -605,7 +601,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
case whatsmeow.QRChannelErrUnexpectedEvent.Event: case whatsmeow.QRChannelErrUnexpectedEvent.Event:
user.log.Debugln("Login via provisioning API failed due to unexpected event") user.log.Debugln("Login via provisioning API failed due to unexpected event")
errCode := "unexpected event" errCode := "unexpected event"
prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{
Error: "Got unexpected event while waiting for QRs, perhaps you're already logged in?", Error: "Got unexpected event while waiting for QRs, perhaps you're already logged in?",
ErrCode: errCode, ErrCode: errCode,
@ -613,14 +609,14 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
case whatsmeow.QRChannelClientOutdated.Event: case whatsmeow.QRChannelClientOutdated.Event:
user.log.Debugln("Login via provisioning API failed due to outdated client") user.log.Debugln("Login via provisioning API failed due to outdated client")
errCode := "bridge outdated" errCode := "bridge outdated"
prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{
Error: "Got client outdated error while waiting for QRs. The bridge must be updated to continue.", Error: "Got client outdated error while waiting for QRs. The bridge must be updated to continue.",
ErrCode: errCode, ErrCode: errCode,
}) })
case whatsmeow.QRChannelScannedWithoutMultidevice.Event: case whatsmeow.QRChannelScannedWithoutMultidevice.Event:
errCode := "multidevice not enabled" errCode := "multidevice not enabled"
prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{
Error: "Please enable the WhatsApp multidevice beta and scan the QR code again.", Error: "Please enable the WhatsApp multidevice beta and scan the QR code again.",
ErrCode: errCode, ErrCode: errCode,
@ -628,13 +624,13 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
continue continue
case "error": case "error":
errCode := "fatal error" errCode := "fatal error"
prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode}) Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
_ = c.WriteJSON(Error{ _ = c.WriteJSON(Error{
Error: "Fatal error while logging in", Error: "Fatal error while logging in",
ErrCode: errCode, ErrCode: errCode,
}) })
case "code": case "code":
prov.segment.Track(user.MXID, "$qrcode_retrieved") Segment.Track(user.MXID, "$qrcode_retrieved")
_ = c.WriteJSON(map[string]interface{}{ _ = c.WriteJSON(map[string]interface{}{
"code": evt.Code, "code": evt.Code,
"timeout": int(evt.Timeout.Seconds()), "timeout": int(evt.Timeout.Seconds()),

View file

@ -27,37 +27,31 @@ import (
const SegmentURL = "https://api.segment.io/v1/track" const SegmentURL = "https://api.segment.io/v1/track"
type Segment struct { type SegmentClient struct {
segmentKey string key string
log log.Logger log log.Logger
client *http.Client client http.Client
} }
func NewSegment(segmentKey string, parentLogger log.Logger) *Segment { var Segment SegmentClient
return &Segment{
segmentKey: segmentKey,
log: parentLogger.Sub("Segment"),
client: &http.Client{},
}
}
func (segment *Segment) track(userID id.UserID, event string, properties map[string]interface{}) error { func (sc *SegmentClient) trackSync(userID id.UserID, event string, properties map[string]interface{}) error {
data := map[string]interface{}{ var buf bytes.Buffer
err := json.NewEncoder(&buf).Encode(map[string]interface{}{
"userId": userID, "userId": userID,
"event": event, "event": event,
"properties": properties, "properties": properties,
} })
json_data, err := json.Marshal(data)
if err != nil { if err != nil {
return err return err
} }
req, err := http.NewRequest("POST", SegmentURL, bytes.NewBuffer(json_data)) req, err := http.NewRequest("POST", SegmentURL, &buf)
if err != nil { if err != nil {
return err return err
} }
req.SetBasicAuth(segment.segmentKey, "") req.SetBasicAuth(sc.key, "")
resp, err := segment.client.Do(req) resp, err := sc.client.Do(req)
if err != nil { if err != nil {
return err return err
} }
@ -65,26 +59,28 @@ func (segment *Segment) track(userID id.UserID, event string, properties map[str
return nil return nil
} }
func (segment *Segment) Track(userID id.UserID, event string, properties ...map[string]interface{}) { func (sc *SegmentClient) IsEnabled() bool {
if segment.segmentKey == "" { return len(sc.key) > 0
return
}
if len(properties) > 1 {
segment.log.Fatalf("Track should be called with at most one property map")
} }
go (func() error { func (sc *SegmentClient) Track(userID id.UserID, event string, properties ...map[string]interface{}) {
if !sc.IsEnabled() {
return
} else if len(properties) > 1 {
panic("Track should be called with at most one property map")
}
go func() {
props := map[string]interface{}{} props := map[string]interface{}{}
if len(properties) > 0 { if len(properties) > 0 {
props = properties[0] props = properties[0]
} }
props["bridge"] = "whatsapp" props["bridge"] = "whatsapp"
err := segment.track(userID, event, props) err := sc.trackSync(userID, event, props)
if err != nil { if err != nil {
segment.log.Errorf("Error tracking %s: %v+", event, err) sc.log.Errorfln("Error tracking %s: %v", event, err)
return err } else {
sc.log.Debugln("Tracked", event)
} }
segment.log.Debug("Tracked ", event) }()
return nil
})()
} }