summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
authorBenjamin Chausse <benjamin@chausse.xyz>2025-02-22 09:59:10 -0500
committerBenjamin Chausse <benjamin@chausse.xyz>2025-02-22 09:59:10 -0500
commitf36f77472a82d6ebfac153aed6d17f154ae239a6 (patch)
treed749ecc2ebf86a39b15ac3026d3e100d0276442b /internal/db
parent2cb9e5fe823391c09a99424138192d0fbec727af (diff)
Good foundations
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/data_validation.go67
-rw-r--r--internal/db/schema_initialisation.go130
-rw-r--r--internal/db/setup.go234
-rw-r--r--internal/db/transaction_definitions.go108
4 files changed, 539 insertions, 0 deletions
diff --git a/internal/db/data_validation.go b/internal/db/data_validation.go
new file mode 100644
index 0000000..1883dc0
--- /dev/null
+++ b/internal/db/data_validation.go
@@ -0,0 +1,67 @@
+package db
+
+import (
+ "context"
+ "database/sql"
+ "log/slog"
+)
+
+const (
+ PublicUserSignupKey = "ACCEPT_PUBLIC_USERS"
+ EnforceHttpsKey = "ENFORCE_HTTPS"
+ MaxUsersKey = "MAX_USERS"
+)
+
+var settingValidations = [...]struct {
+ key string // key to look for in the settings
+ defaultVal string // Default value to init if not set
+}{
+ { // Only admins can create users when false
+ PublicUserSignupKey,
+ "FALSE",
+ },
+ { // If something like traefik manages https, this can be set to
+ // false. But there MUST be https in your stack otherwise
+ // credentials are sent in the clear
+ EnforceHttpsKey,
+ "TRUE",
+ },
+ { // Safeguard to avoid account creation spamming.
+ // An admin can still create users over the limit
+ MaxUsersKey,
+ "25",
+ },
+}
+
+func ValidateSettings(ctx context.Context, db *sql.DB) error {
+ valTx, err := db.PrepareContext(ctx, "SELECT value FROM Settings WHERE key=?")
+ if err != nil {
+ return err
+ }
+ defer valTx.Close()
+
+ newTx, err := db.PrepareContext(ctx, "INSERT INTO Settings (key, value) VALUES (?, ?)")
+ if err != nil {
+ return err
+ }
+ defer newTx.Close()
+
+ for _, s := range settingValidations {
+ var val string
+ err := valTx.QueryRowContext(ctx, s.key).Scan(&val)
+ if err != nil {
+ return err
+ }
+ if val == "" {
+ slog.WarnContext(ctx, "Missing configuration, setting the default",
+ "setting", s.key,
+ "value", s.defaultVal,
+ )
+ _, err := newTx.ExecContext(ctx, s.key, s.defaultVal)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
diff --git a/internal/db/schema_initialisation.go b/internal/db/schema_initialisation.go
new file mode 100644
index 0000000..3963786
--- /dev/null
+++ b/internal/db/schema_initialisation.go
@@ -0,0 +1,130 @@
+package db
+
+import (
+ "context"
+ "database/sql"
+ "log/slog"
+)
+
+// schemaDefinitions is the single source of truth for both creating and validating the DB schema.
+var schemaDefinitions = [...]struct {
+ Name string
+ Cmd string
+}{
+ {
+ "Users",
+ `CREATE TABLE Users (
+ userID TEXT PRIMARY KEY CHECK (length(userID) = 36),
+ name TEXT NOT NULL,
+ email TEXT NOT NULL UNIQUE,
+ createdAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ updatedAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ );`,
+ },
+ {
+ "UserSecrets",
+ `CREATE TABLE UserSecrets (
+ userID TEXT PRIMARY KEY,
+ saltAndHash TEXT NOT NULL,
+ FOREIGN KEY (userID) REFERENCES Users(userID) ON DELETE CASCADE
+ );`,
+ },
+ {
+ "Tasks",
+ `CREATE TABLE Tasks (
+ taskID TEXT PRIMARY KEY CHECK (length(taskID) = 36),
+ title TEXT NOT NULL,
+ priority INTEGER NOT NULL DEFAULT 0,
+ description TEXT,
+ due TIMESTAMP,
+ do TIMESTAMP,
+ cron TEXT,
+ cronIsEnabled BOOLEAN NOT NULL DEFAULT FALSE,
+ owner TEXT NOT NULL,
+ FOREIGN KEY (owner) REFERENCES Users(userID) ON DELETE CASCADE
+ );`,
+ },
+ {
+ "Tags",
+ `CREATE TABLE Tags (
+ tagID INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL UNIQUE
+ );`,
+ },
+ {
+ "TaskTags",
+ `CREATE TABLE TaskTags (
+ taskID TEXT NOT NULL,
+ tagID INTEGER NOT NULL,
+ PRIMARY KEY (taskID, tagID),
+ FOREIGN KEY (taskID) REFERENCES Tasks(taskID) ON DELETE CASCADE,
+ FOREIGN KEY (tagID) REFERENCES Tags(tagID) ON DELETE CASCADE
+ );`,
+ },
+ {
+ "Settings",
+ `CREATE TABLE Settings (
+ key TEXT PRIMARY KEY,
+ value TEXT
+ );`,
+ },
+ {
+ "Roles",
+ `CREATE TABLE Roles (
+ role TEXT PRIMARY KEY CHECK (role GLOB '[A-Z_]*')
+ );`,
+ },
+ {
+ "UserRoles",
+ `CREATE TABLE UserRoles (
+ userID TEXT NOT NULL,
+ role TEXT NOT NULL,
+ PRIMARY KEY (userID, role),
+ FOREIGN KEY (userID) REFERENCES Users(userID) ON DELETE CASCADE,
+ FOREIGN KEY (role) REFERENCES Roles(role) ON DELETE CASCADE
+ );`,
+ },
+}
+
+// genDB creates a new database at path using the expected schema definitions.
+func genDB(ctx context.Context, path string) (*sql.DB, error) {
+ db, err := sql.Open("sqlite3", path)
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to create DB", "error", err)
+ return nil, err
+ }
+
+ // Set the required PRAGMAs.
+ if _, err := db.Exec("PRAGMA foreign_keys = on; PRAGMA journal_mode = wal;"); err != nil {
+ slog.ErrorContext(ctx, "failed to set pragmas", "error", err)
+ db.Close()
+ return nil, err
+ }
+
+ // Create tables inside a transaction.
+ tx, err := db.Begin()
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to begin transaction for schema initialization", "error", err)
+ db.Close()
+ return nil, err
+ }
+ for _, table := range schemaDefinitions {
+ if _, err := tx.Exec(table.Cmd); err != nil {
+ slog.ErrorContext(ctx, "failed to initialize schema", "table", table.Name, "error", err)
+ if errRollback := tx.Rollback(); errRollback != nil {
+ slog.ErrorContext(ctx, "failed to rollback schema initialization", "error", errRollback)
+ }
+
+ db.Close()
+ return nil, err
+ }
+ }
+ if err := tx.Commit(); err != nil {
+ slog.ErrorContext(ctx, "failed to commit schema initialization", "error", err)
+ db.Close()
+ return nil, err
+ }
+
+ slog.InfoContext(ctx, "created new blank DB wit h valid schema", "path", path)
+ return db, nil
+}
diff --git a/internal/db/setup.go b/internal/db/setup.go
new file mode 100644
index 0000000..038334a
--- /dev/null
+++ b/internal/db/setup.go
@@ -0,0 +1,234 @@
+package db
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "log/slog"
+ "os"
+ "strings"
+ "time"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+var (
+ ErrForeignKeysDisabled = errors.New("foreign keys are disabled")
+ ErrIntegrityCheckFailed = errors.New("integrity check failed")
+ ErrJournalModeInvalid = errors.New("journal mode is not WAL")
+ ErrSchemaMismatch = errors.New("database schema does not match expected definition")
+ ErrTableMissing = errors.New("table is missing")
+ ErrTableStructure = errors.New("table structure does not match expected schema")
+)
+
+type Store struct {
+ DB *sql.DB
+ Common []*sql.Stmt
+}
+
+func new(db *sql.DB) (*Store, error) {
+ lst := make([]*sql.Stmt, len(commonTransactions))
+ for _, common := range commonTransactions {
+ stmt, err := db.Prepare(common.Cmd)
+ if err != nil {
+ return nil, err
+ }
+ lst[common.Name] = stmt
+ }
+ return &Store{
+ DB: db,
+ Common: lst,
+ }, nil
+}
+
+func (s *Store) Close() error {
+ errs := make([]error, len(s.Common)+1)
+ for i, s := range s.Common {
+ if s != nil {
+ errs[i] = s.Close()
+ }
+ }
+ errs[len(s.Common)] = s.DB.Close()
+ return errors.Join(errs...)
+}
+
+// opts returns connection options that enforce our desired pragmas.
+func opts() string {
+ return "?_foreign_keys=on&_journal_mode=WAL"
+}
+
+// Setup opens the SQLite DB at path, verifies its integrity and schema,
+// and returns the valid DB handle. If any check fails, it backs up the old file
+// and reinitializes the DB using the schema definitions.
+func Setup(ctx context.Context, path string) (*Store, error) {
+ slog.DebugContext(ctx, "Setting up database connection")
+
+ // If file does not exist, generate a new DB.
+ if _, err := os.Stat(path); err != nil {
+ db, err := genDB(ctx, path)
+ if err != nil {
+ return nil, err
+ }
+ return new(db)
+ }
+
+ db, err := sql.Open("sqlite3", path+opts())
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to open DB", "error", err)
+ backupFile(ctx, path)
+ db, err := genDB(ctx, path)
+ if err != nil {
+ return nil, err
+ }
+ return new(db)
+ }
+
+ // Run integrity check.
+ var integrity string
+ if err = db.QueryRow("PRAGMA integrity_check;").Scan(&integrity); err != nil || integrity != "ok" {
+ if err != nil {
+ slog.ErrorContext(ctx, "integrity check query failed", "error", err)
+ } else {
+ slog.ErrorContext(ctx, "integrity check failed", "integrity", integrity)
+ }
+ db.Close()
+ backupFile(ctx, path)
+ db, err := genDB(ctx, path)
+ if err != nil {
+ return nil, err
+ }
+ return new(db)
+ }
+
+ // Validate the PRAGMA settings and each table's schema.
+ if err = validateSchema(ctx, db); err != nil {
+ slog.ErrorContext(ctx, "schema validation failed", "error", err)
+ db.Close()
+ backupFile(ctx, path)
+ db, err := genDB(ctx, path)
+ if err != nil {
+ return nil, err
+ }
+ return new(db)
+ }
+
+ return new(db)
+}
+
+// validateSchema checks that the PRAGMAs and every table definition match the expected schema.
+func validateSchema(ctx context.Context, db *sql.DB) error {
+ if err := validatePragmas(db); err != nil {
+ return err
+ }
+ for _, table := range schemaDefinitions {
+ if err := validateTable(ctx, db, table.Name, table.Cmd); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// validatePragmas ensures that the required PRAGMAs are set.
+func validatePragmas(db *sql.DB) error {
+ var fk int
+ if err := db.QueryRow("PRAGMA foreign_keys;").Scan(&fk); err != nil {
+ return err
+ }
+ if fk != 1 {
+ return ErrForeignKeysDisabled
+ }
+
+ var jm string
+ if err := db.QueryRow("PRAGMA journal_mode;").Scan(&jm); err != nil {
+ return err
+ }
+ if strings.ToLower(jm) != "wal" {
+ return ErrJournalModeInvalid
+ }
+ return nil
+}
+
+// validateTable fetches the stored SQL for the table and compares it
+// (after normalization) with the expected definition.
+func validateTable(ctx context.Context, db *sql.DB, tableName, expectedSQL string) error {
+ actualSQL, err := fetchTableSQL(db, tableName)
+ if err != nil {
+ slog.ErrorContext(ctx, "failed to fetch table definition", "table", tableName, "error", err)
+ return ErrSchemaMismatch
+ }
+ if actualSQL == "" {
+ slog.ErrorContext(ctx, "table is missing", "table", tableName)
+ return ErrTableMissing
+ }
+
+ normalizedExpected := normalizeSQL(expectedSQL)
+ normalizedActual := normalizeSQL(actualSQL)
+ if normalizedExpected != normalizedActual {
+ slog.ErrorContext(ctx, "table structure does not match expected schema",
+ "table", tableName,
+ "expected", normalizedExpected,
+ "actual", normalizedActual,
+ )
+ return ErrTableStructure
+ }
+ return nil
+}
+
+// normalizeSQL removes SQL comments, converts to lowercase,
+// collapses whitespace, and removes a trailing semicolon.
+func normalizeSQL(sqlStr string) string {
+ sqlStr = removeSQLComments(sqlStr)
+ sqlStr = strings.ToLower(sqlStr)
+ sqlStr = strings.ReplaceAll(sqlStr, "\n", " ")
+ sqlStr = strings.Join(strings.Fields(sqlStr), " ")
+ sqlStr = strings.TrimSuffix(sqlStr, ";")
+ return sqlStr
+}
+
+// removeSQLComments strips out any '--' style comments.
+func removeSQLComments(sqlStr string) string {
+ lines := strings.Split(sqlStr, "\n")
+ for i, line := range lines {
+ if idx := strings.Index(line, "--"); idx != -1 {
+ lines[i] = line[:idx]
+ }
+ }
+ return strings.Join(lines, " ")
+}
+
+// fetchTableSQL retrieves the SQL definition of a table from sqlite_master.
+func fetchTableSQL(db *sql.DB, table string) (string, error) {
+ var sqlDef sql.NullString
+ err := db.QueryRow(
+ "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
+ table,
+ ).Scan(&sqlDef)
+ if err != nil {
+ return "", err
+ }
+ if !sqlDef.Valid {
+ return "", fmt.Errorf("no SQL definition found for table %s", table)
+ }
+ return sqlDef.String, nil
+}
+
+// backupFile renames the existing file by appending a ".bak" (or timestamped) suffix.
+func backupFile(ctx context.Context, path string) {
+ backupPath := path + ".bak"
+ if _, err := os.Stat(backupPath); err == nil {
+ backupPath = fmt.Sprintf("%s-%s.bak", path, time.Now().Format(time.RFC3339))
+ }
+ if err := os.Rename(path, backupPath); err != nil {
+ slog.ErrorContext(ctx, "failed to backup file",
+ "error", err,
+ "original", path,
+ "backup", backupPath,
+ )
+ } else {
+ slog.InfoContext(ctx, "backed up corrupt DB",
+ "original", path,
+ "backup", backupPath,
+ )
+ }
+}
diff --git a/internal/db/transaction_definitions.go b/internal/db/transaction_definitions.go
new file mode 100644
index 0000000..cdfa433
--- /dev/null
+++ b/internal/db/transaction_definitions.go
@@ -0,0 +1,108 @@
+package db
+
+type transactionName int
+
+const (
+ CreateUser transactionName = iota
+ CreateUserSecret
+ CreateTag
+ CreateRole
+ RemoveUser
+ RemoveRole
+ RemoveUnusedTags
+ AssignRoleToUser
+ AssignTagToTask
+ RemoveRoleFromUser
+ RemoveTagFromTask
+ UpdateSetting
+ GetSingleUser
+ GetAllUsers
+ GetSingleTask
+ GetAllTasks
+ GetSingleUserWithSecretAndRoles
+ GetAllTagsRelatedToTask
+)
+
+var commonTransactions = [...]struct {
+ Name transactionName
+ Cmd string
+}{
+ { // Create a user (including salted secret)
+ Name: CreateUser,
+ Cmd: "INSERT INTO Users (userID, name, email) VALUES (?, ?, ?)",
+ },
+ { // Create user secrets
+ Name: CreateUserSecret,
+ Cmd: "INSERT INTO UserSecrets (userID, saltAndHash) VALUES (?, ?)",
+ },
+ { // Create a tag
+ Name: CreateTag,
+ Cmd: "INSERT INTO Tags (name) VALUES (?)",
+ },
+ { // Create a role
+ Name: CreateRole,
+ Cmd: "INSERT INTO Roles (role) VALUES (?)",
+ },
+ { // Remove a user
+ Name: RemoveUser,
+ Cmd: "DELETE FROM Users WHERE userID = ?",
+ },
+ { // Remove a role
+ Name: RemoveRole,
+ Cmd: "DELETE FROM Roles WHERE role = ?",
+ },
+ { // Remove unused tags (assigned to no tasks)
+ Name: RemoveUnusedTags,
+ Cmd: "DELETE FROM Tags WHERE tagID NOT IN (SELECT tagID FROM TaskTags)",
+ },
+ { // Assign a new role to a user
+ Name: AssignRoleToUser,
+ Cmd: "INSERT INTO UserRoles (userID, role) VALUES (?, ?)",
+ },
+ { // Assign a new tag to a task
+ Name: AssignTagToTask,
+ Cmd: "INSERT INTO TaskTags (taskID, tagID) VALUES (?, ?)",
+ },
+ { // Remove a role from a user
+ Name: RemoveRoleFromUser,
+ Cmd: "DELETE FROM UserRoles WHERE userID = ? AND role = ?",
+ },
+ { // Remove a tag from a task
+ Name: RemoveTagFromTask,
+ Cmd: "DELETE FROM TaskTags WHERE taskID = ? AND tagID = ?",
+ },
+ { // Update a setting KeyPair
+ Name: UpdateSetting,
+ Cmd: "UPDATE Settings SET value = ? WHERE key = ?",
+ },
+ { // Get a single user
+ Name: GetSingleUser,
+ Cmd: "SELECT * FROM Users WHERE userID = ?",
+ },
+ { // Get all users
+ Name: GetAllUsers,
+ Cmd: "SELECT * FROM Users",
+ },
+ { // Get a single task
+ Name: GetSingleTask,
+ Cmd: "SELECT * FROM Tasks WHERE taskID = ?",
+ },
+ { // Get all tasks
+ Name: GetAllTasks,
+ Cmd: "SELECT * FROM Tasks",
+ },
+ { // Get a single user with secret info and roles
+ Name: GetSingleUserWithSecretAndRoles,
+ Cmd: `SELECT u.*, us.saltAndHash, ur.role
+ FROM Users u
+ JOIN UserSecrets us ON u.userID = us.userID
+ LEFT JOIN UserRoles ur ON u.userID = ur.userID
+ WHERE u.userID = ?`,
+ },
+ { // Get all tags related to a task
+ Name: GetAllTagsRelatedToTask,
+ Cmd: `SELECT t.* FROM Tags t
+ JOIN TaskTags tt ON t.tagID = tt.tagID
+ WHERE tt.taskID = ?`,
+ },
+}