main.zig (5666B)
const std = @import("std");
const pq = @cImport(
    @cInclude("libpq-fe.h"),
);

pub const PqError = error{PqError};

// libpq wrapper
// later, this could be a pure-zig client implementation

/// Thin wrapper around a single libpq connection (`PGconn`).
pub const Db = struct {
    c_conn: *pq.PGconn,

    /// Opens a connection with `PQconnectdb` using `connect_url`.
    ///
    /// Returns `PqError.PqError` (after logging the cause) when libpq reports it
    /// is not thread safe, when `PQconnectdb` returns null, or when the
    /// connection status is anything other than `CONNECTION_OK`.
    /// Call `deinit` to close a successfully opened connection.
    pub fn init(connect_url: [:0]const u8) !Db {
        if (pq.PQisthreadsafe() == 0) {
            std.log.err("PQisthreadsafe returned 0, can't use libpq in this program", .{});
            return PqError.PqError;
        }
        const conn: *pq.PGconn = pq.PQconnectdb(connect_url) orelse {
            std.log.err("PQconnectdb returned null", .{});
            return PqError.PqError;
        };
        // libpq hands back a PGconn even when the connection attempt failed;
        // it must still be released with PQfinish on every error path below.
        errdefer pq.PQfinish(conn);

        const status = pq.PQstatus(conn);
        if (status == pq.CONNECTION_BAD) {
            std.log.err("PQstatus returned CONNECTION_BAD: {s}", .{pq.PQerrorMessage(conn)});
            return PqError.PqError;
        } else if (status != pq.CONNECTION_OK) {
            std.log.err("PQstatus returned unknown status {}: {s}", .{ status, pq.PQerrorMessage(conn) });
            return PqError.PqError;
        }
        return Db{
            .c_conn = conn,
        };
    }

    /// Closes the connection with `PQfinish`.
    pub fn deinit(self: Db) void {
        pq.PQfinish(self.c_conn);
    }

    /// Runs `query` with `PQexec` and discards the result.
    ///
    /// Only `PGRES_COMMAND_OK` counts as success, so a query that returns rows
    /// is reported as an error here; use `prepare_statement` + `Stmt.step`
    /// for those.
    pub fn exec(self: Db, query: [:0]const u8) !void {
        const res: ?*pq.PGresult = pq.PQexec(self.c_conn, query);
        defer pq.PQclear(res);
        const est: pq.ExecStatusType = pq.PQresultStatus(res);
        if (est != pq.PGRES_COMMAND_OK) {
            std.log.err("PQexec error code {} message {s}", .{ est, pq.PQerrorMessage(self.c_conn) });
            return PqError.PqError;
        }
    }

    /// Creates a statement for `query` without contacting the server; the
    /// query is sent on the first `Stmt.step`.
    ///
    /// `allocator` backs an arena used for bound parameter strings. `query` is
    /// stored by reference and must outlive the returned `Stmt`. The caller
    /// must call `Stmt.deinit`.
    pub fn prepare_statement(self: Db, allocator: std.mem.Allocator, query: [:0]const u8) PqError!Stmt {
        return Stmt{
            .db = self,
            .aa = std.heap.ArenaAllocator.init(allocator),
            .query = query,
        };
    }
};

/// A parameterized query executed lazily with `PQexecParams` on the first
/// `step`. Parameters are passed as text; results are read one row at a time.
pub const Stmt = struct {
    const MAX_PARAMS = 128;
    db: Db,
    query: [:0]const u8,
    aa: std.heap.ArenaAllocator,

    n_params: usize = 0,
    // Slots that are never bound stay null; PQexecParams treats a null entry
    // as SQL NULL instead of dereferencing an uninitialized pointer.
    param_values: [MAX_PARAMS][*c]const u8 = [_][*c]const u8{null} ** MAX_PARAMS,
    did_exec: bool = false,
    c_res: ?*pq.PGresult = null,
    res_index: c_int = -1,
    n_tuples: c_int = -1,
    n_fields: c_int = -1,

    /// Frees bound parameter strings and the result set (if any). Slices
    /// previously returned by `read_column` point into the result and become
    /// invalid after this call.
    pub fn deinit(self: *Stmt) void {
        self.aa.deinit();
        if (self.c_res != null) {
            pq.PQclear(self.c_res);
        }
    }

    /// Advances to the next result row, executing the query on the first call.
    ///
    /// Returns true while a row is available, false once rows are exhausted.
    /// Returns `PqError.PqError` if execution yields a status other than
    /// TUPLES_OK, SINGLE_TUPLE or COMMAND_OK.
    pub fn step(self: *Stmt) !bool {
        if (!self.did_exec) {
            self.did_exec = true;
            self.c_res = pq.PQexecParams(self.db.c_conn, self.query, @as(c_int, @intCast(self.n_params)), null, &self.param_values, null, null, 0);
            const rs = pq.PQresultStatus(self.c_res);
            if (rs != pq.PGRES_TUPLES_OK and rs != pq.PGRES_SINGLE_TUPLE and rs != pq.PGRES_COMMAND_OK) {
                std.log.err("PQresultStatus {} error: {s}", .{ rs, pq.PQerrorMessage(self.db.c_conn) });
                return PqError.PqError;
            }
            self.n_tuples = pq.PQntuples(self.c_res);
            self.n_fields = pq.PQnfields(self.c_res);
        }
        self.res_index = self.res_index + 1;
        return self.res_index < self.n_tuples;
    }

    /// Reads column `idx` of the current row as `T`.
    ///
    /// Integer types are parsed from the text value. `[]const u8` returns a
    /// slice into the result set, valid until `deinit`. Returns
    /// `error.InvalidRowOrColumn` when there is no current row (e.g. before
    /// `step`, or after `step` returned false) or `idx` is out of range.
    pub fn read_column(self: Stmt, idx: c_int, comptime T: type) !T {
        const value_c: [*c]u8 = pq.PQgetvalue(self.c_res, self.res_index, idx);
        // PQgetvalue returns null for an out-of-range row/column; sliceTo
        // would otherwise assert on the null C pointer.
        if (value_c == null) {
            std.log.err("read_column no value at row {d} column {d}", .{ self.res_index, idx });
            return error.InvalidRowOrColumn;
        }
        const value: []const u8 = std.mem.sliceTo(value_c, 0);
        return switch (@typeInfo(T)) {
            .Int => std.fmt.parseInt(T, value, 10),
            .Pointer => |ptrinfo| blk: {
                if (ptrinfo.child != u8) {
                    @compileError("pointer type []const u8 only is supported in read_column");
                }
                if (ptrinfo.size != .Slice) {
                    @compileError("pointer type []const u8 only is supported in read_column");
                }
                break :blk value;
            },
            else => @compileError("unhandled type " ++ @tagName(@typeInfo(T)) ++ " in read_column"),
        };
    }

    /// Like `read_column`, but looks the column up by name with `PQfnumber`.
    /// Returns `error.ColumnNotFound` if no column has that name.
    pub fn read_columnN(self: Stmt, name: [:0]const u8, comptime T: type) !T {
        const idx = pq.PQfnumber(self.c_res, name.ptr);
        if (idx == -1) {
            std.log.err("read_columnN ColumnNotFound [{s}]", .{name});
            return error.ColumnNotFound;
        }
        return read_column(self, idx, T);
    }

    /// Binds parameter `idx` (0-based, corresponding to `$idx+1`) to `t`,
    /// formatted as text into the statement's arena.
    ///
    /// Supports string-like pointers and integers (including integer
    /// literals). Returns `error.TooManyParams` if `idx >= MAX_PARAMS`.
    pub fn bind(self: *Stmt, idx: usize, t: anytype) !void {
        if (idx >= MAX_PARAMS) {
            std.log.err("bind index {d} out of range (max {d} parameters)", .{ idx, MAX_PARAMS });
            return error.TooManyParams;
        }
        const T = @TypeOf(t);
        const allocator = self.aa.allocator();
        const value: [:0]const u8 = switch (@typeInfo(T)) {
            .Pointer => try std.fmt.allocPrintZ(allocator, "{s}", .{t}),
            .Int, .ComptimeInt => try std.fmt.allocPrintZ(allocator, "{d}", .{t}),
            else => @compileError("unhandled type " ++ @tagName(@typeInfo(T)) ++ " in bind"),
        };
        self.param_values[idx] = value.ptr;
        self.n_params = @max(self.n_params, idx + 1);
    }

    /// Fills a struct of type `T` from the current row by matching each field
    /// name to a column name (see `read_columnN`).
    pub fn read_struct(self: Stmt, comptime T: type) !T {
        const ti = @typeInfo(T);
        var t: T = undefined;
        inline for (ti.Struct.fields) |field| {
            const name: [:0]const u8 = &addZ(field.name.len, field.name[0..].*);
            const val = try self.read_columnN(name, field.type);
            @field(t, field.name) = val;
        }
        return t;
    }
};

/// Returns a copy of `value` with a trailing 0 sentinel.
// https://github.com/ziglang/zig/issues/16116
pub fn addZ(comptime length: usize, value: [length]u8) [length:0]u8 {
    var terminated_value: [length:0]u8 = undefined;
    terminated_value[length] = 0;
    @memcpy(&terminated_value, &value);
    return terminated_value;
}

test "connect" {
    const db = try Db.init("postgresql://localhost/comments");
    defer db.deinit();
}