Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ require (
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
github.com/yosida95/uritemplate/v3 v3.0.2
golang.org/x/oauth2 v0.35.0
)

require (
Expand All @@ -40,7 +41,6 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/oauth2 v0.35.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.28.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
Expand Down
157 changes: 157 additions & 0 deletions internal/oauth/callback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package oauth

import (
"context"
"embed"
"fmt"
"html/template"
"net"
"net/http"
"time"
)

//go:embed templates/*.html
var templateFS embed.FS

var (
errorTemplate = template.Must(template.ParseFS(templateFS, "templates/error.html"))
successTemplate = template.Must(template.ParseFS(templateFS, "templates/success.html"))
)

// callbackResult is delivered by the callback server once the browser redirect
// arrives. Exactly one of code or err is set.
type callbackResult struct {
code string
err error
}

// callbackServer is a short-lived local HTTP server that captures the
// authorization code from the OAuth redirect.
type callbackServer struct {
server *http.Server
listener net.Listener
redirect string
results chan callbackResult
}

// listenCallback binds the local callback listener.
//
// It binds to loopback (127.0.0.1) by default so the callback server is never
// exposed on other interfaces. bindAll is set only inside a container, where
// Docker's published-port DNAT delivers traffic to the container's eth0 rather
// than to loopback; host-side exposure is still constrained by the publish
// (e.g. -p 127.0.0.1:8085:8085). A native run — even with a fixed port — stays
// on loopback.
func listenCallback(port int, bindAll bool) (net.Listener, error) {
host := "127.0.0.1"
if bindAll {
host = "0.0.0.0"
}
addr := fmt.Sprintf("%s:%d", host, port)
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("starting callback listener on %s: %w", addr, err)
}
return listener, nil
}

// newCallbackServer starts a callback server on listener that validates state
// and reports the result on a buffered channel. The redirect URI always uses
// localhost so it matches the value registered on the OAuth/GitHub App.
func newCallbackServer(listener net.Listener, expectedState string) *callbackServer {
cs := &callbackServer{
server: &http.Server{ReadHeaderTimeout: 10 * time.Second}, // ReadHeaderTimeout guards against Slowloris.
listener: listener,
redirect: fmt.Sprintf("http://localhost:%d/callback", listener.Addr().(*net.TCPAddr).Port),
results: make(chan callbackResult, 1),
}
cs.server.Handler = cs.handler(expectedState)

go func() {
if err := cs.server.Serve(listener); err != nil && err != http.ErrServerClosed {
cs.report(callbackResult{err: fmt.Errorf("callback server: %w", err)})
}
}()

return cs
}

// handler renders the callback endpoint. It reports the outcome exactly once and
// always shows the user a friendly page.
func (cs *callbackServer) handler(expectedState string) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()

if errCode := q.Get("error"); errCode != "" {
msg := errCode
if desc := q.Get("error_description"); desc != "" {
msg = fmt.Sprintf("%s: %s", errCode, desc)
}
cs.report(callbackResult{err: fmt.Errorf("authorization failed: %s", msg)})
renderError(w, msg)
return
}

if q.Get("state") != expectedState {
cs.report(callbackResult{err: fmt.Errorf("state mismatch (possible CSRF)")})
renderError(w, "state mismatch")
return
}

code := q.Get("code")
if code == "" {
cs.report(callbackResult{err: fmt.Errorf("no authorization code in callback")})
renderError(w, "no authorization code received")
return
}

cs.report(callbackResult{code: code})
renderSuccess(w)
})
return mux
}

// report delivers the first outcome and drops later ones (the channel is
// buffered for one; subsequent redirect retries must not block the handler).
func (cs *callbackServer) report(res callbackResult) {
select {
case cs.results <- res:
default:
}
}

// wait blocks for the callback outcome or ctx cancellation, then shuts the
// server down. It is safe to call once per server.
func (cs *callbackServer) wait(ctx context.Context) (string, error) {
defer cs.close()
select {
case res := <-cs.results:
return res.code, res.err
case <-ctx.Done():
return "", ctx.Err()
}
}

func (cs *callbackServer) close() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = cs.server.Shutdown(shutdownCtx)
_ = cs.listener.Close()
}

func renderSuccess(w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := successTemplate.Execute(w, nil); err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
}
}

// renderError shows the failure page. html/template auto-escapes msg, so a
// hostile error_description cannot inject markup.
func renderError(w http.ResponseWriter, msg string) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := errorTemplate.Execute(w, struct{ ErrorMessage string }{ErrorMessage: msg}); err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
}
}
92 changes: 92 additions & 0 deletions internal/oauth/callback_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package oauth

import (
"net"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// serveCallback drives the callback handler with the given query string and
// returns the recorded response and the single reported result.
func serveCallback(t *testing.T, expectedState, query string) (*httptest.ResponseRecorder, callbackResult) {
t.Helper()
cs := &callbackServer{results: make(chan callbackResult, 1)}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/callback?"+query, nil)

cs.handler(expectedState).ServeHTTP(rec, req)

select {
case res := <-cs.results:
return rec, res
default:
t.Fatal("handler did not report a result")
return nil, callbackResult{}
}
}

func TestCallbackHandlerSuccess(t *testing.T) {
rec, res := serveCallback(t, "state123", "code=the-code&state=state123")

require.NoError(t, res.err)
assert.Equal(t, "the-code", res.code)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), "Authorization Successful")
}

func TestCallbackHandlerStateMismatch(t *testing.T) {
rec, res := serveCallback(t, "expected", "code=the-code&state=attacker")

require.Error(t, res.err)
assert.Empty(t, res.code)
assert.Contains(t, res.err.Error(), "state mismatch")
assert.Contains(t, rec.Body.String(), "state mismatch")
}

func TestCallbackHandlerMissingCode(t *testing.T) {
_, res := serveCallback(t, "state123", "state=state123")

require.Error(t, res.err)
assert.Contains(t, res.err.Error(), "no authorization code")
}

func TestCallbackHandlerOAuthError(t *testing.T) {
_, res := serveCallback(t, "state123", "error=access_denied&error_description=user+said+no")

require.Error(t, res.err)
assert.Contains(t, res.err.Error(), "access_denied")
assert.Contains(t, res.err.Error(), "user said no")
}

func TestCallbackHandlerEscapesError(t *testing.T) {
rec, _ := serveCallback(t, "state123", "error=evil&error_description=%3Cscript%3Ealert(1)%3C%2Fscript%3E")

body := rec.Body.String()
assert.NotContains(t, body, "<script>", "error message must be HTML-escaped")
assert.Contains(t, body, "&lt;script&gt;")
}

func TestListenCallbackRandomPortIsLoopback(t *testing.T) {
listener, err := listenCallback(0, false)
require.NoError(t, err)
defer listener.Close()

addr, ok := listener.Addr().(*net.TCPAddr)
require.True(t, ok)
assert.True(t, addr.IP.IsLoopback(), "default bind must be loopback only, got %s", addr.IP)
assert.NotZero(t, addr.Port)
}

func TestListenCallbackBindAllForContainer(t *testing.T) {
listener, err := listenCallback(0, true)
require.NoError(t, err)
defer listener.Close()

addr, ok := listener.Addr().(*net.TCPAddr)
require.True(t, ok)
assert.True(t, addr.IP.IsUnspecified(), "bindAll must bind all interfaces, got %s", addr.IP)
}
63 changes: 63 additions & 0 deletions internal/oauth/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package oauth

import (
"fmt"
"io"
"os"
"os/exec"
"runtime"
"strings"
)

// openBrowser tries to open url in the user's default browser. It returns an
// error when no browser can plausibly be launched so the caller can fall back
// to elicitation. On Linux it treats a headless session (no display server) as
// unopenable, which is the common case for SSH and containers.
func openBrowser(url string) error {
var cmd *exec.Cmd
switch runtime.GOOS {
case "linux":
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
return fmt.Errorf("no display server detected")
}
cmd = exec.Command("xdg-open", url)
case "darwin":
cmd = exec.Command("open", url)
case "windows":
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}

cmd.Stdout = io.Discard
cmd.Stderr = io.Discard
if err := cmd.Start(); err != nil {
return err
}
// The launcher (xdg-open/open/rundll32) exits as soon as it hands off to the
// browser. Reap it asynchronously so it does not linger as a zombie for the
// lifetime of this long-running server.
go func() { _ = cmd.Wait() }()
return nil
}

// isRunningInDocker reports whether the process is running inside a Docker (or
// containerd) container. Detection relies on Linux-specific paths and is always
// false elsewhere. It is used only to skip a PKCE flow that cannot work: a
// random callback port inside a container cannot be reached from the host
// browser, so we go straight to device flow in that case.
func isRunningInDocker() bool {
if runtime.GOOS != "linux" {
return false
}
if _, err := os.Stat("/.dockerenv"); err == nil {
return true
}
if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
s := string(data)
if strings.Contains(s, "docker") || strings.Contains(s, "containerd") {
return true
}
}
return false
}
Loading
Loading