add user scripting with lua

This commit is contained in:
Lukas Werner 2025-07-03 17:07:35 -07:00
parent 5d2b82876d
commit 3920b8913d
No known key found for this signature in database
8 changed files with 151 additions and 5 deletions

View File

@ -19,6 +19,12 @@ type OAuthProvider struct {
ClientID string `toml:"client_id"`
ClientSecret string `toml:"client_secret"`
RedirectURL string `toml:"redirect_url"`
// Only for custom OAuth provider
AuthURL string `toml:"auth_url"`
TokenURL string `toml:"token_url"`
Scopes []string `toml:"scopes"`
Script string `toml:"info_script"`
}
type Config struct {
@ -50,6 +56,15 @@ func LoadConfig() (Config, oauth2.Config, error) {
case "google":
oa2.Endpoint = endpoints.Google
oa2.Scopes = []string{"https://www.googleapis.com/auth/userinfo.email"}
default:
oa2.Endpoint = oauth2.Endpoint{
AuthURL: config.OAuthProvider.AuthURL,
TokenURL: config.OAuthProvider.TokenURL,
}
oa2.Scopes = config.OAuthProvider.Scopes
if config.OAuthProvider.Script == "" {
panic("no script provided")
}
}
return config, oa2, err

View File

@ -3,6 +3,7 @@ guarded_paths = ["/"]
# A list of all allowed users. For GitHub it is a list of usernames.
# For Google it is a list of emails
allowed_users = ["lukasmwerner"]
redirect_url = "http://localhost:3000/oauth/callback"
[upstream]
addr = "http://localhost:8080"
@ -11,7 +12,11 @@ program = "rezepte"
args = []
[provider]
kind = "github" # can `google` or `github`
kind = "OAuth2" # can `google` or `github` or custom via an undefined name
auth_url = "https://github.com/login/oauth/authorize"
token_url = "https://github.com/login/oauth/access_token"
scopes = ["read:user"]
client_id = "<CHANGE_ME>"
client_secret = "<CHANGE_ME>"
redirect_url = "http://localhost:3000/oauth/callback"
# lua script that contains the function: `get_user_info(token)` and returns a string
info_script = "get_user_info.lua"

18
config_simple.toml Normal file
View File

@ -0,0 +1,18 @@
listen_url = "http://localhost:3000"
guarded_paths = ["/"]
# A list of all allowed users. For GitHub it is a list of usernames.
# For Google it is a list of emails
allowed_users = ["lukasmwerner"]
[upstream]
addr = "http://localhost:8080"
# An optional program to run as the upstream
program = "rezepte"
args = []
[provider]
kind = "custom" # can `google` or `github` or `custom`
client_id = "<CHANGE_ME>"
client_secret = "<CHANGE_ME>"
redirect_url = "http://localhost:3000/oauth/callback"

8
get_user_info.lua Normal file
View File

@ -0,0 +1,8 @@
function get_user_info(token)
local json = require("json")
local resp, status = http.get("https://api.github.com/user", {
Authorization = "Bearer " .. token
})
local data = json.decode(resp)
return data.login
end

2
go.mod
View File

@ -4,6 +4,8 @@ go 1.24.3
require (
github.com/BurntSushi/toml v1.5.0
github.com/Shopify/go-lua v0.0.0-20250605195627-15bbeb73041e
github.com/Shopify/goluago v0.0.0-20240527182001-ec4ec6c26eab
github.com/google/uuid v1.6.0
golang.org/x/oauth2 v0.30.0
)

4
go.sum
View File

@ -1,5 +1,9 @@
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/Shopify/go-lua v0.0.0-20250605195627-15bbeb73041e h1:zT/Iq/ow1l/J45IMajZ487dGbtjO9CfATa1O0T0aA9U=
github.com/Shopify/go-lua v0.0.0-20250605195627-15bbeb73041e/go.mod h1:M4CxjVc/1Nwka5atBv7G/sb7Ac2BDe3+FxbiT9iVNIQ=
github.com/Shopify/goluago v0.0.0-20240527182001-ec4ec6c26eab h1:lEd6vZgWJOjXAoIDUxSgg/U8/DbFEJnTfcBOQyAhej4=
github.com/Shopify/goluago v0.0.0-20240527182001-ec4ec6c26eab/go.mod h1:xIykgNzJggTWudqtySZwJa8Ab8NFgUSbSpPrTHQaHIc=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=

View File

@ -177,7 +177,7 @@ func (s *OAuthStore) CallbackHandler() http.Handler {
http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
return
}
userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken)
userID, err := getUserInfo(s.config.OAuthProvider.Kind, tok.AccessToken, s.config)
if err != nil {
http.Error(w, "Failed to get info", http.StatusInternalServerError)
return
@ -210,7 +210,7 @@ func (s *OAuthStore) CallbackHandler() http.Handler {
})
}
func getUserInfo(providerKind, token string) (string, error) {
func getUserInfo(providerKind, token string, c *Config) (string, error) {
switch providerKind {
case "google":
type UserInfo struct {
@ -249,6 +249,6 @@ func getUserInfo(providerKind, token string) (string, error) {
}
return userInfo.Login, nil
default:
panic("unimplemented")
return RunScript(c.OAuthProvider.Script, token)
}
}

94
scripts.go Normal file
View File

@ -0,0 +1,94 @@
package main
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/Shopify/go-lua"
"github.com/Shopify/goluago/util"
)
var networkFunctions = []lua.RegistryFunction{
{Name: "get", Function: func(l *lua.State) int {
url := lua.CheckString(l, 1)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
lua.Errorf(l, "unable to build new request: %s", err.Error())
return 0
}
if !l.IsNil(2) {
headers, err := util.PullStringTable(l, 2)
if err != nil {
lua.Errorf(l, "unable to acces headers table: %s", err.Error())
return 0
}
for key, value := range headers {
req.Header.Set(key, value)
}
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
lua.Errorf(l, "error fetching: %s", err.Error())
return 0
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
lua.Errorf(l, "error reading body: %s", err.Error())
return 0
}
l.PushString(string(b))
l.PushInteger(resp.StatusCode)
return 2
}},
}
var jsonFunctions = []lua.RegistryFunction{
{Name: "decode", Function: func(l *lua.State) int {
payload := lua.CheckString(l, 1)
var output any
if err := json.Unmarshal([]byte(payload), &output); err != nil {
lua.Errorf(l, "error parsing json: %s", err.Error())
return 0
}
return util.DeepPush(l, output)
}},
}
func RunScript(fileName string, token string) (string, error) {
l := lua.NewState()
lua.OpenLibraries(l)
lua.Require(l, "http", func(state *lua.State) int {
lua.NewLibrary(l, networkFunctions)
return 1
}, true)
lua.Require(l, "json", func(state *lua.State) int {
lua.NewLibrary(l, jsonFunctions)
return 1
}, false)
err := lua.DoFile(l, fileName)
if err != nil {
return "", err
}
l.Global("get_user_info")
l.PushString(token)
err = l.ProtectedCall(1, 1, 0)
if err != nil {
return "", err
}
myStr, ok := l.ToString(-1)
if !ok {
return "", errors.New("unable to get function result as string")
}
return myStr, nil
}