it works!
This commit is contained in:
		
						commit
						7bbaf7ceb2
					
				
							
								
								
									
										56
									
								
								config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								config.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,56 @@
 | 
				
			|||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/BurntSushi/toml"
 | 
				
			||||||
 | 
						"golang.org/x/oauth2"
 | 
				
			||||||
 | 
						"golang.org/x/oauth2/endpoints"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Upstream struct {
 | 
				
			||||||
 | 
						Addr    string   `toml:"addr"`
 | 
				
			||||||
 | 
						Program string   `toml:"program"`
 | 
				
			||||||
 | 
						Args    []string `toml:"args"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OAuthProvider struct {
 | 
				
			||||||
 | 
						Kind         string `toml:"kind"`
 | 
				
			||||||
 | 
						ClientID     string `toml:"client_id"`
 | 
				
			||||||
 | 
						ClientSecret string `toml:"client_secret"`
 | 
				
			||||||
 | 
						RedirectURL  string `toml:"redirect_url"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Config struct {
 | 
				
			||||||
 | 
						ListenURL     string        `toml:"listen_url"`
 | 
				
			||||||
 | 
						GuardedPaths  []string      `toml:"guarded_paths"`
 | 
				
			||||||
 | 
						AllowedUsers  []string      `toml:"allowed_users"`
 | 
				
			||||||
 | 
						Upstream      Upstream      `toml:"upstream"`
 | 
				
			||||||
 | 
						OAuthProvider OAuthProvider `toml:"provider"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func LoadConfig() (Config, oauth2.Config, error) {
 | 
				
			||||||
 | 
						config := Config{}
 | 
				
			||||||
 | 
						oa2 := oauth2.Config{}
 | 
				
			||||||
 | 
						_, err := toml.DecodeFile("config.toml", &config)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return config, oa2, fmt.Errorf("unable to parse 'config.toml' tompl decoding error: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						oa2.ClientID = config.OAuthProvider.ClientID
 | 
				
			||||||
 | 
						oa2.ClientSecret = config.OAuthProvider.ClientSecret
 | 
				
			||||||
 | 
						oa2.Endpoint = oauth2.Endpoint{}
 | 
				
			||||||
 | 
						oa2.RedirectURL = config.OAuthProvider.RedirectURL
 | 
				
			||||||
 | 
						oa2.Scopes = []string{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch config.OAuthProvider.Kind {
 | 
				
			||||||
 | 
						case "github":
 | 
				
			||||||
 | 
							oa2.Endpoint = endpoints.GitHub
 | 
				
			||||||
 | 
							oa2.Scopes = []string{"read:user"}
 | 
				
			||||||
 | 
						case "google":
 | 
				
			||||||
 | 
							oa2.Endpoint = endpoints.Google
 | 
				
			||||||
 | 
							oa2.Scopes = []string{"https://www.googleapis.com/auth/userinfo.email"}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return config, oa2, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										17
									
								
								config.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								config.toml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,17 @@
 | 
				
			|||||||
 | 
					listen_url = "http://localhost:3000"
 | 
				
			||||||
 | 
					guarded_paths = ["/"]
 | 
				
			||||||
 | 
					# A list of all allowed users. For GitHub it is a list of usernames.
 | 
				
			||||||
 | 
					# For Google it is a list of emails
 | 
				
			||||||
 | 
					allowed_users = ["lukasmwerner"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[upstream]
 | 
				
			||||||
 | 
					addr = "http://localhost:8080"
 | 
				
			||||||
 | 
					# An optional program to run as the upstream
 | 
				
			||||||
 | 
					program = "rezepte"
 | 
				
			||||||
 | 
					args = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[provider]
 | 
				
			||||||
 | 
					kind = "github" # can `google` or `github`
 | 
				
			||||||
 | 
					client_id = "<CHANGE_ME>"
 | 
				
			||||||
 | 
					client_secret = "<CHANGE_ME>"
 | 
				
			||||||
 | 
					redirect_url = "http://localhost:3000/oauth/callback"
 | 
				
			||||||
							
								
								
									
										9
									
								
								go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								go.mod
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					module git.hafen.run/lukas/oauth-guard
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					go 1.24.3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					require (
 | 
				
			||||||
 | 
						github.com/BurntSushi/toml v1.5.0
 | 
				
			||||||
 | 
						github.com/google/uuid v1.6.0
 | 
				
			||||||
 | 
						golang.org/x/oauth2 v0.30.0
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
							
								
								
									
										6
									
								
								go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								go.sum
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
 | 
				
			||||||
 | 
					github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
 | 
				
			||||||
 | 
					github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 | 
				
			||||||
 | 
					github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
				
			||||||
 | 
					golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
 | 
				
			||||||
 | 
					golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
 | 
				
			||||||
							
								
								
									
										91
									
								
								main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								main.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,91 @@
 | 
				
			|||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httputil"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"os/exec"
 | 
				
			||||||
 | 
						"os/signal"
 | 
				
			||||||
 | 
						"syscall"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func main() {
 | 
				
			||||||
 | 
						config, oa2, err := LoadConfig()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Printf("error loading config: %s\n", err.Error())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						upstream, _ := url.Parse(config.Upstream.Addr)
 | 
				
			||||||
 | 
						rp := httputil.NewSingleHostReverseProxy(upstream)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var cmd *exec.Cmd
 | 
				
			||||||
 | 
						if config.Upstream.Program != "" {
 | 
				
			||||||
 | 
							ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
 | 
							defer cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cmd = exec.CommandContext(ctx, config.Upstream.Program, config.Upstream.Args...)
 | 
				
			||||||
 | 
							cmd.Stdin = os.Stdin
 | 
				
			||||||
 | 
							cmd.Stdout = os.Stdout
 | 
				
			||||||
 | 
							cmd.Stderr = os.Stderr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							defer func() {
 | 
				
			||||||
 | 
								if cmd != nil && cmd.Process != nil {
 | 
				
			||||||
 | 
									log.Println("Terminating child process...")
 | 
				
			||||||
 | 
									cmd.Process.Signal(syscall.SIGTERM)
 | 
				
			||||||
 | 
									cmd.Process.Wait()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								if err := cmd.Run(); err != nil {
 | 
				
			||||||
 | 
									log.Printf("Child process exited with error: %v\n", err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c := make(chan os.Signal, 1)
 | 
				
			||||||
 | 
						signal.Notify(c, os.Interrupt, syscall.SIGTERM)
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							<-c
 | 
				
			||||||
 | 
							log.Println("Received shutdown signal, cleaning up...")
 | 
				
			||||||
 | 
							if cmd != nil && cmd.Process != nil {
 | 
				
			||||||
 | 
								cmd.Process.Signal(syscall.SIGTERM)
 | 
				
			||||||
 | 
								cmd.Process.Wait()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							os.Exit(0)
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						oauthStore := NewSessionStore(&config, &oa2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						http.Handle("/oauth/callback", oauthStore.CallbackHandler())
 | 
				
			||||||
 | 
						http.Handle("/oauth/login", oauthStore.LoginPage())
 | 
				
			||||||
 | 
						protectedRoot := false
 | 
				
			||||||
 | 
						for _, pattern := range config.GuardedPaths {
 | 
				
			||||||
 | 
							if pattern == "/" {
 | 
				
			||||||
 | 
								protectedRoot = true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							http.Handle(pattern, oauthStore.Protected(rp))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !protectedRoot {
 | 
				
			||||||
 | 
							http.Handle("/", rp)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						http.Handle("/favicon.ico", rp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						listenAddr, _ := url.Parse(config.ListenURL)
 | 
				
			||||||
 | 
						listenPort := listenAddr.Port()
 | 
				
			||||||
 | 
						if listenPort == "" {
 | 
				
			||||||
 | 
							if listenAddr.Scheme == "https" {
 | 
				
			||||||
 | 
								listenPort = "443"
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								listenPort = "80"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Printf("Starting server on port %s\n", listenPort)
 | 
				
			||||||
 | 
						if err := http.ListenAndServe(":"+listenPort, nil); err != nil {
 | 
				
			||||||
 | 
							log.Printf("Server error: %v\n", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										256
									
								
								oauth.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										256
									
								
								oauth.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,256 @@
 | 
				
			|||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						_ "embed"
 | 
				
			||||||
 | 
						"encoding/base64"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"html/template"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/google/uuid"
 | 
				
			||||||
 | 
						"golang.org/x/oauth2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Session struct {
 | 
				
			||||||
 | 
						ID        string    `json:"id"`
 | 
				
			||||||
 | 
						UserID    string    `json:"user_id"`
 | 
				
			||||||
 | 
						ExpiresAt time.Time `json:"expires_at"`
 | 
				
			||||||
 | 
						CreatedAt time.Time `json:"created_at"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const SessionCookie = "session"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OAuthStore struct {
 | 
				
			||||||
 | 
						sessions map[string]*Session
 | 
				
			||||||
 | 
						mutex    sync.RWMutex
 | 
				
			||||||
 | 
						oa2      *oauth2.Config
 | 
				
			||||||
 | 
						config   *Config
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewSessionStore(config *Config, oa2 *oauth2.Config) *OAuthStore {
 | 
				
			||||||
 | 
						return &OAuthStore{
 | 
				
			||||||
 | 
							sessions: make(map[string]*Session),
 | 
				
			||||||
 | 
							oa2:      oa2,
 | 
				
			||||||
 | 
							config:   config,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *OAuthStore) CreateSession(userID string, duration time.Duration) (*Session, error) {
 | 
				
			||||||
 | 
						uid, err := uuid.NewRandom()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						id := uid.String()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session := &Session{
 | 
				
			||||||
 | 
							ID:        id,
 | 
				
			||||||
 | 
							UserID:    userID,
 | 
				
			||||||
 | 
							ExpiresAt: time.Now().Add(duration),
 | 
				
			||||||
 | 
							CreatedAt: time.Now(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s.mutex.Lock()
 | 
				
			||||||
 | 
						s.sessions[id] = session
 | 
				
			||||||
 | 
						s.mutex.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return session, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *OAuthStore) GetSession(sessionID string) (*Session, bool) {
 | 
				
			||||||
 | 
						s.mutex.RLock()
 | 
				
			||||||
 | 
						defer s.mutex.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session, exists := s.sessions[sessionID]
 | 
				
			||||||
 | 
						if !exists {
 | 
				
			||||||
 | 
							return nil, false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if session.ExpiresAt.Before(time.Now()) {
 | 
				
			||||||
 | 
							s.DeleteSession(sessionID)
 | 
				
			||||||
 | 
							return nil, false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return session, true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *OAuthStore) DeleteSession(sessionID string) {
 | 
				
			||||||
 | 
						s.mutex.Lock()
 | 
				
			||||||
 | 
						delete(s.sessions, sessionID)
 | 
				
			||||||
 | 
						s.mutex.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func sendToLoginPage(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
						http.Redirect(w, r, "/oauth/login", http.StatusTemporaryRedirect)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func generateRandomToken() string {
 | 
				
			||||||
 | 
						b := make([]byte, 32)
 | 
				
			||||||
 | 
						rand.Read(b)
 | 
				
			||||||
 | 
						return base64.StdEncoding.EncodeToString(b)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//go:embed templates/LoginPage.html
 | 
				
			||||||
 | 
					var loginPageContent string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *OAuthStore) LoginPage() http.Handler {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						loginPageTemplate := template.Must(template.New("loginPageContent").Parse(loginPageContent))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							state := generateRandomToken()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if cookie, err := r.Cookie("oauth_state"); err == nil && cookie.Value != "" {
 | 
				
			||||||
 | 
								state = cookie.Value
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								http.SetCookie(w, &http.Cookie{
 | 
				
			||||||
 | 
									Name:     "oauth_state",
 | 
				
			||||||
 | 
									Value:    state,
 | 
				
			||||||
 | 
									HttpOnly: true,
 | 
				
			||||||
 | 
									Secure:   true, // Set to true in production
 | 
				
			||||||
 | 
									SameSite: http.SameSiteLaxMode,
 | 
				
			||||||
 | 
									MaxAge:   60 * 10, // 10 minutes
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							url := s.oa2.AuthCodeURL(state)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							provider := s.config.OAuthProvider.Kind
 | 
				
			||||||
 | 
							switch provider {
 | 
				
			||||||
 | 
							case "google":
 | 
				
			||||||
 | 
								provider = "Google"
 | 
				
			||||||
 | 
							case "github":
 | 
				
			||||||
 | 
								provider = "GitHub"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							loginPageTemplate.Execute(w, struct {
 | 
				
			||||||
 | 
								Url      string
 | 
				
			||||||
 | 
								State    string
 | 
				
			||||||
 | 
								Provider string
 | 
				
			||||||
 | 
							}{
 | 
				
			||||||
 | 
								Url:      url,
 | 
				
			||||||
 | 
								State:    state,
 | 
				
			||||||
 | 
								Provider: provider,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *OAuthStore) Protected(next http.Handler) http.Handler {
 | 
				
			||||||
 | 
						return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cookie, err := r.Cookie(SessionCookie)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								sendToLoginPage(w, r)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							sess, exists := s.GetSession(cookie.Value)
 | 
				
			||||||
 | 
							if !exists {
 | 
				
			||||||
 | 
								sendToLoginPage(w, r)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							found := false
 | 
				
			||||||
 | 
							for _, user := range s.config.AllowedUsers {
 | 
				
			||||||
 | 
								if user == sess.UserID {
 | 
				
			||||||
 | 
									found = true
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !found {
 | 
				
			||||||
 | 
								sendToLoginPage(w, r)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							next.ServeHTTP(w, r)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *OAuthStore) CallbackHandler() http.Handler {
 | 
				
			||||||
 | 
						return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							stateCookie, err := r.Cookie("oauth_state")
 | 
				
			||||||
 | 
							if err != nil || stateCookie.Value != r.URL.Query().Get("state") {
 | 
				
			||||||
 | 
								http.Error(w, "Invalid state", http.StatusBadRequest)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							tok, err := s.oa2.Exchange(r.Context(), r.URL.Query().Get("code"))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								http.Error(w, "Failed to get info", http.StatusInternalServerError)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							println(userID)
 | 
				
			||||||
 | 
							sess, err := s.CreateSession(userID, time.Hour*24)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								http.Error(w, "Failed to create session", http.StatusInternalServerError)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							http.SetCookie(w, &http.Cookie{
 | 
				
			||||||
 | 
								Name:     SessionCookie,
 | 
				
			||||||
 | 
								Value:    sess.ID,
 | 
				
			||||||
 | 
								HttpOnly: true,
 | 
				
			||||||
 | 
								Secure:   true,
 | 
				
			||||||
 | 
								SameSite: http.SameSiteLaxMode,
 | 
				
			||||||
 | 
								MaxAge:   int(time.Hour.Seconds() * 24),
 | 
				
			||||||
 | 
								Path:     "/",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// clear cookie
 | 
				
			||||||
 | 
							http.SetCookie(w, &http.Cookie{
 | 
				
			||||||
 | 
								Name:   "oauth_state",
 | 
				
			||||||
 | 
								Value:  "",
 | 
				
			||||||
 | 
								MaxAge: -1,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// TODO: remember what path the user was on and redirect them back there after doing the whole login process
 | 
				
			||||||
 | 
							http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getUserInfo(providerKind, token string) (string, error) {
 | 
				
			||||||
 | 
						switch providerKind {
 | 
				
			||||||
 | 
						case "google":
 | 
				
			||||||
 | 
							type UserInfo struct {
 | 
				
			||||||
 | 
								ID    string `json:"sub"`
 | 
				
			||||||
 | 
								Email string `json:"email"`
 | 
				
			||||||
 | 
								Name  string `json:"name"`
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return "", err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer resp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							var userInfo UserInfo
 | 
				
			||||||
 | 
							if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
 | 
				
			||||||
 | 
								return "", err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return userInfo.Email, nil
 | 
				
			||||||
 | 
						case "github":
 | 
				
			||||||
 | 
							type UserInfo struct {
 | 
				
			||||||
 | 
								Login string `json:"login"`
 | 
				
			||||||
 | 
								Name  string `json:"name"`
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							req, _ := http.NewRequest("GET", "https://api.github.com/user", nil)
 | 
				
			||||||
 | 
							req.Header.Add("Authorization", "Bearer "+token)
 | 
				
			||||||
 | 
							resp, err := http.DefaultClient.Do(req)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return "", err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer resp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							var userInfo UserInfo
 | 
				
			||||||
 | 
							if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
 | 
				
			||||||
 | 
								return "", err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return userInfo.Login, nil
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							panic("unimplemented")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										82
									
								
								templates/LoginPage.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								templates/LoginPage.html
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,82 @@
 | 
				
			|||||||
 | 
					<!DOCTYPE html>
 | 
				
			||||||
 | 
					<html lang="en">
 | 
				
			||||||
 | 
					  <head>
 | 
				
			||||||
 | 
					    <meta charset="UTF-8">
 | 
				
			||||||
 | 
					    <meta name="viewport" content="width=device-width, initial-scale=1">
 | 
				
			||||||
 | 
					    <title>Login</title>
 | 
				
			||||||
 | 
					    <style>
 | 
				
			||||||
 | 
					      * {
 | 
				
			||||||
 | 
					        margin: 0;
 | 
				
			||||||
 | 
					        padding: 0;
 | 
				
			||||||
 | 
					        box-sizing: border-box;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					      body {
 | 
				
			||||||
 | 
					        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
 | 
				
			||||||
 | 
					        min-height: 100vh;
 | 
				
			||||||
 | 
					        display: flex;
 | 
				
			||||||
 | 
					        align-items: center;
 | 
				
			||||||
 | 
					        justify-content: center;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					      .login-container {
 | 
				
			||||||
 | 
					        background: white;
 | 
				
			||||||
 | 
					        padding: 3rem;
 | 
				
			||||||
 | 
					        border-radius: 16px;
 | 
				
			||||||
 | 
					        max-width: 400px;
 | 
				
			||||||
 | 
					        width: 90%;
 | 
				
			||||||
 | 
					        display: flex;
 | 
				
			||||||
 | 
					        justify-content: center;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					      .login-button {
 | 
				
			||||||
 | 
					        background: #4c4c4c;
 | 
				
			||||||
 | 
					        color: white;
 | 
				
			||||||
 | 
					        border: none;
 | 
				
			||||||
 | 
					        padding: 16px 32px;
 | 
				
			||||||
 | 
					        border-radius: 12px;
 | 
				
			||||||
 | 
					        font-size: 16px;
 | 
				
			||||||
 | 
					        font-weight: 600;
 | 
				
			||||||
 | 
					        cursor: pointer;
 | 
				
			||||||
 | 
					        transition: all 0.3s ease;
 | 
				
			||||||
 | 
					        display: inline-flex;
 | 
				
			||||||
 | 
					        align-items: center;
 | 
				
			||||||
 | 
					        gap: 12px;
 | 
				
			||||||
 | 
					        text-decoration: none;
 | 
				
			||||||
 | 
					        min-width: 200px;
 | 
				
			||||||
 | 
					        justify-content: center;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					      .login-button:hover {
 | 
				
			||||||
 | 
					        transform: translateY(-2px);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					      .login-button:active {
 | 
				
			||||||
 | 
					        transform: translateY(0);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					      .lock-icon {
 | 
				
			||||||
 | 
					        width: 20px;
 | 
				
			||||||
 | 
					        height: 20px;
 | 
				
			||||||
 | 
					        fill: currentColor;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					      h1 {
 | 
				
			||||||
 | 
					        color: #333;
 | 
				
			||||||
 | 
					        margin-bottom: 2rem;
 | 
				
			||||||
 | 
					        font-weight: 300;
 | 
				
			||||||
 | 
					        font-size: 2rem;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    </style>
 | 
				
			||||||
 | 
					  </head>
 | 
				
			||||||
 | 
					  <body>
 | 
				
			||||||
 | 
					    <div class="login-container">
 | 
				
			||||||
 | 
					      <a href="{{.Url}}" class="login-button">
 | 
				
			||||||
 | 
					        <svg class="lock-icon" viewBox="0 0 24 24">
 | 
				
			||||||
 | 
					          <path d="M18,8h-1V6c0-2.76-2.24-5-5-5S7,3.24,7,6v2H6c-1.1,0-2,0.9-2,2v10c0,1.1,0.9,2,2,2h12c1.1,0,2-0.9,2-2V10C20,8.9,19.1,8,18,8z M12,17c-1.1,0-2-0.9-2-2s0.9-2,2-2s2,0.9,2,2S13.1,17,12,17z M15.1,8H8.9V6c0-1.71,1.39-3.1,3.1-3.1s3.1,1.39,3.1,3.1V8z"/>
 | 
				
			||||||
 | 
					        </svg>
 | 
				
			||||||
 | 
					        Login with {{.Provider}}
 | 
				
			||||||
 | 
					      </a>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					  </body>
 | 
				
			||||||
 | 
					</html>
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user