diff options
author | Benjamin Chausse <benjamin@chausse.xyz> | 2025-02-22 09:59:10 -0500 |
---|---|---|
committer | Benjamin Chausse <benjamin@chausse.xyz> | 2025-02-22 09:59:10 -0500 |
commit | f36f77472a82d6ebfac153aed6d17f154ae239a6 (patch) | |
tree | d749ecc2ebf86a39b15ac3026d3e100d0276442b /internal/db/setup.go | |
parent | 2cb9e5fe823391c09a99424138192d0fbec727af (diff) |
Good foundations
Diffstat (limited to 'internal/db/setup.go')
-rw-r--r-- | internal/db/setup.go | 234 |
1 files changed, 234 insertions, 0 deletions
diff --git a/internal/db/setup.go b/internal/db/setup.go new file mode 100644 index 0000000..038334a --- /dev/null +++ b/internal/db/setup.go @@ -0,0 +1,234 @@ +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log/slog" + "os" + "strings" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + ErrForeignKeysDisabled = errors.New("foreign keys are disabled") + ErrIntegrityCheckFailed = errors.New("integrity check failed") + ErrJournalModeInvalid = errors.New("journal mode is not WAL") + ErrSchemaMismatch = errors.New("database schema does not match expected definition") + ErrTableMissing = errors.New("table is missing") + ErrTableStructure = errors.New("table structure does not match expected schema") +) + +type Store struct { + DB *sql.DB + Common []*sql.Stmt +} + +func new(db *sql.DB) (*Store, error) { + lst := make([]*sql.Stmt, len(commonTransactions)) + for _, common := range commonTransactions { + stmt, err := db.Prepare(common.Cmd) + if err != nil { + return nil, err + } + lst[common.Name] = stmt + } + return &Store{ + DB: db, + Common: lst, + }, nil +} + +func (s *Store) Close() error { + errs := make([]error, len(s.Common)+1) + for i, s := range s.Common { + if s != nil { + errs[i] = s.Close() + } + } + errs[len(s.Common)] = s.DB.Close() + return errors.Join(errs...) +} + +// opts returns connection options that enforce our desired pragmas. +func opts() string { + return "?_foreign_keys=on&_journal_mode=WAL" +} + +// Setup opens the SQLite DB at path, verifies its integrity and schema, +// and returns the valid DB handle. If any check fails, it backs up the old file +// and reinitializes the DB using the schema definitions. +func Setup(ctx context.Context, path string) (*Store, error) { + slog.DebugContext(ctx, "Setting up database connection") + + // If file does not exist, generate a new DB. + if _, err := os.Stat(path); err != nil { + db, err := genDB(ctx, path) + if err != nil { + return nil, err + } + return new(db) + } + + db, err := sql.Open("sqlite3", path+opts()) + if err != nil { + slog.ErrorContext(ctx, "failed to open DB", "error", err) + backupFile(ctx, path) + db, err := genDB(ctx, path) + if err != nil { + return nil, err + } + return new(db) + } + + // Run integrity check. + var integrity string + if err = db.QueryRow("PRAGMA integrity_check;").Scan(&integrity); err != nil || integrity != "ok" { + if err != nil { + slog.ErrorContext(ctx, "integrity check query failed", "error", err) + } else { + slog.ErrorContext(ctx, "integrity check failed", "integrity", integrity) + } + db.Close() + backupFile(ctx, path) + db, err := genDB(ctx, path) + if err != nil { + return nil, err + } + return new(db) + } + + // Validate the PRAGMA settings and each table's schema. + if err = validateSchema(ctx, db); err != nil { + slog.ErrorContext(ctx, "schema validation failed", "error", err) + db.Close() + backupFile(ctx, path) + db, err := genDB(ctx, path) + if err != nil { + return nil, err + } + return new(db) + } + + return new(db) +} + +// validateSchema checks that the PRAGMAs and every table definition match the expected schema. +func validateSchema(ctx context.Context, db *sql.DB) error { + if err := validatePragmas(db); err != nil { + return err + } + for _, table := range schemaDefinitions { + if err := validateTable(ctx, db, table.Name, table.Cmd); err != nil { + return err + } + } + return nil +} + +// validatePragmas ensures that the required PRAGMAs are set. +func validatePragmas(db *sql.DB) error { + var fk int + if err := db.QueryRow("PRAGMA foreign_keys;").Scan(&fk); err != nil { + return err + } + if fk != 1 { + return ErrForeignKeysDisabled + } + + var jm string + if err := db.QueryRow("PRAGMA journal_mode;").Scan(&jm); err != nil { + return err + } + if strings.ToLower(jm) != "wal" { + return ErrJournalModeInvalid + } + return nil +} + +// validateTable fetches the stored SQL for the table and compares it +// (after normalization) with the expected definition. +func validateTable(ctx context.Context, db *sql.DB, tableName, expectedSQL string) error { + actualSQL, err := fetchTableSQL(db, tableName) + if err != nil { + slog.ErrorContext(ctx, "failed to fetch table definition", "table", tableName, "error", err) + return ErrSchemaMismatch + } + if actualSQL == "" { + slog.ErrorContext(ctx, "table is missing", "table", tableName) + return ErrTableMissing + } + + normalizedExpected := normalizeSQL(expectedSQL) + normalizedActual := normalizeSQL(actualSQL) + if normalizedExpected != normalizedActual { + slog.ErrorContext(ctx, "table structure does not match expected schema", + "table", tableName, + "expected", normalizedExpected, + "actual", normalizedActual, + ) + return ErrTableStructure + } + return nil +} + +// normalizeSQL removes SQL comments, converts to lowercase, +// collapses whitespace, and removes a trailing semicolon. +func normalizeSQL(sqlStr string) string { + sqlStr = removeSQLComments(sqlStr) + sqlStr = strings.ToLower(sqlStr) + sqlStr = strings.ReplaceAll(sqlStr, "\n", " ") + sqlStr = strings.Join(strings.Fields(sqlStr), " ") + sqlStr = strings.TrimSuffix(sqlStr, ";") + return sqlStr +} + +// removeSQLComments strips out any '--' style comments. +func removeSQLComments(sqlStr string) string { + lines := strings.Split(sqlStr, "\n") + for i, line := range lines { + if idx := strings.Index(line, "--"); idx != -1 { + lines[i] = line[:idx] + } + } + return strings.Join(lines, " ") +} + +// fetchTableSQL retrieves the SQL definition of a table from sqlite_master. +func fetchTableSQL(db *sql.DB, table string) (string, error) { + var sqlDef sql.NullString + err := db.QueryRow( + "SELECT sql FROM sqlite_master WHERE type='table' AND name=?", + table, + ).Scan(&sqlDef) + if err != nil { + return "", err + } + if !sqlDef.Valid { + return "", fmt.Errorf("no SQL definition found for table %s", table) + } + return sqlDef.String, nil +} + +// backupFile renames the existing file by appending a ".bak" (or timestamped) suffix. +func backupFile(ctx context.Context, path string) { + backupPath := path + ".bak" + if _, err := os.Stat(backupPath); err == nil { + backupPath = fmt.Sprintf("%s-%s.bak", path, time.Now().Format(time.RFC3339)) + } + if err := os.Rename(path, backupPath); err != nil { + slog.ErrorContext(ctx, "failed to backup file", + "error", err, + "original", path, + "backup", backupPath, + ) + } else { + slog.InfoContext(ctx, "backed up corrupt DB", + "original", path, + "backup", backupPath, + ) + } +} |