pq-zig

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README

main.zig (5666B)


      1 const std = @import("std");
      2 const pq = @cImport(
      3     @cInclude("libpq-fe.h"),
      4 );
      5 
/// Single-member error set returned by this wrapper when a libpq call
/// reports failure; the details are written to `std.log.err` before returning.
pub const PqError = error{PqError};
      7 
      8 // libpq wrapper
      9 // later, this could be a pure-zig client implementation
     10 pub const Db = struct {
     11     c_conn: *pq.PGconn,
     12     pub fn init(connect_url: [:0]const u8) !Db {
     13         if (pq.PQisthreadsafe() == 0) {
     14             std.log.err("PQisthreadsafe returned 0, can't use libpq in this program", .{});
     15             return PqError.PqError;
     16         }
     17         const maybe_conn: ?*pq.PGconn = pq.PQconnectdb(connect_url);
     18         if (maybe_conn == null) {
     19             std.log.err("PQconnectdb returned null", .{});
     20             return PqError.PqError;
     21         }
     22         if (pq.PQstatus(maybe_conn) == pq.CONNECTION_BAD) {
     23             std.log.err("PQstatus returned CONNECTION_BAD: {s}", .{pq.PQerrorMessage(maybe_conn)});
     24 
     25             return PqError.PqError;
     26         } else if (pq.PQstatus(maybe_conn) != pq.CONNECTION_OK) {
     27             std.log.err("PQstatus returned unknown status {}: {s}", .{ pq.PQstatus(maybe_conn), pq.PQerrorMessage(maybe_conn) });
     28             return PqError.PqError;
     29         }
     30         return Db{
     31             .c_conn = maybe_conn.?,
     32         };
     33     }
     34 
     35     pub fn deinit(self: Db) void {
     36         pq.PQfinish(self.c_conn);
     37     }
     38 
     39     pub fn exec(self: Db, query: [:0]const u8) !void {
     40         const res: ?*pq.PGresult = pq.PQexec(self.c_conn, query);
     41         defer pq.PQclear(res);
     42         const est: pq.ExecStatusType = pq.PQresultStatus(res);
     43         if (est != pq.PGRES_COMMAND_OK) {
     44             std.log.err("PQexec error code {} message {s}", .{ est, pq.PQerrorMessage(self.c_conn) });
     45             return PqError.PqError;
     46         }
     47     }
     48 
     49     pub fn prepare_statement(self: Db, allocator: std.mem.Allocator, query: [:0]const u8) PqError!Stmt {
     50         return Stmt{
     51             .db = self,
     52             .aa = std.heap.ArenaAllocator.init(allocator),
     53             .query = query,
     54         };
     55     }
     56 };
     57 
     58 pub const Stmt = struct {
     59     const MAX_PARAMS = 128;
     60     db: Db,
     61     query: [:0]const u8,
     62     aa: std.heap.ArenaAllocator,
     63 
     64     n_params: usize = 0,
     65     param_values: [MAX_PARAMS][*c]const u8 = undefined,
     66     did_exec: bool = false,
     67     c_res: ?*pq.PGresult = null,
     68     res_index: c_int = -1,
     69     n_tuples: c_int = -1,
     70     n_fields: c_int = -1,
     71 
     72     pub fn deinit(self: *Stmt) void {
     73         self.aa.deinit();
     74         if (self.c_res != null) {
     75             pq.PQclear(self.c_res);
     76         }
     77     }
     78 
     79     pub fn step(self: *Stmt) !bool {
     80         if (!self.did_exec) {
     81             self.did_exec = true;
     82             self.c_res = pq.PQexecParams(self.db.c_conn, self.query, @as(c_int, @intCast(self.n_params)), null, &self.param_values, null, null, 0);
     83             const rs = pq.PQresultStatus(self.c_res);
     84             if (rs != pq.PGRES_TUPLES_OK and rs != pq.PGRES_SINGLE_TUPLE and rs != pq.PGRES_COMMAND_OK) {
     85                 std.log.err("PQresultStatus {} error: {s}", .{ rs, pq.PQerrorMessage(self.db.c_conn) });
     86                 return PqError.PqError;
     87             }
     88             self.n_tuples = pq.PQntuples(self.c_res);
     89             self.n_fields = pq.PQnfields(self.c_res);
     90         }
     91         self.res_index = self.res_index + 1;
     92         return self.res_index < self.n_tuples;
     93     }
     94 
     95     pub fn read_column(self: Stmt, idx: c_int, comptime T: type) !T {
     96         const value_c: [*c]u8 = pq.PQgetvalue(self.c_res, self.res_index, idx);
     97         const value: []const u8 = std.mem.sliceTo(value_c, 0);
     98         return switch (@typeInfo(T)) {
     99             .Int => std.fmt.parseInt(T, value, 10),
    100             .Pointer => |ptrinfo| blk: {
    101                 if (ptrinfo.child != u8) {
    102                     @compileError("pointer type []const u8 only is supported in read_column");
    103                 }
    104                 if (ptrinfo.size != .Slice) {
    105                     @compileError("pointer type []const u8 only is supported in read_column");
    106                 }
    107                 break :blk value;
    108             },
    109             else => @compileError("unhandled type " ++ @tagName(@typeInfo(T)) ++ " in read_column"),
    110         };
    111     }
    112 
    113     pub fn read_columnN(self: Stmt, name: [:0]const u8, comptime T: type) !T {
    114         const idx = pq.PQfnumber(self.c_res, name.ptr);
    115         if (idx == -1) {
    116             std.log.err("read_columnN ColumnNotFound [{s}]", .{name});
    117             return error.ColumnNotFound;
    118         }
    119         return read_column(self, idx, T);
    120     }
    121 
    122     pub fn bind(self: *Stmt, idx: usize, t: anytype) !void {
    123         const T = @TypeOf(t);
    124         const value: [:0]const u8 = switch (@typeInfo(T)) {
    125             .Pointer => try std.fmt.allocPrintZ(self.aa.allocator(), "{s}", .{t}),
    126             .Int => try std.fmt.allocPrintZ(self.aa.allocator(), "{d}", .{t}),
    127             else => @compileError("unhandled type " ++ @tagName(@typeInfo(T) ++ " in bind")),
    128         };
    129         self.param_values[idx] = value.ptr;
    130         self.n_params = @max(self.n_params, idx + 1);
    131     }
    132 
    133     pub fn read_struct(self: Stmt, comptime T: type) !T {
    134         const ti = @typeInfo(T);
    135         var t: T = undefined;
    136         inline for (ti.Struct.fields) |field| {
    137             const name: [:0]const u8 = &addZ(field.name.len, field.name[0..].*);
    138             const val = try self.read_columnN(name, field.type);
    139             @field(t, field.name) = val;
    140         }
    141         return t;
    142     }
    143 };
    144 
    145 // https://github.com/ziglang/zig/issues/16116
    146 pub fn addZ(comptime length: usize, value: [length]u8) [length:0]u8 {
    147     var terminated_value: [length:0]u8 = undefined;
    148     terminated_value[length] = 0;
    149     @memcpy(&terminated_value, &value);
    150     return terminated_value;
    151 }
    152 
    153 test "connect" {
    154     var db = try Db.init("postgresql://localhost/comments");
    155     db.deinit();
    156 }