const std = @import("std");
const pq = @cImport(
    @cInclude("libpq-fe.h"),
);

pub const PqError = error{PqError};

// libpq wrapper.
// Later, this could become a pure-Zig client implementation.

/// A single libpq connection. Create with `init`, release with `deinit`.
pub const Db = struct {
    c_conn: *pq.PGconn,

    /// Open a connection from a libpq connection string / URI.
    ///
    /// Logs the reason and returns `PqError.PqError` if libpq is not
    /// thread-safe, if `PQconnectdb` returns null, or if the connection status
    /// is anything other than `CONNECTION_OK`. On every failure after
    /// `PQconnectdb` succeeds, the partially created connection is released
    /// with `PQfinish`.
    pub fn init(connect_url: [:0]const u8) !Db {
        if (pq.PQisthreadsafe() == 0) {
            std.log.err("PQisthreadsafe returned 0, can't use libpq in this program", .{});
            return PqError.PqError;
        }

        const maybe_conn: ?*pq.PGconn = pq.PQconnectdb(connect_url.ptr);
        const conn = maybe_conn orelse {
            std.log.err("PQconnectdb returned null", .{});
            return PqError.PqError;
        };
        // libpq returns an allocated PGconn even when the connection attempt
        // fails; it must be freed with PQfinish on the error paths below.
        errdefer pq.PQfinish(conn);

        const status = pq.PQstatus(conn);
        if (status == pq.CONNECTION_BAD) {
            std.log.err("PQstatus returned CONNECTION_BAD: {s}", .{pq.PQerrorMessage(conn)});
            return PqError.PqError;
        } else if (status != pq.CONNECTION_OK) {
            std.log.err("PQstatus returned unknown status {}: {s}", .{ status, pq.PQerrorMessage(conn) });
            return PqError.PqError;
        }

        return Db{
            .c_conn = conn,
        };
    }

    /// Close the connection with `PQfinish`.
    pub fn deinit(self: Db) void {
        pq.PQfinish(self.c_conn);
    }

    /// Execute `query` with `PQexec`, expecting status `PGRES_COMMAND_OK`.
    ///
    /// Any other result status (including `PGRES_TUPLES_OK`) is logged and
    /// reported as `PqError.PqError`. The result is always cleared.
    pub fn exec(self: Db, query: [:0]const u8) !void {
        const res: ?*pq.PGresult = pq.PQexec(self.c_conn, query.ptr);
        // PQclear is a no-op on a null result, so this is safe even if
        // PQexec itself failed to allocate one.
        defer pq.PQclear(res);

        const est: pq.ExecStatusType = pq.PQresultStatus(res);
        if (est != pq.PGRES_COMMAND_OK) {
            std.log.err("PQexec error code {} message {s}", .{ est, pq.PQerrorMessage(self.c_conn) });
            return PqError.PqError;
        }
    }

    /// Create a `Stmt` for `query` on this connection.
    ///
    /// Nothing is sent to the server here; the query runs on the first
    /// `Stmt.step`. The `Stmt` keeps a reference to `query`, so `query` must
    /// outlive it. Call `Stmt.deinit` when finished.
    pub fn prepare_statement(self: Db, query: [:0]const u8) PqError!Stmt {
        return Stmt{
            .db = self,
            .query = query,
        };
    }
};

/// A parameterized query executed with `PQexecParams`.
///
/// Typical use: `bind` parameters, call `step` until it returns false, and
/// read the current row with `read_column`, `read_columnN` or `read_struct`
/// after each `step`. Call `deinit` when finished.
pub const Stmt = struct {
    const MAX_PARAMS = 128;

    db: Db,
    query: [:0]const u8,
    // TODO take child allocator as a param
    /// Owns the text of bound parameter values until `deinit`.
    aa: std.heap.ArenaAllocator = std.heap.ArenaAllocator.init(std.heap.page_allocator),
    /// Highest bound index + 1; this many entries of `param_values` are sent.
    n_params: usize = 0,
    /// Text-format parameter values, indexed from 0 (`$1`). Slots never
    /// passed to `bind` stay null, which libpq transmits as SQL NULL.
    param_values: [MAX_PARAMS][*c]const u8 = [_][*c]const u8{null} ** MAX_PARAMS,
    did_exec: bool = false,
    c_res: ?*pq.PGresult = null,
    /// Index of the current row; -1 before the first `step`.
    res_index: c_int = -1,
    n_tuples: c_int = -1,
    n_fields: c_int = -1,

    /// Free the bound-parameter arena and clear the query result.
    ///
    /// Slices returned by `read_column` for `[]const u8` point into the
    /// result and become invalid after this.
    pub fn deinit(self: *Stmt) void {
        self.aa.deinit();
        pq.PQclear(self.c_res);
    }

    /// Advance to the next row, executing the query on the first call.
    ///
    /// Returns true while a row is available. If execution fails, logs the
    /// server error and returns `PqError.PqError`. Parameters bound after the
    /// first call have no effect, because the query is executed only once.
    pub fn step(self: *Stmt) !bool {
        if (!self.did_exec) {
            self.did_exec = true;
            self.c_res = pq.PQexecParams(
                self.db.c_conn,
                self.query.ptr,
                @as(c_int, @intCast(self.n_params)),
                null, // paramTypes: let the server infer parameter types
                &self.param_values,
                null, // paramLengths: not needed for text-format params
                null, // paramFormats: all params are text
                0, // resultFormat: text
            );
            const rs = pq.PQresultStatus(self.c_res);
            if (rs != pq.PGRES_TUPLES_OK and rs != pq.PGRES_SINGLE_TUPLE and rs != pq.PGRES_COMMAND_OK) {
                std.log.err("PQresultStatus {} error: {s}", .{ rs, pq.PQerrorMessage(self.db.c_conn) });
                return PqError.PqError;
            }
            self.n_tuples = pq.PQntuples(self.c_res);
            self.n_fields = pq.PQnfields(self.c_res);
        }
        self.res_index += 1;
        return self.res_index < self.n_tuples;
    }

    /// Read column `idx` (0-based) of the current row as `T`.
    ///
    /// Supported types:
    /// - integer types, parsed as base-10 text;
    /// - `[]const u8`, returned without copying and valid until `deinit`.
    ///
    /// Returns `error.NoSuchValue` when there is no current row (before the
    /// first `step` or after the last one) or when `idx` is out of range.
    pub fn read_column(self: Stmt, idx: c_int, comptime T: type) !T {
        const value_c: [*c]u8 = pq.PQgetvalue(self.c_res, self.res_index, idx);
        // PQgetvalue yields NULL for an out-of-range row/column or a null
        // result; slicing it would dereference a null pointer.
        if (value_c == null) return error.NoSuchValue;
        const value: []const u8 = std.mem.sliceTo(value_c, 0);
        return switch (@typeInfo(T)) {
            .Int => std.fmt.parseInt(T, value, 10),
            .Pointer => |ptrinfo| blk: {
                if (ptrinfo.size != .Slice or ptrinfo.child != u8) {
                    @compileError("pointer type []const u8 only is supported in read_column");
                }
                break :blk value;
            },
            else => @compileError("unhandled type " ++ @tagName(@typeInfo(T)) ++ " in read_column"),
        };
    }

    /// Like `read_column`, but finds the column by name using `PQfnumber`.
    ///
    /// Returns `error.ColumnNotFound` if no column matches `name`.
    pub fn read_columnN(self: Stmt, name: [:0]const u8, comptime T: type) !T {
        const idx = pq.PQfnumber(self.c_res, name.ptr);
        if (idx == -1) return error.ColumnNotFound;
        return self.read_column(idx, T);
    }

    /// Bind parameter `$<idx + 1>` (`idx` is 0-based) as text.
    ///
    /// Pointer (string) values are formatted with `{s}`. Integers, including
    /// integer literals, are formatted with `{d}`. The text is copied into the
    /// statement's arena. Returns `error.TooManyParams` if `idx >= MAX_PARAMS`.
    pub fn bind(self: *Stmt, idx: usize, t: anytype) !void {
        if (idx >= MAX_PARAMS) {
            std.log.err("bind index {} exceeds the maximum of {} parameters", .{ idx, MAX_PARAMS });
            return error.TooManyParams;
        }
        const T = @TypeOf(t);
        const allocator = self.aa.allocator();
        const value: [:0]const u8 = switch (@typeInfo(T)) {
            .Pointer => try std.fmt.allocPrintZ(allocator, "{s}", .{t}),
            .Int, .ComptimeInt => try std.fmt.allocPrintZ(allocator, "{d}", .{t}),
            else => @compileError("unhandled type " ++ @tagName(@typeInfo(T)) ++ " in bind"),
        };
        self.param_values[idx] = value.ptr;
        self.n_params = @max(self.n_params, idx + 1);
    }

    /// Read the current row into a `T`.
    ///
    /// Each struct field is filled from the column with the same name,
    /// parsed as that field's type (see `read_columnN` and `read_column`).
    pub fn read_struct(self: Stmt, comptime T: type) !T {
        const ti = @typeInfo(T);
        if (ti != .Struct) {
            @compileError("read_struct requires a struct type, got " ++ @typeName(T));
        }
        var t: T = undefined;
        inline for (ti.Struct.fields) |field| {
            @field(t, field.name) = try self.read_columnN(field.name, field.type);
        }
        return t;
    }
};