diff options
author | Benjamin Chausse <benjamin@chausse.xyz> | 2025-02-03 01:12:45 -0500 |
---|---|---|
committer | Benjamin Chausse <benjamin@chausse.xyz> | 2025-02-03 01:12:45 -0500 |
commit | 5389e1a5d26fdbf2441fa5a1e101999e8449b9d1 (patch) | |
tree | 069cd37cb8e556c1ba3b47c3ea8576a1aa91ea2c /internal |
Batman
Diffstat (limited to 'internal')
-rw-r--r-- | internal/app/action.go | 47 | ||||
-rw-r--r-- | internal/app/command.go | 23 | ||||
-rw-r--r-- | internal/app/flags.go | 127 | ||||
-rw-r--r-- | internal/logging/context.go | 39 | ||||
-rw-r--r-- | internal/logging/logging.go | 92 | ||||
-rw-r--r-- | internal/logging/trace.go | 61 | ||||
-rw-r--r-- | internal/manualgen/manualgen.go | 21 | ||||
-rw-r--r-- | internal/server/model/.gitignore | 2 | ||||
-rw-r--r-- | internal/server/server.go | 16 | ||||
-rw-r--r-- | internal/server/setup.go | 34 | ||||
-rw-r--r-- | internal/server/task.go | 33 | ||||
-rw-r--r-- | internal/server/user.go | 34 | ||||
-rw-r--r-- | internal/storage/db.go | 18 | ||||
-rw-r--r-- | internal/storage/schema.go | 38 | ||||
-rw-r--r-- | internal/storage/setup.go | 201 | ||||
-rw-r--r-- | internal/tagging/tagging.go | 23 | ||||
-rw-r--r-- | internal/util/context.go | 8 |
17 files changed, 817 insertions, 0 deletions
diff --git a/internal/app/action.go b/internal/app/action.go new file mode 100644 index 0000000..b579253 --- /dev/null +++ b/internal/app/action.go @@ -0,0 +1,47 @@ +package app + +import ( + "context" + "fmt" + "log/slog" + + "github.com/ChausseBenjamin/rafta/internal/logging" + "github.com/ChausseBenjamin/rafta/internal/server" + "github.com/ChausseBenjamin/rafta/internal/storage" + "github.com/urfave/cli/v3" +) + +func action(ctx context.Context, cmd *cli.Command) error { + err := logging.Setup( + cmd.String(FlagLogLevel), + cmd.String(FlagLogFormat), + cmd.String(FlagLogOutput), + ) + if err != nil { + slog.Warn("Error(s) occured during logger initialization", logging.ErrKey, err) + } + + slog.Info("Starting rafta server") + + // TODO: Setup the db + store, err := storage.Setup(cmd.String(FlagDBPath)) + if err != nil { + slog.Error("Unable to setup database", logging.ErrKey, err) + } + + srv, lis, err := server.Setup(cmd.Int(FlagListenPort), store) + if err != nil { + slog.Error("Unable to setup server", logging.ErrKey, err) + + return err + } + + slog.Info(fmt.Sprintf("Listening on port %d", cmd.Int(FlagListenPort))) + if err := srv.Serve(lis); err != nil { + slog.Error("Server runtime error", logging.ErrKey, err) + + return err + } + + return nil +} diff --git a/internal/app/command.go b/internal/app/command.go new file mode 100644 index 0000000..5c32ba2 --- /dev/null +++ b/internal/app/command.go @@ -0,0 +1,23 @@ +package app + +import ( + "github.com/urfave/cli/v3" +) + +const ( + AppName = "rafta" + AppUsage = "Really, Another Freaking Todo App?!" +) + +var version = "COMPILED" + +func Command() *cli.Command { + return &cli.Command{ + Name: AppName, + Usage: AppUsage, + Authors: []any{"Benjamin Chausse <benjamin@chausse.xyz>"}, + Version: version, + Flags: flags(), + Action: action, + } +} diff --git a/internal/app/flags.go b/internal/app/flags.go new file mode 100644 index 0000000..e96b8ef --- /dev/null +++ b/internal/app/flags.go @@ -0,0 +1,127 @@ +package app + +import ( + "context" + "fmt" + "log/slog" + "os" + "strings" + + "github.com/ChausseBenjamin/rafta/internal/logging" + "github.com/ChausseBenjamin/rafta/internal/server" + "github.com/urfave/cli/v3" +) + +const ( + FlagListenPort = "port" + FlagLogLevel = "log-level" + FlagLogFormat = "log-format" + FlagLogOutput = "log-output" + FlagDBPath = "database" +) + +func flags() []cli.Flag { + return []cli.Flag{ + // Logging {{{ + &cli.StringFlag{ + Name: FlagLogFormat, + Aliases: []string{"f"}, + Value: "plain", + Usage: "plain, json", + Sources: cli.EnvVars("LOG_FORMAT"), + Action: validateLogFormat, + }, + &cli.StringFlag{ + Name: FlagLogOutput, + Aliases: []string{"o"}, + Value: "stdout", + Usage: "stdout, stderr, file", + Sources: cli.EnvVars("LOG_OUTPUT"), + Action: validateLogOutput, + }, + &cli.StringFlag{ + Name: FlagLogLevel, + Aliases: []string{"l"}, + Value: "info", + Usage: "debug, info, warn, error", + Sources: cli.EnvVars("LOG_LEVEL"), + Action: validateLogLevel, + }, // }}} + // gRPC server {{{ + &cli.IntFlag{ + Name: FlagListenPort, + Aliases: []string{"p"}, + Value: 1234, + Sources: cli.EnvVars("LISTEN_PORT"), + Action: validateListenPort, + }, // }}} + // Database {{{ + &cli.StringFlag{ + Name: FlagDBPath, + Aliases: []string{"d"}, + Value: "store.db", + Usage: "database file", + Sources: cli.EnvVars("DATABASE_PATH"), + Action: validateDBPath, + }, // }}} + } +} + +func validateLogOutput(ctx context.Context, cmd *cli.Command, s string) error { + switch { + case s == "stdout" || s == "stderr": + return nil + default: + // assume file + f, err := os.OpenFile(s, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + slog.ErrorContext( + ctx, + fmt.Sprintf("Error creating/accessing provided log file %s", s), + ) + return err + } + defer f.Close() + return nil + } +} + +func validateLogLevel(ctx context.Context, cmd *cli.Command, s string) error { + for _, lvl := range []string{"deb", "inf", "warn", "err"} { + if strings.Contains(strings.ToLower(s), lvl) { + return nil + } + } + slog.ErrorContext( + ctx, + fmt.Sprintf("Unknown log level provided: %s", s), + ) + return logging.ErrInvalidLevel +} + +func validateLogFormat(ctx context.Context, cmd *cli.Command, s string) error { + s = strings.ToLower(s) + if s == "json" || s == "plain" { + return nil + } + return nil +} + +func validateListenPort(ctx context.Context, cmd *cli.Command, p int64) error { + if p < 1024 || p > 65535 { + slog.ErrorContext( + ctx, + fmt.Sprintf("Out-of-bound port provided: %d", p), + ) + return server.ErrOutOfBoundsPort + } + return nil +} + +func validateDBPath(ctx context.Context, cmd *cli.Command, s string) error { + // TODO: Ensure the db file is writable. + // TODO: Ensure the db file is a valid sqlite3 db. + // TODO: Call db.Reset() if either of the above fail. + // TODO: Log the error/crash if the db file is not writable. + return nil +} diff --git a/internal/logging/context.go b/internal/logging/context.go new file mode 100644 index 0000000..29bad62 --- /dev/null +++ b/internal/logging/context.go @@ -0,0 +1,39 @@ +package logging + +import ( + "context" + "log/slog" +) + +type ctxTracker struct { + ctxKey interface{} + logKey string + next slog.Handler +} + +func (h ctxTracker) Handle(ctx context.Context, r slog.Record) error { + if v := ctx.Value(h.ctxKey); v != nil { + r.AddAttrs(slog.Any(h.logKey, v)) + } + return h.next.Handle(ctx, r) +} + +func (h ctxTracker) Enabled(ctx context.Context, lvl slog.Level) bool { + return h.next.Enabled(ctx, lvl) +} + +func (h ctxTracker) WithAttrs(attrs []slog.Attr) slog.Handler { + return h.next.WithAttrs(attrs) +} + +func (h ctxTracker) WithGroup(name string) slog.Handler { + return h.next.WithGroup(name) +} + +func withTrackedContext(current slog.Handler, ctxKey interface{}, logKey string) *ctxTracker { + return &ctxTracker{ + ctxKey: ctxKey, + logKey: logKey, + next: current, + } +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go new file mode 100644 index 0000000..91a9734 --- /dev/null +++ b/internal/logging/logging.go @@ -0,0 +1,92 @@ +package logging + +import ( + "errors" + "io" + "log/slog" + "os" + "strings" + "time" + + "github.com/ChausseBenjamin/rafta/internal/util" + "github.com/charmbracelet/log" +) + +const ( + ErrKey = "error_message" +) + +var ( + ErrInvalidLevel = errors.New("invalid log level") + ErrInvalidFormat = errors.New("invalid log format") +) + +func Setup(lvlStr, fmtStr, outStr string) error { + output, outputErr := setOutput(outStr) + format, formatErr := setFormat(fmtStr) + level, levelErr := setLevel(lvlStr) + + prefixStr := "" + if format != log.JSONFormatter { + prefixStr = "Rafta 🚢" + } + + var h slog.Handler = log.NewWithOptions( + output, + log.Options{ + TimeFormat: time.DateTime, + Prefix: prefixStr, + Level: level, + ReportCaller: true, + Formatter: format, + }, + ) + + h = withTrackedContext(h, util.ReqIDKey, "request_id") + h = withStackTrace(h) + slog.SetDefault(slog.New(h)) + return errors.Join(outputErr, formatErr, levelErr) +} + +func setLevel(target string) (log.Level, error) { + for _, l := range []struct { + prefix string + level log.Level + }{ + {"deb", log.DebugLevel}, + {"inf", log.InfoLevel}, + {"warn", log.WarnLevel}, + {"err", log.ErrorLevel}, + } { + if strings.HasPrefix(strings.ToLower(target), l.prefix) { + return l.level, nil + } + } + return log.InfoLevel, ErrInvalidLevel +} + +func setFormat(f string) (log.Formatter, error) { + switch f { + case "plain", "text": + return log.TextFormatter, nil + case "json", "structured": + return log.JSONFormatter, nil + } + return log.TextFormatter, ErrInvalidFormat +} + +func setOutput(path string) (io.Writer, error) { + switch path { + case "stdout": + return os.Stdout, nil + case "stderr": + return os.Stderr, nil + default: + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return os.Stdout, err + } else { + return f, nil + } + } +} diff --git a/internal/logging/trace.go b/internal/logging/trace.go new file mode 100644 index 0000000..8900727 --- /dev/null +++ b/internal/logging/trace.go @@ -0,0 +1,61 @@ +package logging + +import ( + "context" + "log/slog" + "runtime" + "strconv" + "strings" +) + +const ( + prgCount = 20 + defSkip = 6 +) + +type stackTracer struct { + h slog.Handler + nSkip int +} + +func (h stackTracer) Enabled(ctx context.Context, lvl slog.Level) bool { + return h.h.Enabled(ctx, lvl) +} + +func (h stackTracer) WithAttrs(attrs []slog.Attr) slog.Handler { + return h.h.WithAttrs(attrs) +} + +func (h stackTracer) WithGroup(name string) slog.Handler { + return h.h.WithGroup(name) +} + +func (h stackTracer) Handle(ctx context.Context, r slog.Record) error { + if r.Level < slog.LevelError { + return h.h.Handle(ctx, r) + } + + trace := h.GetTrace() + r.AddAttrs(slog.String("trace", trace)) + + return h.h.Handle(ctx, r) +} + +func (h stackTracer) GetTrace() string { + var b strings.Builder + pc := make([]uintptr, prgCount) + n := runtime.Callers(h.nSkip, pc) + frames := runtime.CallersFrames(pc[:n]) + + for frame, more := frames.Next(); more; frame, more = frames.Next() { + b.WriteString(frame.Function + "\n " + frame.File + ":" + strconv.Itoa(frame.Line) + "\n") + } + return b.String() +} + +func withStackTrace(h slog.Handler) slog.Handler { + return stackTracer{ + h: h, + nSkip: defSkip, + } +} diff --git a/internal/manualgen/manualgen.go b/internal/manualgen/manualgen.go new file mode 100644 index 0000000..88a72e2 --- /dev/null +++ b/internal/manualgen/manualgen.go @@ -0,0 +1,21 @@ +package main + +import ( + "log/slog" + "os" + + "github.com/ChausseBenjamin/rafta/internal/app" + "github.com/ChausseBenjamin/rafta/internal/logging" + docs "github.com/urfave/cli-docs/v3" +) + +func main() { + a := app.Command() + + man, err := docs.ToManWithSection(a, 1) + if err != nil { + slog.Error("failed to generate markdown", logging.ErrKey, err) + os.Exit(1) + } + os.Stdout.Write([]byte(man)) +} diff --git a/internal/server/model/.gitignore b/internal/server/model/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/internal/server/model/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..70e5861 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,16 @@ +package server + +import ( + "database/sql" + "errors" + + m "github.com/ChausseBenjamin/rafta/internal/server/model" +) + +var ErrOutOfBoundsPort = errors.New("port out of bounds") + +// Implements ComsServer interface +type Service struct { + store *sql.DB + m.UnimplementedRaftaServer +} diff --git a/internal/server/setup.go b/internal/server/setup.go new file mode 100644 index 0000000..551cd95 --- /dev/null +++ b/internal/server/setup.go @@ -0,0 +1,34 @@ +package server + +import ( + "database/sql" + "fmt" + "log/slog" + "net" + + "github.com/ChausseBenjamin/rafta/internal/logging" + m "github.com/ChausseBenjamin/rafta/internal/server/model" + "github.com/ChausseBenjamin/rafta/internal/tagging" + "google.golang.org/grpc" +) + +func Setup(port int64, storage *sql.DB) (*grpc.Server, net.Listener, error) { + lis, err := net.Listen( + "tcp", + fmt.Sprintf(":%d", port), + ) + if err != nil { + slog.Error("Unable to create listener", logging.ErrKey, err) + return nil, nil, err + } + + grpcServer := grpc.NewServer( + grpc.ChainUnaryInterceptor( + tagging.UnaryInterceptor, + ), + ) + raftaService := &Service{store: storage} + m.RegisterRaftaServer(grpcServer, raftaService) + + return grpcServer, lis, nil +} diff --git a/internal/server/task.go b/internal/server/task.go new file mode 100644 index 0000000..b8cf0b8 --- /dev/null +++ b/internal/server/task.go @@ -0,0 +1,33 @@ +package server + +import ( + "context" + "log/slog" + + m "github.com/ChausseBenjamin/rafta/internal/server/model" + "google.golang.org/protobuf/types/known/emptypb" +) + +func (s Service) GetUserTasks(ctx context.Context, id *m.UserID) (*m.TaskList, error) { + slog.ErrorContext(ctx, "GetUserTasks not implemented yet") + return nil, nil +} + +func (s Service) GetTask(ctx context.Context, id *m.TaskID) (*m.Task, error) { + return nil, nil +} + +func (s Service) DeleteTask(ctx context.Context, id *m.TaskID) (*emptypb.Empty, error) { + slog.ErrorContext(ctx, "DeleteTask not implemented yet") + return nil, nil +} + +func (s Service) UpdateTask(ctx context.Context, t *m.Task) (*m.Task, error) { + slog.ErrorContext(ctx, "UpdateTask not implemented yet") + return t, nil +} + +func (s Service) CreateTask(ctx context.Context, data *m.TaskData) (*m.Task, error) { + slog.ErrorContext(ctx, "CreateTask not implemented yet") + return nil, nil +} diff --git a/internal/server/user.go b/internal/server/user.go new file mode 100644 index 0000000..c4a97c4 --- /dev/null +++ b/internal/server/user.go @@ -0,0 +1,34 @@ +package server + +import ( + "context" + "log/slog" + + m "github.com/ChausseBenjamin/rafta/internal/server/model" + "google.golang.org/protobuf/types/known/emptypb" +) + +func (s Service) GetAllUsers(ctx context.Context, empty *emptypb.Empty) (*m.UserList, error) { + slog.ErrorContext(ctx, "GetAllUsers not implemented yet") + return nil, nil +} + +func (s Service) GetUser(ctx context.Context, id *m.UserID) (*m.User, error) { + slog.ErrorContext(ctx, "GetUser not implemented yet") + return nil, nil +} + +func (s Service) DeleteUser(ctx context.Context, id *m.UserID) (*emptypb.Empty, error) { + slog.ErrorContext(ctx, "DeleteUser not implemented yet") + return nil, nil +} + +func (s Service) UpdateUser(ctx context.Context, u *m.User) (*m.User, error) { + slog.ErrorContext(ctx, "UpdateUser not implemented yet") + return nil, nil +} + +func (s Service) CreateUser(ctx context.Context, data *m.UserData) (*m.User, error) { + slog.ErrorContext(ctx, "CreateUser not implemented yet") + return nil, nil +} diff --git a/internal/storage/db.go b/internal/storage/db.go new file mode 100644 index 0000000..06fe31c --- /dev/null +++ b/internal/storage/db.go @@ -0,0 +1,18 @@ +package storage + +import ( + "context" + "database/sql" + "log/slog" + + "github.com/ChausseBenjamin/rafta/internal/util" +) + +func GetDB(ctx context.Context) *sql.DB { + db, ok := ctx.Value(util.DBKey).(*sql.DB) + if !ok { + slog.Error("Unable to retrieve database from context") + return nil + } + return db +} diff --git a/internal/storage/schema.go b/internal/storage/schema.go new file mode 100644 index 0000000..042c09b --- /dev/null +++ b/internal/storage/schema.go @@ -0,0 +1,38 @@ +package storage + +func opts() string { + return "?_foreign_keys=on&_journal_mode=WAL" +} + +func schema() string { + return ` +CREATE TABLE User ( + UserID INTEGER PRIMARY KEY AUTOINCREMENT, + Name TEXT NOT NULL, + Email TEXT NOT NULL UNIQUE +); + +CREATE TABLE Task ( + TaskID INTEGER PRIMARY KEY AUTOINCREMENT, + Title TEXT NOT NULL, + Description TEXT NOT NULL, + Due DATE NOT NULL, + Do DATE NOT NULL, + Owner INTEGER NOT NULL, + FOREIGN KEY (Owner) REFERENCES User(UserID) ON DELETE CASCADE +); + +CREATE TABLE Tag ( + TagID INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + Name TEXT NOT NULL UNIQUE +); + +CREATE TABLE TaskTag ( + TaskUUID INTEGER NOT NULL, + TagID INTEGER NOT NULL, + PRIMARY KEY (TaskUUID, TagID), + FOREIGN KEY (TaskUUID) REFERENCES Task(UUID) ON DELETE CASCADE, + FOREIGN KEY (TagID) REFERENCES Tag(TagID) ON DELETE CASCADE +); +` +} diff --git a/internal/storage/setup.go b/internal/storage/setup.go new file mode 100644 index 0000000..f103b15 --- /dev/null +++ b/internal/storage/setup.go @@ -0,0 +1,201 @@ +package storage + +import ( + "database/sql" + "errors" + "fmt" + "log/slog" + "os" + "strings" + + _ "github.com/mattn/go-sqlite3" +) + +var ( + ErrIntegrityCheckFailed = errors.New("integrity check failed") + ErrForeignKeysDisabled = errors.New("foreign_keys pragma is not enabled") + ErrJournalModeInvalid = errors.New("journal_mode is not wal") + ErrSchemaMismatch = errors.New("schema does not match expected definition") +) + +// Setup opens the SQLite DB at path, verifies its integrity and schema, +// and returns the valid DB handle. On any error, it backs up the old file +// (if it exists) and calls genDB() to initialize a valid schema. +func Setup(path string) (*sql.DB, error) { + _, statErr := os.Stat(path) + exists := statErr == nil + + // If file doesn't exist, generate new DB. + if !exists { + return genDB(path) + } + + db, err := sql.Open("sqlite3", path+opts()) + if err != nil { + slog.Error("failed to open DB", "error", err) + backupFile(path) + return genDB(path) + } + + // Integrity check. + var integrity string + if err = db.QueryRow("PRAGMA integrity_check;").Scan(&integrity); err != nil { + slog.Error("integrity check query failed", "error", err) + db.Close() + backupFile(path) + return genDB(path) + } + if integrity != "ok" { + slog.Error("integrity check failed", "error", ErrIntegrityCheckFailed) + db.Close() + backupFile(path) + return genDB(path) + } + + // Validate schema and pragmas. + if err = validateSchema(db); err != nil { + slog.Error("schema validation failed", "error", err) + db.Close() + backupFile(path) + return genDB(path) + } + + return db, nil +} + +// validateSchema verifies that required pragmas and table definitions are set. +func validateSchema(db *sql.DB) error { + // Check PRAGMA foreign_keys = on. + var fk int + if err := db.QueryRow("PRAGMA foreign_keys;").Scan(&fk); err != nil { + return err + } + if fk != 1 { + return ErrForeignKeysDisabled + } + + // Check PRAGMA journal_mode = wal. + var jm string + if err := db.QueryRow("PRAGMA journal_mode;").Scan(&jm); err != nil { + return err + } + if strings.ToLower(jm) != "wal" { + return ErrJournalModeInvalid + } + + // Define required table definitions (as substrings in lower-case). + type tableCheck struct { + name string + substrings []string + } + + checks := []tableCheck{ + { + name: "User", + substrings: []string{ + "create table user", + "userid", "integer", "primary key", "autoincrement", + "name", "text", "not null", + "email", "text", "not null", "unique", + }, + }, + { + name: "Task", + substrings: []string{ + "create table task", + "taskid", "integer", "primary key", "autoincrement", + "title", "not null", + "description", "not null", + "due", "date", "not null", + "do", "date", "not null", + "owner", "integer", "not null", + "foreign key", "references user", + }, + }, + { + name: "Tag", + substrings: []string{ + "create table tag", + "tagid", "integer", "primary key", "autoincrement", + "name", "text", "not null", "unique", + }, + }, + { + name: "TaskTag", + substrings: []string{ + "create table tasktag", + "taskuuid", "integer", "not null", + "tagid", "integer", "not null", + "primary key", + "foreign key", "references task", + "foreign key", "references tag", + }, + }, + } + + for _, chk := range checks { + sqlDef, err := fetchTableSQL(db, chk.name) + if err != nil { + return fmt.Errorf("failed to fetch definition for table %s: %w", chk.name, err) + } + lc := strings.ToLower(sqlDef) + for _, substr := range chk.substrings { + if !strings.Contains(lc, substr) { + return fmt.Errorf("%w: table %s missing %q", ErrSchemaMismatch, chk.name, substr) + } + } + } + + return nil +} + +func fetchTableSQL(db *sql.DB, table string) (string, error) { + var sqlDef sql.NullString + err := db.QueryRow("SELECT sql FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&sqlDef) + if err != nil { + return "", err + } + if !sqlDef.Valid { + return "", fmt.Errorf("no SQL definition found for table %s", table) + } + return sqlDef.String, nil +} + +// backupFile renames the existing file by appending a ".bak" suffix. +func backupFile(path string) { + backupPath := path + ".bak" + // If backupPath exists, append a timestamp. + if _, err := os.Stat(backupPath); err == nil { + backupPath = fmt.Sprintf("%s.%d.bak", path, os.Getpid()) + } + if err := os.Rename(path, backupPath); err != nil { + slog.Error("failed to backup file", "error", err, "original", path, "backup", backupPath) + } else { + slog.Info("backed up corrupt DB", "original", path, "backup", backupPath) + } +} + +// genDB creates a new database at path with the valid schema. +func genDB(path string) (*sql.DB, error) { + db, err := sql.Open("sqlite3", path) + if err != nil { + slog.Error("failed to create DB", "error", err) + return nil, err + } + + // Set pragmas. + if _, err := db.Exec("PRAGMA foreign_keys = on; PRAGMA journal_mode = wal;"); err != nil { + slog.Error("failed to set pragmas", "error", err) + db.Close() + return nil, err + } + + if _, err := db.Exec(schema()); err != nil { + slog.Error("failed to initialize schema", "error", err) + db.Close() + return nil, err + } + + slog.Info("created new blank DB with valid schema", "path", path) + return db, nil +} diff --git a/internal/tagging/tagging.go b/internal/tagging/tagging.go new file mode 100644 index 0000000..8e8efe9 --- /dev/null +++ b/internal/tagging/tagging.go @@ -0,0 +1,23 @@ +package tagging + +import ( + "context" + "log/slog" + + "github.com/ChausseBenjamin/rafta/internal/logging" + "github.com/ChausseBenjamin/rafta/internal/util" + "github.com/hashicorp/go-uuid" + "google.golang.org/grpc" +) + +// gRPC interceptor to tag requests with a unique identifier +func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + id, err := uuid.GenerateUUID() + if err != nil { + slog.Error("Unable to generate UUID for request", logging.ErrKey, err) + } + ctx = context.WithValue(ctx, util.ReqIDKey, id) + slog.DebugContext(ctx, "Tagging request with UUID", "value", id) + + return handler(ctx, req) +} diff --git a/internal/util/context.go b/internal/util/context.go new file mode 100644 index 0000000..e7cbab2 --- /dev/null +++ b/internal/util/context.go @@ -0,0 +1,8 @@ +package util + +type ContextKey uint8 + +const ( + DBKey ContextKey = iota + ReqIDKey +) |