dashboard / erock/pico / feat(pgs): admins can impersonate #57 rss

open · opened on 2025-03-28T16:55:01Z by erock
Help
checkout latest patchset:
ssh pr.pico.sh print pr-57 | git am -3
checkout any patchset in a patch request:
ssh pr.pico.sh print ps-X | git am -3
add changes to patch request:
git format-patch main --stdout | ssh pr.pico.sh pr add 57
add review to patch request:
git format-patch main --stdout | ssh pr.pico.sh pr add --review 57
accept PR:
ssh pr.pico.sh pr accept 57
close PR:
ssh pr.pico.sh pr close 57

Logs

erock created pr with ps-117 on 2025-03-28T16:55:01Z

Patchsets

ps-117 by erock on 2025-03-28T16:55:01Z

feat(pgs): admins can impersonate

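This change lets admins impersonate any user in the pgs CLI: an admin with a valid `admin` feature flag connects over SSH with the username `admin__<username>` and is authenticated as `<username>`.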
pkg/apps/auth/api.go link
+3 -3
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
diff --git a/pkg/apps/auth/api.go b/pkg/apps/auth/api.go
index 44b1769..508996b 100644
--- a/pkg/apps/auth/api.go
+++ b/pkg/apps/auth/api.go
@@ -303,7 +303,7 @@ func userHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
 			"publicKey", data.PublicKey,
 		)
 
-		user, err := apiConfig.Dbpool.FindUserForName(data.Username)
+		user, err := apiConfig.Dbpool.FindUserByName(data.Username)
 		if err != nil {
 			apiConfig.Cfg.Logger.Error(err.Error())
 			http.Error(w, err.Error(), http.StatusNotFound)
@@ -461,7 +461,7 @@ func paymentWebhookHandler(apiConfig *shared.ApiConfig) http.HandlerFunc {
 		status := event.Data.Attr.Status
 		txID := fmt.Sprint(event.Data.Attr.OrderNumber)
 
-		user, err := apiConfig.Dbpool.FindUserForName(username)
+		user, err := apiConfig.Dbpool.FindUserByName(username)
 		if err != nil {
 			logger.Error("no user found with username", "username", username)
 			w.WriteHeader(http.StatusOK)
@@ -624,7 +624,7 @@ func deserializeCaddyAccessLog(dbpool db.DB, access *AccessLog) (*db.AnalyticsVi
 	}
 
 	// get user ID
-	user, err := dbpool.FindUserForName(props.Username)
+	user, err := dbpool.FindUserByName(props.Username)
 	if err != nil {
 		return nil, fmt.Errorf("could not find user for name %s: %w", props.Username, err)
 	}
pkg/apps/auth/api_test.go link
+2 -2
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
diff --git a/pkg/apps/auth/api_test.go b/pkg/apps/auth/api_test.go
index 597865e..210cb45 100644
--- a/pkg/apps/auth/api_test.go
+++ b/pkg/apps/auth/api_test.go
@@ -220,7 +220,7 @@ func (a *AuthDb) AddPicoPlusUser(username, email, from, txid string) error {
 	return nil
 }
 
-func (a *AuthDb) FindUserForName(username string) (*db.User, error) {
+func (a *AuthDb) FindUserByName(username string) (*db.User, error) {
 	return &db.User{ID: testUserID, Name: username}, nil
 }
 
@@ -243,7 +243,7 @@ func (a *AuthDb) FindKeysForUser(user *db.User) ([]*db.PublicKey, error) {
 	return []*db.PublicKey{{ID: "1", UserID: user.ID, Name: "my-key", Key: "nice-pubkey", CreatedAt: &time.Time{}}}, nil
 }
 
-func (a *AuthDb) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
+func (a *AuthDb) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
 	now := time.Date(2021, 8, 15, 14, 30, 45, 100, time.UTC)
 	oneDayWarning := now.AddDate(0, 0, 1)
 	return &db.FeatureFlag{ID: "2", UserID: userID, Name: "plus", ExpiresAt: &oneDayWarning, CreatedAt: &now}, nil
pkg/apps/pastes/api.go link
+3 -3
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
diff --git a/pkg/apps/pastes/api.go b/pkg/apps/pastes/api.go
index 6ad50d7..a31d52f 100644
--- a/pkg/apps/pastes/api.go
+++ b/pkg/apps/pastes/api.go
@@ -78,7 +78,7 @@ func blogHandler(w http.ResponseWriter, r *http.Request) {
 	logger := blogger.With("user", username)
 	cfg := shared.GetCfg(r)
 
-	user, err := dbpool.FindUserForName(username)
+	user, err := dbpool.FindUserByName(username)
 	if err != nil {
 		logger.Info("user not found")
 		http.Error(w, "user not found", http.StatusNotFound)
@@ -170,7 +170,7 @@ func postHandler(w http.ResponseWriter, r *http.Request) {
 	blogger := shared.GetLogger(r)
 	logger := blogger.With("slug", slug, "user", username)
 
-	user, err := dbpool.FindUserForName(username)
+	user, err := dbpool.FindUserByName(username)
 	if err != nil {
 		logger.Info("paste not found")
 		http.Error(w, "paste not found", http.StatusNotFound)
@@ -271,7 +271,7 @@ func postHandlerRaw(w http.ResponseWriter, r *http.Request) {
 	blogger := shared.GetLogger(r)
 	logger := blogger.With("user", username, "slug", slug)
 
-	user, err := dbpool.FindUserForName(username)
+	user, err := dbpool.FindUserByName(username)
 	if err != nil {
 		logger.Info("user not found")
 		http.Error(w, "user not found", http.StatusNotFound)
pkg/apps/pgs/cli_middleware.go link
+4 -16
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
diff --git a/pkg/apps/pgs/cli_middleware.go b/pkg/apps/pgs/cli_middleware.go
index f974066..05707d6 100644
--- a/pkg/apps/pgs/cli_middleware.go
+++ b/pkg/apps/pgs/cli_middleware.go
@@ -10,26 +10,14 @@ import (
 	"github.com/picosh/pico/pkg/db"
 	"github.com/picosh/pico/pkg/pssh"
 	sendutils "github.com/picosh/pico/pkg/send/utils"
-	"github.com/picosh/utils"
 )
 
 func getUser(s *pssh.SSHServerConnSession, dbpool pgsdb.PgsDB) (*db.User, error) {
-	if s.PublicKey() == nil {
-		return nil, fmt.Errorf("key not found")
+	userID, ok := s.Conn.Permissions.Extensions["user_id"]
+	if !ok {
+		return nil, fmt.Errorf("`user_id` extension not found")
 	}
-
-	key := utils.KeyForKeyText(s.PublicKey())
-
-	user, err := dbpool.FindUserByPubkey(key)
-	if err != nil {
-		return nil, err
-	}
-
-	if user.Name == "" {
-		return nil, fmt.Errorf("must have username set")
-	}
-
-	return user, nil
+	return dbpool.FindUser(userID)
 }
 
 type arrayFlags []string
pkg/apps/pico/cli.go link
+2 -2
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
diff --git a/pkg/apps/pico/cli.go b/pkg/apps/pico/cli.go
index 4588d59..3408e8f 100644
--- a/pkg/apps/pico/cli.go
+++ b/pkg/apps/pico/cli.go
@@ -117,10 +117,10 @@ func Middleware(handler *CliHandler) pssh.SSHServerMiddleware {
 					return err
 				}
 
-				ff, err := dbpool.FindFeatureForUser(user.ID, "plus")
+				ff, err := dbpool.FindFeature(user.ID, "plus")
 				if err != nil {
 					handler.Logger.Error("Unable to find plus feature flag", "err", err, "user", user, "command", args)
-					ff, err = dbpool.FindFeatureForUser(user.ID, "bouncer")
+					ff, err = dbpool.FindFeature(user.ID, "bouncer")
 					if err != nil {
 						handler.Logger.Error("Unable to find bouncer feature flag", "err", err, "user", user, "command", args)
 						sesh.Fatal(err)
pkg/apps/prose/api.go link
+6 -6
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
diff --git a/pkg/apps/prose/api.go b/pkg/apps/prose/api.go
index ee86a9d..a0ae6d1 100644
--- a/pkg/apps/prose/api.go
+++ b/pkg/apps/prose/api.go
@@ -125,7 +125,7 @@ func blogStyleHandler(w http.ResponseWriter, r *http.Request) {
 	logger := shared.GetLogger(r)
 	cfg := shared.GetCfg(r)
 
-	user, err := dbpool.FindUserForName(username)
+	user, err := dbpool.FindUserByName(username)
 	if err != nil {
 		logger.Info("blog not found", "user", username)
 		http.Error(w, "blog not found", http.StatusNotFound)
@@ -155,7 +155,7 @@ func blogHandler(w http.ResponseWriter, r *http.Request) {
 	logger := shared.GetLogger(r)
 	cfg := shared.GetCfg(r)
 
-	user, err := dbpool.FindUserForName(username)
+	user, err := dbpool.FindUserByName(username)
 	if err != nil {
 		logger.Info("blog not found", "user", username)
 		http.Error(w, "blog not found", http.StatusNotFound)
@@ -301,7 +301,7 @@ func postRawHandler(w http.ResponseWriter, r *http.Request) {
 	logger := shared.GetLogger(r)
 	logger = logger.With("slug", slug)
 
-	user, err := dbpool.FindUserForName(username)
+	user, err := dbpool.FindUserByName(username)
 	if err != nil {
 		logger.Info("blog not found", "user", username)
 		http.Error(w, "blog not found", http.StatusNotFound)
@@ -341,7 +341,7 @@ func postHandler(w http.ResponseWriter, r *http.Request) {
 	dbpool := shared.GetDB(r)
 	logger := shared.GetLogger(r)
 
-	user, err := dbpool.FindUserForName(username)
+	user, err := dbpool.FindUserByName(username)
 	if err != nil {
 		logger.Info("blog not found", "user", username)
 		http.Error(w, "blog not found", http.StatusNotFound)
@@ -589,7 +589,7 @@ func rssBlogHandler(w http.ResponseWriter, r *http.Request) {
 	logger := shared.GetLogger(r)
 	cfg := shared.GetCfg(r)
 
-	user, err := dbpool.FindUserForName(username)
+	user, err := dbpool.FindUserByName(username)
 	if err != nil {
 		logger.Info("rss feed not found", "user", username)
 		http.Error(w, "rss feed not found", http.StatusNotFound)
@@ -852,7 +852,7 @@ func imgRequest(w http.ResponseWriter, r *http.Request) {
 	logger := shared.GetLogger(r)
 	dbpool := shared.GetDB(r)
 	username := shared.GetUsernameFromRequest(r)
-	user, err := dbpool.FindUserForName(username)
+	user, err := dbpool.FindUserByName(username)
 	if err != nil {
 		logger.Error("could not find user", "username", username)
 		http.Error(w, "could find user", http.StatusNotFound)
pkg/db/db.go link
+3 -3
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
diff --git a/pkg/db/db.go b/pkg/db/db.go
index 1dd264d..bf0ff85 100644
--- a/pkg/db/db.go
+++ b/pkg/db/db.go
@@ -212,7 +212,7 @@ type Token struct {
 type FeatureFlag struct {
 	ID               string          `json:"id" db:"id"`
 	UserID           string          `json:"user_id" db:"user_id"`
-	PaymentHistoryID string          `json:"payment_history_id" db:"payment_history_id"`
+	PaymentHistoryID sql.NullString  `json:"payment_history_id" db:"payment_history_id"`
 	Name             string          `json:"name" db:"name"`
 	CreatedAt        *time.Time      `json:"created_at" db:"created_at"`
 	ExpiresAt        *time.Time      `json:"expires_at" db:"expires_at"`
@@ -370,7 +370,7 @@ type DB interface {
 	RemoveKeys(pubkeyIDs []string) error
 
 	FindUsers() ([]*User, error)
-	FindUserForName(name string) (*User, error)
+	FindUserByName(name string) (*User, error)
 	FindUserForNameAndKey(name string, pubkey string) (*User, error)
 	FindUserForKey(name string, pubkey string) (*User, error)
 	FindUserByPubkey(pubkey string) (*User, error)
@@ -414,7 +414,7 @@ type DB interface {
 	FindVisitSiteList(opts *SummaryOpts) ([]*VisitUrl, error)
 
 	AddPicoPlusUser(username, email, paymentType, txId string) error
-	FindFeatureForUser(userID string, feature string) (*FeatureFlag, error)
+	FindFeature(userID string, feature string) (*FeatureFlag, error)
 	FindFeaturesForUser(userID string) ([]*FeatureFlag, error)
 	HasFeatureForUser(userID string, feature string) bool
 	FindTotalSizeForUser(userID string) (int, error)
pkg/db/postgres/storage.go link
+9 -9
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
diff --git a/pkg/db/postgres/storage.go b/pkg/db/postgres/storage.go
index ffac498..0ba8d58 100644
--- a/pkg/db/postgres/storage.go
+++ b/pkg/db/postgres/storage.go
@@ -617,14 +617,14 @@ func (me *PsqlDB) ValidateName(name string) (bool, error) {
 	if !v {
 		return false, fmt.Errorf("%s is invalid: %w", lower, db.ErrNameInvalid)
 	}
-	user, _ := me.FindUserForName(lower)
+	user, _ := me.FindUserByName(lower)
 	if user == nil {
 		return true, nil
 	}
 	return false, fmt.Errorf("%s already taken: %w", lower, db.ErrNameTaken)
 }
 
-func (me *PsqlDB) FindUserForName(name string) (*db.User, error) {
+func (me *PsqlDB) FindUserByName(name string) (*db.User, error) {
 	user := &db.User{}
 	r := me.Db.QueryRow(sqlSelectUserForName, strings.ToLower(name))
 	err := r.Scan(&user.ID, &user.Name, &user.CreatedAt)
@@ -1457,7 +1457,7 @@ func (me *PsqlDB) FindTagsForPost(postID string) ([]string, error) {
 	return tags, nil
 }
 
-func (me *PsqlDB) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
+func (me *PsqlDB) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
 	ff := &db.FeatureFlag{}
 	// payment history is allowed to be null
 	// https://devtidbits.com/2020/08/03/go-sql-error-converting-null-to-string-is-unsupported/
@@ -1475,7 +1475,7 @@ func (me *PsqlDB) FindFeatureForUser(userID string, feature string) (*db.Feature
 		return nil, err
 	}
 
-	ff.PaymentHistoryID = paymentHistoryID.String
+	ff.PaymentHistoryID = paymentHistoryID
 
 	return ff, nil
 }
@@ -1507,7 +1507,7 @@ func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error)
 		if err != nil {
 			return features, err
 		}
-		ff.PaymentHistoryID = paymentHistoryID.String
+		ff.PaymentHistoryID = paymentHistoryID
 
 		features = append(features, ff)
 	}
@@ -1518,7 +1518,7 @@ func (me *PsqlDB) FindFeaturesForUser(userID string) ([]*db.FeatureFlag, error)
 }
 
 func (me *PsqlDB) HasFeatureForUser(userID string, feature string) bool {
-	ff, err := me.FindFeatureForUser(userID, feature)
+	ff, err := me.FindFeature(userID, feature)
 	if err != nil {
 		return false
 	}
@@ -1695,7 +1695,7 @@ func (me *PsqlDB) InsertFeature(userID, name string, expiresAt time.Time) (*db.F
 		return nil, err
 	}
 
-	feature, err := me.FindFeatureForUser(userID, name)
+	feature, err := me.FindFeature(userID, name)
 	if err != nil {
 		return nil, err
 	}
@@ -1709,7 +1709,7 @@ func (me *PsqlDB) RemoveFeature(userID string, name string) error {
 }
 
 func (me *PsqlDB) createFeatureExpiresAt(userID, name string) time.Time {
-	ff, _ := me.FindFeatureForUser(userID, name)
+	ff, _ := me.FindFeature(userID, name)
 	if ff == nil {
 		t := time.Now()
 		return t.AddDate(1, 0, 0)
@@ -1718,7 +1718,7 @@ func (me *PsqlDB) createFeatureExpiresAt(userID, name string) time.Time {
 }
 
 func (me *PsqlDB) AddPicoPlusUser(username, email, paymentType, txId string) error {
-	user, err := me.FindUserForName(username)
+	user, err := me.FindUserByName(username)
 	if err != nil {
 		return err
 	}
pkg/db/stub/stub.go link
+2 -2
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
diff --git a/pkg/db/stub/stub.go b/pkg/db/stub/stub.go
index a240455..454d0c4 100644
--- a/pkg/db/stub/stub.go
+++ b/pkg/db/stub/stub.go
@@ -77,7 +77,7 @@ func (me *StubDB) ValidateName(name string) (bool, error) {
 	return false, notImpl
 }
 
-func (me *StubDB) FindUserForName(name string) (*db.User, error) {
+func (me *StubDB) FindUserByName(name string) (*db.User, error) {
 	return nil, notImpl
 }
 
@@ -189,7 +189,7 @@ func (me *StubDB) FindTagsForPost(postID string) ([]string, error) {
 	return []string{}, notImpl
 }
 
-func (me *StubDB) FindFeatureForUser(userID string, feature string) (*db.FeatureFlag, error) {
+func (me *StubDB) FindFeature(userID string, feature string) (*db.FeatureFlag, error) {
 	return nil, notImpl
 }
 
pkg/shared/api.go link
+1 -1
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
diff --git a/pkg/shared/api.go b/pkg/shared/api.go
index 9089b1a..d58aeba 100644
--- a/pkg/shared/api.go
+++ b/pkg/shared/api.go
@@ -76,7 +76,7 @@ func CheckHandler(w http.ResponseWriter, r *http.Request) {
 		if !strings.Contains(hostDomain, appDomain) {
 			subdomain := GetCustomDomain(hostDomain, cfg.Space)
 			if subdomain != "" {
-				u, err := dbpool.FindUserForName(subdomain)
+				u, err := dbpool.FindUserByName(subdomain)
 				if u != nil && err == nil {
 					w.WriteHeader(http.StatusOK)
 					return
pkg/shared/feed.go link
+1 -1
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
diff --git a/pkg/shared/feed.go b/pkg/shared/feed.go
index 50cafd4..d222982 100644
--- a/pkg/shared/feed.go
+++ b/pkg/shared/feed.go
@@ -90,7 +90,7 @@ func UserFeed(me db.DB, user *db.User, token string) (*feeds.Feed, error) {
 	var feedItems []*feeds.Item
 
 	now := time.Now()
-	ff, err := me.FindFeatureForUser(user.ID, "plus")
+	ff, err := me.FindFeature(user.ID, "plus")
 	if err != nil {
 		// still want to send an empty feed
 	} else {
pkg/shared/ssh.go link
+25 -1
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
diff --git a/pkg/shared/ssh.go b/pkg/shared/ssh.go
index 7f52b74..29107e1 100644
--- a/pkg/shared/ssh.go
+++ b/pkg/shared/ssh.go
@@ -3,6 +3,7 @@ package shared
 import (
 	"fmt"
 	"log/slog"
+	"strings"
 
 	"github.com/picosh/pico/pkg/db"
 	"github.com/picosh/utils"
@@ -16,6 +17,8 @@ type SshAuthHandler struct {
 
 type AuthFindUser interface {
 	FindUserByPubkey(key string) (*db.User, error)
+	FindUserByName(name string) (*db.User, error)
+	FindFeature(userID, name string) (*db.FeatureFlag, error)
 }
 
 func NewSshAuthHandler(dbh AuthFindUser, logger *slog.Logger) *SshAuthHandler {
@@ -43,8 +46,29 @@ func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.Public
 		return nil, fmt.Errorf("username is not set")
 	}
 
+	// impersonation
+	impID := user.ID
+	adminPrefix := "admin__"
+	usr := conn.User()
+	if strings.HasPrefix(usr, adminPrefix) {
+		ff, err := r.DB.FindFeature(user.ID, "admin")
+		if err != nil {
+			return nil, fmt.Errorf("only admins can impersonate a user: %w", err)
+		}
+		if !ff.IsValid() {
+			return nil, fmt.Errorf("expired admin feature flag, cannot impersonate a user")
+		}
+
+		impersonate := strings.Replace(usr, adminPrefix, "", 1)
+		user, err = r.DB.FindUserByName(impersonate)
+		if err != nil {
+			return nil, err
+		}
+	}
+
 	return &ssh.Permissions{
 		Extensions: map[string]string{
+			"imp_id":  impID,
 			"user_id": user.ID,
 			"pubkey":  pubkey,
 		},
@@ -52,7 +76,7 @@ func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.Public
 }
 
 func FindPlusFF(dbpool db.DB, cfg *ConfigSite, userID string) *db.FeatureFlag {
-	ff, _ := dbpool.FindFeatureForUser(userID, "plus")
+	ff, _ := dbpool.FindFeature(userID, "plus")
 	// we have free tiers so users might not have a feature flag
 	// in which case we set sane defaults
 	if ff == nil {
pkg/tui/tuns.go link
+1 -1
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
diff --git a/pkg/tui/tuns.go b/pkg/tui/tuns.go
index 9d9ec9d..e60150a 100644
--- a/pkg/tui/tuns.go
+++ b/pkg/tui/tuns.go
@@ -230,7 +230,7 @@ func (m *TunsPage) HandleEvent(ev vaxis.Event, ph vxfw.EventPhase) (vxfw.Command
 	switch msg := ev.(type) {
 	case PageIn:
 		m.loading = true
-		ff, _ := m.shared.Dbpool.FindFeatureForUser(m.shared.User.ID, "admin")
+		ff, _ := m.shared.Dbpool.FindFeature(m.shared.User.ID, "admin")
 		if ff != nil {
 			m.isAdmin = true
 		}
pkg/tui/ui.go link
+2 -2
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
diff --git a/pkg/tui/ui.go b/pkg/tui/ui.go
index f6b3d14..0627e81 100644
--- a/pkg/tui/ui.go
+++ b/pkg/tui/ui.go
@@ -296,7 +296,7 @@ func FindUser(shrd *SharedModel) (*db.User, error) {
 			return nil, fmt.Errorf("only admins can impersonate a user")
 		}
 		impersonate := strings.Replace(usr, adminPrefix, "", 1)
-		user, err = shrd.Dbpool.FindUserForName(impersonate)
+		user, err = shrd.Dbpool.FindUserByName(impersonate)
 		if err != nil {
 			return nil, err
 		}
@@ -311,7 +311,7 @@ func FindFeatureFlag(shrd *SharedModel, name string) (*db.FeatureFlag, error) {
 		return nil, nil
 	}
 
-	ff, err := shrd.Dbpool.FindFeatureForUser(shrd.User.ID, name)
+	ff, err := shrd.Dbpool.FindFeature(shrd.User.ID, name)
 	if err != nil {
 		return nil, err
 	}

chore: use `user_id` in log middleware to set the user ctx

pkg/pssh/logger.go link
+7 -2
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
diff --git a/pkg/pssh/logger.go b/pkg/pssh/logger.go
index a1ce9cd..10e465f 100644
--- a/pkg/pssh/logger.go
+++ b/pkg/pssh/logger.go
@@ -1,6 +1,7 @@
 package pssh
 
 import (
+	"fmt"
 	"log/slog"
 	"time"
 
@@ -11,7 +12,7 @@ type ctxLoggerKey struct{}
 type ctxUserKey struct{}
 
 type FindUserInterface interface {
-	FindUserByPubkey(string) (*db.User, error)
+	FindUser(string) (*db.User, error)
 }
 
 type GetLoggerInterface interface {
@@ -29,7 +30,11 @@ func LogMiddleware(getLogger GetLoggerInterface, db FindUserInterface) SSHServer
 
 				user := GetUser(s)
 				if user == nil {
-					user, err := db.FindUserByPubkey(s.Permissions().Extensions["pubkey"])
+					userID, ok := s.Permissions().Extensions["user_id"]
+					if !ok {
+						return fmt.Errorf("`user_id` not set in permissions")
+					}
+					user, err := db.FindUser(userID)
 					if err == nil && user != nil {
 						logger = logger.With(
 							"user", user.Name,