From 2c4ac3819b8c42de1410fd524c2c9d08d937ec70 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Sun, 3 Sep 2023 20:32:51 +0100 Subject: Initial --- src/log.zig | 18 ++++++++ src/main.zig | 103 +++++++++++++++++++++++++++++++++++++++++++++ src/postgres.zig | 126 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/sqlite.zig | 101 ++++++++++++++++++++++++++++++++++++++++++++ src/test.zig | 102 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 450 insertions(+) create mode 100644 src/log.zig create mode 100644 src/main.zig create mode 100644 src/postgres.zig create mode 100644 src/sqlite.zig create mode 100644 src/test.zig (limited to 'src') diff --git a/src/log.zig b/src/log.zig new file mode 100644 index 0000000..dd81d5f --- /dev/null +++ b/src/log.zig @@ -0,0 +1,18 @@ +const std = @import("std"); +const builtin = @import("builtin"); +/// Workaround for failing tests when errors are logged +/// https://github.com/ziglang/zig/issues/5738#issuecomment-1466902082 +pub fn scoped_log_t(comptime scope: @Type(.EnumLiteral)) type { + return if (builtin.is_test) + // Downgrade `err` to `warn` for tests. + // Zig fails any test that does `log.err`, but we want to test those code paths here. + struct { + pub const base = std.log.scoped(scope); + pub const err = warn; + pub const warn = base.warn; + pub const info = base.info; + pub const debug = base.debug; + } + else + std.log.scoped(scope); +} diff --git a/src/main.zig b/src/main.zig new file mode 100644 index 0000000..ebb9fde --- /dev/null +++ b/src/main.zig @@ -0,0 +1,103 @@ +const std = @import("std"); +const log = @import("log.zig").scoped_log_t(.db); + +/// Database abstraction layer +pub const Db = @This(); + +/// Type erased pointer to actual implementation, and reference to implementation functions +/// Inspired by std.mem.Allocator +ptr: *anyopaque, +vtable: VTable, + +fn init(ptr: *anyopaque, vtable: VTable) Db { + return Db{ + .ptr = ptr, + .vtable = vtable, + }; +} + +pub const OpenError = error{ Failed, NotThreadSafe } || std.mem.Allocator.Error; +pub const PrepareError = error{Failed} || std.mem.Allocator.Error; +pub const BindError = error{}; +pub const StepError = error{Failed} || std.mem.Allocator.Error; +pub const ColumnError = error{WrongType} || std.fmt.ParseIntError; + +// Dispatcher for concrete implementations. Inspired by std.mem.Allocator. +const VTable = struct { + prepare: *const fn (*anyopaque, query: [:0]const u8) PrepareError!*anyopaque, + // bind: *const fn (db: *anyopaque, stmt: *anyopaque, idx: usize, val: anytype) BindError!void, + step: *const fn (db: *anyopaque, stmt: *anyopaque) StepError!bool, + // TODO support more types + // TODO think of something clever to avoid this function proliferation. + // Values returned should live at least until next call to .step() + column_i64: *const fn (db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?i64, + column_slice_const_u8: *const fn (db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?[:0]const u8, + close_stmt: *const fn (db: *anyopaque, stmt: *anyopaque) void, + close_db: *const fn (db: *anyopaque) void, +}; + +pub const Stmt = struct { + db: *Db, + ptr: *anyopaque, + + fn bind(self: *Stmt, idx: u31, val: anytype) !void { + try self.db.vtable.bind(self.db.ptr, self.ptr, idx, val); + } + fn bind_named(self: *Stmt, name: [:0]const u8, val: anytype) !void { + try self.db.vtable.bind_named(self.db.ptr, self.ptr, name, val); + } + /// Advance the result set to the next row. + pub fn step(self: *Stmt) !bool { + return try self.db.vtable.step(self.db.ptr, self.ptr); + } + /// Read a column + pub fn column(self: *Stmt, comptime T: type, index: u31) !?T { + switch (@typeInfo(T)) { + .Int => |intinfo| { + if (intinfo.signedness != .signed) @compileError("integer type i64 only is supported in Stmt#column"); + if (intinfo.bits != 64) @compileError("integer type i64 only is supported in Stmt#column"); + return try self.db.vtable.column_i64(self.db.ptr, self.ptr, index); + }, + .Pointer => |ptrinfo| { + if (ptrinfo.size != .Slice) @compileError("pointer type []const u8 only is supported in Stmt#column"); + if (ptrinfo.child != u8) @compileError("pointer type []const u8 only is supported in Stmt#column"); + return try self.db.vtable.column_slice_const_u8(self.db.ptr, self.ptr, index); + }, + else => @compileError("unhandled type " ++ @tagName(@typeInfo(T)) ++ " in Stmt#column"), + } + } + pub fn close(self: *Stmt) void { + self.db.vtable.close_stmt(self.db.ptr, self.ptr); + } +}; + +/// Interface for running queries +/// params should be a tuple with values for placeholders. +pub fn query(self: *Db, comptime qry: [:0]const u8, params: anytype) !Stmt { + const ti = @typeInfo(@TypeOf(params)); + if (ti != .Struct) @compileError("Db.query params must be a tuple struct but it's a " ++ @tagName(ti)); + const si = ti.Struct; + if (!si.is_tuple) @compileError("Db.query params must be a tuple"); + + const stmt_ptr = try self.vtable.prepare(self.ptr, qry); + var stmt = Stmt{ .db = self, .ptr = stmt_ptr }; + inline for (0..si.fields.len) |idx| { + try stmt.bind(idx, params[idx]); + } + return stmt; +} + +/// Shortcut for query/step/close +pub fn exec(self: *Db, comptime qry: [:0]const u8) !void { + var stmt = try self.query(qry, .{}); + defer stmt.close(); + _ = try stmt.step(); +} + +pub fn close(self: *Db) void { + self.vtable.close_db(self.ptr); +} + +test { + _ = @import("test.zig"); +} \ No newline at end of file diff --git a/src/postgres.zig b/src/postgres.zig new file mode 100644 index 0000000..5a17bd7 --- /dev/null +++ b/src/postgres.zig @@ -0,0 +1,126 @@ +const std = @import("std"); +const pq = @cImport( + @cInclude("libpq-fe.h"), +); +const Db = @import("main.zig"); +const OpenError = Db.OpenError; +const PrepareError = Db.PrepareError; +const StepError = Db.StepError; +const ColumnError = Db.ColumnError; +const log = @import("log.zig").scoped_log_t(.postgres); + +/// Postgres implementation +/// Single persistent connection implementation of postgres via libpq +pub const Postgres = @This(); + +allocator: std.mem.Allocator, +conn: *pq.PGconn, + +/// Connect to a postgres database. URL is format accepted by libpq +pub fn open(allocator: std.mem.Allocator, url: [:0]const u8) OpenError!Db { + if (pq.PQisthreadsafe() == 0) { + log.err("Postgres#open: PQisthreadsafe returned 0, can't use libpq in this program", .{}); + return OpenError.NotThreadSafe; + } + var maybe_conn: ?*pq.PGconn = pq.PQconnectdb(url); + if (pq.PQstatus(maybe_conn) != pq.CONNECTION_OK) { + log.err("Postgres#open: PQstatus returned error {}: {s}", .{ pq.PQstatus(maybe_conn), pq.PQerrorMessage(maybe_conn) }); + return OpenError.Failed; + } + + var pg = try allocator.create(Postgres); + pg.allocator = allocator; + pg.conn = maybe_conn.?; + return Db{ .ptr = pg, .vtable = .{ + .prepare = prepare, + .step = step, + .column_i64 = column_i64, + .column_slice_const_u8 = column_slice_const_u8, + .close_stmt = close_stmt, + .close_db = close_db, + } }; +} + +fn prepare(db: *anyopaque, query: [:0]const u8) PrepareError!*anyopaque { + var self: *Postgres = @alignCast(@ptrCast(db)); + var pg_stmt = try self.allocator.create(PgStmt); + pg_stmt.* = .{ + .query = query, + .params = std.ArrayList([*c]const u8).init(self.allocator), + }; + return pg_stmt; +} + +const PgStmt = struct { + query: [:0]const u8, + params: std.ArrayList([*c]const u8), + c_res: ?*pq.PGresult = null, + did_exec: bool = false, + n_tuples: ?c_int = null, + n_fields: ?c_int = null, + res_index: c_int = -1, +}; + +// TODO +// fn bind + +fn step(db: *anyopaque, stmt: *anyopaque) StepError!bool { + var self: *Postgres = @alignCast(@ptrCast(db)); + var pgstmt: *PgStmt = @alignCast(@ptrCast(stmt)); + if (!pgstmt.did_exec) { + const params = try pgstmt.params.toOwnedSlice(); + pgstmt.c_res = pq.PQexecParams(self.conn, pgstmt.query, @intCast(params.len), null, params.ptr, null, null, 0); + const rs = pq.PQresultStatus(pgstmt.c_res); + if (rs != pq.PGRES_TUPLES_OK and rs != pq.PGRES_SINGLE_TUPLE and rs != pq.PGRES_COMMAND_OK) { + log.err("PQresultStatus {} error: {s}", .{ rs, pq.PQerrorMessage(self.conn) }); + return StepError.Failed; + } + pgstmt.n_tuples = pq.PQntuples(pgstmt.c_res); + pgstmt.n_fields = pq.PQnfields(pgstmt.c_res); + pgstmt.did_exec = true; + } + pgstmt.res_index = pgstmt.res_index + 1; + return pgstmt.res_index < pgstmt.n_tuples.?; +} + +fn column_i64(db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?i64 { + _ = db; + var pgstmt: *PgStmt = @alignCast(@ptrCast(stmt)); + if (!pgstmt.did_exec) @panic("did_exec == false you must call exec then step before trying column"); + if (pgstmt.res_index == -1) @panic("res_index == -1 you must call exec then step before trying column"); + if (pq.PQgetisnull(pgstmt.c_res, pgstmt.res_index, idx) == 1) { + return null; + } + const value_c: [*c]const u8 = pq.PQgetvalue(pgstmt.c_res, pgstmt.res_index, idx); + const slice = std.mem.sliceTo(value_c, 0); + return try std.fmt.parseInt(i64, slice, 10); +} + +fn column_slice_const_u8(db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?[:0]const u8 { + _ = db; + var pgstmt: *PgStmt = @alignCast(@ptrCast(stmt)); + if (!pgstmt.did_exec) @panic("did_exec == false you must call exec then step before trying column"); + if (pgstmt.res_index == -1) @panic("res_index == -1 you must call exec then step before trying column"); + if (pq.PQgetisnull(pgstmt.c_res, pgstmt.res_index, idx) == 1) { + return null; + } + const value_c: ?[*:0]u8 = pq.PQgetvalue(pgstmt.c_res, pgstmt.res_index, idx); + const value_c_nonnull = value_c orelse return null; + return std.mem.sliceTo(value_c_nonnull, 0); +} + +fn close_stmt(db: *anyopaque, stmt: *anyopaque) void { + var self: *Postgres = @alignCast(@ptrCast(db)); + var pgstmt: *PgStmt = @alignCast(@ptrCast(stmt)); + if (pgstmt.c_res != null) { + pq.PQclear(pgstmt.c_res.?); + } + pgstmt.params.deinit(); + self.allocator.destroy(pgstmt); +} + +fn close_db(db: *anyopaque) void { + var self: *Postgres = @alignCast(@ptrCast(db)); + pq.PQfinish(self.conn); + self.allocator.destroy(self); +} \ No newline at end of file diff --git a/src/sqlite.zig b/src/sqlite.zig new file mode 100644 index 0000000..e0668cc --- /dev/null +++ b/src/sqlite.zig @@ -0,0 +1,101 @@ +const std = @import("std"); +const sqlite = @cImport({ + @cInclude("sqlite3.h"); +}); +const Db = @import("main.zig"); +const OpenError = Db.OpenError; +const PrepareError = Db.PrepareError; +const StepError = Db.StepError; +const ColumnError = Db.ColumnError; +const log = @import("log.zig").scoped_log_t(.sqlite); + +//// Sqlite implementation +pub const Sqlite = @This(); + +allocator: std.mem.Allocator, +c_db: *sqlite.sqlite3, + +pub fn open(allocator: std.mem.Allocator, filename: [:0]const u8) OpenError!Db { + var db: ?*sqlite.sqlite3 = null; + const oo = sqlite.SQLITE_OPEN_CREATE | sqlite.SQLITE_OPEN_READWRITE | sqlite.SQLITE_OPEN_FULLMUTEX; + if (sqlite.sqlite3_open_v2(filename.ptr, &db, oo, null) != sqlite.SQLITE_OK) { + log.err("Sqlite#open: sqlite3_open_v2 error {s}", .{sqlite.sqlite3_errmsg(db)}); + return OpenError.Failed; + } + var sqlite_db = try allocator.create(Sqlite); + sqlite_db.allocator = allocator; + sqlite_db.c_db = db.?; + return Db{ + .ptr = sqlite_db, + .vtable = .{ + .prepare = prepare, + .step = step, + .column_i64 = column_i64, + .column_slice_const_u8 = column_slice_const_u8, + .close_stmt = close_stmt, + .close_db = close_db, + }, + }; +} + +fn prepare(db: *anyopaque, query: [:0]const u8) PrepareError!*anyopaque { + var self: *Sqlite = @alignCast(@ptrCast(db)); + var sstmt: ?*sqlite.sqlite3_stmt = null; + if (sqlite.sqlite3_prepare_v2(self.c_db, query.ptr, @intCast(query.len), &sstmt, null) != sqlite.SQLITE_OK) { + log.err("Sqlite#prepare: sqlite3_prepare_v2: {s}", .{sqlite.sqlite3_errmsg(self.c_db)}); + return PrepareError.Failed; + } + return sstmt.?; +} + +fn step(db: *anyopaque, stmt: *anyopaque) StepError!bool { + var self: *Sqlite = @alignCast(@ptrCast(db)); + var sstmt: *sqlite.sqlite3_stmt = @alignCast(@ptrCast(stmt)); + const res = sqlite.sqlite3_step(sstmt); + if (res == sqlite.SQLITE_ROW) { + return true; + } else if (res == sqlite.SQLITE_DONE) { + return false; + } else { + log.err("Sqlite#step: sqlite3_step: {s}", .{sqlite.sqlite3_errmsg(self.c_db)}); + return StepError.Failed; + } +} + +fn column_i64(db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?i64 { + _ = db; + var sstmt: *sqlite.sqlite3_stmt = @alignCast(@ptrCast(stmt)); + const ct = sqlite.sqlite3_column_type(sstmt, idx); + if (ct == sqlite.SQLITE_NULL) { + return null; + } else if (ct == sqlite.SQLITE_INTEGER) { + return sqlite.sqlite3_column_int64(sstmt, idx); + } else { + return ColumnError.WrongType; + } +} +fn column_slice_const_u8(db: *anyopaque, stmt: *anyopaque, idx: u31) ColumnError!?[:0]const u8 { + _ = db; + var sstmt: *sqlite.sqlite3_stmt = @alignCast(@ptrCast(stmt)); + const ct = sqlite.sqlite3_column_type(sstmt, idx); + if (ct == sqlite.SQLITE_NULL) { + return null; + } else if (ct == sqlite.SQLITE_TEXT) { + const value_c = sqlite.sqlite3_column_text(sstmt, idx); + return std.mem.sliceTo(value_c, 0); + } else { + return ColumnError.WrongType; + } +} + +fn close_stmt(db: *anyopaque, stmt: *anyopaque) void { + _ = db; + var sstmt: *sqlite.sqlite3_stmt = @alignCast(@ptrCast(stmt)); + _ = sqlite.sqlite3_finalize(sstmt); +} + +fn close_db(db: *anyopaque) void { + var self: *Sqlite = @alignCast(@ptrCast(db)); + _ = sqlite.sqlite3_close_v2(self.c_db); + self.allocator.destroy(self); +} \ No newline at end of file diff --git a/src/test.zig b/src/test.zig new file mode 100644 index 0000000..afd726b --- /dev/null +++ b/src/test.zig @@ -0,0 +1,102 @@ +const std = @import("std"); +const Db = @import("main.zig"); +const OpenError = Db.OpenError; +const PrepareError = Db.PrepareError; +const StepError = Db.StepError; +const ColumnError = Db.ColumnError; + +const Postgres = @import("postgres.zig"); +const Sqlite = @import("sqlite.zig"); + +const TestDbIterator = struct { + allocator: std.mem.Allocator, + pg_url: [:0]const u8, + sqlite_url: [:0]const u8, + done_postgres: bool = false, + done_sqlite: bool = false, + fn init(allocator: std.mem.Allocator, pg_url: [:0]const u8, sqlite_url: [:0]const u8) TestDbIterator { + return TestDbIterator{ + .allocator = allocator, + .pg_url = pg_url, + .sqlite_url = sqlite_url, + }; + } + fn next(self: *TestDbIterator) ?(OpenError!Db) { + if (!self.done_postgres) { + self.done_postgres = true; + return Postgres.open(self.allocator, self.pg_url); + } + if (!self.done_sqlite) { + self.done_sqlite = true; + return Sqlite.open(self.allocator, self.sqlite_url); + } + return null; + } +}; + +test "open" { + var it = TestDbIterator.init(std.testing.allocator, "postgresql:///testdb", "testdb.db"); + var maybe_db = it.next(); + while (maybe_db != null): (maybe_db = it.next()) { + var db = try maybe_db.?; + defer db.close(); + } +} + +test "open error" { + var it = TestDbIterator.init(std.testing.allocator, "NOPE DOPE", "floogle/flungle"); + var maybe_db = it.next(); + while (maybe_db != null): (maybe_db = it.next()) { + var db = maybe_db.?; + try std.testing.expectEqual(@as(OpenError!Db, OpenError.Failed), db); + } +} + +test "query" { + var it = TestDbIterator.init(std.testing.allocator, "postgresql:///testdb", "testdb.db"); + var maybe_db = it.next(); + while (maybe_db != null): (maybe_db = it.next()) { + var db = try maybe_db.?; + defer db.close(); + try db.exec("begin"); // deliberately don't commit! + try db.exec("create table foo(col1 int, col2 text)"); + try db.exec("insert into foo(col1, col2) values(123, 'hi')"); + var stmt = try db.query("select col1, col2 from foo", .{}); + defer stmt.close(); + try std.testing.expect(try stmt.step()); + const col1 = try stmt.column(i64, 0); + const col2 = try stmt.column([]const u8, 1); + try std.testing.expectEqual(@as(?i64, 123), col1); + try std.testing.expectEqualStrings("hi", col2.?); + } +} + +test "query null column" { + var it = TestDbIterator.init(std.testing.allocator, "postgresql:///testdb", "testdb.db"); + var maybe_db = it.next(); + while (maybe_db != null): (maybe_db = it.next()) { + var db = try maybe_db.?; + defer db.close(); + try db.exec("begin"); // deliberately don't commit! + try db.exec("create table foo(col1 int, col2 text)"); + try db.exec("insert into foo(col1, col2) values(null, null)"); + var stmt = try db.query("select col1, col2 from foo", .{}); + defer stmt.close(); + try std.testing.expect(try stmt.step()); + const col1 = try stmt.column(i64, 0); + const col2 = try stmt.column([]const u8, 1); + try std.testing.expectEqual(@as(?i64, null), col1); + try std.testing.expectEqual(@as(?[]const u8,null), col2); + } +} + +test "exec error" { + var it = TestDbIterator.init(std.testing.allocator, "postgresql:///testdb", "testdb.db"); + var maybe_db = it.next(); + while (maybe_db != null): (maybe_db = it.next()) { + var db = try maybe_db.?; + defer db.close(); + const res = db.exec("AIN'T VALID BRO"); + try std.testing.expectEqual(@as(StepError!void, StepError.Failed), res); + } +} -- cgit v1.2.3-ZIG