summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorBenjamin Chausse <benjamin@chausse.xyz>2025-02-03 01:12:45 -0500
committerBenjamin Chausse <benjamin@chausse.xyz>2025-02-03 01:12:45 -0500
commit5389e1a5d26fdbf2441fa5a1e101999e8449b9d1 (patch)
tree069cd37cb8e556c1ba3b47c3ea8576a1aa91ea2c /internal
Batman
Diffstat (limited to 'internal')
-rw-r--r--internal/app/action.go47
-rw-r--r--internal/app/command.go23
-rw-r--r--internal/app/flags.go127
-rw-r--r--internal/logging/context.go39
-rw-r--r--internal/logging/logging.go92
-rw-r--r--internal/logging/trace.go61
-rw-r--r--internal/manualgen/manualgen.go21
-rw-r--r--internal/server/model/.gitignore2
-rw-r--r--internal/server/server.go16
-rw-r--r--internal/server/setup.go34
-rw-r--r--internal/server/task.go33
-rw-r--r--internal/server/user.go34
-rw-r--r--internal/storage/db.go18
-rw-r--r--internal/storage/schema.go38
-rw-r--r--internal/storage/setup.go201
-rw-r--r--internal/tagging/tagging.go23
-rw-r--r--internal/util/context.go8
17 files changed, 817 insertions, 0 deletions
diff --git a/internal/app/action.go b/internal/app/action.go
new file mode 100644
index 0000000..b579253
--- /dev/null
+++ b/internal/app/action.go
@@ -0,0 +1,47 @@
+package app
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+
+ "github.com/ChausseBenjamin/rafta/internal/logging"
+ "github.com/ChausseBenjamin/rafta/internal/server"
+ "github.com/ChausseBenjamin/rafta/internal/storage"
+ "github.com/urfave/cli/v3"
+)
+
+func action(ctx context.Context, cmd *cli.Command) error {
+ err := logging.Setup(
+ cmd.String(FlagLogLevel),
+ cmd.String(FlagLogFormat),
+ cmd.String(FlagLogOutput),
+ )
+ if err != nil {
+ slog.Warn("Error(s) occured during logger initialization", logging.ErrKey, err)
+ }
+
+ slog.Info("Starting rafta server")
+
+ // TODO: Setup the db
+ store, err := storage.Setup(cmd.String(FlagDBPath))
+ if err != nil {
+ slog.Error("Unable to setup database", logging.ErrKey, err)
+ }
+
+ srv, lis, err := server.Setup(cmd.Int(FlagListenPort), store)
+ if err != nil {
+ slog.Error("Unable to setup server", logging.ErrKey, err)
+
+ return err
+ }
+
+ slog.Info(fmt.Sprintf("Listening on port %d", cmd.Int(FlagListenPort)))
+ if err := srv.Serve(lis); err != nil {
+ slog.Error("Server runtime error", logging.ErrKey, err)
+
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/app/command.go b/internal/app/command.go
new file mode 100644
index 0000000..5c32ba2
--- /dev/null
+++ b/internal/app/command.go
@@ -0,0 +1,23 @@
+package app
+
+import (
+ "github.com/urfave/cli/v3"
+)
+
+const (
+ AppName = "rafta"
+ AppUsage = "Really, Another Freaking Todo App?!"
+)
+
+var version = "COMPILED"
+
+func Command() *cli.Command {
+ return &cli.Command{
+ Name: AppName,
+ Usage: AppUsage,
+ Authors: []any{"Benjamin Chausse <benjamin@chausse.xyz>"},
+ Version: version,
+ Flags: flags(),
+ Action: action,
+ }
+}
diff --git a/internal/app/flags.go b/internal/app/flags.go
new file mode 100644
index 0000000..e96b8ef
--- /dev/null
+++ b/internal/app/flags.go
@@ -0,0 +1,127 @@
+package app
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "os"
+ "strings"
+
+ "github.com/ChausseBenjamin/rafta/internal/logging"
+ "github.com/ChausseBenjamin/rafta/internal/server"
+ "github.com/urfave/cli/v3"
+)
+
+const (
+ FlagListenPort = "port"
+ FlagLogLevel = "log-level"
+ FlagLogFormat = "log-format"
+ FlagLogOutput = "log-output"
+ FlagDBPath = "database"
+)
+
+func flags() []cli.Flag {
+ return []cli.Flag{
+ // Logging {{{
+ &cli.StringFlag{
+ Name: FlagLogFormat,
+ Aliases: []string{"f"},
+ Value: "plain",
+ Usage: "plain, json",
+ Sources: cli.EnvVars("LOG_FORMAT"),
+ Action: validateLogFormat,
+ },
+ &cli.StringFlag{
+ Name: FlagLogOutput,
+ Aliases: []string{"o"},
+ Value: "stdout",
+ Usage: "stdout, stderr, file",
+ Sources: cli.EnvVars("LOG_OUTPUT"),
+ Action: validateLogOutput,
+ },
+ &cli.StringFlag{
+ Name: FlagLogLevel,
+ Aliases: []string{"l"},
+ Value: "info",
+ Usage: "debug, info, warn, error",
+ Sources: cli.EnvVars("LOG_LEVEL"),
+ Action: validateLogLevel,
+ }, // }}}
+ // gRPC server {{{
+ &cli.IntFlag{
+ Name: FlagListenPort,
+ Aliases: []string{"p"},
+ Value: 1234,
+ Sources: cli.EnvVars("LISTEN_PORT"),
+ Action: validateListenPort,
+ }, // }}}
+ // Database {{{
+ &cli.StringFlag{
+ Name: FlagDBPath,
+ Aliases: []string{"d"},
+ Value: "store.db",
+ Usage: "database file",
+ Sources: cli.EnvVars("DATABASE_PATH"),
+ Action: validateDBPath,
+ }, // }}}
+ }
+}
+
+func validateLogOutput(ctx context.Context, cmd *cli.Command, s string) error {
+ switch {
+ case s == "stdout" || s == "stderr":
+ return nil
+ default:
+ // assume file
+ f, err := os.OpenFile(s, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
+ if err != nil {
+ slog.ErrorContext(
+ ctx,
+ fmt.Sprintf("Error creating/accessing provided log file %s", s),
+ )
+ return err
+ }
+ defer f.Close()
+ return nil
+ }
+}
+
+func validateLogLevel(ctx context.Context, cmd *cli.Command, s string) error {
+ for _, lvl := range []string{"deb", "inf", "warn", "err"} {
+ if strings.Contains(strings.ToLower(s), lvl) {
+ return nil
+ }
+ }
+ slog.ErrorContext(
+ ctx,
+ fmt.Sprintf("Unknown log level provided: %s", s),
+ )
+ return logging.ErrInvalidLevel
+}
+
+func validateLogFormat(ctx context.Context, cmd *cli.Command, s string) error {
+ s = strings.ToLower(s)
+ if s == "json" || s == "plain" {
+ return nil
+ }
+ return nil
+}
+
+func validateListenPort(ctx context.Context, cmd *cli.Command, p int64) error {
+ if p < 1024 || p > 65535 {
+ slog.ErrorContext(
+ ctx,
+ fmt.Sprintf("Out-of-bound port provided: %d", p),
+ )
+ return server.ErrOutOfBoundsPort
+ }
+ return nil
+}
+
+func validateDBPath(ctx context.Context, cmd *cli.Command, s string) error {
+ // TODO: Ensure the db file is writable.
+ // TODO: Ensure the db file is a valid sqlite3 db.
+ // TODO: Call db.Reset() if either of the above fail.
+ // TODO: Log the error/crash if the db file is not writable.
+ return nil
+}
diff --git a/internal/logging/context.go b/internal/logging/context.go
new file mode 100644
index 0000000..29bad62
--- /dev/null
+++ b/internal/logging/context.go
@@ -0,0 +1,39 @@
+package logging
+
+import (
+ "context"
+ "log/slog"
+)
+
+type ctxTracker struct {
+ ctxKey interface{}
+ logKey string
+ next slog.Handler
+}
+
+func (h ctxTracker) Handle(ctx context.Context, r slog.Record) error {
+ if v := ctx.Value(h.ctxKey); v != nil {
+ r.AddAttrs(slog.Any(h.logKey, v))
+ }
+ return h.next.Handle(ctx, r)
+}
+
+func (h ctxTracker) Enabled(ctx context.Context, lvl slog.Level) bool {
+ return h.next.Enabled(ctx, lvl)
+}
+
+func (h ctxTracker) WithAttrs(attrs []slog.Attr) slog.Handler {
+ return h.next.WithAttrs(attrs)
+}
+
+func (h ctxTracker) WithGroup(name string) slog.Handler {
+ return h.next.WithGroup(name)
+}
+
+func withTrackedContext(current slog.Handler, ctxKey interface{}, logKey string) *ctxTracker {
+ return &ctxTracker{
+ ctxKey: ctxKey,
+ logKey: logKey,
+ next: current,
+ }
+}
diff --git a/internal/logging/logging.go b/internal/logging/logging.go
new file mode 100644
index 0000000..91a9734
--- /dev/null
+++ b/internal/logging/logging.go
@@ -0,0 +1,92 @@
+package logging
+
+import (
+ "errors"
+ "io"
+ "log/slog"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/ChausseBenjamin/rafta/internal/util"
+ "github.com/charmbracelet/log"
+)
+
+const (
+ ErrKey = "error_message"
+)
+
+var (
+ ErrInvalidLevel = errors.New("invalid log level")
+ ErrInvalidFormat = errors.New("invalid log format")
+)
+
+func Setup(lvlStr, fmtStr, outStr string) error {
+ output, outputErr := setOutput(outStr)
+ format, formatErr := setFormat(fmtStr)
+ level, levelErr := setLevel(lvlStr)
+
+ prefixStr := ""
+ if format != log.JSONFormatter {
+ prefixStr = "Rafta 🚢"
+ }
+
+ var h slog.Handler = log.NewWithOptions(
+ output,
+ log.Options{
+ TimeFormat: time.DateTime,
+ Prefix: prefixStr,
+ Level: level,
+ ReportCaller: true,
+ Formatter: format,
+ },
+ )
+
+ h = withTrackedContext(h, util.ReqIDKey, "request_id")
+ h = withStackTrace(h)
+ slog.SetDefault(slog.New(h))
+ return errors.Join(outputErr, formatErr, levelErr)
+}
+
+func setLevel(target string) (log.Level, error) {
+ for _, l := range []struct {
+ prefix string
+ level log.Level
+ }{
+ {"deb", log.DebugLevel},
+ {"inf", log.InfoLevel},
+ {"warn", log.WarnLevel},
+ {"err", log.ErrorLevel},
+ } {
+ if strings.HasPrefix(strings.ToLower(target), l.prefix) {
+ return l.level, nil
+ }
+ }
+ return log.InfoLevel, ErrInvalidLevel
+}
+
+func setFormat(f string) (log.Formatter, error) {
+ switch f {
+ case "plain", "text":
+ return log.TextFormatter, nil
+ case "json", "structured":
+ return log.JSONFormatter, nil
+ }
+ return log.TextFormatter, ErrInvalidFormat
+}
+
+func setOutput(path string) (io.Writer, error) {
+ switch path {
+ case "stdout":
+ return os.Stdout, nil
+ case "stderr":
+ return os.Stderr, nil
+ default:
+ f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
+ if err != nil {
+ return os.Stdout, err
+ } else {
+ return f, nil
+ }
+ }
+}
diff --git a/internal/logging/trace.go b/internal/logging/trace.go
new file mode 100644
index 0000000..8900727
--- /dev/null
+++ b/internal/logging/trace.go
@@ -0,0 +1,61 @@
+package logging
+
+import (
+ "context"
+ "log/slog"
+ "runtime"
+ "strconv"
+ "strings"
+)
+
+const (
+ prgCount = 20
+ defSkip = 6
+)
+
+type stackTracer struct {
+ h slog.Handler
+ nSkip int
+}
+
+func (h stackTracer) Enabled(ctx context.Context, lvl slog.Level) bool {
+ return h.h.Enabled(ctx, lvl)
+}
+
+func (h stackTracer) WithAttrs(attrs []slog.Attr) slog.Handler {
+ return h.h.WithAttrs(attrs)
+}
+
+func (h stackTracer) WithGroup(name string) slog.Handler {
+ return h.h.WithGroup(name)
+}
+
+func (h stackTracer) Handle(ctx context.Context, r slog.Record) error {
+ if r.Level < slog.LevelError {
+ return h.h.Handle(ctx, r)
+ }
+
+ trace := h.GetTrace()
+ r.AddAttrs(slog.String("trace", trace))
+
+ return h.h.Handle(ctx, r)
+}
+
+func (h stackTracer) GetTrace() string {
+ var b strings.Builder
+ pc := make([]uintptr, prgCount)
+ n := runtime.Callers(h.nSkip, pc)
+ frames := runtime.CallersFrames(pc[:n])
+
+ for frame, more := frames.Next(); more; frame, more = frames.Next() {
+ b.WriteString(frame.Function + "\n " + frame.File + ":" + strconv.Itoa(frame.Line) + "\n")
+ }
+ return b.String()
+}
+
+func withStackTrace(h slog.Handler) slog.Handler {
+ return stackTracer{
+ h: h,
+ nSkip: defSkip,
+ }
+}
diff --git a/internal/manualgen/manualgen.go b/internal/manualgen/manualgen.go
new file mode 100644
index 0000000..88a72e2
--- /dev/null
+++ b/internal/manualgen/manualgen.go
@@ -0,0 +1,21 @@
+package main
+
+import (
+ "log/slog"
+ "os"
+
+ "github.com/ChausseBenjamin/rafta/internal/app"
+ "github.com/ChausseBenjamin/rafta/internal/logging"
+ docs "github.com/urfave/cli-docs/v3"
+)
+
+func main() {
+ a := app.Command()
+
+ man, err := docs.ToManWithSection(a, 1)
+ if err != nil {
+ slog.Error("failed to generate markdown", logging.ErrKey, err)
+ os.Exit(1)
+ }
+ os.Stdout.Write([]byte(man))
+}
diff --git a/internal/server/model/.gitignore b/internal/server/model/.gitignore
new file mode 100644
index 0000000..d6b7ef3
--- /dev/null
+++ b/internal/server/model/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
diff --git a/internal/server/server.go b/internal/server/server.go
new file mode 100644
index 0000000..70e5861
--- /dev/null
+++ b/internal/server/server.go
@@ -0,0 +1,16 @@
+package server
+
+import (
+ "database/sql"
+ "errors"
+
+ m "github.com/ChausseBenjamin/rafta/internal/server/model"
+)
+
+var ErrOutOfBoundsPort = errors.New("port out of bounds")
+
+// Implements ComsServer interface
+type Service struct {
+ store *sql.DB
+ m.UnimplementedRaftaServer
+}
diff --git a/internal/server/setup.go b/internal/server/setup.go
new file mode 100644
index 0000000..551cd95
--- /dev/null
+++ b/internal/server/setup.go
@@ -0,0 +1,34 @@
+package server
+
+import (
+ "database/sql"
+ "fmt"
+ "log/slog"
+ "net"
+
+ "github.com/ChausseBenjamin/rafta/internal/logging"
+ m "github.com/ChausseBenjamin/rafta/internal/server/model"
+ "github.com/ChausseBenjamin/rafta/internal/tagging"
+ "google.golang.org/grpc"
+)
+
+func Setup(port int64, storage *sql.DB) (*grpc.Server, net.Listener, error) {
+ lis, err := net.Listen(
+ "tcp",
+ fmt.Sprintf(":%d", port),
+ )
+ if err != nil {
+ slog.Error("Unable to create listener", logging.ErrKey, err)
+ return nil, nil, err
+ }
+
+ grpcServer := grpc.NewServer(
+ grpc.ChainUnaryInterceptor(
+ tagging.UnaryInterceptor,
+ ),
+ )
+ raftaService := &Service{store: storage}
+ m.RegisterRaftaServer(grpcServer, raftaService)
+
+ return grpcServer, lis, nil
+}
diff --git a/internal/server/task.go b/internal/server/task.go
new file mode 100644
index 0000000..b8cf0b8
--- /dev/null
+++ b/internal/server/task.go
@@ -0,0 +1,33 @@
+package server
+
+import (
+ "context"
+ "log/slog"
+
+ m "github.com/ChausseBenjamin/rafta/internal/server/model"
+ "google.golang.org/protobuf/types/known/emptypb"
+)
+
+func (s Service) GetUserTasks(ctx context.Context, id *m.UserID) (*m.TaskList, error) {
+ slog.ErrorContext(ctx, "GetUserTasks not implemented yet")
+ return nil, nil
+}
+
+func (s Service) GetTask(ctx context.Context, id *m.TaskID) (*m.Task, error) {
+ return nil, nil
+}
+
+func (s Service) DeleteTask(ctx context.Context, id *m.TaskID) (*emptypb.Empty, error) {
+ slog.ErrorContext(ctx, "DeleteTask not implemented yet")
+ return nil, nil
+}
+
+func (s Service) UpdateTask(ctx context.Context, t *m.Task) (*m.Task, error) {
+ slog.ErrorContext(ctx, "UpdateTask not implemented yet")
+ return t, nil
+}
+
+func (s Service) CreateTask(ctx context.Context, data *m.TaskData) (*m.Task, error) {
+ slog.ErrorContext(ctx, "CreateTask not implemented yet")
+ return nil, nil
+}
diff --git a/internal/server/user.go b/internal/server/user.go
new file mode 100644
index 0000000..c4a97c4
--- /dev/null
+++ b/internal/server/user.go
@@ -0,0 +1,34 @@
+package server
+
+import (
+ "context"
+ "log/slog"
+
+ m "github.com/ChausseBenjamin/rafta/internal/server/model"
+ "google.golang.org/protobuf/types/known/emptypb"
+)
+
+func (s Service) GetAllUsers(ctx context.Context, empty *emptypb.Empty) (*m.UserList, error) {
+ slog.ErrorContext(ctx, "GetAllUsers not implemented yet")
+ return nil, nil
+}
+
+func (s Service) GetUser(ctx context.Context, id *m.UserID) (*m.User, error) {
+ slog.ErrorContext(ctx, "GetUser not implemented yet")
+ return nil, nil
+}
+
+func (s Service) DeleteUser(ctx context.Context, id *m.UserID) (*emptypb.Empty, error) {
+ slog.ErrorContext(ctx, "DeleteUser not implemented yet")
+ return nil, nil
+}
+
+func (s Service) UpdateUser(ctx context.Context, u *m.User) (*m.User, error) {
+ slog.ErrorContext(ctx, "UpdateUser not implemented yet")
+ return nil, nil
+}
+
+func (s Service) CreateUser(ctx context.Context, data *m.UserData) (*m.User, error) {
+ slog.ErrorContext(ctx, "CreateUser not implemented yet")
+ return nil, nil
+}
diff --git a/internal/storage/db.go b/internal/storage/db.go
new file mode 100644
index 0000000..06fe31c
--- /dev/null
+++ b/internal/storage/db.go
@@ -0,0 +1,18 @@
+package storage
+
+import (
+ "context"
+ "database/sql"
+ "log/slog"
+
+ "github.com/ChausseBenjamin/rafta/internal/util"
+)
+
+func GetDB(ctx context.Context) *sql.DB {
+ db, ok := ctx.Value(util.DBKey).(*sql.DB)
+ if !ok {
+ slog.Error("Unable to retrieve database from context")
+ return nil
+ }
+ return db
+}
diff --git a/internal/storage/schema.go b/internal/storage/schema.go
new file mode 100644
index 0000000..042c09b
--- /dev/null
+++ b/internal/storage/schema.go
@@ -0,0 +1,38 @@
+package storage
+
+func opts() string {
+ return "?_foreign_keys=on&_journal_mode=WAL"
+}
+
+func schema() string {
+ return `
+CREATE TABLE User (
+ UserID INTEGER PRIMARY KEY AUTOINCREMENT,
+ Name TEXT NOT NULL,
+ Email TEXT NOT NULL UNIQUE
+);
+
+CREATE TABLE Task (
+ TaskID INTEGER PRIMARY KEY AUTOINCREMENT,
+ Title TEXT NOT NULL,
+ Description TEXT NOT NULL,
+ Due DATE NOT NULL,
+ Do DATE NOT NULL,
+ Owner INTEGER NOT NULL,
+ FOREIGN KEY (Owner) REFERENCES User(UserID) ON DELETE CASCADE
+);
+
+CREATE TABLE Tag (
+ TagID INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
+ Name TEXT NOT NULL UNIQUE
+);
+
+CREATE TABLE TaskTag (
+ TaskUUID INTEGER NOT NULL,
+ TagID INTEGER NOT NULL,
+ PRIMARY KEY (TaskUUID, TagID),
+ FOREIGN KEY (TaskUUID) REFERENCES Task(UUID) ON DELETE CASCADE,
+ FOREIGN KEY (TagID) REFERENCES Tag(TagID) ON DELETE CASCADE
+);
+`
+}
diff --git a/internal/storage/setup.go b/internal/storage/setup.go
new file mode 100644
index 0000000..f103b15
--- /dev/null
+++ b/internal/storage/setup.go
@@ -0,0 +1,201 @@
+package storage
+
+import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "log/slog"
+ "os"
+ "strings"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+var (
+ ErrIntegrityCheckFailed = errors.New("integrity check failed")
+ ErrForeignKeysDisabled = errors.New("foreign_keys pragma is not enabled")
+ ErrJournalModeInvalid = errors.New("journal_mode is not wal")
+ ErrSchemaMismatch = errors.New("schema does not match expected definition")
+)
+
+// Setup opens the SQLite DB at path, verifies its integrity and schema,
+// and returns the valid DB handle. On any error, it backs up the old file
+// (if it exists) and calls genDB() to initialize a valid schema.
+func Setup(path string) (*sql.DB, error) {
+ _, statErr := os.Stat(path)
+ exists := statErr == nil
+
+ // If file doesn't exist, generate new DB.
+ if !exists {
+ return genDB(path)
+ }
+
+ db, err := sql.Open("sqlite3", path+opts())
+ if err != nil {
+ slog.Error("failed to open DB", "error", err)
+ backupFile(path)
+ return genDB(path)
+ }
+
+ // Integrity check.
+ var integrity string
+ if err = db.QueryRow("PRAGMA integrity_check;").Scan(&integrity); err != nil {
+ slog.Error("integrity check query failed", "error", err)
+ db.Close()
+ backupFile(path)
+ return genDB(path)
+ }
+ if integrity != "ok" {
+ slog.Error("integrity check failed", "error", ErrIntegrityCheckFailed)
+ db.Close()
+ backupFile(path)
+ return genDB(path)
+ }
+
+ // Validate schema and pragmas.
+ if err = validateSchema(db); err != nil {
+ slog.Error("schema validation failed", "error", err)
+ db.Close()
+ backupFile(path)
+ return genDB(path)
+ }
+
+ return db, nil
+}
+
+// validateSchema verifies that required pragmas and table definitions are set.
+func validateSchema(db *sql.DB) error {
+ // Check PRAGMA foreign_keys = on.
+ var fk int
+ if err := db.QueryRow("PRAGMA foreign_keys;").Scan(&fk); err != nil {
+ return err
+ }
+ if fk != 1 {
+ return ErrForeignKeysDisabled
+ }
+
+ // Check PRAGMA journal_mode = wal.
+ var jm string
+ if err := db.QueryRow("PRAGMA journal_mode;").Scan(&jm); err != nil {
+ return err
+ }
+ if strings.ToLower(jm) != "wal" {
+ return ErrJournalModeInvalid
+ }
+
+ // Define required table definitions (as substrings in lower-case).
+ type tableCheck struct {
+ name string
+ substrings []string
+ }
+
+ checks := []tableCheck{
+ {
+ name: "User",
+ substrings: []string{
+ "create table user",
+ "userid", "integer", "primary key", "autoincrement",
+ "name", "text", "not null",
+ "email", "text", "not null", "unique",
+ },
+ },
+ {
+ name: "Task",
+ substrings: []string{
+ "create table task",
+ "taskid", "integer", "primary key", "autoincrement",
+ "title", "not null",
+ "description", "not null",
+ "due", "date", "not null",
+ "do", "date", "not null",
+ "owner", "integer", "not null",
+ "foreign key", "references user",
+ },
+ },
+ {
+ name: "Tag",
+ substrings: []string{
+ "create table tag",
+ "tagid", "integer", "primary key", "autoincrement",
+ "name", "text", "not null", "unique",
+ },
+ },
+ {
+ name: "TaskTag",
+ substrings: []string{
+ "create table tasktag",
+ "taskuuid", "integer", "not null",
+ "tagid", "integer", "not null",
+ "primary key",
+ "foreign key", "references task",
+ "foreign key", "references tag",
+ },
+ },
+ }
+
+ for _, chk := range checks {
+ sqlDef, err := fetchTableSQL(db, chk.name)
+ if err != nil {
+ return fmt.Errorf("failed to fetch definition for table %s: %w", chk.name, err)
+ }
+ lc := strings.ToLower(sqlDef)
+ for _, substr := range chk.substrings {
+ if !strings.Contains(lc, substr) {
+ return fmt.Errorf("%w: table %s missing %q", ErrSchemaMismatch, chk.name, substr)
+ }
+ }
+ }
+
+ return nil
+}
+
+func fetchTableSQL(db *sql.DB, table string) (string, error) {
+ var sqlDef sql.NullString
+ err := db.QueryRow("SELECT sql FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&sqlDef)
+ if err != nil {
+ return "", err
+ }
+ if !sqlDef.Valid {
+ return "", fmt.Errorf("no SQL definition found for table %s", table)
+ }
+ return sqlDef.String, nil
+}
+
+// backupFile renames the existing file by appending a ".bak" suffix.
+func backupFile(path string) {
+ backupPath := path + ".bak"
+ // If backupPath exists, append a timestamp.
+ if _, err := os.Stat(backupPath); err == nil {
+ backupPath = fmt.Sprintf("%s.%d.bak", path, os.Getpid())
+ }
+ if err := os.Rename(path, backupPath); err != nil {
+ slog.Error("failed to backup file", "error", err, "original", path, "backup", backupPath)
+ } else {
+ slog.Info("backed up corrupt DB", "original", path, "backup", backupPath)
+ }
+}
+
+// genDB creates a new database at path with the valid schema.
+func genDB(path string) (*sql.DB, error) {
+ db, err := sql.Open("sqlite3", path)
+ if err != nil {
+ slog.Error("failed to create DB", "error", err)
+ return nil, err
+ }
+
+ // Set pragmas.
+ if _, err := db.Exec("PRAGMA foreign_keys = on; PRAGMA journal_mode = wal;"); err != nil {
+ slog.Error("failed to set pragmas", "error", err)
+ db.Close()
+ return nil, err
+ }
+
+ if _, err := db.Exec(schema()); err != nil {
+ slog.Error("failed to initialize schema", "error", err)
+ db.Close()
+ return nil, err
+ }
+
+ slog.Info("created new blank DB with valid schema", "path", path)
+ return db, nil
+}
diff --git a/internal/tagging/tagging.go b/internal/tagging/tagging.go
new file mode 100644
index 0000000..8e8efe9
--- /dev/null
+++ b/internal/tagging/tagging.go
@@ -0,0 +1,23 @@
+package tagging
+
+import (
+ "context"
+ "log/slog"
+
+ "github.com/ChausseBenjamin/rafta/internal/logging"
+ "github.com/ChausseBenjamin/rafta/internal/util"
+ "github.com/hashicorp/go-uuid"
+ "google.golang.org/grpc"
+)
+
+// gRPC interceptor to tag requests with a unique identifier
+func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ id, err := uuid.GenerateUUID()
+ if err != nil {
+ slog.Error("Unable to generate UUID for request", logging.ErrKey, err)
+ }
+ ctx = context.WithValue(ctx, util.ReqIDKey, id)
+ slog.DebugContext(ctx, "Tagging request with UUID", "value", id)
+
+ return handler(ctx, req)
+}
diff --git a/internal/util/context.go b/internal/util/context.go
new file mode 100644
index 0000000..e7cbab2
--- /dev/null
+++ b/internal/util/context.go
@@ -0,0 +1,8 @@
+package util
+
+type ContextKey uint8
+
+const (
+ DBKey ContextKey = iota
+ ReqIDKey
+)