dashboard / erock/pico / fix(pgs): prevent infinite redirects #106 rss

open · opened on 2026-02-17T01:20:13Z by erock
Help
checkout latest patchset:
ssh pr.pico.sh print pr-106 | git am -3
checkout any patchset in a patch request:
ssh pr.pico.sh print ps-X | git am -3
add changes to patch request:
git format-patch main --stdout | ssh pr.pico.sh pr add 106
add review to patch request:
git format-patch main --stdout | ssh pr.pico.sh pr add --review 106
accept PR:
ssh pr.pico.sh pr accept 106
close PR:
ssh pr.pico.sh pr close 106
Timeline Patchsets
Now when we perform a redirect for a user, we add a special header
`X-Pgs-Redirect-Depth` that contains the number of times a request
has been redirected.  If that number exceeds 5 then we respond with
status 508 Loop Detected.
+24 -0 pkg/apps/pgs/web_asset_handler.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
diff --git a/pkg/apps/pgs/web_asset_handler.go b/pkg/apps/pgs/web_asset_handler.go
index 144b42c..2a3bf1e 100644
--- a/pkg/apps/pgs/web_asset_handler.go
+++ b/pkg/apps/pgs/web_asset_handler.go
@@ -18,6 +18,11 @@ import (
 	"github.com/picosh/pico/pkg/shared/storage"
 )
 
+const (
+	redirectDepthHeader = "X-Pgs-Redirect-Depth"
+	maxRedirectDepth    = 5
+)
+
 type ApiAssetHandler struct {
 	*WebRouter
 	Logger *slog.Logger
@@ -100,6 +105,24 @@ func (h *ApiAssetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		destUrl.RawQuery = r.URL.RawQuery
 
 		if checkIsRedirect(fp.Status) {
+			// Check for redirect loops
+			redirectDepth := 0
+			if depthStr := r.Header.Get(redirectDepthHeader); depthStr != "" {
+				if d, err := strconv.Atoi(depthStr); err == nil {
+					redirectDepth = d
+				}
+			}
+
+			if redirectDepth > maxRedirectDepth {
+				logger.Error(
+					"redirect loop detected",
+					"depth", redirectDepth,
+					"destination", destUrl.String(),
+				)
+				http.Error(w, "Too many redirects", http.StatusLoopDetected)
+				return
+			}
+
 			// hack: check to see if there's an index file in the requested directory
 			// before redirecting, this saves a hop that will just end up a 404
 			if !hasProtocol(fp.Filepath) && strings.HasSuffix(fp.Filepath, "/") {
@@ -115,6 +138,7 @@ func (h *ApiAssetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 				"destination", destUrl.String(),
 				"status", fp.Status,
 			)
+			w.Header().Set(redirectDepthHeader, strconv.Itoa(redirectDepth+1))
 			http.Redirect(w, r, destUrl.String(), fp.Status)
 			return
 		} else if hasProtocol(fp.Filepath) {
+89 -0 pkg/apps/pgs/web_test.go link
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
diff --git a/pkg/apps/pgs/web_test.go b/pkg/apps/pgs/web_test.go
index 7f0de4f..268f7b9 100644
--- a/pkg/apps/pgs/web_test.go
+++ b/pkg/apps/pgs/web_test.go
@@ -572,3 +572,92 @@ func TestImageManipulation(t *testing.T) {
 		})
 	}
 }
+
+func TestRedirectLoopDetection(t *testing.T) {
+	logger := slog.Default()
+	dbpool := NewPgsDb(logger)
+	bucketName := shared.GetAssetBucketName(dbpool.Users[0].ID)
+
+	tt := []struct {
+		name          string
+		path          string
+		redirectDepth string
+		status        int
+		maxDepth      int
+		shouldContain string
+	}{
+		{
+			name:          "no-redirect-depth-header-on-first-request",
+			path:          "/anything",
+			redirectDepth: "",
+			status:        http.StatusMovedPermanently,
+			maxDepth:      5,
+			shouldContain: `<a href="https://example.com">Moved Permanently</a>.`,
+		},
+		{
+			name:          "allow-request-at-max-depth",
+			path:          "/anything",
+			redirectDepth: "5",
+			status:        http.StatusMovedPermanently,
+			maxDepth:      5,
+			shouldContain: `<a href="https://example.com">Moved Permanently</a>.`,
+		},
+		{
+			name:          "allow-request-below-max-depth",
+			path:          "/anything",
+			redirectDepth: "2",
+			status:        http.StatusMovedPermanently,
+			maxDepth:      5,
+			shouldContain: `<a href="https://example.com">Moved Permanently</a>.`,
+		},
+		{
+			name:          "reject-at-depth-6-with-maxdepth-5",
+			path:          "/anything",
+			redirectDepth: "6",
+			status:        http.StatusLoopDetected,
+			maxDepth:      5,
+			shouldContain: "Too many redirects",
+		},
+	}
+
+	for _, tc := range tt {
+		t.Run(tc.name, func(t *testing.T) {
+			request := httptest.NewRequest("GET", dbpool.mkpath(tc.path), strings.NewReader(""))
+			if tc.redirectDepth != "" {
+				request.Header.Set("X-Pgs-Redirect-Depth", tc.redirectDepth)
+			}
+			responseRecorder := httptest.NewRecorder()
+
+			st, _ := storage.NewStorageMemory(map[string]map[string]string{
+				bucketName: {
+					"/test/_redirects": "/anything https://example.com 301",
+				},
+			})
+			pubsub := NewPubsubChan()
+			defer func() {
+				_ = pubsub.Close()
+			}()
+			cfg := NewPgsConfig(logger, dbpool, st, pubsub)
+			cfg.Domain = "pgs.test"
+			router := NewWebRouter(cfg)
+			router.ServeHTTP(responseRecorder, request)
+
+			if responseRecorder.Code != tc.status {
+				t.Errorf("Want status '%d', got '%d'", tc.status, responseRecorder.Code)
+			}
+
+			body := strings.TrimSpace(responseRecorder.Body.String())
+			if !strings.Contains(body, tc.shouldContain) {
+				t.Errorf("Want body to contain '%s', got '%s'", tc.shouldContain, body)
+			}
+
+			// When redirecting, verify the depth header is set for the next hop
+			if tc.status == http.StatusMovedPermanently {
+				nextDepth := responseRecorder.Header().Get("X-Pgs-Redirect-Depth")
+				if nextDepth == "" {
+					t.Error("Expected X-Pgs-Redirect-Depth header in redirect response")
+				}
+			}
+		})
+	}
+}