aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/log.zig18
-rw-r--r--src/main.zig103
-rw-r--r--src/postgres.zig126
-rw-r--r--src/sqlite.zig101
-rw-r--r--src/test.zig102
5 files changed, 450 insertions, 0 deletions
diff --git a/src/log.zig b/src/log.zig
new file mode 100644
index 0000000..dd81d5f
--- /dev/null
+++ b/src/log.zig
@@ -0,0 +1,18 @@
+const std = @import("std");
+const builtin = @import("builtin");
+/// Workaround for failing tests when errors are logged
+/// https://github.com/ziglang/zig/issues/5738#issuecomment-1466902082
+pub fn scoped_log_t(comptime scope: @Type(.EnumLiteral)) type {
+ return if (builtin.is_test)
+ // Downgrade `err` to `warn` for tests.
+ // Zig fails any test that does `log.err`, but we want to test those code paths here.
+ struct {
+ pub const base = std.log.scoped(scope);
+ pub const err = warn;
+ pub const warn = base.warn;
+ pub const info = base.info;
+ pub const debug = base.debug;
+ }
+ else
+ std.log.scoped(scope);
+}
diff --git a/src/main.zig b/src/main.zig
new file mode 100644
index 0000000..ebb9fde
--- /dev/null
+++ b/src/main.zig
@@ -0,0 +1,103 @@
+const std = @import("std");
+const log = @import("log.zig").scoped_log_t(.db);
+
+/// Database abstraction layer
+pub const Db = @This();
+
+/// Type erased pointer to actual implementation, and reference to implementation functions
+/// Inspired by std.mem.Allocator
+ptr: *anyopaque,
+vtable: VTable,
+
+fn init(ptr: *anyopaque, vtable: VTable) Db {
+ return Db{
+ .ptr = ptr,
+ .vtable = vtable,
+ };
+}
+
+pub const OpenError = error{ Failed, NotThreadSafe } || std.mem.Allocator.Error;
+pub const PrepareError = error{Failed} || std.mem.Allocator.Error;
+pub const BindError = error{};
+pub const StepError = error{Failed} || std.mem.Allocator.Error;
+pub const ColumnError = error{WrongType} || std.fmt.ParseIntError;
+
+// Dispatcher for concrete implementations. Inspired by std.mem.Allocator.
+const VTable = struct {
+ prepare: *const fn (*anyopaque, query: [:0]const u8) PrepareError!*anyopaque,
+ // bind: *const fn (db: *anyopaque, stmt: *anyopaque, idx: usize, val: anytype) BindError!void,
+ step: *const fn (db: *anyopaque, stmt: *anyopaque) StepError!bool,
+ // TODO support more types
+ // TODO think of something clever to avoid this function proliferation.
+ // Values returned should live at least until next call to .step()
+ column_i64: *const fn (db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?i64,
+ column_slice_const_u8: *const fn (db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?[:0]const u8,
+ close_stmt: *const fn (db: *anyopaque, stmt: *anyopaque) void,
+ close_db: *const fn (db: *anyopaque) void,
+};
+
+pub const Stmt = struct {
+ db: *Db,
+ ptr: *anyopaque,
+
+ fn bind(self: *Stmt, idx: u31, val: anytype) !void {
+ try self.db.vtable.bind(self.db.ptr, self.ptr, idx, val);
+ }
+ fn bind_named(self: *Stmt, name: [:0]const u8, val: anytype) !void {
+ try self.db.vtable.bind_named(self.db.ptr, self.ptr, name, val);
+ }
+ /// Advance the result set to the next row.
+ pub fn step(self: *Stmt) !bool {
+ return try self.db.vtable.step(self.db.ptr, self.ptr);
+ }
+ /// Read a column
+ pub fn column(self: *Stmt, comptime T: type, index: u31) !?T {
+ switch (@typeInfo(T)) {
+ .Int => |intinfo| {
+ if (intinfo.signedness != .signed) @compileError("integer type i64 only is supported in Stmt#column");
+ if (intinfo.bits != 64) @compileError("integer type i64 only is supported in Stmt#column");
+ return try self.db.vtable.column_i64(self.db.ptr, self.ptr, index);
+ },
+ .Pointer => |ptrinfo| {
+ if (ptrinfo.size != .Slice) @compileError("pointer type []const u8 only is supported in Stmt#column");
+ if (ptrinfo.child != u8) @compileError("pointer type []const u8 only is supported in Stmt#column");
+ return try self.db.vtable.column_slice_const_u8(self.db.ptr, self.ptr, index);
+ },
+ else => @compileError("unhandled type " ++ @tagName(@typeInfo(T)) ++ " in Stmt#column"),
+ }
+ }
+ pub fn close(self: *Stmt) void {
+ self.db.vtable.close_stmt(self.db.ptr, self.ptr);
+ }
+};
+
+/// Interface for running queries
+/// params should be a tuple with values for placeholders.
+pub fn query(self: *Db, comptime qry: [:0]const u8, params: anytype) !Stmt {
+ const ti = @typeInfo(@TypeOf(params));
+ if (ti != .Struct) @compileError("Db.query params must be a tuple struct but it's a " ++ @tagName(ti));
+ const si = ti.Struct;
+ if (!si.is_tuple) @compileError("Db.query params must be a tuple");
+
+ const stmt_ptr = try self.vtable.prepare(self.ptr, qry);
+ var stmt = Stmt{ .db = self, .ptr = stmt_ptr };
+ inline for (0..si.fields.len) |idx| {
+ try stmt.bind(idx, params[idx]);
+ }
+ return stmt;
+}
+
+/// Shortcut for query/step/close
+pub fn exec(self: *Db, comptime qry: [:0]const u8) !void {
+ var stmt = try self.query(qry, .{});
+ defer stmt.close();
+ _ = try stmt.step();
+}
+
+pub fn close(self: *Db) void {
+ self.vtable.close_db(self.ptr);
+}
+
+test {
+ _ = @import("test.zig");
+} \ No newline at end of file
diff --git a/src/postgres.zig b/src/postgres.zig
new file mode 100644
index 0000000..5a17bd7
--- /dev/null
+++ b/src/postgres.zig
@@ -0,0 +1,126 @@
+const std = @import("std");
+const pq = @cImport(
+ @cInclude("libpq-fe.h"),
+);
+const Db = @import("main.zig");
+const OpenError = Db.OpenError;
+const PrepareError = Db.PrepareError;
+const StepError = Db.StepError;
+const ColumnError = Db.ColumnError;
+const log = @import("log.zig").scoped_log_t(.postgres);
+
+/// Postgres implementation
+/// Single persistent connection implementation of postgres via libpq
+pub const Postgres = @This();
+
+allocator: std.mem.Allocator,
+conn: *pq.PGconn,
+
+/// Connect to a postgres database. URL is format accepted by libpq
+pub fn open(allocator: std.mem.Allocator, url: [:0]const u8) OpenError!Db {
+ if (pq.PQisthreadsafe() == 0) {
+ log.err("Postgres#open: PQisthreadsafe returned 0, can't use libpq in this program", .{});
+ return OpenError.NotThreadSafe;
+ }
+ var maybe_conn: ?*pq.PGconn = pq.PQconnectdb(url);
+ if (pq.PQstatus(maybe_conn) != pq.CONNECTION_OK) {
+ log.err("Postgres#open: PQstatus returned error {}: {s}", .{ pq.PQstatus(maybe_conn), pq.PQerrorMessage(maybe_conn) });
+ return OpenError.Failed;
+ }
+
+ var pg = try allocator.create(Postgres);
+ pg.allocator = allocator;
+ pg.conn = maybe_conn.?;
+ return Db{ .ptr = pg, .vtable = .{
+ .prepare = prepare,
+ .step = step,
+ .column_i64 = column_i64,
+ .column_slice_const_u8 = column_slice_const_u8,
+ .close_stmt = close_stmt,
+ .close_db = close_db,
+ } };
+}
+
+fn prepare(db: *anyopaque, query: [:0]const u8) PrepareError!*anyopaque {
+ var self: *Postgres = @alignCast(@ptrCast(db));
+ var pg_stmt = try self.allocator.create(PgStmt);
+ pg_stmt.* = .{
+ .query = query,
+ .params = std.ArrayList([*c]const u8).init(self.allocator),
+ };
+ return pg_stmt;
+}
+
+const PgStmt = struct {
+ query: [:0]const u8,
+ params: std.ArrayList([*c]const u8),
+ c_res: ?*pq.PGresult = null,
+ did_exec: bool = false,
+ n_tuples: ?c_int = null,
+ n_fields: ?c_int = null,
+ res_index: c_int = -1,
+};
+
+// TODO
+// fn bind
+
+fn step(db: *anyopaque, stmt: *anyopaque) StepError!bool {
+ var self: *Postgres = @alignCast(@ptrCast(db));
+ var pgstmt: *PgStmt = @alignCast(@ptrCast(stmt));
+ if (!pgstmt.did_exec) {
+ const params = try pgstmt.params.toOwnedSlice();
+ pgstmt.c_res = pq.PQexecParams(self.conn, pgstmt.query, @intCast(params.len), null, params.ptr, null, null, 0);
+ const rs = pq.PQresultStatus(pgstmt.c_res);
+ if (rs != pq.PGRES_TUPLES_OK and rs != pq.PGRES_SINGLE_TUPLE and rs != pq.PGRES_COMMAND_OK) {
+ log.err("PQresultStatus {} error: {s}", .{ rs, pq.PQerrorMessage(self.conn) });
+ return StepError.Failed;
+ }
+ pgstmt.n_tuples = pq.PQntuples(pgstmt.c_res);
+ pgstmt.n_fields = pq.PQnfields(pgstmt.c_res);
+ pgstmt.did_exec = true;
+ }
+ pgstmt.res_index = pgstmt.res_index + 1;
+ return pgstmt.res_index < pgstmt.n_tuples.?;
+}
+
+fn column_i64(db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?i64 {
+ _ = db;
+ var pgstmt: *PgStmt = @alignCast(@ptrCast(stmt));
+ if (!pgstmt.did_exec) @panic("did_exec == false you must call exec then step before trying column");
+ if (pgstmt.res_index == -1) @panic("res_index == -1 you must call exec then step before trying column");
+ if (pq.PQgetisnull(pgstmt.c_res, pgstmt.res_index, idx) == 1) {
+ return null;
+ }
+ const value_c: [*c]const u8 = pq.PQgetvalue(pgstmt.c_res, pgstmt.res_index, idx);
+ const slice = std.mem.sliceTo(value_c, 0);
+ return try std.fmt.parseInt(i64, slice, 10);
+}
+
+fn column_slice_const_u8(db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?[:0]const u8 {
+ _ = db;
+ var pgstmt: *PgStmt = @alignCast(@ptrCast(stmt));
+ if (!pgstmt.did_exec) @panic("did_exec == false you must call exec then step before trying column");
+ if (pgstmt.res_index == -1) @panic("res_index == -1 you must call exec then step before trying column");
+ if (pq.PQgetisnull(pgstmt.c_res, pgstmt.res_index, idx) == 1) {
+ return null;
+ }
+ const value_c: ?[*:0]u8 = pq.PQgetvalue(pgstmt.c_res, pgstmt.res_index, idx);
+ const value_c_nonnull = value_c orelse return null;
+ return std.mem.sliceTo(value_c_nonnull, 0);
+}
+
+fn close_stmt(db: *anyopaque, stmt: *anyopaque) void {
+ var self: *Postgres = @alignCast(@ptrCast(db));
+ var pgstmt: *PgStmt = @alignCast(@ptrCast(stmt));
+ if (pgstmt.c_res != null) {
+ pq.PQclear(pgstmt.c_res.?);
+ }
+ pgstmt.params.deinit();
+ self.allocator.destroy(pgstmt);
+}
+
+fn close_db(db: *anyopaque) void {
+ var self: *Postgres = @alignCast(@ptrCast(db));
+ pq.PQfinish(self.conn);
+ self.allocator.destroy(self);
+} \ No newline at end of file
diff --git a/src/sqlite.zig b/src/sqlite.zig
new file mode 100644
index 0000000..e0668cc
--- /dev/null
+++ b/src/sqlite.zig
@@ -0,0 +1,101 @@
+const std = @import("std");
+const sqlite = @cImport({
+ @cInclude("sqlite3.h");
+});
+const Db = @import("main.zig");
+const OpenError = Db.OpenError;
+const PrepareError = Db.PrepareError;
+const StepError = Db.StepError;
+const ColumnError = Db.ColumnError;
+const log = @import("log.zig").scoped_log_t(.sqlite);
+
+//// Sqlite implementation
+pub const Sqlite = @This();
+
+allocator: std.mem.Allocator,
+c_db: *sqlite.sqlite3,
+
+pub fn open(allocator: std.mem.Allocator, filename: [:0]const u8) OpenError!Db {
+ var db: ?*sqlite.sqlite3 = null;
+ const oo = sqlite.SQLITE_OPEN_CREATE | sqlite.SQLITE_OPEN_READWRITE | sqlite.SQLITE_OPEN_FULLMUTEX;
+ if (sqlite.sqlite3_open_v2(filename.ptr, &db, oo, null) != sqlite.SQLITE_OK) {
+ log.err("Sqlite#open: sqlite3_open_v2 error {s}", .{sqlite.sqlite3_errmsg(db)});
+ return OpenError.Failed;
+ }
+ var sqlite_db = try allocator.create(Sqlite);
+ sqlite_db.allocator = allocator;
+ sqlite_db.c_db = db.?;
+ return Db{
+ .ptr = sqlite_db,
+ .vtable = .{
+ .prepare = prepare,
+ .step = step,
+ .column_i64 = column_i64,
+ .column_slice_const_u8 = column_slice_const_u8,
+ .close_stmt = close_stmt,
+ .close_db = close_db,
+ },
+ };
+}
+
+fn prepare(db: *anyopaque, query: [:0]const u8) PrepareError!*anyopaque {
+ var self: *Sqlite = @alignCast(@ptrCast(db));
+ var sstmt: ?*sqlite.sqlite3_stmt = null;
+ if (sqlite.sqlite3_prepare_v2(self.c_db, query.ptr, @intCast(query.len), &sstmt, null) != sqlite.SQLITE_OK) {
+ log.err("Sqlite#prepare: sqlite3_prepare_v2: {s}", .{sqlite.sqlite3_errmsg(self.c_db)});
+ return PrepareError.Failed;
+ }
+ return sstmt.?;
+}
+
+fn step(db: *anyopaque, stmt: *anyopaque) StepError!bool {
+ var self: *Sqlite = @alignCast(@ptrCast(db));
+ var sstmt: *sqlite.sqlite3_stmt = @alignCast(@ptrCast(stmt));
+ const res = sqlite.sqlite3_step(sstmt);
+ if (res == sqlite.SQLITE_ROW) {
+ return true;
+ } else if (res == sqlite.SQLITE_DONE) {
+ return false;
+ } else {
+ log.err("Sqlite#step: sqlite3_step: {s}", .{sqlite.sqlite3_errmsg(self.c_db)});
+ return StepError.Failed;
+ }
+}
+
+fn column_i64(db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?i64 {
+ _ = db;
+ var sstmt: *sqlite.sqlite3_stmt = @alignCast(@ptrCast(stmt));
+ const ct = sqlite.sqlite3_column_type(sstmt, idx);
+ if (ct == sqlite.SQLITE_NULL) {
+ return null;
+ } else if (ct == sqlite.SQLITE_INTEGER) {
+ return sqlite.sqlite3_column_int64(sstmt, idx);
+ } else {
+ return ColumnError.WrongType;
+ }
+}
+fn column_slice_const_u8(db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?[:0]const u8 {
+ _ = db;
+ var sstmt: *sqlite.sqlite3_stmt = @alignCast(@ptrCast(stmt));
+ const ct = sqlite.sqlite3_column_type(sstmt, idx);
+ if (ct == sqlite.SQLITE_NULL) {
+ return null;
+ } else if (ct == sqlite.SQLITE_TEXT) {
+ const value_c = sqlite.sqlite3_column_text(sstmt, idx);
+ return std.mem.sliceTo(value_c, 0);
+ } else {
+ return ColumnError.WrongType;
+ }
+}
+
+fn close_stmt(db: *anyopaque, stmt: *anyopaque) void {
+ _ = db;
+ var sstmt: *sqlite.sqlite3_stmt = @alignCast(@ptrCast(stmt));
+ _ = sqlite.sqlite3_finalize(sstmt);
+}
+
+fn close_db(db: *anyopaque) void {
+ var self: *Sqlite = @alignCast(@ptrCast(db));
+ _ = sqlite.sqlite3_close_v2(self.c_db);
+ self.allocator.destroy(self);
+} \ No newline at end of file
diff --git a/src/test.zig b/src/test.zig
new file mode 100644
index 0000000..afd726b
--- /dev/null
+++ b/src/test.zig
@@ -0,0 +1,102 @@
+const std = @import("std");
+const Db = @import("main.zig");
+const OpenError = Db.OpenError;
+const PrepareError = Db.PrepareError;
+const StepError = Db.StepError;
+const ColumnError = Db.ColumnError;
+
+const Postgres = @import("postgres.zig");
+const Sqlite = @import("sqlite.zig");
+
+const TestDbIterator = struct {
+ allocator: std.mem.Allocator,
+ pg_url: [:0]const u8,
+ sqlite_url: [:0]const u8,
+ done_postgres: bool = false,
+ done_sqlite: bool = false,
+ fn init(allocator: std.mem.Allocator, pg_url: [:0]const u8, sqlite_url: [:0]const u8) TestDbIterator {
+ return TestDbIterator{
+ .allocator = allocator,
+ .pg_url = pg_url,
+ .sqlite_url = sqlite_url,
+ };
+ }
+ fn next(self: *TestDbIterator) ?(OpenError!Db) {
+ if (!self.done_postgres) {
+ self.done_postgres = true;
+ return Postgres.open(self.allocator, self.pg_url);
+ }
+ if (!self.done_sqlite) {
+ self.done_sqlite = true;
+ return Sqlite.open(self.allocator, self.sqlite_url);
+ }
+ return null;
+ }
+};
+
+test "open" {
+ var it = TestDbIterator.init(std.testing.allocator, "postgresql:///testdb", "testdb.db");
+ var maybe_db = it.next();
+ while (maybe_db != null): (maybe_db = it.next()) {
+ var db = try maybe_db.?;
+ defer db.close();
+ }
+}
+
+test "open error" {
+ var it = TestDbIterator.init(std.testing.allocator, "NOPE DOPE", "floogle/flungle");
+ var maybe_db = it.next();
+ while (maybe_db != null): (maybe_db = it.next()) {
+ var db = maybe_db.?;
+ try std.testing.expectEqual(@as(OpenError!Db, OpenError.Failed), db);
+ }
+}
+
+test "query" {
+ var it = TestDbIterator.init(std.testing.allocator, "postgresql:///testdb", "testdb.db");
+ var maybe_db = it.next();
+ while (maybe_db != null): (maybe_db = it.next()) {
+ var db = try maybe_db.?;
+ defer db.close();
+ try db.exec("begin"); // deliberately don't commit!
+ try db.exec("create table foo(col1 int, col2 text)");
+ try db.exec("insert into foo(col1, col2) values(123, 'hi')");
+ var stmt = try db.query("select col1, col2 from foo", .{});
+ defer stmt.close();
+ try std.testing.expect(try stmt.step());
+ const col1 = try stmt.column(i64, 0);
+ const col2 = try stmt.column([]const u8, 1);
+ try std.testing.expectEqual(@as(?i64, 123), col1);
+ try std.testing.expectEqualStrings("hi", col2.?);
+ }
+}
+
+test "query null column" {
+ var it = TestDbIterator.init(std.testing.allocator, "postgresql:///testdb", "testdb.db");
+ var maybe_db = it.next();
+ while (maybe_db != null): (maybe_db = it.next()) {
+ var db = try maybe_db.?;
+ defer db.close();
+ try db.exec("begin"); // deliberately don't commit!
+ try db.exec("create table foo(col1 int, col2 text)");
+ try db.exec("insert into foo(col1, col2) values(null, null)");
+ var stmt = try db.query("select col1, col2 from foo", .{});
+ defer stmt.close();
+ try std.testing.expect(try stmt.step());
+ const col1 = try stmt.column(i64, 0);
+ const col2 = try stmt.column([]const u8, 1);
+ try std.testing.expectEqual(@as(?i64, null), col1);
+ try std.testing.expectEqual(@as(?[]const u8,null), col2);
+ }
+}
+
+test "exec error" {
+ var it = TestDbIterator.init(std.testing.allocator, "postgresql:///testdb", "testdb.db");
+ var maybe_db = it.next();
+ while (maybe_db != null): (maybe_db = it.next()) {
+ var db = try maybe_db.?;
+ defer db.close();
+ const res = db.exec("AIN'T VALID BRO");
+ try std.testing.expectEqual(@as(StepError!void, StepError.Failed), res);
+ }
+}