it works!

This commit is contained in:
Lukas Werner 2025-07-02 00:27:22 -07:00
commit 7bbaf7ceb2
No known key found for this signature in database
7 changed files with 517 additions and 0 deletions

56
config.go Normal file
View File

@ -0,0 +1,56 @@
package main
import (
"fmt"
"github.com/BurntSushi/toml"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
type Upstream struct {
Addr string `toml:"addr"`
Program string `toml:"program"`
Args []string `toml:"args"`
}
type OAuthProvider struct {
Kind string `toml:"kind"`
ClientID string `toml:"client_id"`
ClientSecret string `toml:"client_secret"`
RedirectURL string `toml:"redirect_url"`
}
type Config struct {
ListenURL string `toml:"listen_url"`
GuardedPaths []string `toml:"guarded_paths"`
AllowedUsers []string `toml:"allowed_users"`
Upstream Upstream `toml:"upstream"`
OAuthProvider OAuthProvider `toml:"provider"`
}
func LoadConfig() (Config, oauth2.Config, error) {
config := Config{}
oa2 := oauth2.Config{}
_, err := toml.DecodeFile("config.toml", &config)
if err != nil {
return config, oa2, fmt.Errorf("unable to parse 'config.toml' tompl decoding error: %w", err)
}
oa2.ClientID = config.OAuthProvider.ClientID
oa2.ClientSecret = config.OAuthProvider.ClientSecret
oa2.Endpoint = oauth2.Endpoint{}
oa2.RedirectURL = config.OAuthProvider.RedirectURL
oa2.Scopes = []string{}
switch config.OAuthProvider.Kind {
case "github":
oa2.Endpoint = endpoints.GitHub
oa2.Scopes = []string{"read:user"}
case "google":
oa2.Endpoint = endpoints.Google
oa2.Scopes = []string{"https://www.googleapis.com/auth/userinfo.email"}
}
return config, oa2, err
}

17
config.toml Normal file
View File

@ -0,0 +1,17 @@
listen_url = "http://localhost:3000"
guarded_paths = ["/"]
# A list of all allowed users. For GitHub it is a list of usernames.
# For Google it is a list of emails
allowed_users = ["lukasmwerner"]
[upstream]
addr = "http://localhost:8080"
# An optional program to run as the upstream
program = "rezepte"
args = []
[provider]
kind = "github" # can `google` or `github`
client_id = "<CHANGE_ME>"
client_secret = "<CHANGE_ME>"
redirect_url = "http://localhost:3000/oauth/callback"

9
go.mod Normal file
View File

@ -0,0 +1,9 @@
module git.hafen.run/lukas/oauth-guard
go 1.24.3
require (
github.com/BurntSushi/toml v1.5.0
github.com/google/uuid v1.6.0
golang.org/x/oauth2 v0.30.0
)

6
go.sum Normal file
View File

@ -0,0 +1,6 @@
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=

91
main.go Normal file
View File

@ -0,0 +1,91 @@
package main
import (
"context"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/exec"
"os/signal"
"syscall"
)
func main() {
config, oa2, err := LoadConfig()
if err != nil {
log.Printf("error loading config: %s\n", err.Error())
}
upstream, _ := url.Parse(config.Upstream.Addr)
rp := httputil.NewSingleHostReverseProxy(upstream)
var cmd *exec.Cmd
if config.Upstream.Program != "" {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cmd = exec.CommandContext(ctx, config.Upstream.Program, config.Upstream.Args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
defer func() {
if cmd != nil && cmd.Process != nil {
log.Println("Terminating child process...")
cmd.Process.Signal(syscall.SIGTERM)
cmd.Process.Wait()
}
}()
go func() {
if err := cmd.Run(); err != nil {
log.Printf("Child process exited with error: %v\n", err)
}
}()
}
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
log.Println("Received shutdown signal, cleaning up...")
if cmd != nil && cmd.Process != nil {
cmd.Process.Signal(syscall.SIGTERM)
cmd.Process.Wait()
}
os.Exit(0)
}()
oauthStore := NewSessionStore(&config, &oa2)
http.Handle("/oauth/callback", oauthStore.CallbackHandler())
http.Handle("/oauth/login", oauthStore.LoginPage())
protectedRoot := false
for _, pattern := range config.GuardedPaths {
if pattern == "/" {
protectedRoot = true
}
http.Handle(pattern, oauthStore.Protected(rp))
}
if !protectedRoot {
http.Handle("/", rp)
}
http.Handle("/favicon.ico", rp)
listenAddr, _ := url.Parse(config.ListenURL)
listenPort := listenAddr.Port()
if listenPort == "" {
if listenAddr.Scheme == "https" {
listenPort = "443"
} else {
listenPort = "80"
}
}
log.Printf("Starting server on port %s\n", listenPort)
if err := http.ListenAndServe(":"+listenPort, nil); err != nil {
log.Printf("Server error: %v\n", err)
}
}

256
oauth.go Normal file
View File

@ -0,0 +1,256 @@
package main
import (
"crypto/rand"
_ "embed"
"encoding/base64"
"encoding/json"
"html/template"
"log"
"net/http"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/oauth2"
)
type Session struct {
ID string `json:"id"`
UserID string `json:"user_id"`
ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
}
const SessionCookie = "session"
type OAuthStore struct {
sessions map[string]*Session
mutex sync.RWMutex
oa2 *oauth2.Config
config *Config
}
func NewSessionStore(config *Config, oa2 *oauth2.Config) *OAuthStore {
return &OAuthStore{
sessions: make(map[string]*Session),
oa2: oa2,
config: config,
}
}
func (s *OAuthStore) CreateSession(userID string, duration time.Duration) (*Session, error) {
uid, err := uuid.NewRandom()
if err != nil {
return nil, err
}
id := uid.String()
session := &Session{
ID: id,
UserID: userID,
ExpiresAt: time.Now().Add(duration),
CreatedAt: time.Now(),
}
s.mutex.Lock()
s.sessions[id] = session
s.mutex.Unlock()
return session, nil
}
func (s *OAuthStore) GetSession(sessionID string) (*Session, bool) {
s.mutex.RLock()
defer s.mutex.RUnlock()
session, exists := s.sessions[sessionID]
if !exists {
return nil, false
}
if session.ExpiresAt.Before(time.Now()) {
s.DeleteSession(sessionID)
return nil, false
}
return session, true
}
func (s *OAuthStore) DeleteSession(sessionID string) {
s.mutex.Lock()
delete(s.sessions, sessionID)
s.mutex.Unlock()
}
func sendToLoginPage(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/oauth/login", http.StatusTemporaryRedirect)
}
func generateRandomToken() string {
b := make([]byte, 32)
rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
//go:embed templates/LoginPage.html
var loginPageContent string
func (s *OAuthStore) LoginPage() http.Handler {
loginPageTemplate := template.Must(template.New("loginPageContent").Parse(loginPageContent))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
state := generateRandomToken()
if cookie, err := r.Cookie("oauth_state"); err == nil && cookie.Value != "" {
state = cookie.Value
} else {
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Value: state,
HttpOnly: true,
Secure: true, // Set to true in production
SameSite: http.SameSiteLaxMode,
MaxAge: 60 * 10, // 10 minutes
})
}
url := s.oa2.AuthCodeURL(state)
provider := s.config.OAuthProvider.Kind
switch provider {
case "google":
provider = "Google"
case "github":
provider = "GitHub"
}
loginPageTemplate.Execute(w, struct {
Url string
State string
Provider string
}{
Url: url,
State: state,
Provider: provider,
})
})
}
func (s *OAuthStore) Protected(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(SessionCookie)
if err != nil {
sendToLoginPage(w, r)
return
}
sess, exists := s.GetSession(cookie.Value)
if !exists {
sendToLoginPage(w, r)
return
}
found := false
for _, user := range s.config.AllowedUsers {
if user == sess.UserID {
found = true
}
}
if !found {
sendToLoginPage(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (s *OAuthStore) CallbackHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
stateCookie, err := r.Cookie("oauth_state")
if err != nil || stateCookie.Value != r.URL.Query().Get("state") {
http.Error(w, "Invalid state", http.StatusBadRequest)
return
}
tok, err := s.oa2.Exchange(r.Context(), r.URL.Query().Get("code"))
if err != nil {
http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
return
}
userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken)
if err != nil {
http.Error(w, "Failed to get info", http.StatusInternalServerError)
return
}
println(userID)
sess, err := s.CreateSession(userID, time.Hour*24)
if err != nil {
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
http.SetCookie(w, &http.Cookie{
Name: SessionCookie,
Value: sess.ID,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
MaxAge: int(time.Hour.Seconds() * 24),
Path: "/",
})
// clear cookie
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Value: "",
MaxAge: -1,
})
// TODO: remember what path the user was on and redirect them back there after doing the whole login process
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
})
}
func getUserInfo(providerKind, token string) (string, error) {
switch providerKind {
case "google":
type UserInfo struct {
ID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
}
resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token)
if err != nil {
return "", err
}
defer resp.Body.Close()
var userInfo UserInfo
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return "", err
}
return userInfo.Email, nil
case "github":
type UserInfo struct {
Login string `json:"login"`
Name string `json:"name"`
}
req, _ := http.NewRequest("GET", "https://api.github.com/user", nil)
req.Header.Add("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
var userInfo UserInfo
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return "", err
}
return userInfo.Login, nil
default:
panic("unimplemented")
}
}

82
templates/LoginPage.html Normal file
View File

@ -0,0 +1,82 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Login</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
}
.login-container {
background: white;
padding: 3rem;
border-radius: 16px;
max-width: 400px;
width: 90%;
display: flex;
justify-content: center;
}
.login-button {
background: #4c4c4c;
color: white;
border: none;
padding: 16px 32px;
border-radius: 12px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: all 0.3s ease;
display: inline-flex;
align-items: center;
gap: 12px;
text-decoration: none;
min-width: 200px;
justify-content: center;
}
.login-button:hover {
transform: translateY(-2px);
}
.login-button:active {
transform: translateY(0);
}
.lock-icon {
width: 20px;
height: 20px;
fill: currentColor;
}
h1 {
color: #333;
margin-bottom: 2rem;
font-weight: 300;
font-size: 2rem;
}
</style>
</head>
<body>
<div class="login-container">
<a href="{{.Url}}" class="login-button">
<svg class="lock-icon" viewBox="0 0 24 24">
<path d="M18,8h-1V6c0-2.76-2.24-5-5-5S7,3.24,7,6v2H6c-1.1,0-2,0.9-2,2v10c0,1.1,0.9,2,2,2h12c1.1,0,2-0.9,2-2V10C20,8.9,19.1,8,18,8z M12,17c-1.1,0-2-0.9-2-2s0.9-2,2-2s2,0.9,2,2S13.1,17,12,17z M15.1,8H8.9V6c0-1.71,1.39-3.1,3.1-3.1s3.1,1.39,3.1,3.1V8z"/>
</svg>
Login with {{.Provider}}
</a>
</div>
</body>
</html>