// rsslair/db.go
//
// 330 lines
// 7.9 KiB
// Go

package main
import (
"database/sql"
"fmt"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
lua "github.com/yuin/gopher-lua"
)
// luaDatabaseTypeName is the name under which the database userdata
// metatable is registered in the Lua state.
const luaDatabaseTypeName = "sqlite3"

var (
	// db is the package-level connection handle. InitDb assigns it, and the
	// Go-side query helpers in this file read it.
	db *sql.DB

	// databaseMethods maps Lua method names to the Go functions that
	// implement them on the sqlite3 userdata.
	databaseMethods = map[string]lua.LGFunction{
		"check":  check,
		"getRss": GetRssFromDb,
		"insert": insert_rss,
	}
)
// registerDatabaseType installs the sqlite3 userdata type into the Lua state.
//
// It creates the type metatable named luaDatabaseTypeName and exposes it as a
// Lua global of the same name. It installs databaseMethods as the metatable's
// __index table. It then opens the database through InitDb and publishes the
// resulting userdata as the Lua global "db".
func registerDatabaseType(L *lua.LState) {
	logger.Debug("Registering database type")
	// Use the shared constant rather than a literal so this registration and
	// the metatable lookup in InitDb can never drift apart.
	mt := L.NewTypeMetatable(luaDatabaseTypeName)
	L.SetGlobal(luaDatabaseTypeName, mt)
	L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), databaseMethods))
	// The metatable must be registered before InitDb attaches it to the
	// userdata it creates.
	ud := InitDb(L)
	L.SetGlobal("db", ud)
}
// checkDatabase returns the *sql.DB held by the userdata at Lua stack index 1.
// This is the receiver when a method is called as db:method(...).
// It raises a Lua argument error if that argument is not userdata, or if the
// userdata holds something other than a *sql.DB.
func checkDatabase(L *lua.LState) *sql.DB {
	payload := L.CheckUserData(1).Value
	conn, isDB := payload.(*sql.DB)
	if !isDB {
		L.ArgError(1, "sqlite3 expected")
		return nil
	}
	return conn
}
// InitDb opens the SQLite database at ./rss.db and creates the rss and feed
// tables if they are missing. It stores the connection in the package-level
// db variable and returns it wrapped in a Lua userdata. That userdata's
// metatable is the one registered under luaDatabaseTypeName.
//
// If the database cannot be opened, a Lua error is raised. Previously the code
// continued with a nil *sql.DB, which panicked on the first Exec call.
// Table-creation failures are logged and do not abort initialization.
func InitDb(L *lua.LState) *lua.LUserData {
	database, err := sql.Open("sqlite3", "./rss.db")
	if err != nil {
		logger.Error("Error opening database: ", err)
		L.RaiseError("error opening database: %v", err)
		return nil
	}

	// CREATE TABLE IF NOT EXISTS is a no-op when the table already exists,
	// so no separate existence check is needed.
	sqlStmt := `
CREATE TABLE IF NOT EXISTS rss (id INTEGER PRIMARY KEY, title TEXT, link TEXT, description TEXT, pubDate TEXT, guid TEXT UNIQUE, category TEXT, read INTEGER DEFAULT 0);
`
	if _, err := database.Exec(sqlStmt); err != nil {
		logger.Error("Error creating table: ", err)
	}

	// Publish the connection for the Go-side helpers, then create the feed table.
	db = database
	_CreateRSSFeedDbTable(db)

	ud := L.NewUserData()
	ud.Value = database
	L.SetMetatable(ud, L.GetTypeMetatable(luaDatabaseTypeName))
	return ud
}
// _CreateRSSFeedDbTable creates the feed table, whose link column is UNIQUE,
// in db if it does not already exist. A failure is logged rather than returned.
func _CreateRSSFeedDbTable(db *sql.DB) {
	const feedSchema = `
CREATE TABLE IF NOT EXISTS feed (id INTEGER PRIMARY KEY, title TEXT, link TEXT UNIQUE, description TEXT, lastSyncTime TEXT);
`
	if _, execErr := db.Exec(feedSchema); execErr != nil {
		logger.Error("Error creating table: ", execErr)
	}
}
// insert_rss implements db:insert(item) for Lua. It stores the RssItem held by
// the userdata at argument 2 as a new row of the rss table.
//
// The guid column is UNIQUE. INSERT OR IGNORE makes inserting an item whose
// guid is already stored a silent no-op, so the code no longer needs to
// recognise the constraint violation by matching the driver's error text.
// Any other database error is logged. Nothing is pushed onto the Lua stack.
func insert_rss(L *lua.LState) int {
	if L.GetTop() != 2 {
		L.ArgError(2, "Expected RSS table")
	}
	item := L.CheckUserData(2)
	if item == nil {
		return 0
	}
	// Comma-ok assertion: a userdata of any other type must raise a Lua
	// argument error instead of crashing the host with a Go panic.
	rss, ok := item.Value.(*RssItem)
	if !ok || rss == nil {
		L.ArgError(2, "RssItem expected")
		return 0
	}
	sqlStmt := `
INSERT OR IGNORE INTO rss (title, link, description, pubDate, guid, category)
VALUES (?, ?, ?, ?, ?, ?)
`
	if _, err := db.Exec(sqlStmt, rss.Title, rss.Link, rss.Description, rss.PubDate, rss.Guid, rss.Category); err != nil {
		logger.Error("Error inserting into table: ", err)
	}
	return 0
}
// GetRssFeed queries the feed table for the row whose link equals link.
//
// NOTE(review): the result rows are iterated but never read or returned, so
// this function currently does nothing beyond logging query errors. By
// contrast, get_rss_feed scans the row and returns it.
func GetRssFeed(link string) {
	sqlStmt := `
SELECT title, link, description, lastSyncTime FROM feed WHERE link = ?
`
	rows, err := db.Query(sqlStmt, link)
	if err != nil {
		logger.Error("Error querying table: ", err)
		// rows is nil when Query fails; continuing would panic in rows.Next.
		return
	}
	defer rows.Close()
	for rows.Next() {
		// TODO: scan the row into a result once this function returns one.
	}
	if err := rows.Err(); err != nil {
		logger.Error("Error iterating table: ", err)
	}
}
// check implements db:check(items) for Lua. Argument 2 must be a table of
// RssItem userdata. check finds which of those items' guids are already
// stored in the rss table and pushes two tables onto the Lua stack:
//
//  1. the guids (strings) that already exist in the database, and
//  2. RssItem userdata for the items whose guids were not found.
//
// It returns 2, the number of values pushed. Table entries that are not
// RssItem userdata are ignored.
func check(L *lua.LState) int {
	if L.GetTop() != 2 {
		L.ArgError(2, "Expected RSS table")
	}
	items := L.CheckTable(2)
	if items == nil {
		return 0
	}

	// Index the incoming items by guid. Use comma-ok assertions so that a
	// stray non-userdata value cannot crash the host.
	rssMap := make(map[string]*RssItem)
	items.ForEach(func(_, value lua.LValue) {
		ud, ok := value.(*lua.LUserData)
		if !ok {
			return
		}
		if v, ok := ud.Value.(*RssItem); ok && v != nil {
			rssMap[v.Guid] = v
		}
	})

	prevGuids := make([]string, 0, len(rssMap))
	if len(rssMap) > 0 {
		// Bind every guid as a query parameter. Guids come from fetched feeds.
		// Splicing them into the SQL text (the old %q approach) broke on
		// quotes and allowed injection. SQLite also read a guid that matched a
		// column name as an identifier rather than a string.
		args := make([]interface{}, 0, len(rssMap))
		for guid := range rssMap {
			args = append(args, guid)
		}
		placeholders := strings.TrimSuffix(strings.Repeat("?,", len(args)), ",")
		sqlStmt := fmt.Sprintf("SELECT guid FROM rss WHERE guid IN (%s)", placeholders)

		rows, err := db.Query(sqlStmt, args...)
		if err != nil {
			// Treat every item as new rather than dereferencing a nil rows.
			logger.Error("Error querying table: ", err)
		} else {
			defer rows.Close()
			for rows.Next() {
				var guid string
				if err := rows.Scan(&guid); err != nil {
					logger.Error("Error scanning table: ", err)
					continue
				}
				prevGuids = append(prevGuids, guid)
				delete(rssMap, guid)
			}
			if err := rows.Err(); err != nil {
				logger.Error("Error iterating table: ", err)
			}
		}
	}

	newItems := L.NewTable()
	for _, v := range rssMap {
		ud := L.NewUserData()
		ud.Value = v
		L.SetMetatable(ud, L.GetTypeMetatable(luaRSSItemTypeName))
		newItems.Append(ud)
	}
	prevTable := L.NewTable()
	for _, guid := range prevGuids {
		prevTable.Append(lua.LString(guid))
	}
	L.Push(prevTable)
	L.Push(newItems)
	return 2
}
// GetRssFromDb implements db:getRss(guids) for Lua. Argument 2 must be a table
// whose values are converted with String() and used as guids. GetRssFromDb
// pushes one table of RssItem userdata, one per stored rss row with a matching
// guid, and returns 1.
//
// Rows that fail to scan are logged and skipped rather than returned
// half-populated. If the query itself fails, an empty table is pushed.
func GetRssFromDb(L *lua.LState) int {
	if L.GetTop() != 2 {
		L.ArgError(2, "Expected RSS table")
	}
	items := L.CheckTable(2)
	if items == nil {
		return 0
	}

	// Collect the guids as bind parameters. They must never be spliced into
	// the SQL text, because they originate from untrusted feed content.
	args := make([]interface{}, 0, items.Len())
	items.ForEach(func(_, value lua.LValue) {
		args = append(args, value.String())
	})

	table := L.NewTable()
	if len(args) == 0 {
		L.Push(table)
		return 1
	}

	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(args)), ",")
	sqlStmt := fmt.Sprintf("SELECT title, link, description, pubDate, guid, category FROM rss WHERE guid IN (%s)", placeholders)
	rows, err := db.Query(sqlStmt, args...)
	if err != nil {
		// rows is nil on error; return the empty table instead of panicking.
		logger.Error("Error querying table: ", err)
		L.Push(table)
		return 1
	}
	defer rows.Close()
	for rows.Next() {
		item := &RssItem{}
		if err := rows.Scan(&item.Title, &item.Link, &item.Description, &item.PubDate, &item.Guid, &item.Category); err != nil {
			logger.Error("Error scanning table: ", err)
			continue
		}
		ud := L.NewUserData()
		ud.Value = item
		L.SetMetatable(ud, L.GetTypeMetatable(luaRSSItemTypeName))
		table.Append(ud)
	}
	if err := rows.Err(); err != nil {
		logger.Error("Error iterating table: ", err)
	}
	L.Push(table)
	return 1
}
// get_rss_entry is an unimplemented Lua binding. When it is implemented, it
// should only read from the database. For now it calls L.ArgError unless
// exactly two arguments were passed, and it pushes nothing.
func get_rss_entry(L *lua.LState) int {
	const wantArgs = 2
	if argc := L.GetTop(); argc != wantArgs {
		L.ArgError(2, "Expected RSS table")
	}
	return 0
}
// get_rss_feed loads the feed row whose link equals url.
//
// The feed table declares link UNIQUE, so at most one row can match and
// QueryRow is sufficient. It returns the zero RssFeed when no row matches.
// It also returns the zero RssFeed when the query or scan fails; that failure
// is logged, and no partially-filled value is returned.
func get_rss_feed(url string) RssFeed {
	logger.Trace("===Executing get_rss_feed for url: ", url, "===")
	sqlStmt := `
SELECT title, link, description, lastSyncTime FROM feed WHERE link = ?
`
	var rssFeed RssFeed
	err := db.QueryRow(sqlStmt, url).Scan(&rssFeed.Title, &rssFeed.Link, &rssFeed.Content, &rssFeed.LastSyncTime)
	switch {
	case err == sql.ErrNoRows:
		// No stored feed for this link: return the zero value.
		return RssFeed{}
	case err != nil:
		logger.Error("Error querying table: ", err)
		return RssFeed{}
	}
	return rssFeed
}
// insert_rss_feed stores rssFeed as a new row of the feed table. The feed's
// Content field is written to the description column.
//
// The link column is UNIQUE. INSERT OR IGNORE makes inserting an already-known
// feed a silent no-op, so the code no longer has to compare the driver's error
// text. Any other error is logged.
func insert_rss_feed(rssFeed RssFeed) {
	sqlStmt := `
INSERT OR IGNORE INTO feed (title, link, description, lastSyncTime)
VALUES (?, ?, ?, ?)
`
	if _, err := db.Exec(sqlStmt, rssFeed.Title, rssFeed.Link, rssFeed.Content, rssFeed.LastSyncTime); err != nil {
		logger.Error("Error inserting into table: ", err)
	}
}
// update_rss_feed overwrites the description column (from rssFeed.Content) and
// the lastSyncTime column of the feed row whose link equals rssFeed.Link. The
// title is not modified. Errors are logged.
func update_rss_feed(rssFeed RssFeed) {
	const updateFeed = `
UPDATE feed SET description = ?, lastSyncTime = ? WHERE link = ?
`
	params := []interface{}{rssFeed.Content, rssFeed.LastSyncTime, rssFeed.Link}
	if _, updErr := db.Exec(updateFeed, params...); updErr != nil {
		logger.Error("Error updating table: ", updErr)
	}
}
// DoesRssFeedExist reports whether the feed table contains a row whose link
// equals link. A query failure is logged and reported as false. Previously a
// query failure dereferenced a nil *sql.Rows and panicked.
func DoesRssFeedExist(link string) bool {
	// SELECT EXISTS(...) always yields exactly one row, so QueryRow fits.
	sqlStmt := `
SELECT EXISTS(SELECT 1 FROM feed WHERE link = ? LIMIT 1)
`
	var exists bool
	if err := db.QueryRow(sqlStmt, link).Scan(&exists); err != nil {
		logger.Error("Error querying table: ", err)
		return false
	}
	return exists
}
// GetLastSyncTime returns the lastSyncTime stored for the feed whose link
// equals link, parsed as an RFC 3339 timestamp.
//
// It returns a wrapped error in these cases:
//   - the feed is not stored (sql.ErrNoRows);
//   - the column cannot be read into a string (e.g. it is NULL);
//   - the query fails.
//
// Previously a query failure panicked on a nil *sql.Rows, and a missing row
// surfaced as a confusing parse error for "". A stored value that is not valid
// RFC 3339 yields time.Parse's error.
func GetLastSyncTime(link string) (time.Time, error) {
	sqlStmt := `
SELECT lastSyncTime FROM feed WHERE link = ?
`
	var lastSyncTime string
	if err := db.QueryRow(sqlStmt, link).Scan(&lastSyncTime); err != nil {
		return time.Time{}, fmt.Errorf("reading lastSyncTime for feed %q: %w", link, err)
	}
	return time.Parse(time.RFC3339, lastSyncTime)
}
// CloseDb closes the package-level connection assigned by InitDb. If no
// connection was ever opened it does nothing; calling Close on a nil *sql.DB
// would panic. A close failure is logged.
func CloseDb() {
	logger.Debug("Closing database")
	if db == nil {
		return
	}
	if err := db.Close(); err != nil {
		logger.Error("Error closing database: ", err)
	}
}