summaryrefslogtreecommitdiff
path: root/internal/storage/setup.go
diff options
context:
space:
mode:
authorBenjamin Chausse <benjamin@chausse.xyz>2025-02-03 01:12:45 -0500
committerBenjamin Chausse <benjamin@chausse.xyz>2025-02-03 01:12:45 -0500
commit5389e1a5d26fdbf2441fa5a1e101999e8449b9d1 (patch)
tree069cd37cb8e556c1ba3b47c3ea8576a1aa91ea2c /internal/storage/setup.go
Batman
Diffstat (limited to 'internal/storage/setup.go')
-rw-r--r--internal/storage/setup.go201
1 files changed, 201 insertions, 0 deletions
diff --git a/internal/storage/setup.go b/internal/storage/setup.go
new file mode 100644
index 0000000..f103b15
--- /dev/null
+++ b/internal/storage/setup.go
@@ -0,0 +1,201 @@
+package storage
+
+import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "log/slog"
+ "os"
+ "strings"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+var (
+ ErrIntegrityCheckFailed = errors.New("integrity check failed")
+ ErrForeignKeysDisabled = errors.New("foreign_keys pragma is not enabled")
+ ErrJournalModeInvalid = errors.New("journal_mode is not wal")
+ ErrSchemaMismatch = errors.New("schema does not match expected definition")
+)
+
+// Setup opens the SQLite DB at path, verifies its integrity and schema,
+// and returns the valid DB handle. On any error, it backs up the old file
+// (if it exists) and calls genDB() to initialize a valid schema.
+func Setup(path string) (*sql.DB, error) {
+ _, statErr := os.Stat(path)
+ exists := statErr == nil
+
+ // If file doesn't exist, generate new DB.
+ if !exists {
+ return genDB(path)
+ }
+
+ db, err := sql.Open("sqlite3", path+opts())
+ if err != nil {
+ slog.Error("failed to open DB", "error", err)
+ backupFile(path)
+ return genDB(path)
+ }
+
+ // Integrity check.
+ var integrity string
+ if err = db.QueryRow("PRAGMA integrity_check;").Scan(&integrity); err != nil {
+ slog.Error("integrity check query failed", "error", err)
+ db.Close()
+ backupFile(path)
+ return genDB(path)
+ }
+ if integrity != "ok" {
+ slog.Error("integrity check failed", "error", ErrIntegrityCheckFailed)
+ db.Close()
+ backupFile(path)
+ return genDB(path)
+ }
+
+ // Validate schema and pragmas.
+ if err = validateSchema(db); err != nil {
+ slog.Error("schema validation failed", "error", err)
+ db.Close()
+ backupFile(path)
+ return genDB(path)
+ }
+
+ return db, nil
+}
+
+// validateSchema verifies that required pragmas and table definitions are set.
+func validateSchema(db *sql.DB) error {
+ // Check PRAGMA foreign_keys = on.
+ var fk int
+ if err := db.QueryRow("PRAGMA foreign_keys;").Scan(&fk); err != nil {
+ return err
+ }
+ if fk != 1 {
+ return ErrForeignKeysDisabled
+ }
+
+ // Check PRAGMA journal_mode = wal.
+ var jm string
+ if err := db.QueryRow("PRAGMA journal_mode;").Scan(&jm); err != nil {
+ return err
+ }
+ if strings.ToLower(jm) != "wal" {
+ return ErrJournalModeInvalid
+ }
+
+ // Define required table definitions (as substrings in lower-case).
+ type tableCheck struct {
+ name string
+ substrings []string
+ }
+
+ checks := []tableCheck{
+ {
+ name: "User",
+ substrings: []string{
+ "create table user",
+ "userid", "integer", "primary key", "autoincrement",
+ "name", "text", "not null",
+ "email", "text", "not null", "unique",
+ },
+ },
+ {
+ name: "Task",
+ substrings: []string{
+ "create table task",
+ "taskid", "integer", "primary key", "autoincrement",
+ "title", "not null",
+ "description", "not null",
+ "due", "date", "not null",
+ "do", "date", "not null",
+ "owner", "integer", "not null",
+ "foreign key", "references user",
+ },
+ },
+ {
+ name: "Tag",
+ substrings: []string{
+ "create table tag",
+ "tagid", "integer", "primary key", "autoincrement",
+ "name", "text", "not null", "unique",
+ },
+ },
+ {
+ name: "TaskTag",
+ substrings: []string{
+ "create table tasktag",
+ "taskuuid", "integer", "not null",
+ "tagid", "integer", "not null",
+ "primary key",
+ "foreign key", "references task",
+ "foreign key", "references tag",
+ },
+ },
+ }
+
+ for _, chk := range checks {
+ sqlDef, err := fetchTableSQL(db, chk.name)
+ if err != nil {
+ return fmt.Errorf("failed to fetch definition for table %s: %w", chk.name, err)
+ }
+ lc := strings.ToLower(sqlDef)
+ for _, substr := range chk.substrings {
+ if !strings.Contains(lc, substr) {
+ return fmt.Errorf("%w: table %s missing %q", ErrSchemaMismatch, chk.name, substr)
+ }
+ }
+ }
+
+ return nil
+}
+
+func fetchTableSQL(db *sql.DB, table string) (string, error) {
+ var sqlDef sql.NullString
+ err := db.QueryRow("SELECT sql FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&sqlDef)
+ if err != nil {
+ return "", err
+ }
+ if !sqlDef.Valid {
+ return "", fmt.Errorf("no SQL definition found for table %s", table)
+ }
+ return sqlDef.String, nil
+}
+
+// backupFile renames the existing file by appending a ".bak" suffix.
+func backupFile(path string) {
+ backupPath := path + ".bak"
+ // If backupPath exists, append a timestamp.
+ if _, err := os.Stat(backupPath); err == nil {
+ backupPath = fmt.Sprintf("%s.%d.bak", path, os.Getpid())
+ }
+ if err := os.Rename(path, backupPath); err != nil {
+ slog.Error("failed to backup file", "error", err, "original", path, "backup", backupPath)
+ } else {
+ slog.Info("backed up corrupt DB", "original", path, "backup", backupPath)
+ }
+}
+
+// genDB creates a new database at path with the valid schema.
+func genDB(path string) (*sql.DB, error) {
+ db, err := sql.Open("sqlite3", path)
+ if err != nil {
+ slog.Error("failed to create DB", "error", err)
+ return nil, err
+ }
+
+ // Set pragmas.
+ if _, err := db.Exec("PRAGMA foreign_keys = on; PRAGMA journal_mode = wal;"); err != nil {
+ slog.Error("failed to set pragmas", "error", err)
+ db.Close()
+ return nil, err
+ }
+
+ if _, err := db.Exec(schema()); err != nil {
+ slog.Error("failed to initialize schema", "error", err)
+ db.Close()
+ return nil, err
+ }
+
+ slog.Info("created new blank DB with valid schema", "path", path)
+ return db, nil
+}