add user scripting with lua
This commit is contained in:
		
							parent
							
								
									5d2b82876d
								
							
						
					
					
						commit
						3920b8913d
					
				
							
								
								
									
										15
									
								
								config.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								config.go
									
									
									
									
									
								
							@ -19,6 +19,12 @@ type OAuthProvider struct {
 | 
			
		||||
	ClientID     string `toml:"client_id"`
 | 
			
		||||
	ClientSecret string `toml:"client_secret"`
 | 
			
		||||
	RedirectURL  string `toml:"redirect_url"`
 | 
			
		||||
 | 
			
		||||
	// Only for custom OAuth provider
 | 
			
		||||
	AuthURL  string   `toml:"auth_url"`
 | 
			
		||||
	TokenURL string   `toml:"token_url"`
 | 
			
		||||
	Scopes   []string `toml:"scopes"`
 | 
			
		||||
	Script   string   `toml:"info_script"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
@ -50,6 +56,15 @@ func LoadConfig() (Config, oauth2.Config, error) {
 | 
			
		||||
	case "google":
 | 
			
		||||
		oa2.Endpoint = endpoints.Google
 | 
			
		||||
		oa2.Scopes = []string{"https://www.googleapis.com/auth/userinfo.email"}
 | 
			
		||||
	default:
 | 
			
		||||
		oa2.Endpoint = oauth2.Endpoint{
 | 
			
		||||
			AuthURL:  config.OAuthProvider.AuthURL,
 | 
			
		||||
			TokenURL: config.OAuthProvider.TokenURL,
 | 
			
		||||
		}
 | 
			
		||||
		oa2.Scopes = config.OAuthProvider.Scopes
 | 
			
		||||
		if config.OAuthProvider.Script == "" {
 | 
			
		||||
			panic("no script provided")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return config, oa2, err
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@ guarded_paths = ["/"]
 | 
			
		||||
# A list of all allowed users. For GitHub it is a list of usernames.
 | 
			
		||||
# For Google it is a list of emails
 | 
			
		||||
allowed_users = ["lukasmwerner"]
 | 
			
		||||
redirect_url = "http://localhost:3000/oauth/callback"
 | 
			
		||||
 | 
			
		||||
[upstream]
 | 
			
		||||
addr = "http://localhost:8080"
 | 
			
		||||
@ -11,7 +12,11 @@ program = "rezepte"
 | 
			
		||||
args = []
 | 
			
		||||
 | 
			
		||||
[provider]
 | 
			
		||||
kind = "github" # can `google` or `github`
 | 
			
		||||
kind = "OAuth2" # can `google` or `github` or custom via an undefined name
 | 
			
		||||
auth_url = "https://github.com/login/oauth/authorize"
 | 
			
		||||
token_url = "https://github.com/login/oauth/access_token"
 | 
			
		||||
scopes = ["read:user"]
 | 
			
		||||
client_id = "<CHANGE_ME>"
 | 
			
		||||
client_secret = "<CHANGE_ME>"
 | 
			
		||||
redirect_url = "http://localhost:3000/oauth/callback"
 | 
			
		||||
# lua script that contains the function: `get_user_info(token)` and returns a string
 | 
			
		||||
info_script = "get_user_info.lua"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										18
									
								
								config_simple.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								config_simple.toml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,18 @@
 | 
			
		||||
listen_url = "http://localhost:3000"
 | 
			
		||||
guarded_paths = ["/"]
 | 
			
		||||
# A list of all allowed users. For GitHub it is a list of usernames.
 | 
			
		||||
# For Google it is a list of emails
 | 
			
		||||
allowed_users = ["lukasmwerner"]
 | 
			
		||||
 | 
			
		||||
[upstream]
 | 
			
		||||
addr = "http://localhost:8080"
 | 
			
		||||
# An optional program to run as the upstream
 | 
			
		||||
program = "rezepte"
 | 
			
		||||
args = []
 | 
			
		||||
 | 
			
		||||
[provider]
 | 
			
		||||
kind = "custom" # can `google` or `github` or `custom`
 | 
			
		||||
client_id = "<CHANGE_ME>"
 | 
			
		||||
client_secret = "<CHANGE_ME>"
 | 
			
		||||
 | 
			
		||||
redirect_url = "http://localhost:3000/oauth/callback"
 | 
			
		||||
							
								
								
									
										8
									
								
								get_user_info.lua
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								get_user_info.lua
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,8 @@
 | 
			
		||||
function get_user_info(token)
 | 
			
		||||
	local json = require("json")
 | 
			
		||||
	local resp, status = http.get("https://api.github.com/user", {
 | 
			
		||||
		Authorization =  "Bearer " .. token
 | 
			
		||||
	})
 | 
			
		||||
	local data = json.decode(resp)
 | 
			
		||||
	return data.login
 | 
			
		||||
end
 | 
			
		||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@ -4,6 +4,8 @@ go 1.24.3
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/BurntSushi/toml v1.5.0
 | 
			
		||||
	github.com/Shopify/go-lua v0.0.0-20250605195627-15bbeb73041e
 | 
			
		||||
	github.com/Shopify/goluago v0.0.0-20240527182001-ec4ec6c26eab
 | 
			
		||||
	github.com/google/uuid v1.6.0
 | 
			
		||||
	golang.org/x/oauth2 v0.30.0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@ -1,5 +1,9 @@
 | 
			
		||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
 | 
			
		||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
 | 
			
		||||
github.com/Shopify/go-lua v0.0.0-20250605195627-15bbeb73041e h1:zT/Iq/ow1l/J45IMajZ487dGbtjO9CfATa1O0T0aA9U=
 | 
			
		||||
github.com/Shopify/go-lua v0.0.0-20250605195627-15bbeb73041e/go.mod h1:M4CxjVc/1Nwka5atBv7G/sb7Ac2BDe3+FxbiT9iVNIQ=
 | 
			
		||||
github.com/Shopify/goluago v0.0.0-20240527182001-ec4ec6c26eab h1:lEd6vZgWJOjXAoIDUxSgg/U8/DbFEJnTfcBOQyAhej4=
 | 
			
		||||
github.com/Shopify/goluago v0.0.0-20240527182001-ec4ec6c26eab/go.mod h1:xIykgNzJggTWudqtySZwJa8Ab8NFgUSbSpPrTHQaHIc=
 | 
			
		||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 | 
			
		||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										6
									
								
								oauth.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								oauth.go
									
									
									
									
									
								
							@ -177,7 +177,7 @@ func (s *OAuthStore) CallbackHandler() http.Handler {
 | 
			
		||||
			http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken)
 | 
			
		||||
		userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken, s.config)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			http.Error(w, "Failed to get info", http.StatusInternalServerError)
 | 
			
		||||
			return
 | 
			
		||||
@ -210,7 +210,7 @@ func (s *OAuthStore) CallbackHandler() http.Handler {
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getUserInfo(providerKind, token string) (string, error) {
 | 
			
		||||
func getUserInfo(providerKind, token string, c *Config) (string, error) {
 | 
			
		||||
	switch providerKind {
 | 
			
		||||
	case "google":
 | 
			
		||||
		type UserInfo struct {
 | 
			
		||||
@ -249,6 +249,6 @@ func getUserInfo(providerKind, token string) (string, error) {
 | 
			
		||||
		}
 | 
			
		||||
		return userInfo.Login, nil
 | 
			
		||||
	default:
 | 
			
		||||
		panic("unimplemented")
 | 
			
		||||
		return RunScript(c.OAuthProvider.Script, token)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										94
									
								
								scripts.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								scripts.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,94 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/Shopify/go-lua"
 | 
			
		||||
	"github.com/Shopify/goluago/util"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var networkFunctions = []lua.RegistryFunction{
 | 
			
		||||
	{Name: "get", Function: func(l *lua.State) int {
 | 
			
		||||
		url := lua.CheckString(l, 1)
 | 
			
		||||
 | 
			
		||||
		req, err := http.NewRequest("GET", url, nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			lua.Errorf(l, "unable to build new request: %s", err.Error())
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !l.IsNil(2) {
 | 
			
		||||
			headers, err := util.PullStringTable(l, 2)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				lua.Errorf(l, "unable to acces headers table: %s", err.Error())
 | 
			
		||||
				return 0
 | 
			
		||||
			}
 | 
			
		||||
			for key, value := range headers {
 | 
			
		||||
				req.Header.Set(key, value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		resp, err := http.DefaultClient.Do(req)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			lua.Errorf(l, "error fetching: %s", err.Error())
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
		defer resp.Body.Close()
 | 
			
		||||
		b, err := io.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			lua.Errorf(l, "error reading body: %s", err.Error())
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
		l.PushString(string(b))
 | 
			
		||||
		l.PushInteger(resp.StatusCode)
 | 
			
		||||
 | 
			
		||||
		return 2
 | 
			
		||||
	}},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var jsonFunctions = []lua.RegistryFunction{
 | 
			
		||||
	{Name: "decode", Function: func(l *lua.State) int {
 | 
			
		||||
		payload := lua.CheckString(l, 1)
 | 
			
		||||
		var output any
 | 
			
		||||
		if err := json.Unmarshal([]byte(payload), &output); err != nil {
 | 
			
		||||
			lua.Errorf(l, "error parsing json: %s", err.Error())
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
		return util.DeepPush(l, output)
 | 
			
		||||
	}},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RunScript(fileName string, token string) (string, error) {
 | 
			
		||||
	l := lua.NewState()
 | 
			
		||||
	lua.OpenLibraries(l)
 | 
			
		||||
	lua.Require(l, "http", func(state *lua.State) int {
 | 
			
		||||
		lua.NewLibrary(l, networkFunctions)
 | 
			
		||||
		return 1
 | 
			
		||||
 | 
			
		||||
	}, true)
 | 
			
		||||
 | 
			
		||||
	lua.Require(l, "json", func(state *lua.State) int {
 | 
			
		||||
		lua.NewLibrary(l, jsonFunctions)
 | 
			
		||||
		return 1
 | 
			
		||||
	}, false)
 | 
			
		||||
	err := lua.DoFile(l, fileName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	l.Global("get_user_info")
 | 
			
		||||
	l.PushString(token)
 | 
			
		||||
	err = l.ProtectedCall(1, 1, 0)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	myStr, ok := l.ToString(-1)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return "", errors.New("unable to get function result as string")
 | 
			
		||||
	}
 | 
			
		||||
	return myStr, nil
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user