diff options
Diffstat (limited to 'internal/storage/setup.go')
-rw-r--r-- | internal/storage/setup.go | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/internal/storage/setup.go b/internal/storage/setup.go new file mode 100644 index 0000000..f103b15 --- /dev/null +++ b/internal/storage/setup.go @@ -0,0 +1,201 @@ +package storage + +import ( + "database/sql" + "errors" + "fmt" + "log/slog" + "os" + "strings" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + ErrIntegrityCheckFailed = errors.New("integrity check failed") + ErrForeignKeysDisabled = errors.New("foreign_keys pragma is not enabled") + ErrJournalModeInvalid = errors.New("journal_mode is not wal") + ErrSchemaMismatch = errors.New("schema does not match expected definition") +) + +// Setup opens the SQLite DB at path, verifies its integrity and schema, +// and returns the valid DB handle. On any error, it backs up the old file +// (if it exists) and calls genDB() to initialize a valid schema. +func Setup(path string) (*sql.DB, error) { + _, statErr := os.Stat(path) + exists := statErr == nil + + // If file doesn't exist, generate new DB. + if !exists { + return genDB(path) + } + + db, err := sql.Open("sqlite3", path+opts()) + if err != nil { + slog.Error("failed to open DB", "error", err) + backupFile(path) + return genDB(path) + } + + // Integrity check. + var integrity string + if err = db.QueryRow("PRAGMA integrity_check;").Scan(&integrity); err != nil { + slog.Error("integrity check query failed", "error", err) + db.Close() + backupFile(path) + return genDB(path) + } + if integrity != "ok" { + slog.Error("integrity check failed", "error", ErrIntegrityCheckFailed) + db.Close() + backupFile(path) + return genDB(path) + } + + // Validate schema and pragmas. + if err = validateSchema(db); err != nil { + slog.Error("schema validation failed", "error", err) + db.Close() + backupFile(path) + return genDB(path) + } + + return db, nil +} + +// validateSchema verifies that required pragmas and table definitions are set. +func validateSchema(db *sql.DB) error { + // Check PRAGMA foreign_keys = on. + var fk int + if err := db.QueryRow("PRAGMA foreign_keys;").Scan(&fk); err != nil { + return err + } + if fk != 1 { + return ErrForeignKeysDisabled + } + + // Check PRAGMA journal_mode = wal. + var jm string + if err := db.QueryRow("PRAGMA journal_mode;").Scan(&jm); err != nil { + return err + } + if strings.ToLower(jm) != "wal" { + return ErrJournalModeInvalid + } + + // Define required table definitions (as substrings in lower-case). + type tableCheck struct { + name string + substrings []string + } + + checks := []tableCheck{ + { + name: "User", + substrings: []string{ + "create table user", + "userid", "integer", "primary key", "autoincrement", + "name", "text", "not null", + "email", "text", "not null", "unique", + }, + }, + { + name: "Task", + substrings: []string{ + "create table task", + "taskid", "integer", "primary key", "autoincrement", + "title", "not null", + "description", "not null", + "due", "date", "not null", + "do", "date", "not null", + "owner", "integer", "not null", + "foreign key", "references user", + }, + }, + { + name: "Tag", + substrings: []string{ + "create table tag", + "tagid", "integer", "primary key", "autoincrement", + "name", "text", "not null", "unique", + }, + }, + { + name: "TaskTag", + substrings: []string{ + "create table tasktag", + "taskuuid", "integer", "not null", + "tagid", "integer", "not null", + "primary key", + "foreign key", "references task", + "foreign key", "references tag", + }, + }, + } + + for _, chk := range checks { + sqlDef, err := fetchTableSQL(db, chk.name) + if err != nil { + return fmt.Errorf("failed to fetch definition for table %s: %w", chk.name, err) + } + lc := strings.ToLower(sqlDef) + for _, substr := range chk.substrings { + if !strings.Contains(lc, substr) { + return fmt.Errorf("%w: table %s missing %q", ErrSchemaMismatch, chk.name, substr) + } + } + } + + return nil +} + +func fetchTableSQL(db *sql.DB, table string) (string, error) { + var sqlDef sql.NullString + err := db.QueryRow("SELECT sql FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&sqlDef) + if err != nil { + return "", err + } + if !sqlDef.Valid { + return "", fmt.Errorf("no SQL definition found for table %s", table) + } + return sqlDef.String, nil +} + +// backupFile renames the existing file by appending a ".bak" suffix. +func backupFile(path string) { + backupPath := path + ".bak" + // If backupPath exists, append a timestamp. + if _, err := os.Stat(backupPath); err == nil { + backupPath = fmt.Sprintf("%s.%d.bak", path, os.Getpid()) + } + if err := os.Rename(path, backupPath); err != nil { + slog.Error("failed to backup file", "error", err, "original", path, "backup", backupPath) + } else { + slog.Info("backed up corrupt DB", "original", path, "backup", backupPath) + } +} + +// genDB creates a new database at path with the valid schema. +func genDB(path string) (*sql.DB, error) { + db, err := sql.Open("sqlite3", path) + if err != nil { + slog.Error("failed to create DB", "error", err) + return nil, err + } + + // Set pragmas. + if _, err := db.Exec("PRAGMA foreign_keys = on; PRAGMA journal_mode = wal;"); err != nil { + slog.Error("failed to set pragmas", "error", err) + db.Close() + return nil, err + } + + if _, err := db.Exec(schema()); err != nil { + slog.Error("failed to initialize schema", "error", err) + db.Close() + return nil, err + } + + slog.Info("created new blank DB with valid schema", "path", path) + return db, nil +} |