1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
|
diff --git a/shared/ssh.go b/shared/ssh.go
index b30a2ef..f0270e7 100644
--- a/shared/ssh.go
+++ b/shared/ssh.go
@@ -1,7 +1,6 @@
package shared
import (
- "fmt"
"log/slog"
"github.com/charmbracelet/ssh"
@@ -9,47 +8,6 @@ import (
"github.com/picosh/utils"
)
-type ctxUserKey struct{}
-type ctxFeatureFlagKey struct{}
-
-func GetUser(ctx ssh.Context) (*db.User, error) {
- user, ok := ctx.Value(ctxUserKey{}).(*db.User)
- if !ok {
- return user, fmt.Errorf("user not set on `ssh.Context()` for connection")
- }
- return user, nil
-}
-
-func SetUser(ctx ssh.Context, user *db.User) {
- ctx.SetValue(ctxUserKey{}, user)
-}
-
-func GetFeatureFlag(ctx ssh.Context) (*db.FeatureFlag, error) {
- ff, ok := ctx.Value(ctxFeatureFlagKey{}).(*db.FeatureFlag)
- if !ok || ff.Name == "" {
- return ff, fmt.Errorf("feature flag not set on `ssh.Context()` for connection")
- }
- return ff, nil
-}
-
-func SetFeatureFlag(ctx ssh.Context, ff *db.FeatureFlag) {
- ctx.SetValue(ctxFeatureFlagKey{}, ff)
-}
-
-type ctxPublicKey struct{}
-
-func GetPublicKey(ctx ssh.Context) (ssh.PublicKey, error) {
- pk, ok := ctx.Value(ctxPublicKey{}).(ssh.PublicKey)
- if !ok {
- return nil, fmt.Errorf("public key not set on `ssh.Context()` for connection")
- }
- return pk, nil
-}
-
-func SetPublicKey(ctx ssh.Context, pk ssh.PublicKey) {
- ctx.SetValue(ctxPublicKey{}, pk)
-}
-
type SshAuthHandler struct {
DBPool db.DB
Logger *slog.Logger
@@ -64,11 +22,28 @@ func NewSshAuthHandler(dbpool db.DB, logger *slog.Logger, cfg *ConfigSite) *SshA
}
}
-func (r *SshAuthHandler) PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool {
- SetPublicKey(ctx, key)
+func FindPlusFF(dbpool db.DB, cfg *ConfigSite, userID string) *db.FeatureFlag {
+ ff, _ := dbpool.FindFeatureForUser(userID, "plus")
+ // we have free tiers so users might not have a feature flag
+ // in which case we set sane defaults
+ if ff == nil {
+ ff = db.NewFeatureFlag(
+ userID,
+ "plus",
+ cfg.MaxSize,
+ cfg.MaxAssetSize,
+ cfg.MaxSpecialFileSize,
+ )
+ }
+ // this is jank
+ ff.Data.StorageMax = ff.FindStorageMax(cfg.MaxSize)
+ ff.Data.FileMax = ff.FindFileMax(cfg.MaxAssetSize)
+ ff.Data.SpecialFileMax = ff.FindSpecialFileMax(cfg.MaxSpecialFileSize)
+ return ff
+}
+func (r *SshAuthHandler) PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool {
pubkey := utils.KeyForKeyText(key)
-
user, err := r.DBPool.FindUserForKey(ctx.User(), pubkey)
if err != nil {
r.Logger.Error(
@@ -84,24 +59,10 @@ func (r *SshAuthHandler) PubkeyAuthHandler(ctx ssh.Context, key ssh.PublicKey) b
return false
}
- ff, _ := r.DBPool.FindFeatureForUser(user.ID, "plus")
- // we have free tiers so users might not have a feature flag
- // in which case we set sane defaults
- if ff == nil {
- ff = db.NewFeatureFlag(
- user.ID,
- "plus",
- r.Cfg.MaxSize,
- r.Cfg.MaxAssetSize,
- r.Cfg.MaxSpecialFileSize,
- )
+ if ctx.Permissions().Extensions == nil {
+ ctx.Permissions().Extensions = map[string]string{}
}
- // this is jank
- ff.Data.StorageMax = ff.FindStorageMax(r.Cfg.MaxSize)
- ff.Data.FileMax = ff.FindFileMax(r.Cfg.MaxAssetSize)
- ff.Data.SpecialFileMax = ff.FindSpecialFileMax(r.Cfg.MaxSpecialFileSize)
-
- SetUser(ctx, user)
- SetFeatureFlag(ctx, ff)
+ ctx.Permissions().Extensions["user_id"] = user.ID
+ ctx.Permissions().Extensions["pubkey"] = pubkey
return true
}
|