oauth-guard/oauth.go
2025-07-03 17:07:35 -07:00

255 lines
5.6 KiB
Go

package main
import (
	"crypto/rand"
	_ "embed"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"html/template"
	"net/http"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/oauth2"
)
// Session is one server-side login session. Its ID is the value stored in the
// browser's SessionCookie cookie.
type Session struct {
	// ID is the opaque session identifier handed to the browser.
	ID string `json:"id"`
	// UserID is the identifier resolved for the logged-in user.
	UserID string `json:"user_id"`
	// ExpiresAt is the instant after which the session is no longer valid.
	ExpiresAt time.Time `json:"expires_at"`
	// CreatedAt is the instant the session was created.
	CreatedAt time.Time `json:"created_at"`
}
// SessionCookie is the name of the cookie that carries the session ID.
const (
	SessionCookie = "session"
)
// OAuthStore holds the in-memory session table together with the OAuth
// configuration used by the login, callback and protection handlers.
type OAuthStore struct {
	// sessions maps a session ID to its session; guarded by mutex.
	sessions map[string]*Session
	// mutex guards sessions.
	mutex sync.RWMutex
	// oa2 builds authorization URLs and exchanges authorization codes.
	oa2 *oauth2.Config
	// config supplies the provider kind/script and the allowed-user list.
	config *Config
}
// NewSessionStore returns an OAuthStore with an empty session table that uses
// config for application settings and oa2 for the OAuth2 flow.
func NewSessionStore(config *Config, oa2 *oauth2.Config) *OAuthStore {
	store := &OAuthStore{config: config, oa2: oa2}
	store.sessions = make(map[string]*Session)
	return store
}
// CreateSession stores and returns a new session for userID that expires
// duration from now. The session ID is a random UUID from uuid.NewRandom; its
// error, if any, is returned unchanged.
//
// While it holds the write lock, CreateSession also evicts every session that
// has already expired. GetSession only removes an expired entry when that same
// ID is looked up again, so without this sweep sessions whose cookies are never
// presented again would stay in memory forever. The sweep is a linear scan, but
// it runs once per created session (presumably once per login).
func (s *OAuthStore) CreateSession(userID string, duration time.Duration) (*Session, error) {
	uid, err := uuid.NewRandom()
	if err != nil {
		return nil, err
	}

	// A single clock reading keeps ExpiresAt-CreatedAt exactly equal to duration.
	now := time.Now()
	session := &Session{
		ID:        uid.String(),
		UserID:    userID,
		ExpiresAt: now.Add(duration),
		CreatedAt: now,
	}

	s.mutex.Lock()
	defer s.mutex.Unlock()
	for id, existing := range s.sessions {
		// Deleting during range is permitted for Go maps.
		if existing.ExpiresAt.Before(now) {
			delete(s.sessions, id)
		}
	}
	s.sessions[session.ID] = session
	return session, nil
}
// GetSession looks up a session by ID.
//
// It returns (session, true) when the ID is known and the session has not yet
// expired. An unknown ID yields (nil, false). An expired session is removed
// from the store and also yields (nil, false).
//
// The read lock is released before an expired entry is deleted. DeleteSession
// takes the write lock, and sync.RWMutex is not reentrant, so deleting while
// this goroutine still held the read lock would deadlock. Reading ExpiresAt
// after unlocking is safe because sessions are never mutated once stored.
func (s *OAuthStore) GetSession(sessionID string) (*Session, bool) {
	s.mutex.RLock()
	session, exists := s.sessions[sessionID]
	s.mutex.RUnlock()

	if !exists {
		return nil, false
	}
	if session.ExpiresAt.Before(time.Now()) {
		s.DeleteSession(sessionID)
		return nil, false
	}
	return session, true
}
// DeleteSession removes the session with the given ID from the store under the
// write lock. Deleting an unknown ID is a no-op.
func (s *OAuthStore) DeleteSession(sessionID string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	delete(s.sessions, sessionID)
}
// sendToLoginPage answers r with a 307 Temporary Redirect to the login page.
func sendToLoginPage(w http.ResponseWriter, r *http.Request) {
	const loginPath = "/oauth/login"
	http.Redirect(w, r, loginPath, http.StatusTemporaryRedirect)
}
// generateRandomToken returns 32 bytes from crypto/rand, encoded as unpadded
// URL-safe base64 (43 characters).
//
// The token is used as the OAuth state value, which travels both in a cookie
// and in URL query strings. The URL-safe alphabet contains no '+', '/' or '='.
// This matters because if any hop re-emits the state unescaped, a '+' in a
// query string decodes to a space and the callback's state comparison fails.
//
// It panics if the secure random source fails. Continuing with a predictable
// (possibly all-zero) token would silently defeat the CSRF protection the
// state provides.
func generateRandomToken() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic("oauth: reading crypto/rand: " + err.Error())
	}
	return base64.RawURLEncoding.EncodeToString(b)
}
// loginPageContent is the raw text of templates/LoginPage.html, embedded into
// the binary at build time. LoginPage parses it as an html/template.
//
//go:embed templates/LoginPage.html
var loginPageContent string
// LoginPage returns a handler that renders the embedded login template.
//
// The template is parsed once, when the handler is built, and a parse failure
// panics via template.Must. On every request the handler:
//   - reuses a non-empty "oauth_state" cookie as the OAuth state, or otherwise
//     issues a fresh random state in an HttpOnly, Secure, SameSite=Lax cookie
//     that lives for ten minutes;
//   - builds the provider's authorization URL carrying that state;
//   - renders the template with .Url, .State and .Provider. .Provider is
//     "Google"/"GitHub" for those kinds, otherwise the configured kind verbatim.
//
// Template execution errors are deliberately ignored. By then the response may
// already be partially written, so there is nothing useful to send instead.
func (s *OAuthStore) LoginPage() http.Handler {
	tmpl := template.Must(template.New("loginPageContent").Parse(loginPageContent))
	displayNames := map[string]string{
		"google": "Google",
		"github": "GitHub",
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		state := generateRandomToken()
		if existing, err := r.Cookie("oauth_state"); err != nil || existing.Value == "" {
			http.SetCookie(w, &http.Cookie{
				Name:     "oauth_state",
				Value:    state,
				HttpOnly: true,
				Secure:   true,
				SameSite: http.SameSiteLaxMode,
				MaxAge:   10 * 60, // ten minutes to finish the provider round-trip
			})
		} else {
			state = existing.Value
		}

		provider := s.config.OAuthProvider.Kind
		if pretty, ok := displayNames[provider]; ok {
			provider = pretty
		}

		data := struct {
			Url      string
			State    string
			Provider string
		}{
			Url:      s.oa2.AuthCodeURL(state),
			State:    state,
			Provider: provider,
		}
		_ = tmpl.Execute(w, data)
	})
}
// Protected wraps next so that it only runs for requests that carry a known,
// unexpired session whose user ID appears in config.AllowedUsers. Every other
// request (no cookie, unknown or expired session, user not allowed) is
// redirected to the login page.
func (s *OAuthStore) Protected(next http.Handler) http.Handler {
	// allowed reports whether userID is in the allow-list. s.config is read on
	// every call, so the list is consulted per request.
	allowed := func(userID string) bool {
		for _, candidate := range s.config.AllowedUsers {
			if candidate == userID {
				return true
			}
		}
		return false
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		c, err := r.Cookie(SessionCookie)
		if err != nil {
			sendToLoginPage(w, r)
			return
		}
		sess, ok := s.GetSession(c.Value)
		if !ok || !allowed(sess.UserID) {
			sendToLoginPage(w, r)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// CallbackHandler returns the handler for the OAuth provider's redirect back
// to this service. On each request it:
//   - verifies that the non-empty "state" query parameter matches the
//     "oauth_state" cookie set by LoginPage;
//   - exchanges the "code" query parameter for a token;
//   - resolves the token to a user ID via getUserInfo;
//   - starts a 24-hour session, stored in the SessionCookie cookie;
//   - clears the state cookie and redirects to "/".
//
// Whether the user is allowed in is not checked here; Protected does that on
// each request.
func (s *OAuthStore) CallbackHandler() http.Handler {
	// sessionTTL drives both the server-side expiry and the cookie MaxAge so
	// the two cannot drift apart.
	const sessionTTL = 24 * time.Hour

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query()

		// Reject an empty state outright. Otherwise a request with no state
		// parameter plus an empty oauth_state cookie would pass the CSRF check.
		state := query.Get("state")
		stateCookie, err := r.Cookie("oauth_state")
		if err != nil || state == "" || stateCookie.Value != state {
			http.Error(w, "Invalid state", http.StatusBadRequest)
			return
		}

		// No code means the provider did not grant access (e.g. the user
		// declined); that is a client-side outcome, not a server failure.
		code := query.Get("code")
		if code == "" {
			http.Error(w, "Missing authorization code", http.StatusBadRequest)
			return
		}

		tok, err := s.oa2.Exchange(r.Context(), code)
		if err != nil {
			http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
			return
		}
		userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken, s.config)
		if err != nil {
			http.Error(w, "Failed to get info", http.StatusInternalServerError)
			return
		}
		sess, err := s.CreateSession(userID, sessionTTL)
		if err != nil {
			http.Error(w, "Failed to create session", http.StatusInternalServerError)
			return
		}

		http.SetCookie(w, &http.Cookie{
			Name:     SessionCookie,
			Value:    sess.ID,
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteLaxMode,
			MaxAge:   int(sessionTTL.Seconds()),
			Path:     "/",
		})
		// The state has been consumed; delete its cookie.
		http.SetCookie(w, &http.Cookie{
			Name:   "oauth_state",
			Value:  "",
			MaxAge: -1,
		})
		// TODO: remember what path the user was on and redirect them back there after doing the whole login process
		http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
	})
}
// getUserInfo resolves an OAuth access token to the user identifier that is
// matched against the allowed-user list.
//
// For "google" the identifier is the account email. For "github" it is the
// login name. Any other provider kind is delegated to RunScript with the
// configured provider script.
//
// For the built-in providers, an error is returned when the request fails,
// the provider answers with a non-200 status, the body is not valid JSON, or
// the identifier is empty. Each provider request is bounded by a 10-second
// timeout.
func getUserInfo(providerKind, token string, c *Config) (string, error) {
	// fetchJSON GETs endpoint with the access token as a Bearer credential and
	// decodes the JSON body into out.
	//
	// A header keeps the token out of URLs, which tend to end up in logs.
	// Non-200 responses are treated as errors; decoding a provider's error body
	// would otherwise "succeed" with an empty identifier.
	fetchJSON := func(endpoint string, out interface{}) error {
		req, err := http.NewRequest(http.MethodGet, endpoint, nil)
		if err != nil {
			return err
		}
		req.Header.Set("Authorization", "Bearer "+token)
		// A nil Transport means http.DefaultTransport, so connections are
		// still pooled even though the Client value is per call.
		client := &http.Client{Timeout: 10 * time.Second}
		resp, err := client.Do(req)
		if err != nil {
			return err
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			return fmt.Errorf("GET %s: unexpected status %s", endpoint, resp.Status)
		}
		return json.NewDecoder(resp.Body).Decode(out)
	}

	switch providerKind {
	case "google":
		// NOTE(review): the email is used for authorization; consider also
		// rejecting unverified addresses (v2 userinfo reports a verification
		// flag — confirm the exact field name before relying on it).
		var info struct {
			Email string `json:"email"`
		}
		if err := fetchJSON("https://www.googleapis.com/oauth2/v2/userinfo", &info); err != nil {
			return "", fmt.Errorf("fetching google user info: %w", err)
		}
		if info.Email == "" {
			return "", fmt.Errorf("google user info has no email")
		}
		return info.Email, nil
	case "github":
		var info struct {
			Login string `json:"login"`
		}
		if err := fetchJSON("https://api.github.com/user", &info); err != nil {
			return "", fmt.Errorf("fetching github user info: %w", err)
		}
		if info.Login == "" {
			return "", fmt.Errorf("github user info has no login")
		}
		return info.Login, nil
	default:
		return RunScript(c.OAuthProvider.Script, token)
	}
}