it works!
This commit is contained in:
commit
7bbaf7ceb2
56
config.go
Normal file
56
config.go
Normal file
@ -0,0 +1,56 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/endpoints"
|
||||
)
|
||||
|
||||
type Upstream struct {
|
||||
Addr string `toml:"addr"`
|
||||
Program string `toml:"program"`
|
||||
Args []string `toml:"args"`
|
||||
}
|
||||
|
||||
type OAuthProvider struct {
|
||||
Kind string `toml:"kind"`
|
||||
ClientID string `toml:"client_id"`
|
||||
ClientSecret string `toml:"client_secret"`
|
||||
RedirectURL string `toml:"redirect_url"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ListenURL string `toml:"listen_url"`
|
||||
GuardedPaths []string `toml:"guarded_paths"`
|
||||
AllowedUsers []string `toml:"allowed_users"`
|
||||
Upstream Upstream `toml:"upstream"`
|
||||
OAuthProvider OAuthProvider `toml:"provider"`
|
||||
}
|
||||
|
||||
func LoadConfig() (Config, oauth2.Config, error) {
|
||||
config := Config{}
|
||||
oa2 := oauth2.Config{}
|
||||
_, err := toml.DecodeFile("config.toml", &config)
|
||||
if err != nil {
|
||||
return config, oa2, fmt.Errorf("unable to parse 'config.toml' tompl decoding error: %w", err)
|
||||
}
|
||||
|
||||
oa2.ClientID = config.OAuthProvider.ClientID
|
||||
oa2.ClientSecret = config.OAuthProvider.ClientSecret
|
||||
oa2.Endpoint = oauth2.Endpoint{}
|
||||
oa2.RedirectURL = config.OAuthProvider.RedirectURL
|
||||
oa2.Scopes = []string{}
|
||||
|
||||
switch config.OAuthProvider.Kind {
|
||||
case "github":
|
||||
oa2.Endpoint = endpoints.GitHub
|
||||
oa2.Scopes = []string{"read:user"}
|
||||
case "google":
|
||||
oa2.Endpoint = endpoints.Google
|
||||
oa2.Scopes = []string{"https://www.googleapis.com/auth/userinfo.email"}
|
||||
}
|
||||
|
||||
return config, oa2, err
|
||||
}
|
17
config.toml
Normal file
17
config.toml
Normal file
@ -0,0 +1,17 @@
|
||||
listen_url = "http://localhost:3000"
|
||||
guarded_paths = ["/"]
|
||||
# A list of all allowed users. For GitHub it is a list of usernames.
|
||||
# For Google it is a list of emails
|
||||
allowed_users = ["lukasmwerner"]
|
||||
|
||||
[upstream]
|
||||
addr = "http://localhost:8080"
|
||||
# An optional program to run as the upstream
|
||||
program = "rezepte"
|
||||
args = []
|
||||
|
||||
[provider]
|
||||
kind = "github" # can `google` or `github`
|
||||
client_id = "<CHANGE_ME>"
|
||||
client_secret = "<CHANGE_ME>"
|
||||
redirect_url = "http://localhost:3000/oauth/callback"
|
9
go.mod
Normal file
9
go.mod
Normal file
@ -0,0 +1,9 @@
|
||||
module git.hafen.run/lukas/oauth-guard
|
||||
|
||||
go 1.24.3
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.5.0
|
||||
github.com/google/uuid v1.6.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
)
|
6
go.sum
Normal file
6
go.sum
Normal file
@ -0,0 +1,6 @@
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
91
main.go
Normal file
91
main.go
Normal file
@ -0,0 +1,91 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func main() {
|
||||
config, oa2, err := LoadConfig()
|
||||
if err != nil {
|
||||
log.Printf("error loading config: %s\n", err.Error())
|
||||
}
|
||||
|
||||
upstream, _ := url.Parse(config.Upstream.Addr)
|
||||
rp := httputil.NewSingleHostReverseProxy(upstream)
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if config.Upstream.Program != "" {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
cmd = exec.CommandContext(ctx, config.Upstream.Program, config.Upstream.Args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
defer func() {
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
log.Println("Terminating child process...")
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
cmd.Process.Wait()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Printf("Child process exited with error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-c
|
||||
log.Println("Received shutdown signal, cleaning up...")
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
cmd.Process.Signal(syscall.SIGTERM)
|
||||
cmd.Process.Wait()
|
||||
}
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
oauthStore := NewSessionStore(&config, &oa2)
|
||||
|
||||
http.Handle("/oauth/callback", oauthStore.CallbackHandler())
|
||||
http.Handle("/oauth/login", oauthStore.LoginPage())
|
||||
protectedRoot := false
|
||||
for _, pattern := range config.GuardedPaths {
|
||||
if pattern == "/" {
|
||||
protectedRoot = true
|
||||
}
|
||||
http.Handle(pattern, oauthStore.Protected(rp))
|
||||
}
|
||||
if !protectedRoot {
|
||||
http.Handle("/", rp)
|
||||
}
|
||||
http.Handle("/favicon.ico", rp)
|
||||
|
||||
listenAddr, _ := url.Parse(config.ListenURL)
|
||||
listenPort := listenAddr.Port()
|
||||
if listenPort == "" {
|
||||
if listenAddr.Scheme == "https" {
|
||||
listenPort = "443"
|
||||
} else {
|
||||
listenPort = "80"
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Starting server on port %s\n", listenPort)
|
||||
if err := http.ListenAndServe(":"+listenPort, nil); err != nil {
|
||||
log.Printf("Server error: %v\n", err)
|
||||
}
|
||||
}
|
256
oauth.go
Normal file
256
oauth.go
Normal file
@ -0,0 +1,256 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
_ "embed"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
const SessionCookie = "session"
|
||||
|
||||
type OAuthStore struct {
|
||||
sessions map[string]*Session
|
||||
mutex sync.RWMutex
|
||||
oa2 *oauth2.Config
|
||||
config *Config
|
||||
}
|
||||
|
||||
func NewSessionStore(config *Config, oa2 *oauth2.Config) *OAuthStore {
|
||||
return &OAuthStore{
|
||||
sessions: make(map[string]*Session),
|
||||
oa2: oa2,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthStore) CreateSession(userID string, duration time.Duration) (*Session, error) {
|
||||
uid, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id := uid.String()
|
||||
|
||||
session := &Session{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().Add(duration),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.sessions[id] = session
|
||||
s.mutex.Unlock()
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *OAuthStore) GetSession(sessionID string) (*Session, bool) {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
session, exists := s.sessions[sessionID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
s.DeleteSession(sessionID)
|
||||
return nil, false
|
||||
}
|
||||
return session, true
|
||||
|
||||
}
|
||||
|
||||
func (s *OAuthStore) DeleteSession(sessionID string) {
|
||||
s.mutex.Lock()
|
||||
delete(s.sessions, sessionID)
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
|
||||
func sendToLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/oauth/login", http.StatusTemporaryRedirect)
|
||||
}
|
||||
func generateRandomToken() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
//go:embed templates/LoginPage.html
|
||||
var loginPageContent string
|
||||
|
||||
func (s *OAuthStore) LoginPage() http.Handler {
|
||||
|
||||
loginPageTemplate := template.Must(template.New("loginPageContent").Parse(loginPageContent))
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
state := generateRandomToken()
|
||||
|
||||
if cookie, err := r.Cookie("oauth_state"); err == nil && cookie.Value != "" {
|
||||
state = cookie.Value
|
||||
} else {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state,
|
||||
HttpOnly: true,
|
||||
Secure: true, // Set to true in production
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: 60 * 10, // 10 minutes
|
||||
})
|
||||
}
|
||||
|
||||
url := s.oa2.AuthCodeURL(state)
|
||||
|
||||
provider := s.config.OAuthProvider.Kind
|
||||
switch provider {
|
||||
case "google":
|
||||
provider = "Google"
|
||||
case "github":
|
||||
provider = "GitHub"
|
||||
}
|
||||
|
||||
loginPageTemplate.Execute(w, struct {
|
||||
Url string
|
||||
State string
|
||||
Provider string
|
||||
}{
|
||||
Url: url,
|
||||
State: state,
|
||||
Provider: provider,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OAuthStore) Protected(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
cookie, err := r.Cookie(SessionCookie)
|
||||
if err != nil {
|
||||
sendToLoginPage(w, r)
|
||||
return
|
||||
}
|
||||
sess, exists := s.GetSession(cookie.Value)
|
||||
if !exists {
|
||||
sendToLoginPage(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, user := range s.config.AllowedUsers {
|
||||
if user == sess.UserID {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
sendToLoginPage(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OAuthStore) CallbackHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
stateCookie, err := r.Cookie("oauth_state")
|
||||
if err != nil || stateCookie.Value != r.URL.Query().Get("state") {
|
||||
http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tok, err := s.oa2.Exchange(r.Context(), r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get info", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
println(userID)
|
||||
sess, err := s.CreateSession(userID, time.Hour*24)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: SessionCookie,
|
||||
Value: sess.ID,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(time.Hour.Seconds() * 24),
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
// clear cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
})
|
||||
|
||||
// TODO: remember what path the user was on and redirect them back there after doing the whole login process
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
})
|
||||
}
|
||||
|
||||
func getUserInfo(providerKind, token string) (string, error) {
|
||||
switch providerKind {
|
||||
case "google":
|
||||
type UserInfo struct {
|
||||
ID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return userInfo.Email, nil
|
||||
case "github":
|
||||
type UserInfo struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
req, _ := http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
req.Header.Add("Authorization", "Bearer "+token)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var userInfo UserInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userInfo.Login, nil
|
||||
default:
|
||||
panic("unimplemented")
|
||||
}
|
||||
}
|
82
templates/LoginPage.html
Normal file
82
templates/LoginPage.html
Normal file
@ -0,0 +1,82 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>Login</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.login-container {
|
||||
background: white;
|
||||
padding: 3rem;
|
||||
border-radius: 16px;
|
||||
max-width: 400px;
|
||||
width: 90%;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.login-button {
|
||||
background: #4c4c4c;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 16px 32px;
|
||||
border-radius: 12px;
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
text-decoration: none;
|
||||
min-width: 200px;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.login-button:hover {
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.login-button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.lock-icon {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
fill: currentColor;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #333;
|
||||
margin-bottom: 2rem;
|
||||
font-weight: 300;
|
||||
font-size: 2rem;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="login-container">
|
||||
<a href="{{.Url}}" class="login-button">
|
||||
<svg class="lock-icon" viewBox="0 0 24 24">
|
||||
<path d="M18,8h-1V6c0-2.76-2.24-5-5-5S7,3.24,7,6v2H6c-1.1,0-2,0.9-2,2v10c0,1.1,0.9,2,2,2h12c1.1,0,2-0.9,2-2V10C20,8.9,19.1,8,18,8z M12,17c-1.1,0-2-0.9-2-2s0.9-2,2-2s2,0.9,2,2S13.1,17,12,17z M15.1,8H8.9V6c0-1.71,1.39-3.1,3.1-3.1s3.1,1.39,3.1,3.1V8z"/>
|
||||
</svg>
|
||||
Login with {{.Provider}}
|
||||
</a>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
Loading…
x
Reference in New Issue
Block a user