diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/data_validation.go | 67 | ||||
-rw-r--r-- | internal/db/schema_initialisation.go | 130 | ||||
-rw-r--r-- | internal/db/setup.go | 234 | ||||
-rw-r--r-- | internal/db/transaction_definitions.go | 108 |
4 files changed, 539 insertions, 0 deletions
diff --git a/internal/db/data_validation.go b/internal/db/data_validation.go new file mode 100644 index 0000000..1883dc0 --- /dev/null +++ b/internal/db/data_validation.go @@ -0,0 +1,67 @@ +package db + +import ( + "context" + "database/sql" + "log/slog" +) + +const ( + PublicUserSignupKey = "ACCEPT_PUBLIC_USERS" + EnforceHttpsKey = "ENFORCE_HTTPS" + MaxUsersKey = "MAX_USERS" +) + +var settingValidations = [...]struct { + key string // key to look for in the settings + defaultVal string // Default value to init if not set +}{ + { // Only admins can create users when false + PublicUserSignupKey, + "FALSE", + }, + { // If something like traefik manages https, this can be set to + // false. But there MUST be https in your stack otherwise + // credentials are sent in the clear + EnforceHttpsKey, + "TRUE", + }, + { // Safeguard to avoid account creation spamming. + // An admin can still create users over the limit + MaxUsersKey, + "25", + }, +} + +func ValidateSettings(ctx context.Context, db *sql.DB) error { + valTx, err := db.PrepareContext(ctx, "SELECT value FROM Settings WHERE key=?") + if err != nil { + return err + } + defer valTx.Close() + + newTx, err := db.PrepareContext(ctx, "INSERT INTO Settings (key, value) VALUES (?, ?)") + if err != nil { + return err + } + defer newTx.Close() + + for _, s := range settingValidations { + var val string + err := valTx.QueryRowContext(ctx, s.key).Scan(&val) + if err != nil { + return err + } + if val == "" { + slog.WarnContext(ctx, "Missing configuration, setting the default", + "setting", s.key, + "value", s.defaultVal, + ) + _, err := newTx.ExecContext(ctx, s.key, s.defaultVal) + if err != nil { + return err + } + } + } + return nil +} diff --git a/internal/db/schema_initialisation.go b/internal/db/schema_initialisation.go new file mode 100644 index 0000000..3963786 --- /dev/null +++ b/internal/db/schema_initialisation.go @@ -0,0 +1,130 @@ +package db + +import ( + "context" + "database/sql" + "log/slog" +) + +// schemaDefinitions is the single source of truth for both creating and validating the DB schema. +var schemaDefinitions = [...]struct { + Name string + Cmd string +}{ + { + "Users", + `CREATE TABLE Users ( + userID TEXT PRIMARY KEY CHECK (length(userID) = 36), + name TEXT NOT NULL, + email TEXT NOT NULL UNIQUE, + createdAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updatedAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP + );`, + }, + { + "UserSecrets", + `CREATE TABLE UserSecrets ( + userID TEXT PRIMARY KEY, + saltAndHash TEXT NOT NULL, + FOREIGN KEY (userID) REFERENCES Users(userID) ON DELETE CASCADE + );`, + }, + { + "Tasks", + `CREATE TABLE Tasks ( + taskID TEXT PRIMARY KEY CHECK (length(taskID) = 36), + title TEXT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + description TEXT, + due TIMESTAMP, + do TIMESTAMP, + cron TEXT, + cronIsEnabled BOOLEAN NOT NULL DEFAULT FALSE, + owner TEXT NOT NULL, + FOREIGN KEY (owner) REFERENCES Users(userID) ON DELETE CASCADE + );`, + }, + { + "Tags", + `CREATE TABLE Tags ( + tagID INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE + );`, + }, + { + "TaskTags", + `CREATE TABLE TaskTags ( + taskID TEXT NOT NULL, + tagID INTEGER NOT NULL, + PRIMARY KEY (taskID, tagID), + FOREIGN KEY (taskID) REFERENCES Tasks(taskID) ON DELETE CASCADE, + FOREIGN KEY (tagID) REFERENCES Tags(tagID) ON DELETE CASCADE + );`, + }, + { + "Settings", + `CREATE TABLE Settings ( + key TEXT PRIMARY KEY, + value TEXT + );`, + }, + { + "Roles", + `CREATE TABLE Roles ( + role TEXT PRIMARY KEY CHECK (role GLOB '[A-Z_]*') + );`, + }, + { + "UserRoles", + `CREATE TABLE UserRoles ( + userID TEXT NOT NULL, + role TEXT NOT NULL, + PRIMARY KEY (userID, role), + FOREIGN KEY (userID) REFERENCES Users(userID) ON DELETE CASCADE, + FOREIGN KEY (role) REFERENCES Roles(role) ON DELETE CASCADE + );`, + }, +} + +// genDB creates a new database at path using the expected schema definitions. +func genDB(ctx context.Context, path string) (*sql.DB, error) { + db, err := sql.Open("sqlite3", path) + if err != nil { + slog.ErrorContext(ctx, "failed to create DB", "error", err) + return nil, err + } + + // Set the required PRAGMAs. + if _, err := db.Exec("PRAGMA foreign_keys = on; PRAGMA journal_mode = wal;"); err != nil { + slog.ErrorContext(ctx, "failed to set pragmas", "error", err) + db.Close() + return nil, err + } + + // Create tables inside a transaction. + tx, err := db.Begin() + if err != nil { + slog.ErrorContext(ctx, "failed to begin transaction for schema initialization", "error", err) + db.Close() + return nil, err + } + for _, table := range schemaDefinitions { + if _, err := tx.Exec(table.Cmd); err != nil { + slog.ErrorContext(ctx, "failed to initialize schema", "table", table.Name, "error", err) + if errRollback := tx.Rollback(); errRollback != nil { + slog.ErrorContext(ctx, "failed to rollback schema initialization", "error", errRollback) + } + + db.Close() + return nil, err + } + } + if err := tx.Commit(); err != nil { + slog.ErrorContext(ctx, "failed to commit schema initialization", "error", err) + db.Close() + return nil, err + } + + slog.InfoContext(ctx, "created new blank DB wit h valid schema", "path", path) + return db, nil +} diff --git a/internal/db/setup.go b/internal/db/setup.go new file mode 100644 index 0000000..038334a --- /dev/null +++ b/internal/db/setup.go @@ -0,0 +1,234 @@ +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log/slog" + "os" + "strings" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + ErrForeignKeysDisabled = errors.New("foreign keys are disabled") + ErrIntegrityCheckFailed = errors.New("integrity check failed") + ErrJournalModeInvalid = errors.New("journal mode is not WAL") + ErrSchemaMismatch = errors.New("database schema does not match expected definition") + ErrTableMissing = errors.New("table is missing") + ErrTableStructure = errors.New("table structure does not match expected schema") +) + +type Store struct { + DB *sql.DB + Common []*sql.Stmt +} + +func new(db *sql.DB) (*Store, error) { + lst := make([]*sql.Stmt, len(commonTransactions)) + for _, common := range commonTransactions { + stmt, err := db.Prepare(common.Cmd) + if err != nil { + return nil, err + } + lst[common.Name] = stmt + } + return &Store{ + DB: db, + Common: lst, + }, nil +} + +func (s *Store) Close() error { + errs := make([]error, len(s.Common)+1) + for i, s := range s.Common { + if s != nil { + errs[i] = s.Close() + } + } + errs[len(s.Common)] = s.DB.Close() + return errors.Join(errs...) +} + +// opts returns connection options that enforce our desired pragmas. +func opts() string { + return "?_foreign_keys=on&_journal_mode=WAL" +} + +// Setup opens the SQLite DB at path, verifies its integrity and schema, +// and returns the valid DB handle. If any check fails, it backs up the old file +// and reinitializes the DB using the schema definitions. +func Setup(ctx context.Context, path string) (*Store, error) { + slog.DebugContext(ctx, "Setting up database connection") + + // If file does not exist, generate a new DB. + if _, err := os.Stat(path); err != nil { + db, err := genDB(ctx, path) + if err != nil { + return nil, err + } + return new(db) + } + + db, err := sql.Open("sqlite3", path+opts()) + if err != nil { + slog.ErrorContext(ctx, "failed to open DB", "error", err) + backupFile(ctx, path) + db, err := genDB(ctx, path) + if err != nil { + return nil, err + } + return new(db) + } + + // Run integrity check. + var integrity string + if err = db.QueryRow("PRAGMA integrity_check;").Scan(&integrity); err != nil || integrity != "ok" { + if err != nil { + slog.ErrorContext(ctx, "integrity check query failed", "error", err) + } else { + slog.ErrorContext(ctx, "integrity check failed", "integrity", integrity) + } + db.Close() + backupFile(ctx, path) + db, err := genDB(ctx, path) + if err != nil { + return nil, err + } + return new(db) + } + + // Validate the PRAGMA settings and each table's schema. + if err = validateSchema(ctx, db); err != nil { + slog.ErrorContext(ctx, "schema validation failed", "error", err) + db.Close() + backupFile(ctx, path) + db, err := genDB(ctx, path) + if err != nil { + return nil, err + } + return new(db) + } + + return new(db) +} + +// validateSchema checks that the PRAGMAs and every table definition match the expected schema. +func validateSchema(ctx context.Context, db *sql.DB) error { + if err := validatePragmas(db); err != nil { + return err + } + for _, table := range schemaDefinitions { + if err := validateTable(ctx, db, table.Name, table.Cmd); err != nil { + return err + } + } + return nil +} + +// validatePragmas ensures that the required PRAGMAs are set. +func validatePragmas(db *sql.DB) error { + var fk int + if err := db.QueryRow("PRAGMA foreign_keys;").Scan(&fk); err != nil { + return err + } + if fk != 1 { + return ErrForeignKeysDisabled + } + + var jm string + if err := db.QueryRow("PRAGMA journal_mode;").Scan(&jm); err != nil { + return err + } + if strings.ToLower(jm) != "wal" { + return ErrJournalModeInvalid + } + return nil +} + +// validateTable fetches the stored SQL for the table and compares it +// (after normalization) with the expected definition. +func validateTable(ctx context.Context, db *sql.DB, tableName, expectedSQL string) error { + actualSQL, err := fetchTableSQL(db, tableName) + if err != nil { + slog.ErrorContext(ctx, "failed to fetch table definition", "table", tableName, "error", err) + return ErrSchemaMismatch + } + if actualSQL == "" { + slog.ErrorContext(ctx, "table is missing", "table", tableName) + return ErrTableMissing + } + + normalizedExpected := normalizeSQL(expectedSQL) + normalizedActual := normalizeSQL(actualSQL) + if normalizedExpected != normalizedActual { + slog.ErrorContext(ctx, "table structure does not match expected schema", + "table", tableName, + "expected", normalizedExpected, + "actual", normalizedActual, + ) + return ErrTableStructure + } + return nil +} + +// normalizeSQL removes SQL comments, converts to lowercase, +// collapses whitespace, and removes a trailing semicolon. +func normalizeSQL(sqlStr string) string { + sqlStr = removeSQLComments(sqlStr) + sqlStr = strings.ToLower(sqlStr) + sqlStr = strings.ReplaceAll(sqlStr, "\n", " ") + sqlStr = strings.Join(strings.Fields(sqlStr), " ") + sqlStr = strings.TrimSuffix(sqlStr, ";") + return sqlStr +} + +// removeSQLComments strips out any '--' style comments. +func removeSQLComments(sqlStr string) string { + lines := strings.Split(sqlStr, "\n") + for i, line := range lines { + if idx := strings.Index(line, "--"); idx != -1 { + lines[i] = line[:idx] + } + } + return strings.Join(lines, " ") +} + +// fetchTableSQL retrieves the SQL definition of a table from sqlite_master. +func fetchTableSQL(db *sql.DB, table string) (string, error) { + var sqlDef sql.NullString + err := db.QueryRow( + "SELECT sql FROM sqlite_master WHERE type='table' AND name=?", + table, + ).Scan(&sqlDef) + if err != nil { + return "", err + } + if !sqlDef.Valid { + return "", fmt.Errorf("no SQL definition found for table %s", table) + } + return sqlDef.String, nil +} + +// backupFile renames the existing file by appending a ".bak" (or timestamped) suffix. +func backupFile(ctx context.Context, path string) { + backupPath := path + ".bak" + if _, err := os.Stat(backupPath); err == nil { + backupPath = fmt.Sprintf("%s-%s.bak", path, time.Now().Format(time.RFC3339)) + } + if err := os.Rename(path, backupPath); err != nil { + slog.ErrorContext(ctx, "failed to backup file", + "error", err, + "original", path, + "backup", backupPath, + ) + } else { + slog.InfoContext(ctx, "backed up corrupt DB", + "original", path, + "backup", backupPath, + ) + } +} diff --git a/internal/db/transaction_definitions.go b/internal/db/transaction_definitions.go new file mode 100644 index 0000000..cdfa433 --- /dev/null +++ b/internal/db/transaction_definitions.go @@ -0,0 +1,108 @@ +package db + +type transactionName int + +const ( + CreateUser transactionName = iota + CreateUserSecret + CreateTag + CreateRole + RemoveUser + RemoveRole + RemoveUnusedTags + AssignRoleToUser + AssignTagToTask + RemoveRoleFromUser + RemoveTagFromTask + UpdateSetting + GetSingleUser + GetAllUsers + GetSingleTask + GetAllTasks + GetSingleUserWithSecretAndRoles + GetAllTagsRelatedToTask +) + +var commonTransactions = [...]struct { + Name transactionName + Cmd string +}{ + { // Create a user (including salted secret) + Name: CreateUser, + Cmd: "INSERT INTO Users (userID, name, email) VALUES (?, ?, ?)", + }, + { // Create user secrets + Name: CreateUserSecret, + Cmd: "INSERT INTO UserSecrets (userID, saltAndHash) VALUES (?, ?)", + }, + { // Create a tag + Name: CreateTag, + Cmd: "INSERT INTO Tags (name) VALUES (?)", + }, + { // Create a role + Name: CreateRole, + Cmd: "INSERT INTO Roles (role) VALUES (?)", + }, + { // Remove a user + Name: RemoveUser, + Cmd: "DELETE FROM Users WHERE userID = ?", + }, + { // Remove a role + Name: RemoveRole, + Cmd: "DELETE FROM Roles WHERE role = ?", + }, + { // Remove unused tags (assigned to no tasks) + Name: RemoveUnusedTags, + Cmd: "DELETE FROM Tags WHERE tagID NOT IN (SELECT tagID FROM TaskTags)", + }, + { // Assign a new role to a user + Name: AssignRoleToUser, + Cmd: "INSERT INTO UserRoles (userID, role) VALUES (?, ?)", + }, + { // Assign a new tag to a task + Name: AssignTagToTask, + Cmd: "INSERT INTO TaskTags (taskID, tagID) VALUES (?, ?)", + }, + { // Remove a role from a user + Name: RemoveRoleFromUser, + Cmd: "DELETE FROM UserRoles WHERE userID = ? AND role = ?", + }, + { // Remove a tag from a task + Name: RemoveTagFromTask, + Cmd: "DELETE FROM TaskTags WHERE taskID = ? AND tagID = ?", + }, + { // Update a setting KeyPair + Name: UpdateSetting, + Cmd: "UPDATE Settings SET value = ? WHERE key = ?", + }, + { // Get a single user + Name: GetSingleUser, + Cmd: "SELECT * FROM Users WHERE userID = ?", + }, + { // Get all users + Name: GetAllUsers, + Cmd: "SELECT * FROM Users", + }, + { // Get a single task + Name: GetSingleTask, + Cmd: "SELECT * FROM Tasks WHERE taskID = ?", + }, + { // Get all tasks + Name: GetAllTasks, + Cmd: "SELECT * FROM Tasks", + }, + { // Get a single user with secret info and roles + Name: GetSingleUserWithSecretAndRoles, + Cmd: `SELECT u.*, us.saltAndHash, ur.role + FROM Users u + JOIN UserSecrets us ON u.userID = us.userID + LEFT JOIN UserRoles ur ON u.userID = ur.userID + WHERE u.userID = ?`, + }, + { // Get all tags related to a task + Name: GetAllTagsRelatedToTask, + Cmd: `SELECT t.* FROM Tags t + JOIN TaskTags tt ON t.tagID = tt.tagID + WHERE tt.taskID = ?`, + }, +} |