aboutsummaryrefslogtreecommitdiff
path: root/src/main.zig
blob: a4fc38266f2ae836c98c1391c6dffdb09c4ac35c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
const std = @import("std");
const pq = @cImport(
    @cInclude("libpq-fe.h"),
);

/// Single catch-all error returned by this wrapper when a libpq call fails.
/// The failing call site logs the details with `std.log.err` before returning it.
pub const PqError = error{PqError};

// Thin wrapper around libpq.
// Later, this could be replaced by a pure-Zig client implementation.
/// Handle to a single libpq connection.
///
/// Create with `init` and release with `deinit`. `Db` only stores the raw
/// `PGconn` pointer, so copies of a `Db` refer to the same connection.
pub const Db = struct {
    c_conn: *pq.PGconn,

    /// Opens a connection described by `connect_url` using `PQconnectdb`.
    ///
    /// Returns `PqError.PqError` (after logging the cause) when
    /// `PQisthreadsafe` reports 0, when `PQconnectdb` returns null, or when
    /// the connection status is anything other than `CONNECTION_OK`.
    pub fn init(connect_url: [:0]const u8) !Db {
        if (pq.PQisthreadsafe() == 0) {
            std.log.err("PQisthreadsafe returned 0, can't use libpq in this program", .{});
            return PqError.PqError;
        }
        const maybe_conn: ?*pq.PGconn = pq.PQconnectdb(connect_url);
        const conn = maybe_conn orelse {
            std.log.err("PQconnectdb returned null", .{});
            return PqError.PqError;
        };
        // libpq hands back a PGconn object even when the connection attempt
        // failed; it still has to be released with PQfinish, otherwise every
        // failed connect leaks it.
        errdefer pq.PQfinish(conn);

        const status = pq.PQstatus(conn);
        if (status == pq.CONNECTION_BAD) {
            std.log.err("PQstatus returned CONNECTION_BAD: {s}", .{pq.PQerrorMessage(conn)});

            return PqError.PqError;
        } else if (status != pq.CONNECTION_OK) {
            std.log.err("PQstatus returned unknown status {}: {s}", .{ status, pq.PQerrorMessage(conn) });
            return PqError.PqError;
        }
        return Db{
            .c_conn = conn,
        };
    }

    /// Closes the connection with `PQfinish`. The `Db` must not be used afterwards.
    pub fn deinit(self: Db) void {
        pq.PQfinish(self.c_conn);
    }

    /// Runs `query` with `PQexec` and discards the result.
    ///
    /// Only a result status of `PGRES_COMMAND_OK` counts as success; any other
    /// status (including one that carries rows) is logged and reported as
    /// `PqError.PqError`.
    pub fn exec(self: Db, query: [:0]const u8) !void {
        const res: ?*pq.PGresult = pq.PQexec(self.c_conn, query);
        // Registered before inspecting the status so the result is cleared on
        // both the success and the error path.
        defer pq.PQclear(res);
        const est: pq.ExecStatusType = pq.PQresultStatus(res);
        if (est != pq.PGRES_COMMAND_OK) {
            std.log.err("PQexec error code {} message {s}", .{ est, pq.PQerrorMessage(self.c_conn) });
            return PqError.PqError;
        }
    }

    /// Builds a `Stmt` for `query` on this connection.
    ///
    /// Nothing is sent to the server here: the statement only records the
    /// query text and sets up an arena backed by `allocator`. `query` is stored
    /// by reference, not copied. The caller must call `Stmt.deinit` on the result.
    pub fn prepare_statement(self: Db, allocator: std.mem.Allocator, query: [:0]const u8) PqError!Stmt {
        return Stmt{
            .db = self,
            .aa = std.heap.ArenaAllocator.init(allocator),
            .query = query,
        };
    }
};

/// A parameterised query tied to a `Db` connection.
///
/// Flow implemented by the methods below: `bind` the parameters, call `step`
/// (its first call executes the query through `PQexecParams`), then read the
/// current row with `read_column`, `read_columnN` or `read_struct`. Call
/// `deinit` to release the bound parameter strings and the libpq result.
pub const Stmt = struct {
    /// Capacity of `param_values`; `bind` rejects indexes at or above this.
    const MAX_PARAMS = 128;
    db: Db,
    query: [:0]const u8,
    /// Arena that owns the NUL-terminated text copies created by `bind`.
    aa: std.heap.ArenaAllocator,

    /// One past the highest parameter index bound so far.
    n_params: usize = 0,
    /// Parameter pointers handed to `PQexecParams`. Slots start out null so an
    /// index that was skipped by the caller is never an uninitialised pointer
    /// (per the libpq documentation a null entry is sent as SQL NULL).
    param_values: [MAX_PARAMS][*c]const u8 = [_][*c]const u8{null} ** MAX_PARAMS,
    /// Set by the first `step` call, whether or not execution succeeded.
    did_exec: bool = false,
    c_res: ?*pq.PGresult = null,
    /// Row that the read functions address; -1 until `step` advances it.
    res_index: c_int = -1,
    n_tuples: c_int = -1,
    n_fields: c_int = -1,

    /// Frees the parameter arena and clears the libpq result, if any.
    /// Slices previously returned by the read functions point into that
    /// result and must not be used afterwards.
    pub fn deinit(self: *Stmt) void {
        self.aa.deinit();
        if (self.c_res != null) {
            pq.PQclear(self.c_res);
        }
    }

    /// Advances to the next result row and reports whether there is one.
    ///
    /// The first call executes the query with the parameters bound so far.
    /// A result status other than `PGRES_TUPLES_OK`, `PGRES_SINGLE_TUPLE` or
    /// `PGRES_COMMAND_OK` is logged and returned as `PqError.PqError`.
    pub fn step(self: *Stmt) !bool {
        if (!self.did_exec) {
            self.did_exec = true;
            self.c_res = pq.PQexecParams(self.db.c_conn, self.query, @as(c_int, @intCast(self.n_params)), null, &self.param_values, null, null, 0);
            const rs = pq.PQresultStatus(self.c_res);
            if (rs != pq.PGRES_TUPLES_OK and rs != pq.PGRES_SINGLE_TUPLE and rs != pq.PGRES_COMMAND_OK) {
                std.log.err("PQresultStatus {} error: {s}", .{ rs, pq.PQerrorMessage(self.db.c_conn) });
                return PqError.PqError;
            }
            self.n_tuples = pq.PQntuples(self.c_res);
            self.n_fields = pq.PQnfields(self.c_res);
        }
        // Only advance while rows remain, so calling step again after the end
        // keeps returning false instead of pushing the index towards overflow.
        if (self.res_index < self.n_tuples) {
            self.res_index += 1;
        }
        return self.res_index < self.n_tuples;
    }

    /// Reads column `idx` of the current row as `T`.
    ///
    /// Supported `T`: integer types (parsed base 10 from the text value) and
    /// `[]const u8`. A returned slice points into the libpq result and stays
    /// valid only until `deinit`. Returns `error.ValueNotFound` when
    /// `PQgetvalue` yields null, e.g. when no row is current or `idx` is out
    /// of range.
    pub fn read_column(self: Stmt, idx: c_int, comptime T: type) !T {
        const value_c: [*c]u8 = pq.PQgetvalue(self.c_res, self.res_index, idx);
        if (value_c == null) {
            std.log.err("read_column no value at row {} column {}", .{ self.res_index, idx });
            return error.ValueNotFound;
        }
        const value: []const u8 = std.mem.sliceTo(value_c, 0);
        return switch (@typeInfo(T)) {
            .Int => std.fmt.parseInt(T, value, 10),
            .Pointer => |ptrinfo| blk: {
                if (ptrinfo.child != u8) {
                    @compileError("pointer type []const u8 only is supported in read_column");
                }
                if (ptrinfo.size != .Slice) {
                    @compileError("pointer type []const u8 only is supported in read_column");
                }
                break :blk value;
            },
            else => @compileError("unhandled type " ++ @tagName(@typeInfo(T)) ++ " in read_column"),
        };
    }

    /// Like `read_column`, but looks the column up by `name` via `PQfnumber`.
    /// Returns `error.ColumnNotFound` when the result has no such column.
    pub fn read_columnN(self: Stmt, name: [:0]const u8, comptime T: type) !T {
        const idx = pq.PQfnumber(self.c_res, name.ptr);
        if (idx == -1) {
            std.log.err("read_columnN ColumnNotFound [{s}]", .{name});
            return error.ColumnNotFound;
        }
        return self.read_column(idx, T);
    }

    /// Binds `t` as the parameter at zero-based position `idx`.
    ///
    /// Pointer-typed values are formatted with `{s}`, integers (including
    /// integer literals) with `{d}`; the text copy lives in the statement's
    /// arena. Returns `error.ParamIndexOutOfRange` when `idx` does not fit in
    /// `MAX_PARAMS`, or an allocation error from the arena.
    pub fn bind(self: *Stmt, idx: usize, t: anytype) !void {
        if (idx >= MAX_PARAMS) {
            std.log.err("bind index {} out of range (max {} params)", .{ idx, MAX_PARAMS });
            return error.ParamIndexOutOfRange;
        }
        const T = @TypeOf(t);
        const value: [:0]const u8 = switch (@typeInfo(T)) {
            .Pointer => try std.fmt.allocPrintZ(self.aa.allocator(), "{s}", .{t}),
            .Int, .ComptimeInt => try std.fmt.allocPrintZ(self.aa.allocator(), "{d}", .{t}),
            else => @compileError("unhandled type " ++ @tagName(@typeInfo(T)) ++ " in bind"),
        };
        self.param_values[idx] = value.ptr;
        self.n_params = @max(self.n_params, idx + 1);
    }

    /// Fills a `T` by reading, for every field of struct `T`, the column with
    /// the same name from the current row (see `read_columnN`).
    pub fn read_struct(self: Stmt, comptime T: type) !T {
        const ti = @typeInfo(T);
        var t: T = undefined;
        inline for (ti.Struct.fields) |field| {
            // PQfnumber needs a NUL-terminated name; addZ builds one from the field name.
            const name: [:0]const u8 = &addZ(field.name.len, field.name[0..].*);
            const val = try self.read_columnN(name, field.type);
            @field(t, field.name) = val;
        }
        return t;
    }
};

// https://github.com/ziglang/zig/issues/16116
/// Returns a copy of `value` as an array of the same length that carries a
/// trailing `0` sentinel.
pub fn addZ(comptime length: usize, value: [length]u8) [length:0]u8 {
    var result: [length:0]u8 = undefined;
    // Copy the payload bytes one by one, then fill the sentinel slot at index `length`.
    for (value, 0..) |byte, i| {
        result[i] = byte;
    }
    result[length] = 0;
    return result;
}