rsslair/lua.go

258 lines
5.4 KiB
Go

package main
import (
"encoding/xml"
"io"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
lua "github.com/yuin/gopher-lua"
)
// InitLua creates a fresh Lua state, registers the Go bindings on it
// (bind_lua_functions, which also installs the `log` table) and then loads
// the configured scripts (LoadScripts). The ready-to-use state is returned.
//
// bind_lua_functions always returns nil, and LoadScripts reports its own
// failures through logger.Fatal, so neither returned error is inspected here.
func InitLua() *lua.LState {
	logger.Info("Initializing lua")
	state := lua.NewState()
	_ = bind_lua_functions(state)
	_ = LoadScripts(state)
	return state
}
// LoadLogFunctions installs a global Lua table named `log` whose fields
// info/debug/error/fatal forward to the matching Go logger level.
func LoadLogFunctions(L *lua.LState) {
	levels := map[string]lua.LGFunction{
		"info":  luaLogInfo,
		"debug": luaLogDebug,
		"error": luaLogError,
		"fatal": luaLogFatal,
	}
	logTable := L.NewTable()
	L.SetFuncs(logTable, levels)
	L.SetGlobal("log", logTable)
}
// loadScript executes the Lua file at path `script` inside L.
// On success a debug line is logged and nil is returned; on failure the error
// is logged via logger.Fatal and also returned to the caller.
func loadScript(L *lua.LState, script string) error {
	err := L.DoFile(script)
	if err == nil {
		logger.Debug("Loaded script: ", script)
		return nil
	}
	logger.Fatal(err)
	return err
}
// LoadScripts loads Lua scripts from the "scripts" directory into L.
//
// When appConfig.Scripts is empty, or its first entry is "all" or "", every
// non-directory entry with a ".lua" extension in the directory is loaded.
// Otherwise each listed name is loaded from scripts/<name>.lua, in order.
// Any failure is logged via logger.Fatal and returned.
func LoadScripts(L *lua.LState) error {
	files, err := os.ReadDir("scripts")
	if err != nil {
		logger.Fatal(err)
		return err
	}

	// Guard the index: an empty Scripts list used to panic with
	// "index out of range"; treat it the same as "all".
	scripts := appConfig.Scripts
	loadAll := len(scripts) == 0 || scripts[0] == "all" || scripts[0] == ""

	if !loadAll {
		for _, name := range scripts {
			if err := loadScript(L, filepath.Join("scripts", name+".lua")); err != nil {
				logger.Fatal(err)
				return err
			}
		}
		return nil
	}

	for _, file := range files {
		if file.IsDir() || filepath.Ext(file.Name()) != ".lua" {
			continue
		}
		if err := loadScript(L, filepath.Join("scripts", file.Name())); err != nil {
			logger.Fatal(err)
			return err
		}
	}
	return nil
}
// execute_route calls the `route` function stored in the Lua global table
// named `script` (registered by lua_add_route), caches the stringified result
// in responses[route] and returns it.
//
// If the script global is not a table, or has no `route` function, an error
// is logged and ("", nil) is returned. An error raised while running the Lua
// function is logged and returned.
func execute_route(L *lua.LState, route string, script string) (string, error) {
	logger.Debug("Executing route: ", route)

	// Check the type before GetField: indexing a non-table value (e.g. nil
	// when the script never called add_route) raises a Lua error, which
	// panics here because we are not inside a protected call.
	table := L.GetGlobal(script)
	if table.Type() != lua.LTTable {
		logger.Error("script table was not found for script: ", script)
		return "", nil
	}

	fn := L.GetField(table, "route")
	if fn.Type() != lua.LTFunction {
		logger.Error("route function was not found in script: ", script)
		return "", nil
	}

	if err := L.CallByParam(lua.P{
		Fn:      fn,
		NRet:    1,
		Protect: true,
	}); err != nil {
		logger.Error(err)
		return "", err
	}

	// NRet: 1 leaves exactly one value on the stack; read it and pop it so
	// the stack does not grow by one slot on every request.
	ret := L.Get(-1)
	L.Pop(1)

	out := ret.String()
	responses[route] = strings.Clone(out)
	return out, nil
}
// add_html_parser is an LGFunction that builds a table exposing the HTML
// parser helpers ("parse" -> newHtmlParser, "select" -> select_html_node)
// and pushes it as its single return value.
//
// NOTE(review): it is not registered in bind_lua_functions in this file;
// presumably wired up elsewhere — confirm.
func add_html_parser(L *lua.LState) int {
	helpers := map[string]lua.LGFunction{
		"parse":  newHtmlParser,
		"select": select_html_node,
	}
	module := L.NewTable()
	L.SetFuncs(module, helpers)
	L.Push(module)
	return 1
}
// bind_lua_functions exposes the Go side of the application to Lua: the
// `log` table, the global helper functions, and the userdata types for the
// database, RSS items, the HTML parser and RSS images. It currently always
// returns nil.
func bind_lua_functions(L *lua.LState) error {
	LoadLogFunctions(L)

	globals := []struct {
		name string
		fn   lua.LGFunction
	}{
		{"add_route", lua_add_route},
		{"get", get_request},
		{"parse_xml_feed", parse_xml_feed},
		{"create_rss_feed", create_rss_feed},
	}
	for _, g := range globals {
		L.SetGlobal(g.name, L.NewFunction(g.fn))
	}

	// Go struct types made available to scripts as userdata.
	registerDatabaseType(L)
	registerRSSItemType(L)
	registerHtmlParserType(L)
	registerRssImageType(L)
	return nil
}
// lua_add_route is exposed to Lua as add_route(script, route).
//
// It creates a fresh global table named `script` (into which the script is
// expected to define its `route` function), maps `route` to that script in
// the routes table, and resets the cached response for the route.
//
// Both arguments are validated with CheckString. A missing or non-string
// argument now raises a Lua argument error instead of silently registering
// an empty name.
func lua_add_route(L *lua.LState) int {
	script := L.CheckString(1)
	route := L.CheckString(2)

	L.SetGlobal(script, L.NewTable())
	routes[route] = script
	responses[route] = ""
	logger.Debug("Adding route: ", route, " -> ", script)
	return 0
}
// get_request is exposed to Lua as get(url). It performs an HTTP GET and
// returns the response body as a string. If the request or the body read
// fails, the error is logged and nothing is returned (nil in Lua).
//
// NOTE(review): http.Get uses http.DefaultClient, which has no timeout, so a
// stalled upstream can block the calling route indefinitely.
func get_request(L *lua.LState) int {
	url := L.ToString(1)
	logger.Debug("GET: ", url)

	resp, err := http.Get(url)
	if err != nil {
		logger.Error(err)
		return 0
	}
	// The body must always be closed, or the underlying connection leaks.
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		// One failed remote read should not terminate the whole process;
		// handle it the same way as a failed request.
		logger.Error(err)
		return 0
	}
	L.Push(lua.LString(string(body)))
	return 1
}
// parse_xml_feed is exposed to Lua as parse_xml_feed(xml). It unmarshals an
// RSS document into RssRoot and returns a Lua array of RssItem userdata, one
// per channel item, each carrying the RssItem metatable.
//
// Malformed XML is logged and nothing is returned (nil in Lua). Previously
// this called logger.Fatal, so a bad upstream feed could bring down the
// process.
func parse_xml_feed(L *lua.LState) int {
	xmlStr := L.ToString(1)

	var rss RssRoot
	if err := xml.Unmarshal([]byte(xmlStr), &rss); err != nil {
		logger.Error(err)
		return 0
	}

	items := rss.Channel.Items
	feedItems := L.CreateTable(len(items), 0)
	// The metatable is the same for every item; look it up once.
	itemMeta := L.GetTypeMetatable(luaRSSItemTypeName)
	for i := range items {
		ud := L.NewUserData()
		// Point into the slice rather than at a loop copy so each userdata
		// refers to a distinct item.
		ud.Value = &items[i]
		L.SetMetatable(ud, itemMeta)
		feedItems.Append(ud)
	}
	L.Push(feedItems)
	return 1
}
// TODO: Move this into the rss table

// create_rss_feed is exposed to Lua as create_rss_feed(title, link, image, items).
//
//	title, link - strings placed on the <channel>
//	image       - RssImage userdata
//	items       - Lua array of RssItem userdata
//
// The items are sorted with ByPubDate and the feed is returned as indented
// XML. If `items` is empty, or marshalling fails, an error is logged and
// nothing is returned (nil in Lua). Arguments of the wrong type raise a Lua
// argument error rather than a Go panic.
func create_rss_feed(L *lua.LState) int {
	title := L.CheckString(1)
	link := L.CheckString(2)

	image, ok := L.CheckUserData(3).Value.(*RssImage)
	if !ok || image == nil {
		L.ArgError(3, "RssImage expected")
		return 0
	}

	entries := L.CheckTable(4)
	count := entries.Len()
	if count == 0 {
		logger.Error("No entries to create rss feed found in table")
		return 0
	}

	// TODO: Channel needs to also be passed in.
	// TODO: The image also needs to be added to the database and then
	// re-served through the server.
	feed := RssRoot{}
	feed.Channel = RssChannel{
		Title: title,
		Image: *image,
		Link:  link,
	}

	// Each entry becomes an <item> element inside <channel>.
	feed.Channel.Items = make([]RssItem, 0, count)
	for i := 1; i <= count; i++ {
		ud, ok := entries.RawGetInt(i).(*lua.LUserData)
		if !ok {
			L.ArgError(4, "table of RssItem expected")
			return 0
		}
		item, ok := ud.Value.(*RssItem)
		if !ok || item == nil {
			L.ArgError(4, "table of RssItem expected")
			return 0
		}
		feed.Channel.Items = append(feed.Channel.Items, *item)
	}

	sort.Sort(ByPubDate(feed.Channel.Items))

	output, err := xml.MarshalIndent(feed, "", " ")
	if err != nil {
		// Failing to render one feed should not terminate the process.
		logger.Error(err)
		return 0
	}
	L.Push(lua.LString(string(output)))
	return 1
}
// luaLogInfo backs log.info(msg): it logs msg at info level, prefixed with
// the Lua caller's location as reported by L.Where(1). Returns no values.
func luaLogInfo(L *lua.LState) int {
	logger.Info(L.Where(1), L.ToString(1))
	return 0
}
// luaLogDebug backs log.debug(msg): it logs msg at debug level, prefixed with
// the Lua caller's location as reported by L.Where(1). Returns no values.
func luaLogDebug(L *lua.LState) int {
	logger.Debug(L.Where(1), L.ToString(1))
	return 0
}
// luaLogError backs log.error(msg): it logs msg at error level, prefixed with
// the Lua caller's location as reported by L.Where(1). Returns no values.
func luaLogError(L *lua.LState) int {
	logger.Error(L.Where(1), L.ToString(1))
	return 0
}
// luaLogFatal backs log.fatal(msg): it passes msg, prefixed with the Lua
// caller's location from L.Where(1), to logger.Fatal.
// NOTE(review): whether this terminates the process depends on the logger
// implementation — confirm.
func luaLogFatal(L *lua.LState) int {
	logger.Fatal(L.Where(1), L.ToString(1))
	return 0
}