package main import ( "database/sql" "fmt" "strings" "time" _ "github.com/mattn/go-sqlite3" lua "github.com/yuin/gopher-lua" ) const luaDatabaseTypeName = "sqlite3" var db *sql.DB var databaseMethods = map[string]lua.LGFunction{ "check": check, "insert": insert_rss, "getRss": GetRssFromDb, } func registerDatabaseType(L *lua.LState) { logger.Debug("Registering database type") mt := L.NewTypeMetatable("sqlite3") L.SetGlobal("sqlite3", mt) L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), databaseMethods)) // Create an instance of the database ud := InitDb(L) L.SetGlobal("db", ud) } // Check if the first lua argument is a *LUserData func checkDatabase(L *lua.LState) *sql.DB { ud := L.CheckUserData(1) if v, ok := ud.Value.(*sql.DB); ok { return v } L.ArgError(1, "sqlite3 expected") return nil } // Constructor for new database userdata func InitDb(L *lua.LState) *lua.LUserData { ud := L.NewUserData() // Create a database connection database, err := sql.Open("sqlite3", "./rss.db") if err != nil { logger.Error("Error opening database: ", err) } // Create the table with the rss schema if this is the first time // TODO: Check if the table exists first sqlStmt := ` CREATE TABLE IF NOT EXISTS rss (id INTEGER PRIMARY KEY, title TEXT, link TEXT, description TEXT, pubDate TEXT, guid TEXT UNIQUE, category TEXT, read INTEGER DEFAULT 0); ` _, err = database.Exec(sqlStmt) if err != nil { logger.Error("Error creating table: ", err) } ud.Value = database // Store the database connection in the userdata db = database _CreateRSSFeedDbTable(db) L.SetMetatable(ud, L.GetTypeMetatable(luaDatabaseTypeName)) // Set the metatable for the userdata return ud // Return the number of values pushed onto the stack } func _CreateRSSFeedDbTable(db *sql.DB) { // Create the table with the rss schema if this is the first time sqlStmt := ` CREATE TABLE IF NOT EXISTS feed (id INTEGER PRIMARY KEY, title TEXT, link TEXT UNIQUE, description TEXT, lastSyncTime TEXT); ` _, err := db.Exec(sqlStmt) if err != nil { logger.Error("Error creating table: ", err) } } // TODO: Check for duplicate entries before inserting func insert_rss(L *lua.LState) int { if L.GetTop() != 2 { L.ArgError(2, "Expected RSS table") } item := L.CheckUserData(2) if item == nil { return 0 } ud := item.Value.(*RssItem) // Insert the rss into the database sqlStmt := ` INSERT INTO rss (title, link, description, pubDate, guid, category) VALUES (?, ?, ?, ?, ?, ?) ` _, err := db.Exec(sqlStmt, ud.Title, ud.Link, ud.Description, ud.PubDate, ud.Guid, ud.Category) if err != nil && err.Error() != "UNIQUE constraint failed: rss.guid" { logger.Error("Error inserting into table: ", err) } return 0 } func GetRssFeed(link string) { // Get the rss feed from the database sqlStmt := ` SELECT title, link, description, lastSyncTime FROM feed WHERE link = ? ` rows, err := db.Query(sqlStmt, link) if err != nil { logger.Error("Error querying table: ", err) } defer rows.Close() for rows.Next() { } } func check(L *lua.LState) int { // Pass in a table of RSSItem if L.GetTop() != 2 { L.ArgError(2, "Expected RSS table") } items := L.CheckTable(2) if items == nil { return 0 } // Create a map of the rss items rssMap := make(map[string]*RssItem) items.ForEach(func(_, value lua.LValue) { item := value.(*lua.LUserData) if v, ok := item.Value.(*RssItem); ok { rssMap[v.Guid] = v } }) // SELECT * FROM your_table WHERE your_column_name IN (value1, value2, ..., valueN); // Create a string with all the guids guids := make([]string, 0, len(rssMap)) for guid := range rssMap { id := fmt.Sprintf("%q", guid) guids = append(guids, id) } values := strings.Join(guids, ",") sqlStmt := ` SELECT guid FROM rss WHERE guid IN (` + values + `) ` // Return to lua the values that do not exist in the database rows, err := db.Query(sqlStmt) if err != nil { logger.Error("Error querying table: ", err) } // Create array to store the previous guids prevGuids := make([]string, 0, len(rssMap)) defer rows.Close() for rows.Next() { var guid string err = rows.Scan(&guid) if err != nil { logger.Error("Error scanning table: ", err) } prevGuids = append(prevGuids, guid) delete(rssMap, guid) } table := L.NewTable() prevTable := L.NewTable() for _, v := range rssMap { ud := L.NewUserData() ud.Value = v L.SetMetatable(ud, L.GetTypeMetatable(luaRSSItemTypeName)) table.Append(ud) } for _, v := range prevGuids { prevTable.Append(lua.LString(v)) } L.Push(prevTable) L.Push(table) return 2 } func GetRssFromDb(L *lua.LState) int { // Pass in a table of RSSItem if L.GetTop() != 2 { L.ArgError(2, "Expected RSS table") } items := L.CheckTable(2) if items == nil { return 0 } // Create an array of the rss items arr := make([]string, 0, items.Len()) items.ForEach(func(_, value lua.LValue) { item := value.String() id := fmt.Sprintf("%q", item) arr = append(arr, id) }) values := strings.Join(arr, ",") sqlStmt := ` SELECT title, link, description, pubDate, guid, category FROM rss WHERE guid IN (` + values + `) ` rows, err := db.Query(sqlStmt) if err != nil { logger.Error("Error querying table: ", err) } table := L.NewTable() defer rows.Close() for rows.Next() { item := RssItem{} err = rows.Scan(&item.Title, &item.Link, &item.Description, &item.PubDate, &item.Guid, &item.Category) // TODO: Handle this error better if err != nil { logger.Error("Error scanning table: ", err) } ud := L.NewUserData() ud.Value = &item L.SetMetatable(ud, L.GetTypeMetatable(luaRSSItemTypeName)) table.Append(ud) } L.Push(table) return 1 } func get_rss_entry(L *lua.LState) int { if L.GetTop() != 2 { L.ArgError(2, "Expected RSS table") } // This function should only be getting data from the database return 0 } func get_rss_feed(url string) RssFeed { logger.Trace("===Executing get_rss_feed for url: ", url, "===") // Get the rss feed from the database sqlStmt := ` SELECT title, link, description, lastSyncTime FROM feed WHERE link = ? ` rows, err := db.Query(sqlStmt, url) if err != nil { logger.Error("Error querying table: ", err) } var rssFeed RssFeed defer rows.Close() for rows.Next() { err = rows.Scan(&rssFeed.Title, &rssFeed.Link, &rssFeed.Content, &rssFeed.LastSyncTime) if err != nil { logger.Error("Error scanning table: ", err) } } return rssFeed } func insert_rss_feed(rssFeed RssFeed) { sqlStmt := ` INSERT INTO feed (title, link, description, lastSyncTime) VALUES (?, ?, ?, ?) ` _, err := db.Exec(sqlStmt, rssFeed.Title, rssFeed.Link, rssFeed.Content, rssFeed.LastSyncTime) if err != nil && err.Error() != "UNIQUE constraint failed: feed.link" { logger.Error("Error inserting into table: ", err) } } func update_rss_feed(rssFeed RssFeed) { sqlStmt := ` UPDATE feed SET description = ?, lastSyncTime = ? WHERE link = ? ` _, err := db.Exec(sqlStmt, rssFeed.Content, rssFeed.LastSyncTime, rssFeed.Link) if err != nil { logger.Error("Error updating table: ", err) } } func DoesRssFeedExist(link string) bool { sqlStmt := ` SELECT EXISTS(SELECT 1 FROM feed WHERE link = ? LIMIT 1) ` rows, err := db.Query(sqlStmt, link) if err != nil { logger.Error("Error querying table: ", err) } var exists bool defer rows.Close() for rows.Next() { err = rows.Scan(&exists) if err != nil { logger.Error("Error scanning table: ", err) } } return exists } func GetLastSyncTime(link string) (time.Time, error) { sqlStmt := ` SELECT lastSyncTime FROM feed WHERE link = ? ` rows, err := db.Query(sqlStmt, link) if err != nil { logger.Error("Error querying table: ", err) } var lastSyncTime string defer rows.Close() for rows.Next() { err = rows.Scan(&lastSyncTime) if err != nil { logger.Error("Error scanning table: ", err) } } return time.Parse(time.RFC3339, lastSyncTime) } func CloseDb() { logger.Debug("Closing database") err := db.Close() if err != nil { logger.Error("Error closing database: ", err) } }