package main import ( "encoding/xml" "io" "net/http" "os" "path/filepath" "sort" "strings" lua "github.com/yuin/gopher-lua" ) func InitLua() *lua.LState { logger.Info("Initializing lua") L := lua.NewState() // load_log_library(L) bind_lua_functions(L) LoadScripts(L) return L } func LoadLogFunctions(L *lua.LState) { table := L.NewTable() L.SetFuncs(table, map[string]lua.LGFunction{ "info": luaLogInfo, "debug": luaLogDebug, "error": luaLogError, "fatal": luaLogFatal, }) L.SetGlobal("log", table) } func loadScript(L *lua.LState, script string) error { if err := L.DoFile(script); err != nil { logger.Fatal(err) return err } logger.Debug("Loaded script: ", script) return nil } func LoadScripts(L *lua.LState) error { // iterate over the scripts directory and load each script files, err := os.ReadDir("scripts") if err != nil { logger.Fatal(err) return err } loadAllScripts := appConfig.Scripts[0] == "all" || appConfig.Scripts[0] == "" if loadAllScripts { for _, file := range files { var extension = filepath.Ext(file.Name()) if file.IsDir() || extension != ".lua" { continue } if err := loadScript(L, "scripts/"+file.Name()); err != nil { logger.Fatal(err) return err } } } else { for _, script := range appConfig.Scripts { if err := loadScript(L, "scripts/"+script+".lua"); err != nil { logger.Fatal(err) return err } } } return nil } func execute_route(L *lua.LState, route string, script string) (string, error) { logger.Debug("Executing route: ", route) // Get table by script name table := L.GetGlobal(script) // Get the function from the table fn := L.GetField(table, "route") // TODO: The message should say something else if fn.Type() != lua.LTFunction { logger.Error("route function was not found in script: ", script) return "", nil } if err := L.CallByParam(lua.P{ Fn: fn, NRet: 1, Protect: true, }); err != nil { logger.Error(err) return "", err } ret := L.Get(-1) responses[route] = strings.Clone(ret.String()) return ret.String(), nil } func add_html_parser(L *lua.LState) int { // Create a table for the html_parser add_functions html_table := L.NewTable() L.SetFuncs(html_table, map[string]lua.LGFunction{ "parse": newHtmlParser, "select": select_html_node, }) L.Push(html_table) return 1 } func bind_lua_functions(L *lua.LState) error { // Add global functions to lua LoadLogFunctions(L) L.SetGlobal("add_route", L.NewFunction(lua_add_route)) L.SetGlobal("get", L.NewFunction(get_request)) L.SetGlobal("parse_xml_feed", L.NewFunction(parse_xml_feed)) L.SetGlobal("create_rss_feed", L.NewFunction(create_rss_feed)) // Register Golang structs with lua registerDatabaseType(L) registerRSSItemType(L) registerHtmlParserType(L) registerRssImageType(L) return nil } func lua_add_route(L *lua.LState) int { // TODO: Check that the parameters are correct // Handle errors script := L.ToString(1) route := L.ToString(2) table := L.NewTable() L.SetGlobal(script, table) routes[route] = script responses[route] = "" logger.Debug("Adding route: ", route, " -> ", script) return 0 } func get_request(L *lua.LState) int { url := L.ToString(1) logger.Debug("GET: ", url) resp, err := http.Get(url) if err != nil { logger.Error(err) return 0 } body, err := io.ReadAll(resp.Body) if err != nil { logger.Fatal(err) return 0 } L.Push(lua.LString(string(body))) return 1 } func parse_xml_feed(L *lua.LState) int { xml_str := L.ToString(1) var rss RssRoot err := xml.Unmarshal([]byte(xml_str), &rss) if err != nil { logger.Fatal(err) return 0 } feed_items := L.NewTable() items := rss.Channel.Items for i := range items { ud := L.NewUserData() ud.Value = &items[i] L.SetMetatable(ud, L.GetTypeMetatable(luaRSSItemTypeName)) feed_items.Append(ud) } L.Push(feed_items) return 1 } // TODO: Move this into the rss table func create_rss_feed(L *lua.LState) int { //TODO: Check that the parameters are correct feed := RssRoot{} //TODO: Bind RssImage to lua and pass it in // Also needs to be added to the database and then re-served through the server title := L.CheckString(1) link := L.CheckString(2) image := L.CheckUserData(3) // TODO: Channel needs to also be passed in feed.Channel = RssChannel{ Title: title, Image: *image.Value.(*RssImage), Link: link, } // Append the items to the feed // Notably the tag as s entries := L.CheckTable(4) if entries.Len() == 0 { logger.Error("No entries to create rss feed found in table") return 0 } for i := 1; i <= entries.Len(); i++ { entry := entries.RawGetInt(i) ud := entry.(*lua.LUserData) item := ud.Value.(*RssItem) feed.Channel.Items = append(feed.Channel.Items, *item) } // sort by date sort.Sort(ByPubDate(feed.Channel.Items)) output, err := xml.MarshalIndent(feed, "", " ") if err != nil { logger.Fatal(err) return 0 } L.Push(lua.LString(string(output))) return 1 } func luaLogInfo(L *lua.LState) int { loc := L.Where(1) msg := L.ToString(1) logger.Info(loc, msg) return 0 } func luaLogDebug(L *lua.LState) int { // Get the where and message loc := L.Where(1) msg := L.ToString(1) logger.Debug(loc, msg) return 0 } func luaLogError(L *lua.LState) int { loc := L.Where(1) msg := L.ToString(1) logger.Error(loc, msg) return 0 } func luaLogFatal(L *lua.LState) int { loc := L.Where(1) msg := L.ToString(1) logger.Fatal(loc, msg) return 0 }