258 lines
5.4 KiB
Go
258 lines
5.4 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/xml"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
lua "github.com/yuin/gopher-lua"
|
|
)
|
|
|
|
func InitLua() *lua.LState {
|
|
logger.Info("Initializing lua")
|
|
L := lua.NewState()
|
|
// load_log_library(L)
|
|
bind_lua_functions(L)
|
|
LoadScripts(L)
|
|
return L
|
|
}
|
|
|
|
func LoadLogFunctions(L *lua.LState) {
|
|
table := L.NewTable()
|
|
L.SetFuncs(table, map[string]lua.LGFunction{
|
|
"info": luaLogInfo,
|
|
"debug": luaLogDebug,
|
|
"error": luaLogError,
|
|
"fatal": luaLogFatal,
|
|
})
|
|
L.SetGlobal("log", table)
|
|
}
|
|
|
|
func loadScript(L *lua.LState, script string) error {
|
|
if err := L.DoFile(script); err != nil {
|
|
logger.Fatal(err)
|
|
return err
|
|
}
|
|
logger.Debug("Loaded script: ", script)
|
|
return nil
|
|
}
|
|
|
|
func LoadScripts(L *lua.LState) error {
|
|
// iterate over the scripts directory and load each script
|
|
files, err := os.ReadDir("scripts")
|
|
if err != nil {
|
|
logger.Fatal(err)
|
|
return err
|
|
}
|
|
|
|
loadAllScripts := appConfig.Scripts[0] == "all" || appConfig.Scripts[0] == ""
|
|
if loadAllScripts {
|
|
for _, file := range files {
|
|
var extension = filepath.Ext(file.Name())
|
|
if file.IsDir() || extension != ".lua" {
|
|
continue
|
|
}
|
|
if err := loadScript(L, "scripts/"+file.Name()); err != nil {
|
|
logger.Fatal(err)
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
for _, script := range appConfig.Scripts {
|
|
if err := loadScript(L, "scripts/"+script+".lua"); err != nil {
|
|
logger.Fatal(err)
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func execute_route(L *lua.LState, route string, script string) (string, error) {
|
|
logger.Debug("Executing route: ", route)
|
|
|
|
// Get table by script name
|
|
table := L.GetGlobal(script)
|
|
// Get the function from the table
|
|
fn := L.GetField(table, "route")
|
|
// TODO: The message should say something else
|
|
if fn.Type() != lua.LTFunction {
|
|
logger.Error("route function was not found in script: ", script)
|
|
return "", nil
|
|
}
|
|
|
|
if err := L.CallByParam(lua.P{
|
|
Fn: fn,
|
|
NRet: 1,
|
|
Protect: true,
|
|
}); err != nil {
|
|
logger.Error(err)
|
|
return "", err
|
|
}
|
|
ret := L.Get(-1)
|
|
responses[route] = strings.Clone(ret.String())
|
|
|
|
return ret.String(), nil
|
|
}
|
|
|
|
func add_html_parser(L *lua.LState) int {
|
|
// Create a table for the html_parser add_functions
|
|
html_table := L.NewTable()
|
|
L.SetFuncs(html_table, map[string]lua.LGFunction{
|
|
"parse": newHtmlParser,
|
|
"select": select_html_node,
|
|
})
|
|
L.Push(html_table)
|
|
return 1
|
|
}
|
|
|
|
func bind_lua_functions(L *lua.LState) error {
|
|
// Add global functions to lua
|
|
LoadLogFunctions(L)
|
|
|
|
L.SetGlobal("add_route", L.NewFunction(lua_add_route))
|
|
L.SetGlobal("get", L.NewFunction(get_request))
|
|
L.SetGlobal("parse_xml_feed", L.NewFunction(parse_xml_feed))
|
|
L.SetGlobal("create_rss_feed", L.NewFunction(create_rss_feed))
|
|
|
|
// Register Golang structs with lua
|
|
registerDatabaseType(L)
|
|
registerRSSItemType(L)
|
|
registerHtmlParserType(L)
|
|
registerRssImageType(L)
|
|
|
|
return nil
|
|
}
|
|
|
|
func lua_add_route(L *lua.LState) int {
|
|
// TODO: Check that the parameters are correct
|
|
// Handle errors
|
|
script := L.ToString(1)
|
|
route := L.ToString(2)
|
|
table := L.NewTable()
|
|
|
|
L.SetGlobal(script, table)
|
|
|
|
routes[route] = script
|
|
responses[route] = ""
|
|
logger.Debug("Adding route: ", route, " -> ", script)
|
|
return 0
|
|
}
|
|
|
|
func get_request(L *lua.LState) int {
|
|
url := L.ToString(1)
|
|
logger.Debug("GET: ", url)
|
|
resp, err := http.Get(url)
|
|
if err != nil {
|
|
logger.Error(err)
|
|
return 0
|
|
}
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
logger.Fatal(err)
|
|
return 0
|
|
}
|
|
L.Push(lua.LString(string(body)))
|
|
return 1
|
|
}
|
|
|
|
func parse_xml_feed(L *lua.LState) int {
|
|
xml_str := L.ToString(1)
|
|
|
|
var rss RssRoot
|
|
err := xml.Unmarshal([]byte(xml_str), &rss)
|
|
if err != nil {
|
|
logger.Fatal(err)
|
|
return 0
|
|
}
|
|
feed_items := L.NewTable()
|
|
items := rss.Channel.Items
|
|
|
|
for i := range items {
|
|
ud := L.NewUserData()
|
|
ud.Value = &items[i]
|
|
L.SetMetatable(ud, L.GetTypeMetatable(luaRSSItemTypeName))
|
|
feed_items.Append(ud)
|
|
}
|
|
L.Push(feed_items)
|
|
return 1
|
|
}
|
|
|
|
// TODO: Move this into the rss table
|
|
func create_rss_feed(L *lua.LState) int {
|
|
//TODO: Check that the parameters are correct
|
|
feed := RssRoot{}
|
|
|
|
//TODO: Bind RssImage to lua and pass it in
|
|
// Also needs to be added to the database and then re-served through the server
|
|
title := L.CheckString(1)
|
|
link := L.CheckString(2)
|
|
image := L.CheckUserData(3)
|
|
// TODO: Channel needs to also be passed in
|
|
feed.Channel = RssChannel{
|
|
Title: title,
|
|
Image: *image.Value.(*RssImage),
|
|
Link: link,
|
|
}
|
|
|
|
// Append the items to the feed
|
|
// Notably the <channel> tag as <item>s
|
|
entries := L.CheckTable(4)
|
|
if entries.Len() == 0 {
|
|
logger.Error("No entries to create rss feed found in table")
|
|
return 0
|
|
}
|
|
|
|
for i := 1; i <= entries.Len(); i++ {
|
|
entry := entries.RawGetInt(i)
|
|
ud := entry.(*lua.LUserData)
|
|
item := ud.Value.(*RssItem)
|
|
feed.Channel.Items = append(feed.Channel.Items, *item)
|
|
}
|
|
|
|
// sort by date
|
|
sort.Sort(ByPubDate(feed.Channel.Items))
|
|
|
|
output, err := xml.MarshalIndent(feed, "", " ")
|
|
if err != nil {
|
|
logger.Fatal(err)
|
|
return 0
|
|
}
|
|
|
|
L.Push(lua.LString(string(output)))
|
|
return 1
|
|
}
|
|
|
|
func luaLogInfo(L *lua.LState) int {
|
|
loc := L.Where(1)
|
|
msg := L.ToString(1)
|
|
logger.Info(loc, msg)
|
|
return 0
|
|
}
|
|
|
|
func luaLogDebug(L *lua.LState) int {
|
|
// Get the where and message
|
|
loc := L.Where(1)
|
|
msg := L.ToString(1)
|
|
logger.Debug(loc, msg)
|
|
return 0
|
|
}
|
|
|
|
func luaLogError(L *lua.LState) int {
|
|
loc := L.Where(1)
|
|
msg := L.ToString(1)
|
|
logger.Error(loc, msg)
|
|
return 0
|
|
}
|
|
|
|
func luaLogFatal(L *lua.LState) int {
|
|
loc := L.Where(1)
|
|
msg := L.ToString(1)
|
|
logger.Fatal(loc, msg)
|
|
return 0
|
|
}
|