dashboard / erock/pico / feat: access control using ssh certs #84 rss

open · opened on 2025-12-01T05:25:52Z by erock
Help
checkout latest patchset:
ssh pr.pico.sh print pr-84 | git am -3
checkout any patchset in a patch request:
ssh pr.pico.sh print ps-X | git am -3
add changes to patch request:
git format-patch main --stdout | ssh pr.pico.sh pr add 84
add review to patch request:
git format-patch main --stdout | ssh pr.pico.sh pr add --review 84
accept PR:
ssh pr.pico.sh pr accept 84
close PR:
ssh pr.pico.sh pr close 84

Logs

erock created pr with ps-158 on 2025-12-01T05:25:52Z

Patchsets

ps-158 by erock on 2025-12-01T05:25:52Z

+2 -1 Makefile link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
diff --git a/Makefile b/Makefile
index f0b596c..67c883f 100644
--- a/Makefile
+++ b/Makefile
@@ -142,10 +142,11 @@ migrate:
 	$(DOCKER_CMD) exec -i $(DB_CONTAINER) psql -U $(PGUSER) -d $(PGDATABASE) < ./sql/migrations/20250320_add_tunnel_id_to_tuns_event_logs_table.sql
 	$(DOCKER_CMD) exec -i $(DB_CONTAINER) psql -U $(PGUSER) -d $(PGDATABASE) < ./sql/migrations/20250410_add_index_analytics_visits_host_list.sql
 	$(DOCKER_CMD) exec -i $(DB_CONTAINER) psql -U $(PGUSER) -d $(PGDATABASE) < ./sql/migrations/20250418_add_project_post_idx_analytics.sql
+	$(DOCKER_CMD) exec -i $(DB_CONTAINER) psql -U $(PGUSER) -d $(PGDATABASE) < ./sql/migrations/20251130_add_expires_at_to_public_keys.sql
 .PHONY: migrate
 
 latest:
-	$(DOCKER_CMD) exec -i $(DB_CONTAINER) psql -U $(PGUSER) -d $(PGDATABASE) < ./sql/migrations/20250418_add_project_post_idx_analytics.sql
+	$(DOCKER_CMD) exec -i $(DB_CONTAINER) psql -U $(PGUSER) -d $(PGDATABASE) < ./sql/migrations/20251130_add_expires_at_to_public_keys.sql
 .PHONY: latest
 
 psql:
+1 -1 pkg/apps/feeds/ssh.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
diff --git a/pkg/apps/feeds/ssh.go b/pkg/apps/feeds/ssh.go
index 37ed52a..548f1a8 100644
--- a/pkg/apps/feeds/ssh.go
+++ b/pkg/apps/feeds/ssh.go
@@ -46,7 +46,7 @@ func StartSshServer() {
 	}
 	handler := filehandlers.NewFileHandlerRouter(cfg, dbh, fileMap)
 
-	sshAuth := shared.NewSshAuthHandler(dbh, logger)
+	sshAuth := shared.NewSshAuthHandler(dbh, logger, "feeds")
 
 	// Create a new SSH server
 	server, err := pssh.NewSSHServerWithConfig(
+1 -1 pkg/apps/pastes/ssh.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
diff --git a/pkg/apps/pastes/ssh.go b/pkg/apps/pastes/ssh.go
index 71ed1b8..86a6f42 100644
--- a/pkg/apps/pastes/ssh.go
+++ b/pkg/apps/pastes/ssh.go
@@ -45,7 +45,7 @@ func StartSshServer() {
 		"fallback": filehandlers.NewScpPostHandler(dbh, cfg, hooks),
 	}
 	handler := filehandlers.NewFileHandlerRouter(cfg, dbh, fileMap)
-	sshAuth := shared.NewSshAuthHandler(dbh, logger)
+	sshAuth := shared.NewSshAuthHandler(dbh, logger, "pastes")
 
 	// Create a new SSH server
 	server, err := pssh.NewSSHServerWithConfig(
+1 -1 pkg/apps/pgs/ssh.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
diff --git a/pkg/apps/pgs/ssh.go b/pkg/apps/pgs/ssh.go
index 30ddbeb..98d302d 100644
--- a/pkg/apps/pgs/ssh.go
+++ b/pkg/apps/pgs/ssh.go
@@ -34,7 +34,7 @@ func StartSshServer(cfg *PgsConfig, killCh chan error) {
 		ctx,
 	)
 
-	sshAuth := shared.NewSshAuthHandler(cfg.DB, logger)
+	sshAuth := shared.NewSshAuthHandler(cfg.DB, logger, "pgs")
 
 	webTunnel := &tunkit.WebTunnelHandler{
 		Logger:      logger,
+1 -1 pkg/apps/pico/ssh.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
diff --git a/pkg/apps/pico/ssh.go b/pkg/apps/pico/ssh.go
index 2050e28..c5d710b 100644
--- a/pkg/apps/pico/ssh.go
+++ b/pkg/apps/pico/ssh.go
@@ -64,7 +64,7 @@ func StartSshServer() {
 		DBPool: dbpool,
 	}
 
-	sshAuth := shared.NewSshAuthHandler(dbpool, logger)
+	sshAuth := shared.NewSshAuthHandler(dbpool, logger, "pico")
 
 	// Create a new SSH server
 	server, err := pssh.NewSSHServerWithConfig(
+1 -1 pkg/apps/pipe/ssh.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
diff --git a/pkg/apps/pipe/ssh.go b/pkg/apps/pipe/ssh.go
index 9ddf122..af7312f 100644
--- a/pkg/apps/pipe/ssh.go
+++ b/pkg/apps/pipe/ssh.go
@@ -46,7 +46,7 @@ func StartSshServer() {
 		Access:  syncmap.New[string, []string](),
 	}
 
-	sshAuth := shared.NewSshAuthHandler(dbh, logger)
+	sshAuth := shared.NewSshAuthHandler(dbh, logger, "pipe")
 
 	// Create a new SSH server
 	server, err := pssh.NewSSHServerWithConfig(
+1 -1 pkg/apps/prose/ssh.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
diff --git a/pkg/apps/prose/ssh.go b/pkg/apps/prose/ssh.go
index 94dd7b9..709f8d8 100644
--- a/pkg/apps/prose/ssh.go
+++ b/pkg/apps/prose/ssh.go
@@ -59,7 +59,7 @@ func StartSshServer() {
 	}
 	handler := filehandlers.NewFileHandlerRouter(cfg, dbh, fileMap)
 
-	sshAuth := shared.NewSshAuthHandler(dbh, logger)
+	sshAuth := shared.NewSshAuthHandler(dbh, logger, "prose")
 
 	// Create a new SSH server
 	server, err := pssh.NewSSHServerWithConfig(
+9 -0 pkg/db/db.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
diff --git a/pkg/db/db.go b/pkg/db/db.go
index 8176c21..9f99f27 100644
--- a/pkg/db/db.go
+++ b/pkg/db/db.go
@@ -36,6 +36,20 @@ type PublicKey struct {
 	Name      string     `json:"name" db:"name"`
 	Key       string     `json:"public_key" db:"public_key"`
 	CreatedAt *time.Time `json:"created_at" db:"created_at"`
+	// ExpiresAt is when the key stops being accepted; nil means the key
+	// never expires.
+	ExpiresAt *time.Time `json:"expires_at" db:"expires_at"`
+}
+
+// IsValid reports whether the key may still be used: it has no expiry or
+// the expiry lies in the future. A nil *PublicKey carries no expiry
+// information and is reported as valid, so callers that could not load
+// the key record do not panic.
+func (pk *PublicKey) IsValid() bool {
+	if pk == nil || pk.ExpiresAt == nil {
+		return true
+	}
+	return pk.ExpiresAt.After(time.Now())
 }
 
 type User struct {
+7 -7 pkg/db/postgres/storage.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
diff --git a/pkg/db/postgres/storage.go b/pkg/db/postgres/storage.go
index 84fcfcf..eb7bd1b 100644
--- a/pkg/db/postgres/storage.go
+++ b/pkg/db/postgres/storage.go
@@ -135,11 +135,11 @@ var (
 )
 
 const (
-	sqlSelectPublicKey         = `SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE public_key = $1`
-	sqlSelectPublicKeys        = `SELECT id, user_id, name, public_key, created_at FROM public_keys WHERE user_id = $1 ORDER BY created_at ASC`
+	sqlSelectPublicKey         = `SELECT id, user_id, name, public_key, created_at, expires_at FROM public_keys WHERE public_key = $1`
+	sqlSelectPublicKeys        = `SELECT id, user_id, name, public_key, created_at, expires_at FROM public_keys WHERE user_id = $1 ORDER BY created_at ASC`
 	sqlSelectUser              = `SELECT id, name, created_at FROM app_users WHERE id = $1`
 	sqlSelectUserForName       = `SELECT id, name, created_at FROM app_users WHERE name = $1`
-	sqlSelectUserForNameAndKey = `SELECT app_users.id, app_users.name, app_users.created_at, public_keys.id as pk_id, public_keys.public_key, public_keys.created_at as pk_created_at FROM app_users LEFT JOIN public_keys ON public_keys.user_id = app_users.id WHERE app_users.name = $1 AND public_keys.public_key = $2`
+	sqlSelectUserForNameAndKey = `SELECT app_users.id, app_users.name, app_users.created_at, public_keys.id as pk_id, public_keys.public_key, public_keys.created_at as pk_created_at, public_keys.expires_at FROM app_users LEFT JOIN public_keys ON public_keys.user_id = app_users.id WHERE app_users.name = $1 AND public_keys.public_key = $2`
 	sqlSelectUsers             = `SELECT id, name, created_at FROM app_users ORDER BY name ASC`
 
 	sqlSelectUserForToken = `
@@ -450,7 +450,7 @@ func (me *PsqlDB) FindPublicKeyForKey(key string) (*db.PublicKey, error) {
 
 	for rs.Next() {
 		pk := &db.PublicKey{}
-		err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
+		err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt, &pk.ExpiresAt)
 		if err != nil {
 			return nil, err
 		}
@@ -486,7 +486,7 @@ func (me *PsqlDB) FindPublicKey(pubkeyID string) (*db.PublicKey, error) {
 
 	for rs.Next() {
 		pk := &db.PublicKey{}
-		err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
+		err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt, &pk.ExpiresAt)
 		if err != nil {
 			return nil, err
 		}
@@ -513,7 +513,7 @@ func (me *PsqlDB) FindKeysForUser(user *db.User) ([]*db.PublicKey, error) {
 	}
 	for rs.Next() {
 		pk := &db.PublicKey{}
-		err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt)
+		err := rs.Scan(&pk.ID, &pk.UserID, &pk.Name, &pk.Key, &pk.CreatedAt, &pk.ExpiresAt)
 		if err != nil {
 			return keys, err
 		}
@@ -678,7 +678,7 @@ func (me *PsqlDB) FindUserForNameAndKey(name string, key string) (*db.User, erro
 	pk := &db.PublicKey{}
 
 	r := me.Db.QueryRow(sqlSelectUserForNameAndKey, strings.ToLower(name), key)
-	err := r.Scan(&user.ID, &user.Name, &user.CreatedAt, &pk.ID, &pk.Key, &pk.CreatedAt)
+	err := r.Scan(&user.ID, &user.Name, &user.CreatedAt, &pk.ID, &pk.Key, &pk.CreatedAt, &pk.ExpiresAt)
 	if err != nil {
 		return nil, err
 	}
+52 -9 pkg/shared/ssh.go link
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
diff --git a/pkg/shared/ssh.go b/pkg/shared/ssh.go
index 30840c5..ec080cc 100644
--- a/pkg/shared/ssh.go
+++ b/pkg/shared/ssh.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"log/slog"
 	"strings"
+	"time"
 
 	"github.com/picosh/pico/pkg/db"
 	"github.com/picosh/utils"
@@ -13,8 +14,13 @@ import (
 const adminPrefix = "admin__"
 
+// SshAuthHandler authenticates SSH public keys and user certificates against
+// the users stored in DB.
 type SshAuthHandler struct {
-	DB     AuthFindUser
-	Logger *slog.Logger
+	DB        AuthFindUser
+	Logger    *slog.Logger
+	// Principal is the certificate principal that identifies this service;
+	// user certificates must list it (or "admin") to be accepted.
+	Principal string
 }
 
 type AuthFindUser interface {
@@ -23,18 +29,71 @@ type AuthFindUser interface {
 	FindFeature(userID, name string) (*db.FeatureFlag, error)
 }
 
-func NewSshAuthHandler(dbh AuthFindUser, logger *slog.Logger) *SshAuthHandler {
+// NewSshAuthHandler returns an SshAuthHandler backed by dbh. principal names
+// the calling service (for example "pgs") and is required on user certs.
+func NewSshAuthHandler(dbh AuthFindUser, logger *slog.Logger, principal string) *SshAuthHandler {
 	return &SshAuthHandler{
-		DB:     dbh,
-		Logger: logger,
+		DB:        dbh,
+		Logger:    logger,
+		Principal: principal,
 	}
 }
 
+// PubkeyAuthHandler authenticates a connection by public key. A plain key is
+// looked up directly. A user certificate must name this handler's Principal
+// (or "admin"), be signed by its SignatureKey and be inside its validity
+// window; the user is then looked up by that signing key.
 func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
-	pubkey := utils.KeyForKeyText(key)
-	user, err := r.DB.FindUserByPubkey(pubkey)
+	log := r.Logger
+	var user *db.User
+	var err error
+	pubkey := ""
+
+	cert, ok := key.(*ssh.Certificate)
+	if ok {
+		if cert.CertType != ssh.UserCert {
+			return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
+		}
+
+		// Accept certs issued for this service or for every service
+		// ("admin"). Unlike ssh.CertChecker on its own, a cert that lists
+		// no principals at all is rejected here.
+		principal := ""
+		found := false
+		for _, princ := range cert.ValidPrincipals {
+			if princ == "admin" || princ == r.Principal {
+				principal = princ
+				found = true
+				break
+			}
+		}
+		if !found {
+			return nil, fmt.Errorf("ssh: principals not valid")
+		}
+
+		// CheckCert verifies that cert.SignatureKey really signed this cert
+		// and enforces ValidAfter/ValidBefore and critical options. The SSH
+		// handshake only proves possession of the cert's own key, so without
+		// this check anyone could mint a cert naming a victim's registered
+		// key as SignatureKey and log in as that victim.
+		// NOTE(review): CheckCert skips "source-address"; the server enforces
+		// it only if it is present in the Permissions this handler returns.
+		checker := &ssh.CertChecker{Clock: time.Now}
+		err = checker.CheckCert(principal, cert)
+		if err != nil {
+			return nil, err
+		}
+
+		// The account is resolved from the signing key, which must be one
+		// of the user's registered public keys.
+		pubkey = utils.KeyForKeyText(cert.SignatureKey)
+	} else {
+		pubkey = utils.KeyForKeyText(key)
+	}
+
+	user, err = r.DB.FindUserByPubkey(pubkey)
 	if err != nil {
-		r.Logger.Error(
+		log.Error(
 			"could not find user for key",
 			"keyType", key.Type(),
 			"key", string(key.Marshal()),
@@ -43,8 +102,15 @@ func (r *SshAuthHandler) PubkeyAuthHandler(conn ssh.ConnMetadata, key ssh.Public
 		return nil, err
 	}
 
+	// FindUserByPubkey does not always populate user.PublicKey, so guard
+	// against nil before checking expiry instead of panicking.
+	// TODO: always load the key record so expiry is enforced for every login.
+	if user.PublicKey != nil && !user.PublicKey.IsValid() {
+		return nil, fmt.Errorf("public key has been revoked")
+	}
+
 	if user.Name == "" {
-		r.Logger.Error("username is not set")
+		log.Error("username is not set")
 		return nil, fmt.Errorf("username is not set")
 	}
 
+1 -0 sql/migrations/20251130_add_expires_at_to_public_keys.sql link
1
2
3
4
5
6
7
diff --git a/sql/migrations/20251130_add_expires_at_to_public_keys.sql b/sql/migrations/20251130_add_expires_at_to_public_keys.sql
new file mode 100644
index 0000000..936698e
--- /dev/null
+++ b/sql/migrations/20251130_add_expires_at_to_public_keys.sql
@@ -0,0 +1,3 @@
+-- NULL expires_at means the key never expires. NOT NULL DEFAULT NOW() would
+-- stamp every existing and newly inserted key as already expired.
+ALTER TABLE public_keys ADD COLUMN expires_at timestamp without time zone;