255 lines
5.6 KiB
Go
255 lines
5.6 KiB
Go
package main
|
|
|
|
import (
	"bytes"
	"crypto/rand"
	_ "embed"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"html/template"
	"net/http"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/oauth2"
)
|
|
|
|
// Session is one signed-in user's server-side session record.
//
// Field order matters: encoding/json emits fields in declaration order.
type Session struct {
	// ID is the opaque session identifier sent to the client in a cookie.
	ID string `json:"id"`
	// UserID identifies the signed-in user (as produced by the OAuth flow).
	UserID string `json:"user_id"`
	// ExpiresAt is the instant after which the session is no longer valid.
	ExpiresAt time.Time `json:"expires_at"`
	// CreatedAt records when the session was issued.
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
const (
	// SessionCookie is the name of the cookie that carries the session ID set
	// by CallbackHandler and read by Protected.
	SessionCookie = "session"
)
|
|
|
|
type OAuthStore struct {
|
|
sessions map[string]*Session
|
|
mutex sync.RWMutex
|
|
oa2 *oauth2.Config
|
|
config *Config
|
|
}
|
|
|
|
func NewSessionStore(config *Config, oa2 *oauth2.Config) *OAuthStore {
|
|
return &OAuthStore{
|
|
sessions: make(map[string]*Session),
|
|
oa2: oa2,
|
|
config: config,
|
|
}
|
|
}
|
|
|
|
func (s *OAuthStore) CreateSession(userID string, duration time.Duration) (*Session, error) {
|
|
uid, err := uuid.NewRandom()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
id := uid.String()
|
|
|
|
session := &Session{
|
|
ID: id,
|
|
UserID: userID,
|
|
ExpiresAt: time.Now().Add(duration),
|
|
CreatedAt: time.Now(),
|
|
}
|
|
|
|
s.mutex.Lock()
|
|
s.sessions[id] = session
|
|
s.mutex.Unlock()
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func (s *OAuthStore) GetSession(sessionID string) (*Session, bool) {
|
|
s.mutex.RLock()
|
|
defer s.mutex.RUnlock()
|
|
|
|
session, exists := s.sessions[sessionID]
|
|
if !exists {
|
|
return nil, false
|
|
}
|
|
if session.ExpiresAt.Before(time.Now()) {
|
|
s.DeleteSession(sessionID)
|
|
return nil, false
|
|
}
|
|
return session, true
|
|
|
|
}
|
|
|
|
func (s *OAuthStore) DeleteSession(sessionID string) {
|
|
s.mutex.Lock()
|
|
delete(s.sessions, sessionID)
|
|
s.mutex.Unlock()
|
|
}
|
|
|
|
// sendToLoginPage answers the request with a 307 Temporary Redirect to the
// OAuth login page at /oauth/login.
func sendToLoginPage(w http.ResponseWriter, r *http.Request) {
	const loginPath = "/oauth/login"
	http.Redirect(w, r, loginPath, http.StatusTemporaryRedirect)
}
|
|
// generateRandomToken returns 32 bytes from crypto/rand, encoded with the
// unpadded URL-safe base64 alphabet (43 characters). That alphabet needs no
// escaping in query strings or cookie values, where this token is used as the
// OAuth "state".
//
// It panics if the system randomness source fails: continuing with the
// zero-filled buffer would hand out a predictable token.
func generateRandomToken() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic("generateRandomToken: crypto/rand failed: " + err.Error())
	}
	return base64.RawURLEncoding.EncodeToString(b)
}
|
|
|
|
//go:embed templates/LoginPage.html
|
|
var loginPageContent string
|
|
|
|
func (s *OAuthStore) LoginPage() http.Handler {
|
|
|
|
loginPageTemplate := template.Must(template.New("loginPageContent").Parse(loginPageContent))
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
state := generateRandomToken()
|
|
|
|
if cookie, err := r.Cookie("oauth_state"); err == nil && cookie.Value != "" {
|
|
state = cookie.Value
|
|
} else {
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "oauth_state",
|
|
Value: state,
|
|
HttpOnly: true,
|
|
Secure: true, // Set to true in production
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: 60 * 10, // 10 minutes
|
|
})
|
|
}
|
|
|
|
url := s.oa2.AuthCodeURL(state)
|
|
|
|
provider := s.config.OAuthProvider.Kind
|
|
switch provider {
|
|
case "google":
|
|
provider = "Google"
|
|
case "github":
|
|
provider = "GitHub"
|
|
}
|
|
|
|
loginPageTemplate.Execute(w, struct {
|
|
Url string
|
|
State string
|
|
Provider string
|
|
}{
|
|
Url: url,
|
|
State: state,
|
|
Provider: provider,
|
|
})
|
|
})
|
|
}
|
|
|
|
func (s *OAuthStore) Protected(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
cookie, err := r.Cookie(SessionCookie)
|
|
if err != nil {
|
|
sendToLoginPage(w, r)
|
|
return
|
|
}
|
|
sess, exists := s.GetSession(cookie.Value)
|
|
if !exists {
|
|
sendToLoginPage(w, r)
|
|
return
|
|
}
|
|
|
|
found := false
|
|
for _, user := range s.config.AllowedUsers {
|
|
if user == sess.UserID {
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
sendToLoginPage(w, r)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (s *OAuthStore) CallbackHandler() http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
stateCookie, err := r.Cookie("oauth_state")
|
|
if err != nil || stateCookie.Value != r.URL.Query().Get("state") {
|
|
http.Error(w, "Invalid state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
tok, err := s.oa2.Exchange(r.Context(), r.URL.Query().Get("code"))
|
|
if err != nil {
|
|
http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken)
|
|
if err != nil {
|
|
http.Error(w, "Failed to get info", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
sess, err := s.CreateSession(userID, time.Hour*24)
|
|
if err != nil {
|
|
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: SessionCookie,
|
|
Value: sess.ID,
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: int(time.Hour.Seconds() * 24),
|
|
Path: "/",
|
|
})
|
|
|
|
// clear cookie
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "oauth_state",
|
|
Value: "",
|
|
MaxAge: -1,
|
|
})
|
|
|
|
// TODO: remember what path the user was on and redirect them back there after doing the whole login process
|
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
})
|
|
}
|
|
|
|
// userInfoClient performs provider profile lookups. Unlike http.DefaultClient
// it has a timeout, so an unresponsive provider cannot hang the OAuth callback
// indefinitely.
var userInfoClient = &http.Client{Timeout: 10 * time.Second}

// getUserInfo resolves an OAuth access token to the user identifier that is
// checked against the allowed-users list:
//
//   - "google": the account e-mail, which must be present and marked verified
//     (an unverified address is not proof of identity);
//   - "github": the account login name.
//
// It returns an error for an unsupported provider, a transport failure, a
// non-200 response, an undecodable body, or an empty identifier. It never
// returns an empty ID together with a nil error.
func getUserInfo(providerKind, token string) (string, error) {
	switch providerKind {
	case "google":
		// NOTE(review): "email" is only returned when the email scope was
		// granted — confirm the configured scopes include it.
		var info struct {
			Email         string `json:"email"`
			VerifiedEmail bool   `json:"verified_email"`
		}
		if err := fetchUserJSON("https://www.googleapis.com/oauth2/v2/userinfo", token, &info); err != nil {
			return "", fmt.Errorf("fetching google user info: %w", err)
		}
		if info.Email == "" {
			return "", fmt.Errorf("google user info contains no email")
		}
		if !info.VerifiedEmail {
			return "", fmt.Errorf("google email %q is not verified", info.Email)
		}
		return info.Email, nil

	case "github":
		var info struct {
			Login string `json:"login"`
		}
		if err := fetchUserJSON("https://api.github.com/user", token, &info); err != nil {
			return "", fmt.Errorf("fetching github user info: %w", err)
		}
		if info.Login == "" {
			return "", fmt.Errorf("github user info contains no login")
		}
		return info.Login, nil

	default:
		return "", fmt.Errorf("unsupported OAuth provider %q", providerKind)
	}
}

// fetchUserJSON GETs endpoint with token as a Bearer credential and decodes a
// 200 OK JSON response body into dst.
func fetchUserJSON(endpoint, token string, dst any) error {
	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		return err
	}
	// Send the token in a header, not ?access_token=, to keep it out of URLs
	// (and therefore out of proxy/server logs).
	req.Header.Set("Authorization", "Bearer "+token)

	resp, err := userInfoClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Providers answer bad tokens with a JSON error body, which would
	// otherwise decode "successfully" into a struct of empty fields.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status %s", resp.Status)
	}
	return json.NewDecoder(resp.Body).Decode(dst)
}
|