summaryrefslogtreecommitdiff
path: root/internal/storage/setup.go
blob: f103b15460f452a67cdac70704d1986014d8dfe1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
package storage

import (
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"strings"

	_ "github.com/mattn/go-sqlite3"
)

// Sentinel errors used by this package's database validation. Callers that
// receive a wrapped form can match them with errors.Is.
var (
	// ErrIntegrityCheckFailed is logged by Setup when PRAGMA integrity_check
	// does not report "ok".
	ErrIntegrityCheckFailed = errors.New("integrity check failed")

	// ErrForeignKeysDisabled is returned by validateSchema when
	// PRAGMA foreign_keys does not read 1.
	ErrForeignKeysDisabled = errors.New("foreign_keys pragma is not enabled")

	// ErrJournalModeInvalid is returned by validateSchema when
	// PRAGMA journal_mode is not "wal".
	ErrJournalModeInvalid = errors.New("journal_mode is not wal")

	// ErrSchemaMismatch is wrapped by validateSchema when a stored table
	// definition lacks an expected fragment.
	ErrSchemaMismatch = errors.New("schema does not match expected definition")
)

// Setup opens the SQLite database at path, verifies its integrity, pragmas
// and schema, and returns a handle that is ready to use.
//
// If nothing exists at path, a new database is created with genDB. If the
// existing file cannot be opened, fails PRAGMA integrity_check, or is
// rejected by validateSchema, the handle is closed, the file is moved aside
// with backupFile, and a new database is created in its place with genDB.
//
// An error is returned when genDB fails, or when os.Stat fails for a reason
// other than the file not existing (for example a permission error). In
// that case it is unknown whether a database is present, so none is created
// on top of it.
func Setup(path string) (*sql.DB, error) {
	if _, statErr := os.Stat(path); statErr != nil {
		if errors.Is(statErr, os.ErrNotExist) {
			return genDB(path)
		}
		return nil, fmt.Errorf("checking database file %s: %w", path, statErr)
	}

	// sql.Open may only validate its arguments without connecting; real
	// connection problems surface at the first query below.
	db, err := sql.Open("sqlite3", path+opts())
	if err != nil {
		slog.Error("failed to open DB", "error", err)
		backupFile(path)
		return genDB(path)
	}

	// replace discards the existing database and builds a fresh one at path.
	replace := func() (*sql.DB, error) {
		// The file is about to be moved aside; a Close error would not change that.
		_ = db.Close()
		backupFile(path)
		return genDB(path)
	}

	// integrity_check yields a single "ok" row when no problems are found;
	// otherwise its rows describe the problems (only the first is read here).
	var integrity string
	if err = db.QueryRow("PRAGMA integrity_check;").Scan(&integrity); err != nil {
		slog.Error("integrity check query failed", "error", err)
		return replace()
	}
	if integrity != "ok" {
		slog.Error("integrity check failed", "error", ErrIntegrityCheckFailed, "detail", integrity)
		return replace()
	}

	// Validate schema and pragmas.
	if err = validateSchema(db); err != nil {
		slog.Error("schema validation failed", "error", err)
		return replace()
	}

	return db, nil
}

// validateSchema checks that db has the connection settings and table
// definitions this package relies on.
//
// Checks run in this order, and the first failure is returned:
//  1. PRAGMA foreign_keys must be 1, else ErrForeignKeysDisabled.
//  2. PRAGMA journal_mode must be "wal" in any letter case, else
//     ErrJournalModeInvalid.
//  3. For each of User, Task, Tag and TaskTag, the lower-cased CREATE TABLE
//     text returned by fetchTableSQL must contain every expected fragment.
//     A missing fragment yields an error wrapping ErrSchemaMismatch; a lookup
//     failure is wrapped with the table name.
//
// Errors from the PRAGMA queries themselves are returned as-is.
//
// The fragment match is loose. strings.Contains is used, so a fragment
// listed more than once (such as "foreign key" for TaskTag) is satisfied by
// a single occurrence, and short fragments like "do" or "name" may match
// inside unrelated words.
func validateSchema(db *sql.DB) error {
	// foreign_keys is a per-connection setting in SQLite; this reads it from
	// whichever pooled connection serves the query.
	var foreignKeys int
	if err := db.QueryRow("PRAGMA foreign_keys;").Scan(&foreignKeys); err != nil {
		return err
	}
	if foreignKeys != 1 {
		return ErrForeignKeysDisabled
	}

	var journalMode string
	if err := db.QueryRow("PRAGMA journal_mode;").Scan(&journalMode); err != nil {
		return err
	}
	if !strings.EqualFold(journalMode, "wal") {
		return ErrJournalModeInvalid
	}

	// Expected lower-case fragments per table, grouped by column. Slice order
	// decides which missing fragment is reported first.
	// NOTE(review): TaskTag is expected to have a "taskuuid" column while
	// Task's key is "taskid"; confirm this matches schema().
	expected := []struct {
		table     string
		fragments []string
	}{
		{"User", []string{
			"create table user",
			"userid", "integer", "primary key", "autoincrement",
			"name", "text", "not null",
			"email", "text", "not null", "unique",
		}},
		{"Task", []string{
			"create table task",
			"taskid", "integer", "primary key", "autoincrement",
			"title", "not null",
			"description", "not null",
			"due", "date", "not null",
			"do", "date", "not null",
			"owner", "integer", "not null",
			"foreign key", "references user",
		}},
		{"Tag", []string{
			"create table tag",
			"tagid", "integer", "primary key", "autoincrement",
			"name", "text", "not null", "unique",
		}},
		{"TaskTag", []string{
			"create table tasktag",
			"taskuuid", "integer", "not null",
			"tagid", "integer", "not null",
			"primary key",
			"foreign key", "references task",
			"foreign key", "references tag",
		}},
	}

	for _, spec := range expected {
		def, err := fetchTableSQL(db, spec.table)
		if err != nil {
			return fmt.Errorf("failed to fetch definition for table %s: %w", spec.table, err)
		}
		def = strings.ToLower(def)
		for _, fragment := range spec.fragments {
			if strings.Contains(def, fragment) {
				continue
			}
			return fmt.Errorf("%w: table %s missing %q", ErrSchemaMismatch, spec.table, fragment)
		}
	}
	return nil
}

// fetchTableSQL returns the CREATE TABLE statement that SQLite stored in
// sqlite_master for the named table.
//
// The name is matched with COLLATE NOCASE because SQLite identifiers are
// case-insensitive. A table created as "user" is the same table as "User",
// yet sqlite_master.name keeps the spelling used at creation, and the
// default BINARY comparison would miss it.
//
// If no such table exists, the returned error wraps sql.ErrNoRows. A row
// whose sql column is NULL also produces an error.
func fetchTableSQL(db *sql.DB, table string) (string, error) {
	var sqlDef sql.NullString
	err := db.QueryRow(
		"SELECT sql FROM sqlite_master WHERE type='table' AND name=? COLLATE NOCASE",
		table,
	).Scan(&sqlDef)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", fmt.Errorf("no table named %q in sqlite_master: %w", table, err)
	case err != nil:
		return "", err
	}
	if !sqlDef.Valid {
		return "", fmt.Errorf("no SQL definition found for table %s", table)
	}
	return sqlDef.String, nil
}

// backupFile moves the database file at path out of the way so a fresh one
// can be created there.
//
// The file is renamed to "<path>.bak". If that name, or one of its SQLite
// sidecars, is already taken, "<path>.1.bak", "<path>.2.bak", ... are tried
// in order, so an earlier backup is never overwritten.
//
// After the main file has moved, the "-wal" and "-shm" sidecars, when
// present, are renamed to match (e.g. "<path>-wal" -> "<path>.bak-wal").
// In WAL mode, committed data can still live in the -wal file, and SQLite
// only finds it under the database's own name plus "-wal".
//
// Errors are logged, not returned. The backup is best-effort, and the
// caller recreates the database regardless.
func backupFile(path string) {
	sidecars := []string{"-wal", "-shm"}

	exists := func(name string) bool {
		_, err := os.Stat(name)
		return err == nil
	}
	// taken reports whether base, or any sidecar named after it, exists.
	taken := func(base string) bool {
		if exists(base) {
			return true
		}
		for _, suffix := range sidecars {
			if exists(base + suffix) {
				return true
			}
		}
		return false
	}

	backupPath := path + ".bak"
	for i := 1; taken(backupPath); i++ {
		backupPath = fmt.Sprintf("%s.%d.bak", path, i)
	}

	if err := os.Rename(path, backupPath); err != nil {
		slog.Error("failed to backup file", "error", err, "original", path, "backup", backupPath)
		return
	}
	slog.Info("backed up corrupt DB", "original", path, "backup", backupPath)

	// Sidecars move only after the main file did, so a failed rename above
	// never separates a database from its write-ahead log.
	for _, suffix := range sidecars {
		src, dst := path+suffix, backupPath+suffix
		if err := os.Rename(src, dst); err != nil && !errors.Is(err, os.ErrNotExist) {
			slog.Error("failed to backup file", "error", err, "original", src, "backup", dst)
		}
	}
}

// genDB creates a new database at path with the valid schema.
func genDB(path string) (*sql.DB, error) {
	db, err := sql.Open("sqlite3", path)
	if err != nil {
		slog.Error("failed to create DB", "error", err)
		return nil, err
	}

	// Set pragmas.
	if _, err := db.Exec("PRAGMA foreign_keys = on; PRAGMA journal_mode = wal;"); err != nil {
		slog.Error("failed to set pragmas", "error", err)
		db.Close()
		return nil, err
	}

	if _, err := db.Exec(schema()); err != nil {
		slog.Error("failed to initialize schema", "error", err)
		db.Close()
		return nil, err
	}

	slog.Info("created new blank DB with valid schema", "path", path)
	return db, nil
}