package main

import (
	"bytes"
	"crypto/rand"
	_ "embed"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"html/template"
	"net/http"
	"sync"
	"time"

	"github.com/google/uuid"
	"golang.org/x/oauth2"
)

// Session is a server-side login session. Its ID is a random UUID that is
// handed to the browser in the SessionCookie cookie.
type Session struct {
	ID        string    `json:"id"`
	UserID    string    `json:"user_id"`
	ExpiresAt time.Time `json:"expires_at"`
	CreatedAt time.Time `json:"created_at"`
}

// SessionCookie is the name of the cookie that carries the session ID.
const SessionCookie = "session"

const (
	// oauthStateCookie holds the OAuth "state" value between LoginPage and
	// CallbackHandler; the callback rejects requests whose state differs.
	oauthStateCookie = "oauth_state"

	// oauthStateMaxAge is the lifetime of the state cookie, in seconds (10 minutes).
	oauthStateMaxAge = 10 * 60

	// sessionDuration is the lifetime of sessions created by CallbackHandler,
	// used for both the server-side expiry and the cookie MaxAge.
	sessionDuration = 24 * time.Hour
)

// userInfoClient performs provider user-info lookups. The timeout keeps a slow
// or unresponsive provider from stalling the OAuth callback indefinitely.
var userInfoClient = &http.Client{Timeout: 10 * time.Second}

// OAuthStore holds in-memory sessions plus the OAuth2 and application
// configuration used to log users in. The sessions map is guarded by mutex.
type OAuthStore struct {
	sessions map[string]*Session
	mutex    sync.RWMutex
	oa2      *oauth2.Config
	config   *Config
}

// NewSessionStore returns an empty OAuthStore using the given application
// config (provider kind and allowed users) and OAuth2 client config.
func NewSessionStore(config *Config, oa2 *oauth2.Config) *OAuthStore {
	return &OAuthStore{
		sessions: make(map[string]*Session),
		oa2:      oa2,
		config:   config,
	}
}

// CreateSession stores and returns a new session for userID that expires
// after duration. It fails only if a random session ID cannot be generated.
func (s *OAuthStore) CreateSession(userID string, duration time.Duration) (*Session, error) {
	uid, err := uuid.NewRandom()
	if err != nil {
		return nil, err
	}
	now := time.Now()
	session := &Session{
		ID:        uid.String(),
		UserID:    userID,
		ExpiresAt: now.Add(duration),
		CreatedAt: now,
	}
	s.mutex.Lock()
	s.sessions[session.ID] = session
	s.mutex.Unlock()
	return session, nil
}

// GetSession looks up sessionID. It reports false for unknown IDs and for
// expired sessions; an expired session is removed from the store.
func (s *OAuthStore) GetSession(sessionID string) (*Session, bool) {
	// The read lock must be released before DeleteSession takes the write
	// lock: RWMutex is not reentrant, and upgrading while holding RLock
	// deadlocks.
	s.mutex.RLock()
	session, exists := s.sessions[sessionID]
	s.mutex.RUnlock()
	if !exists {
		return nil, false
	}
	// Session fields are never modified after CreateSession, so reading
	// ExpiresAt without the lock is safe.
	if session.ExpiresAt.Before(time.Now()) {
		s.DeleteSession(sessionID)
		return nil, false
	}
	return session, true
}

// DeleteSession removes sessionID from the store; unknown IDs are a no-op.
func (s *OAuthStore) DeleteSession(sessionID string) {
	s.mutex.Lock()
	delete(s.sessions, sessionID)
	s.mutex.Unlock()
}

// sendToLoginPage redirects the client to the OAuth login page.
func sendToLoginPage(w http.ResponseWriter, r *http.Request) {
	http.Redirect(w, r, "/oauth/login", http.StatusTemporaryRedirect)
}

// generateRandomToken returns 32 bytes from crypto/rand, encoded as URL-safe
// unpadded base64. It panics if the system random source fails, because a
// zero-filled token would make the OAuth state predictable.
func generateRandomToken() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("generateRandomToken: crypto/rand failed: %v", err))
	}
	return base64.RawURLEncoding.EncodeToString(b)
}

//go:embed templates/LoginPage.html
var loginPageContent string

// LoginPage returns a handler that renders the embedded login template with
// the provider's authorization URL. It reuses an existing oauth_state cookie,
// or creates one holding a fresh random state.
func (s *OAuthStore) LoginPage() http.Handler {
	loginPageTemplate := template.Must(template.New("loginPageContent").Parse(loginPageContent))
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var state string
		if cookie, err := r.Cookie(oauthStateCookie); err == nil && cookie.Value != "" {
			state = cookie.Value
		} else {
			state = generateRandomToken()
			// Path is left unset, so the browser applies its default path
			// (the directory of this request). NOTE(review): this presumes
			// the callback is served under the same directory — confirm.
			http.SetCookie(w, &http.Cookie{
				Name:     oauthStateCookie,
				Value:    state,
				HttpOnly: true,
				Secure:   true, // only sent over HTTPS
				SameSite: http.SameSiteLaxMode,
				MaxAge:   oauthStateMaxAge,
			})
		}

		provider := s.config.OAuthProvider.Kind
		switch provider {
		case "google":
			provider = "Google"
		case "github":
			provider = "GitHub"
		}

		// Field names must stay as-is: the embedded template refers to them.
		data := struct {
			Url      string
			State    string
			Provider string
		}{
			Url:      s.oa2.AuthCodeURL(state),
			State:    state,
			Provider: provider,
		}

		// Render into a buffer first so a template error can still produce a
		// clean 500 instead of a half-written page.
		var buf bytes.Buffer
		if err := loginPageTemplate.Execute(&buf, data); err != nil {
			http.Error(w, "Failed to render login page", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		// A failure here means the client went away; nothing useful to do.
		_, _ = buf.WriteTo(w)
	})
}

// isAllowed reports whether userID appears in the configured allow-list.
func (s *OAuthStore) isAllowed(userID string) bool {
	for _, user := range s.config.AllowedUsers {
		if user == userID {
			return true
		}
	}
	return false
}

// Protected wraps next so that it only runs for requests carrying a valid,
// unexpired session whose user is in the allow-list; all other requests are
// redirected to the login page.
func (s *OAuthStore) Protected(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie(SessionCookie)
		if err != nil {
			sendToLoginPage(w, r)
			return
		}
		sess, ok := s.GetSession(cookie.Value)
		if !ok || !s.isAllowed(sess.UserID) {
			sendToLoginPage(w, r)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// CallbackHandler returns the OAuth redirect handler. It verifies the state
// against the oauth_state cookie, exchanges the code for a token, resolves the
// user ID from the provider, creates a session, sets the session cookie, clears
// the state cookie, and redirects to "/".
func (s *OAuthStore) CallbackHandler() http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query()

		// An empty state must never match, otherwise a request with neither
		// state cookie value nor state parameter would pass the check.
		stateCookie, err := r.Cookie(oauthStateCookie)
		if err != nil || stateCookie.Value == "" || stateCookie.Value != query.Get("state") {
			http.Error(w, "Invalid state", http.StatusBadRequest)
			return
		}

		code := query.Get("code")
		if code == "" {
			http.Error(w, "Missing authorization code", http.StatusBadRequest)
			return
		}

		tok, err := s.oa2.Exchange(r.Context(), code)
		if err != nil {
			http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
			return
		}

		userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken)
		if err != nil {
			http.Error(w, "Failed to get info", http.StatusInternalServerError)
			return
		}

		sess, err := s.CreateSession(userID, sessionDuration)
		if err != nil {
			http.Error(w, "Failed to create session", http.StatusInternalServerError)
			return
		}

		http.SetCookie(w, &http.Cookie{
			Name:     SessionCookie,
			Value:    sess.ID,
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteLaxMode,
			MaxAge:   int(sessionDuration.Seconds()),
			Path:     "/",
		})

		// Clear the one-shot state cookie. Path is left unset, as when it was
		// set in LoginPage, so the browser applies the same default-path rule.
		http.SetCookie(w, &http.Cookie{
			Name:     oauthStateCookie,
			Value:    "",
			MaxAge:   -1,
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteLaxMode,
		})

		// TODO: remember what path the user was on and redirect them back there after doing the whole login process
		http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
	})
}

// fetchUserInfo GETs endpoint with token as a Bearer credential and decodes
// the JSON response into dst. Non-200 responses are returned as errors rather
// than decoded, so callers never see a zero-valued result from an error body.
func fetchUserInfo(endpoint, token string, dst interface{}) error {
	req, err := http.NewRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		return fmt.Errorf("building user info request: %w", err)
	}
	// Header instead of query string keeps the token out of URLs and logs.
	req.Header.Set("Authorization", "Bearer "+token)

	resp, err := userInfoClient.Do(req)
	if err != nil {
		return fmt.Errorf("requesting user info: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("user info request to %s: unexpected status %s", endpoint, resp.Status)
	}
	if err := json.NewDecoder(resp.Body).Decode(dst); err != nil {
		return fmt.Errorf("decoding user info: %w", err)
	}
	return nil
}

// getUserInfo resolves the identifier checked against AllowedUsers. For
// "google" that is the account email; for "github" it is the login name.
// It returns an error for unknown providers, failed lookups, or an empty ID.
func getUserInfo(providerKind, token string) (string, error) {
	switch providerKind {
	case "google":
		// NOTE(review): consider also rejecting unverified emails, since the
		// email is used as the identity for the allow-list.
		var info struct {
			Email string `json:"email"`
		}
		if err := fetchUserInfo("https://www.googleapis.com/oauth2/v2/userinfo", token, &info); err != nil {
			return "", err
		}
		if info.Email == "" {
			return "", errors.New("google user info response has no email")
		}
		return info.Email, nil
	case "github":
		var info struct {
			Login string `json:"login"`
		}
		if err := fetchUserInfo("https://api.github.com/user", token, &info); err != nil {
			return "", err
		}
		if info.Login == "" {
			return "", errors.New("github user info response has no login")
		}
		return info.Login, nil
	default:
		return "", fmt.Errorf("unsupported OAuth provider %q", providerKind)
	}
}