330 lines
7.9 KiB
Go
330 lines
7.9 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
lua "github.com/yuin/gopher-lua"
|
|
)
|
|
|
|
const luaDatabaseTypeName = "sqlite3"
|
|
|
|
var db *sql.DB
|
|
|
|
var databaseMethods = map[string]lua.LGFunction{
|
|
"check": check,
|
|
"insert": insert_rss,
|
|
"getRss": GetRssFromDb,
|
|
}
|
|
|
|
func registerDatabaseType(L *lua.LState) {
|
|
logger.Debug("Registering database type")
|
|
mt := L.NewTypeMetatable("sqlite3")
|
|
L.SetGlobal("sqlite3", mt)
|
|
L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), databaseMethods))
|
|
|
|
// Create an instance of the database
|
|
ud := InitDb(L)
|
|
L.SetGlobal("db", ud)
|
|
}
|
|
|
|
// Check if the first lua argument is a *LUserData
|
|
func checkDatabase(L *lua.LState) *sql.DB {
|
|
ud := L.CheckUserData(1)
|
|
if v, ok := ud.Value.(*sql.DB); ok {
|
|
return v
|
|
}
|
|
L.ArgError(1, "sqlite3 expected")
|
|
return nil
|
|
}
|
|
|
|
// Constructor for new database userdata
|
|
func InitDb(L *lua.LState) *lua.LUserData {
|
|
ud := L.NewUserData()
|
|
|
|
// Create a database connection
|
|
database, err := sql.Open("sqlite3", "./rss.db")
|
|
if err != nil {
|
|
logger.Error("Error opening database: ", err)
|
|
} // Create the table with the rss schema if this is the first time
|
|
// TODO: Check if the table exists first
|
|
sqlStmt := `
|
|
CREATE TABLE IF NOT EXISTS rss (id INTEGER PRIMARY KEY, title TEXT, link TEXT, description TEXT, pubDate TEXT, guid TEXT UNIQUE, category TEXT, read INTEGER DEFAULT 0);
|
|
`
|
|
_, err = database.Exec(sqlStmt)
|
|
if err != nil {
|
|
logger.Error("Error creating table: ", err)
|
|
}
|
|
ud.Value = database // Store the database connection in the userdata
|
|
db = database
|
|
|
|
_CreateRSSFeedDbTable(db)
|
|
|
|
L.SetMetatable(ud, L.GetTypeMetatable(luaDatabaseTypeName)) // Set the metatable for the userdata
|
|
return ud // Return the number of values pushed onto the stack
|
|
}
|
|
|
|
func _CreateRSSFeedDbTable(db *sql.DB) {
|
|
// Create the table with the rss schema if this is the first time
|
|
sqlStmt := `
|
|
CREATE TABLE IF NOT EXISTS feed (id INTEGER PRIMARY KEY, title TEXT, link TEXT UNIQUE, description TEXT, lastSyncTime TEXT);
|
|
`
|
|
_, err := db.Exec(sqlStmt)
|
|
if err != nil {
|
|
logger.Error("Error creating table: ", err)
|
|
}
|
|
}
|
|
|
|
|
|
// TODO: Check for duplicate entries before inserting
|
|
func insert_rss(L *lua.LState) int {
|
|
if L.GetTop() != 2 {
|
|
L.ArgError(2, "Expected RSS table")
|
|
}
|
|
item := L.CheckUserData(2)
|
|
if item == nil {
|
|
return 0
|
|
}
|
|
ud := item.Value.(*RssItem)
|
|
|
|
// Insert the rss into the database
|
|
sqlStmt := `
|
|
INSERT INTO rss (title, link, description, pubDate, guid, category)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
`
|
|
_, err := db.Exec(sqlStmt, ud.Title, ud.Link, ud.Description, ud.PubDate, ud.Guid, ud.Category)
|
|
if err != nil && err.Error() != "UNIQUE constraint failed: rss.guid" {
|
|
logger.Error("Error inserting into table: ", err)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func GetRssFeed(link string) {
|
|
// Get the rss feed from the database
|
|
sqlStmt := `
|
|
SELECT title, link, description, lastSyncTime FROM feed WHERE link = ?
|
|
`
|
|
rows, err := db.Query(sqlStmt, link)
|
|
if err != nil {
|
|
logger.Error("Error querying table: ", err)
|
|
}
|
|
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
}
|
|
}
|
|
|
|
func check(L *lua.LState) int {
|
|
// Pass in a table of RSSItem
|
|
if L.GetTop() != 2 {
|
|
L.ArgError(2, "Expected RSS table")
|
|
}
|
|
items := L.CheckTable(2)
|
|
if items == nil {
|
|
return 0
|
|
}
|
|
|
|
// Create a map of the rss items
|
|
rssMap := make(map[string]*RssItem)
|
|
items.ForEach(func(_, value lua.LValue) {
|
|
item := value.(*lua.LUserData)
|
|
if v, ok := item.Value.(*RssItem); ok {
|
|
rssMap[v.Guid] = v
|
|
}
|
|
})
|
|
|
|
// SELECT * FROM your_table WHERE your_column_name IN (value1, value2, ..., valueN);
|
|
// Create a string with all the guids
|
|
guids := make([]string, 0, len(rssMap))
|
|
for guid := range rssMap {
|
|
id := fmt.Sprintf("%q", guid)
|
|
guids = append(guids, id)
|
|
}
|
|
values := strings.Join(guids, ",")
|
|
sqlStmt := `
|
|
SELECT guid FROM rss WHERE guid IN (` + values + `)
|
|
`
|
|
|
|
// Return to lua the values that do not exist in the database
|
|
rows, err := db.Query(sqlStmt)
|
|
if err != nil {
|
|
logger.Error("Error querying table: ", err)
|
|
}
|
|
|
|
// Create array to store the previous guids
|
|
prevGuids := make([]string, 0, len(rssMap))
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var guid string
|
|
err = rows.Scan(&guid)
|
|
if err != nil {
|
|
logger.Error("Error scanning table: ", err)
|
|
}
|
|
prevGuids = append(prevGuids, guid)
|
|
delete(rssMap, guid)
|
|
}
|
|
|
|
table := L.NewTable()
|
|
prevTable := L.NewTable()
|
|
for _, v := range rssMap {
|
|
ud := L.NewUserData()
|
|
ud.Value = v
|
|
L.SetMetatable(ud, L.GetTypeMetatable(luaRSSItemTypeName))
|
|
table.Append(ud)
|
|
}
|
|
for _, v := range prevGuids {
|
|
prevTable.Append(lua.LString(v))
|
|
}
|
|
|
|
L.Push(prevTable)
|
|
L.Push(table)
|
|
return 2
|
|
}
|
|
|
|
func GetRssFromDb(L *lua.LState) int {
|
|
// Pass in a table of RSSItem
|
|
if L.GetTop() != 2 {
|
|
L.ArgError(2, "Expected RSS table")
|
|
}
|
|
items := L.CheckTable(2)
|
|
if items == nil {
|
|
return 0
|
|
}
|
|
|
|
// Create an array of the rss items
|
|
arr := make([]string, 0, items.Len())
|
|
items.ForEach(func(_, value lua.LValue) {
|
|
item := value.String()
|
|
id := fmt.Sprintf("%q", item)
|
|
arr = append(arr, id)
|
|
})
|
|
|
|
values := strings.Join(arr, ",")
|
|
sqlStmt := `
|
|
SELECT title, link, description, pubDate, guid, category FROM rss WHERE guid IN (` + values + `)
|
|
`
|
|
rows, err := db.Query(sqlStmt)
|
|
if err != nil {
|
|
logger.Error("Error querying table: ", err)
|
|
}
|
|
|
|
table := L.NewTable()
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
item := RssItem{}
|
|
err = rows.Scan(&item.Title, &item.Link, &item.Description, &item.PubDate, &item.Guid, &item.Category)
|
|
// TODO: Handle this error better
|
|
if err != nil {
|
|
logger.Error("Error scanning table: ", err)
|
|
}
|
|
ud := L.NewUserData()
|
|
ud.Value = &item
|
|
L.SetMetatable(ud, L.GetTypeMetatable(luaRSSItemTypeName))
|
|
table.Append(ud)
|
|
}
|
|
L.Push(table)
|
|
return 1
|
|
}
|
|
|
|
func get_rss_entry(L *lua.LState) int {
|
|
if L.GetTop() != 2 {
|
|
L.ArgError(2, "Expected RSS table")
|
|
}
|
|
// This function should only be getting data from the database
|
|
return 0
|
|
}
|
|
|
|
func get_rss_feed(url string) RssFeed {
|
|
logger.Trace("===Executing get_rss_feed for url: ", url, "===")
|
|
// Get the rss feed from the database
|
|
sqlStmt := `
|
|
SELECT title, link, description, lastSyncTime FROM feed WHERE link = ?
|
|
`
|
|
rows, err := db.Query(sqlStmt, url)
|
|
if err != nil {
|
|
logger.Error("Error querying table: ", err)
|
|
}
|
|
|
|
var rssFeed RssFeed
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&rssFeed.Title, &rssFeed.Link, &rssFeed.Content, &rssFeed.LastSyncTime)
|
|
if err != nil {
|
|
logger.Error("Error scanning table: ", err)
|
|
}
|
|
}
|
|
return rssFeed
|
|
}
|
|
|
|
func insert_rss_feed(rssFeed RssFeed) {
|
|
sqlStmt := `
|
|
INSERT INTO feed (title, link, description, lastSyncTime)
|
|
VALUES (?, ?, ?, ?)
|
|
`
|
|
_, err := db.Exec(sqlStmt, rssFeed.Title, rssFeed.Link, rssFeed.Content, rssFeed.LastSyncTime)
|
|
if err != nil && err.Error() != "UNIQUE constraint failed: feed.link" {
|
|
logger.Error("Error inserting into table: ", err)
|
|
}
|
|
}
|
|
|
|
func update_rss_feed(rssFeed RssFeed) {
|
|
sqlStmt := `
|
|
UPDATE feed SET description = ?, lastSyncTime = ? WHERE link = ?
|
|
`
|
|
_, err := db.Exec(sqlStmt, rssFeed.Content, rssFeed.LastSyncTime, rssFeed.Link)
|
|
if err != nil {
|
|
logger.Error("Error updating table: ", err)
|
|
}
|
|
}
|
|
|
|
func DoesRssFeedExist(link string) bool {
|
|
sqlStmt := `
|
|
SELECT EXISTS(SELECT 1 FROM feed WHERE link = ? LIMIT 1)
|
|
`
|
|
rows, err := db.Query(sqlStmt, link)
|
|
if err != nil {
|
|
logger.Error("Error querying table: ", err)
|
|
}
|
|
|
|
var exists bool
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&exists)
|
|
if err != nil {
|
|
logger.Error("Error scanning table: ", err)
|
|
}
|
|
}
|
|
return exists
|
|
}
|
|
|
|
func GetLastSyncTime(link string) (time.Time, error) {
|
|
sqlStmt := `
|
|
SELECT lastSyncTime FROM feed WHERE link = ?
|
|
`
|
|
rows, err := db.Query(sqlStmt, link)
|
|
if err != nil {
|
|
logger.Error("Error querying table: ", err)
|
|
}
|
|
|
|
var lastSyncTime string
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&lastSyncTime)
|
|
if err != nil {
|
|
logger.Error("Error scanning table: ", err)
|
|
}
|
|
}
|
|
return time.Parse(time.RFC3339, lastSyncTime)
|
|
}
|
|
|
|
func CloseDb() {
|
|
logger.Debug("Closing database")
|
|
err := db.Close()
|
|
if err != nil {
|
|
logger.Error("Error closing database: ", err)
|
|
}
|
|
}
|