diff --git a/config.go b/config.go index b2c862f..44c767d 100644 --- a/config.go +++ b/config.go @@ -19,6 +19,12 @@ type OAuthProvider struct { ClientID string `toml:"client_id"` ClientSecret string `toml:"client_secret"` RedirectURL string `toml:"redirect_url"` + + // Only for custom OAuth provider + AuthURL string `toml:"auth_url"` + TokenURL string `toml:"token_url"` + Scopes []string `toml:"scopes"` + Script string `toml:"info_script"` } type Config struct { @@ -50,6 +56,15 @@ func LoadConfig() (Config, oauth2.Config, error) { case "google": oa2.Endpoint = endpoints.Google oa2.Scopes = []string{"https://www.googleapis.com/auth/userinfo.email"} + default: + oa2.Endpoint = oauth2.Endpoint{ + AuthURL: config.OAuthProvider.AuthURL, + TokenURL: config.OAuthProvider.TokenURL, + } + oa2.Scopes = config.OAuthProvider.Scopes + if config.OAuthProvider.Script == "" { + panic("no script provided") + } } return config, oa2, err diff --git a/config.toml b/config.toml index f3e9377..daa1ad8 100644 --- a/config.toml +++ b/config.toml @@ -3,6 +3,7 @@ guarded_paths = ["/"] # A list of all allowed users. For GitHub it is a list of usernames. # For Google it is a list of emails allowed_users = ["lukasmwerner"] +redirect_url = "http://localhost:3000/oauth/callback" [upstream] addr = "http://localhost:8080" @@ -11,7 +12,11 @@ program = "rezepte" args = [] [provider] -kind = "github" # can `google` or `github` +kind = "OAuth2" # can `google` or `github` or custom via an undefined name +auth_url = "https://github.com/login/oauth/authorize" +token_url = "https://github.com/login/oauth/access_token" +scopes = ["read:user"] client_id = "" client_secret = "" -redirect_url = "http://localhost:3000/oauth/callback" +# lua script that contains the function: `get_user_info(token)` and returns a string +info_script = "get_user_info.lua" diff --git a/config_simple.toml b/config_simple.toml new file mode 100644 index 0000000..5f22a75 --- /dev/null +++ b/config_simple.toml @@ -0,0 +1,18 @@ +listen_url = "http://localhost:3000" +guarded_paths = ["/"] +# A list of all allowed users. For GitHub it is a list of usernames. +# For Google it is a list of emails +allowed_users = ["lukasmwerner"] + +[upstream] +addr = "http://localhost:8080" +# An optional program to run as the upstream +program = "rezepte" +args = [] + +[provider] +kind = "custom" # can `google` or `github` or `custom` +client_id = "" +client_secret = "" + +redirect_url = "http://localhost:3000/oauth/callback" diff --git a/get_user_info.lua b/get_user_info.lua new file mode 100644 index 0000000..a6df917 --- /dev/null +++ b/get_user_info.lua @@ -0,0 +1,8 @@ +function get_user_info(token) + local json = require("json") + local resp, status = http.get("https://api.github.com/user", { + Authorization = "Bearer " .. token + }) + local data = json.decode(resp) + return data.login +end diff --git a/go.mod b/go.mod index 340342f..023aa28 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.24.3 require ( github.com/BurntSushi/toml v1.5.0 + github.com/Shopify/go-lua v0.0.0-20250605195627-15bbeb73041e + github.com/Shopify/goluago v0.0.0-20240527182001-ec4ec6c26eab github.com/google/uuid v1.6.0 golang.org/x/oauth2 v0.30.0 ) diff --git a/go.sum b/go.sum index b2f0a1f..27ad82d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/Shopify/go-lua v0.0.0-20250605195627-15bbeb73041e h1:zT/Iq/ow1l/J45IMajZ487dGbtjO9CfATa1O0T0aA9U= +github.com/Shopify/go-lua v0.0.0-20250605195627-15bbeb73041e/go.mod h1:M4CxjVc/1Nwka5atBv7G/sb7Ac2BDe3+FxbiT9iVNIQ= +github.com/Shopify/goluago v0.0.0-20240527182001-ec4ec6c26eab h1:lEd6vZgWJOjXAoIDUxSgg/U8/DbFEJnTfcBOQyAhej4= +github.com/Shopify/goluago v0.0.0-20240527182001-ec4ec6c26eab/go.mod h1:xIykgNzJggTWudqtySZwJa8Ab8NFgUSbSpPrTHQaHIc= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= diff --git a/oauth.go b/oauth.go index 55fda53..efd6595 100644 --- a/oauth.go +++ b/oauth.go @@ -177,7 +177,7 @@ func (s *OAuthStore) CallbackHandler() http.Handler { http.Error(w, "Failed to exchange token", http.StatusInternalServerError) return } - userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken) + userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken, s.config) if err != nil { http.Error(w, "Failed to get info", http.StatusInternalServerError) return @@ -210,7 +210,7 @@ func (s *OAuthStore) CallbackHandler() http.Handler { }) } -func getUserInfo(providerKind, token string) (string, error) { +func getUserInfo(providerKind, token string, c *Config) (string, error) { switch providerKind { case "google": type UserInfo struct { @@ -249,6 +249,6 @@ func getUserInfo(providerKind, token string) (string, error) { } return userInfo.Login, nil default: - panic("unimplemented") + return RunScript(c.OAuthProvider.Script, token) } } diff --git a/scripts.go b/scripts.go new file mode 100644 index 0000000..3426be8 --- /dev/null +++ b/scripts.go @@ -0,0 +1,94 @@ +package main + +import ( + "encoding/json" + "errors" + "io" + "net/http" + + "github.com/Shopify/go-lua" + "github.com/Shopify/goluago/util" +) + +var networkFunctions = []lua.RegistryFunction{ + {Name: "get", Function: func(l *lua.State) int { + url := lua.CheckString(l, 1) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + lua.Errorf(l, "unable to build new request: %s", err.Error()) + return 0 + } + + if !l.IsNil(2) { + headers, err := util.PullStringTable(l, 2) + if err != nil { + lua.Errorf(l, "unable to acces headers table: %s", err.Error()) + return 0 + } + for key, value := range headers { + req.Header.Set(key, value) + } + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + lua.Errorf(l, "error fetching: %s", err.Error()) + return 0 + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + lua.Errorf(l, "error reading body: %s", err.Error()) + return 0 + } + l.PushString(string(b)) + l.PushInteger(resp.StatusCode) + + return 2 + }}, +} + +var jsonFunctions = []lua.RegistryFunction{ + {Name: "decode", Function: func(l *lua.State) int { + payload := lua.CheckString(l, 1) + var output any + if err := json.Unmarshal([]byte(payload), &output); err != nil { + lua.Errorf(l, "error parsing json: %s", err.Error()) + return 0 + } + return util.DeepPush(l, output) + }}, +} + +func RunScript(fileName string, token string) (string, error) { + l := lua.NewState() + lua.OpenLibraries(l) + lua.Require(l, "http", func(state *lua.State) int { + lua.NewLibrary(l, networkFunctions) + return 1 + + }, true) + + lua.Require(l, "json", func(state *lua.State) int { + lua.NewLibrary(l, jsonFunctions) + return 1 + }, false) + err := lua.DoFile(l, fileName) + if err != nil { + return "", err + } + l.Global("get_user_info") + l.PushString(token) + err = l.ProtectedCall(1, 1, 0) + if err != nil { + return "", err + } + + myStr, ok := l.ToString(-1) + if !ok { + return "", errors.New("unable to get function result as string") + } + return myStr, nil + +}