Pass DB into functions. Implement import feature

Esse commit está contido em:
lawl 2021-03-20 20:47:05 +01:00
commit 5604c310d7
1 arquivos alterados com 28 adições e 11 exclusões

39
main.go
Ver arquivo

@ -1,6 +1,7 @@
package main package main
import ( import (
"bufio"
"database/sql" "database/sql"
"flag" "flag"
"fmt" "fmt"
@ -25,9 +26,7 @@ func createConnection() *sql.DB {
return db return db
} }
func initDatabase() { func initDatabase(conn *sql.DB) {
conn := createConnection()
queryStmt := "CREATE TABLE history(id INTEGER PRIMARY KEY, command varchar(512), timestamp datetime DEFAULT current_timestamp, user varchar(25), hostname varchar(32));\n" + queryStmt := "CREATE TABLE history(id INTEGER PRIMARY KEY, command varchar(512), timestamp datetime DEFAULT current_timestamp, user varchar(25), hostname varchar(32));\n" +
"CREATE VIEW count_by_date AS SELECT COUNT(id), STRFTIME('%Y-%m-%d', timestamp) FROM history GROUP BY strftime('%Y-%m-%d', timestamp)" "CREATE VIEW count_by_date AS SELECT COUNT(id), STRFTIME('%Y-%m-%d', timestamp) FROM history GROUP BY strftime('%Y-%m-%d', timestamp)"
@ -37,9 +36,25 @@ func initDatabase() {
} }
} }
func search(q string) { func importFromStdin(conn *sql.DB) {
conn := createConnection() scanner := bufio.NewScanner(os.Stdin)
_, err := conn.Exec("BEGIN;")
if err != nil {
log.Panic(err)
}
for scanner.Scan() {
add(conn, scanner.Text())
}
_, err = conn.Exec("END;")
if err != nil {
log.Panic(err)
}
}
func search(conn *sql.DB, q string) {
queryStmt := "SELECT command FROM history WHERE command LIKE ? ORDER BY timestamp ASC" queryStmt := "SELECT command FROM history WHERE command LIKE ? ORDER BY timestamp ASC"
rows, err := conn.Query(queryStmt, "%"+q+"%") rows, err := conn.Query(queryStmt, "%"+q+"%")
@ -57,9 +72,7 @@ func search(q string) {
} }
} }
func add(cmd string) { func add(conn *sql.DB, cmd string) {
conn := createConnection()
user := os.Getenv("USER") user := os.Getenv("USER")
hostname, err := os.Hostname() hostname, err := os.Hostname()
if err != nil { if err != nil {
@ -113,6 +126,8 @@ func main() {
cmd := args[0] cmd := args[0]
conn := createConnection()
if cmd == "add" { if cmd == "add" {
if argslen < 2 { if argslen < 2 {
fmt.Fprint(os.Stderr, "Error: You need to provide the command to be added") fmt.Fprint(os.Stderr, "Error: You need to provide the command to be added")
@ -121,18 +136,20 @@ func main() {
historycmd := args[1] historycmd := args[1]
var rgx = regexp.MustCompile("\\s+\\d+\\s+(.*)") var rgx = regexp.MustCompile("\\s+\\d+\\s+(.*)")
rs := rgx.FindStringSubmatch(historycmd) rs := rgx.FindStringSubmatch(historycmd)
add(rs[1]) add(conn, rs[1])
} else if cmd == "search" { } else if cmd == "search" {
if argslen < 2 { if argslen < 2 {
fmt.Fprint(os.Stderr, "Please provide the search query\n") fmt.Fprint(os.Stderr, "Please provide the search query\n")
} }
q := args[1] q := args[1]
search(q) search(conn, q)
} else if cmd == "init" { } else if cmd == "init" {
err := os.MkdirAll(filepath.Dir(databaseLocation()), 0755) err := os.MkdirAll(filepath.Dir(databaseLocation()), 0755)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
initDatabase() initDatabase(conn)
} else if cmd == "import" {
importFromStdin(conn)
} }
} }