commit 7bbaf7ceb2019453788053b4f3fc514041ecbb6c Author: Lukas Werner Date: Wed Jul 2 00:27:22 2025 -0700 it works! diff --git a/config.go b/config.go new file mode 100644 index 0000000..b2c862f --- /dev/null +++ b/config.go @@ -0,0 +1,56 @@ +package main + +import ( + "fmt" + + "github.com/BurntSushi/toml" + "golang.org/x/oauth2" + "golang.org/x/oauth2/endpoints" +) + +type Upstream struct { + Addr string `toml:"addr"` + Program string `toml:"program"` + Args []string `toml:"args"` +} + +type OAuthProvider struct { + Kind string `toml:"kind"` + ClientID string `toml:"client_id"` + ClientSecret string `toml:"client_secret"` + RedirectURL string `toml:"redirect_url"` +} + +type Config struct { + ListenURL string `toml:"listen_url"` + GuardedPaths []string `toml:"guarded_paths"` + AllowedUsers []string `toml:"allowed_users"` + Upstream Upstream `toml:"upstream"` + OAuthProvider OAuthProvider `toml:"provider"` +} + +func LoadConfig() (Config, oauth2.Config, error) { + config := Config{} + oa2 := oauth2.Config{} + _, err := toml.DecodeFile("config.toml", &config) + if err != nil { + return config, oa2, fmt.Errorf("unable to parse 'config.toml' tompl decoding error: %w", err) + } + + oa2.ClientID = config.OAuthProvider.ClientID + oa2.ClientSecret = config.OAuthProvider.ClientSecret + oa2.Endpoint = oauth2.Endpoint{} + oa2.RedirectURL = config.OAuthProvider.RedirectURL + oa2.Scopes = []string{} + + switch config.OAuthProvider.Kind { + case "github": + oa2.Endpoint = endpoints.GitHub + oa2.Scopes = []string{"read:user"} + case "google": + oa2.Endpoint = endpoints.Google + oa2.Scopes = []string{"https://www.googleapis.com/auth/userinfo.email"} + } + + return config, oa2, err +} diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..f3e9377 --- /dev/null +++ b/config.toml @@ -0,0 +1,17 @@ +listen_url = "http://localhost:3000" +guarded_paths = ["/"] +# A list of all allowed users. For GitHub it is a list of usernames. +# For Google it is a list of emails +allowed_users = ["lukasmwerner"] + +[upstream] +addr = "http://localhost:8080" +# An optional program to run as the upstream +program = "rezepte" +args = [] + +[provider] +kind = "github" # can `google` or `github` +client_id = "" +client_secret = "" +redirect_url = "http://localhost:3000/oauth/callback" diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..340342f --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module git.hafen.run/lukas/oauth-guard + +go 1.24.3 + +require ( + github.com/BurntSushi/toml v1.5.0 + github.com/google/uuid v1.6.0 + golang.org/x/oauth2 v0.30.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b2f0a1f --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= diff --git a/main.go b/main.go new file mode 100644 index 0000000..f148c2d --- /dev/null +++ b/main.go @@ -0,0 +1,91 @@ +package main + +import ( + "context" + "log" + "net/http" + "net/http/httputil" + "net/url" + "os" + "os/exec" + "os/signal" + "syscall" +) + +func main() { + config, oa2, err := LoadConfig() + if err != nil { + log.Printf("error loading config: %s\n", err.Error()) + } + + upstream, _ := url.Parse(config.Upstream.Addr) + rp := httputil.NewSingleHostReverseProxy(upstream) + + var cmd *exec.Cmd + if config.Upstream.Program != "" { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cmd = exec.CommandContext(ctx, config.Upstream.Program, config.Upstream.Args...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + defer func() { + if cmd != nil && cmd.Process != nil { + log.Println("Terminating child process...") + cmd.Process.Signal(syscall.SIGTERM) + cmd.Process.Wait() + } + }() + + go func() { + if err := cmd.Run(); err != nil { + log.Printf("Child process exited with error: %v\n", err) + } + }() + } + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + log.Println("Received shutdown signal, cleaning up...") + if cmd != nil && cmd.Process != nil { + cmd.Process.Signal(syscall.SIGTERM) + cmd.Process.Wait() + } + os.Exit(0) + }() + + oauthStore := NewSessionStore(&config, &oa2) + + http.Handle("/oauth/callback", oauthStore.CallbackHandler()) + http.Handle("/oauth/login", oauthStore.LoginPage()) + protectedRoot := false + for _, pattern := range config.GuardedPaths { + if pattern == "/" { + protectedRoot = true + } + http.Handle(pattern, oauthStore.Protected(rp)) + } + if !protectedRoot { + http.Handle("/", rp) + } + http.Handle("/favicon.ico", rp) + + listenAddr, _ := url.Parse(config.ListenURL) + listenPort := listenAddr.Port() + if listenPort == "" { + if listenAddr.Scheme == "https" { + listenPort = "443" + } else { + listenPort = "80" + } + } + + log.Printf("Starting server on port %s\n", listenPort) + if err := http.ListenAndServe(":"+listenPort, nil); err != nil { + log.Printf("Server error: %v\n", err) + } +} diff --git a/oauth.go b/oauth.go new file mode 100644 index 0000000..5f58951 --- /dev/null +++ b/oauth.go @@ -0,0 +1,256 @@ +package main + +import ( + "crypto/rand" + _ "embed" + "encoding/base64" + "encoding/json" + "html/template" + "log" + "net/http" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/oauth2" +) + +type Session struct { + ID string `json:"id"` + UserID string `json:"user_id"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` +} + +const SessionCookie = "session" + +type OAuthStore struct { + sessions map[string]*Session + mutex sync.RWMutex + oa2 *oauth2.Config + config *Config +} + +func NewSessionStore(config *Config, oa2 *oauth2.Config) *OAuthStore { + return &OAuthStore{ + sessions: make(map[string]*Session), + oa2: oa2, + config: config, + } +} + +func (s *OAuthStore) CreateSession(userID string, duration time.Duration) (*Session, error) { + uid, err := uuid.NewRandom() + if err != nil { + return nil, err + } + id := uid.String() + + session := &Session{ + ID: id, + UserID: userID, + ExpiresAt: time.Now().Add(duration), + CreatedAt: time.Now(), + } + + s.mutex.Lock() + s.sessions[id] = session + s.mutex.Unlock() + + return session, nil +} + +func (s *OAuthStore) GetSession(sessionID string) (*Session, bool) { + s.mutex.RLock() + defer s.mutex.RUnlock() + + session, exists := s.sessions[sessionID] + if !exists { + return nil, false + } + if session.ExpiresAt.Before(time.Now()) { + s.DeleteSession(sessionID) + return nil, false + } + return session, true + +} + +func (s *OAuthStore) DeleteSession(sessionID string) { + s.mutex.Lock() + delete(s.sessions, sessionID) + s.mutex.Unlock() +} + +func sendToLoginPage(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/oauth/login", http.StatusTemporaryRedirect) +} +func generateRandomToken() string { + b := make([]byte, 32) + rand.Read(b) + return base64.StdEncoding.EncodeToString(b) +} + +//go:embed templates/LoginPage.html +var loginPageContent string + +func (s *OAuthStore) LoginPage() http.Handler { + + loginPageTemplate := template.Must(template.New("loginPageContent").Parse(loginPageContent)) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + state := generateRandomToken() + + if cookie, err := r.Cookie("oauth_state"); err == nil && cookie.Value != "" { + state = cookie.Value + } else { + http.SetCookie(w, &http.Cookie{ + Name: "oauth_state", + Value: state, + HttpOnly: true, + Secure: true, // Set to true in production + SameSite: http.SameSiteLaxMode, + MaxAge: 60 * 10, // 10 minutes + }) + } + + url := s.oa2.AuthCodeURL(state) + + provider := s.config.OAuthProvider.Kind + switch provider { + case "google": + provider = "Google" + case "github": + provider = "GitHub" + } + + loginPageTemplate.Execute(w, struct { + Url string + State string + Provider string + }{ + Url: url, + State: state, + Provider: provider, + }) + }) +} + +func (s *OAuthStore) Protected(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + cookie, err := r.Cookie(SessionCookie) + if err != nil { + sendToLoginPage(w, r) + return + } + sess, exists := s.GetSession(cookie.Value) + if !exists { + sendToLoginPage(w, r) + return + } + + found := false + for _, user := range s.config.AllowedUsers { + if user == sess.UserID { + found = true + } + } + if !found { + sendToLoginPage(w, r) + return + } + + next.ServeHTTP(w, r) + }) +} + +func (s *OAuthStore) CallbackHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + stateCookie, err := r.Cookie("oauth_state") + if err != nil || stateCookie.Value != r.URL.Query().Get("state") { + http.Error(w, "Invalid state", http.StatusBadRequest) + return + } + + tok, err := s.oa2.Exchange(r.Context(), r.URL.Query().Get("code")) + if err != nil { + http.Error(w, "Failed to exchange token", http.StatusInternalServerError) + return + } + userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken) + if err != nil { + http.Error(w, "Failed to get info", http.StatusInternalServerError) + return + } + println(userID) + sess, err := s.CreateSession(userID, time.Hour*24) + if err != nil { + http.Error(w, "Failed to create session", http.StatusInternalServerError) + return + } + + http.SetCookie(w, &http.Cookie{ + Name: SessionCookie, + Value: sess.ID, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: int(time.Hour.Seconds() * 24), + Path: "/", + }) + + // clear cookie + http.SetCookie(w, &http.Cookie{ + Name: "oauth_state", + Value: "", + MaxAge: -1, + }) + + // TODO: remember what path the user was on and redirect them back there after doing the whole login process + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + }) +} + +func getUserInfo(providerKind, token string) (string, error) { + switch providerKind { + case "google": + type UserInfo struct { + ID string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + } + resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var userInfo UserInfo + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + return "", err + } + + return userInfo.Email, nil + case "github": + type UserInfo struct { + Login string `json:"login"` + Name string `json:"name"` + } + req, _ := http.NewRequest("GET", "https://api.github.com/user", nil) + req.Header.Add("Authorization", "Bearer "+token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var userInfo UserInfo + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + return "", err + } + return userInfo.Login, nil + default: + panic("unimplemented") + } +} diff --git a/templates/LoginPage.html b/templates/LoginPage.html new file mode 100644 index 0000000..486e7dd --- /dev/null +++ b/templates/LoginPage.html @@ -0,0 +1,82 @@ + + + + + + Login + + + + + +