aboutsummaryrefslogtreecommitdiff
path: root/src/main.zig
blob: 07b66288a1f6690b3a924e1057b1091e8c4cbb17 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
const std = @import("std");
const testing = std.testing;
const StartupMessage = @import("startup_message.zig");
const AuthenticationRequest = @import("authentication_request.zig");
const PasswordMessage = @import("password_message.zig");
const ErrorResponse = @import("error_response.zig");
const ReadyForQuery = @import("ready_for_query.zig");
const ParameterStatus = @import("parameter_status.zig");
const BackendKeyData = @import("backend_key_data.zig");
const Query = @import("query.zig");
const DataRow = @import("data_row.zig");
const Conn = @import("conn.zig");

/// Presumably the errors for received data that does not match the expected
/// message format (inferred from the member names).
/// NOTE(review): nothing in this file returns these; confirm the meaning of
/// each member against the message modules that use them.
pub const ProtocolError = error{
    InvalidProtocolVersion,
    InvalidKeyValuePair,
    InvalidMessageLength,
    InvalidAuthType,
    MissingField,
    WrongMessageType,
    InvalidTransactionStatus,
};

/// Presumably errors for requests this client cannot honour.
/// NOTE(review): `UnsupportedAuthType` looks like it is returned when the
/// server asks for an authentication method the client lacks; confirm
/// against callers (not visible in this file).
pub const ClientError = error{
    UnsupportedAuthType,
};

/// Presumably signals that the server itself reported a failure.
/// NOTE(review): likely returned after an ErrorResponse message is read;
/// confirm against callers (not visible in this file).
pub const ServerError = error{
    ErrorResponse,
};

/// Fallible counterpart of `@enumFromInt`.
///
/// Returns the member of enum `e` whose value equals `i`, or null when no
/// declared field has that value. Only declared fields are checked, so an
/// unnamed value of a non-exhaustive enum also yields null.
///
/// `i` must have exactly the enum's tag type; any other type (including
/// `comptime_int`) is rejected with a compile error.
pub fn enum_from_int(comptime e: type, i: anytype) ?e {
    const enum_info = switch (@typeInfo(e)) {
        .Enum => |info| info,
        else => @compileError("e should be an enum but instead it's a " ++ @typeName(e)),
    };
    if (@TypeOf(i) != enum_info.tag_type) {
        // Name the expected *tag* type in the message, not the enum type itself.
        @compileError("i should be of type " ++ @typeName(enum_info.tag_type) ++ " but instead it's " ++ @typeName(@TypeOf(i)));
    }
    // Unrolled at comptime: one runtime comparison per declared field.
    inline for (enum_info.fields) |field| {
        if (field.value == i) return @enumFromInt(i);
    }
    return null;
}

/// Reads one length-prefixed message body from `stream_reader` and decodes it
/// with `msg_type.read(allocator, body)`.
///
/// The message tag must already have been consumed from the stream, because it
/// is what selects `msg_type`. This function then reads the big-endian u32
/// length, which counts its own 4 bytes, reads the remaining `len - 4` body
/// bytes, and passes them to the decoder.
///
/// `body` is freed before this function returns, so `msg_type.read` must copy
/// any bytes it wants to keep.
///
/// Errors:
///   - `error.InvalidMessageLength` if the length field is smaller than 4;
///   - any error from `stream_reader`, including `error.EndOfStream` on a
///     truncated body;
///   - allocation failures;
///   - any error returned by `msg_type.read`.
///
/// NOTE(review): `len` comes from the peer and is not capped, so a hostile
/// server can force an allocation of up to ~4 GiB. Consider enforcing a
/// maximum message size.
pub fn read_message(comptime msg_type: type, allocator: std.mem.Allocator, stream_reader: anytype) !msg_type {
    if (!@hasDecl(msg_type, "Tag")) @compileError("msg_type must have a Tag declaration!");
    if (!@hasDecl(msg_type, "read")) @compileError("msg_type must have a read() function!");

    const len = try stream_reader.readIntBig(u32);
    // The length includes itself. Anything below 4 is malformed and would
    // otherwise underflow `len - 4` (a panic in safe builds, UB in ReleaseFast).
    if (len < 4) return error.InvalidMessageLength;

    const body = try allocator.alloc(u8, len - 4);
    defer allocator.free(body);
    try stream_reader.readNoEof(body);
    return msg_type.read(allocator, body);
}


/// Wraps `base_reader` so that the last `n` bytes read through the wrapper
/// are remembered and can later be retrieved for diagnostics.
pub fn diagnosticReader(comptime n: usize, base_reader: anytype) DiagnosticReader(n, @TypeOf(base_reader)) {
    const Wrapped = DiagnosticReader(n, @TypeOf(base_reader));
    return Wrapped{ .child_reader = base_reader };
}

/// Returns a reader type that forwards every read to a `ReaderType` while
/// recording the most recent `n` bytes that passed through it. The history
/// can be dumped with `get`, e.g. to show context after a parse error.
pub fn DiagnosticReader(comptime n: usize, comptime ReaderType: type) type {
    // A zero-sized history would make `pos % n` divide by zero.
    if (n == 0) @compileError("DiagnosticReader needs a history size n > 0");

    return struct {
        const Self = @This();

        child_reader: ReaderType,
        /// Circular history buffer; `pos` is the slot the next byte goes into.
        ring: [n]u8 = [_]u8{0} ** n,
        pos: usize = 0,
        /// Number of valid bytes in `ring` (total bytes read, capped at `n`).
        filled: usize = 0,

        pub const Error = ReaderType.Error;
        pub const Reader = std.io.Reader(*Self, Error, read);

        /// Reads from the child reader into `buf` and records the bytes that
        /// were actually read.
        pub fn read(self: *Self, buf: []u8) Error!usize {
            const amt = try self.child_reader.read(buf);
            for (buf[0..amt]) |byte| {
                self.ring[self.pos] = byte;
                self.pos = (self.pos + 1) % n;
            }
            self.filled = @min(n, self.filled +| amt);
            return amt;
        }

        pub fn reader(self: *Self) Reader {
            return .{ .context = self };
        }

        /// Returns a copy of the recorded history, oldest byte first.
        ///
        /// The slice holds min(n, total bytes read) bytes, so it never
        /// contains padding. The caller owns the slice and must free it with
        /// `allocator`.
        pub fn get(self: Self, allocator: std.mem.Allocator) std.mem.Allocator.Error![]const u8 {
            const out = try allocator.alloc(u8, self.filled);
            if (self.filled < n) {
                // Not wrapped yet: the history is ring[0..filled] in order.
                @memcpy(out, self.ring[0..self.filled]);
            } else {
                // Full ring: the oldest byte sits at `pos`.
                const older = self.ring[self.pos..];
                @memcpy(out[0..older.len], older);
                @memcpy(out[older.len..], self.ring[0..self.pos]);
            }
            return out;
        }
    };
}

// The history keeps only the most recent 15 of the 20 bytes read.
test "diagnostic reader" {
    const allocator = std.testing.allocator;
    const text = "The quick brown fox jumped over the lazy dog";
    var fbs = std.io.fixedBufferStream(text);
    var dr = diagnosticReader(15, fbs.reader());
    const reader = dr.reader();
    // Consume bytes 0..20; the 15-byte history should hold bytes 5..20.
    var buf = [_]u8{0} ** 20;
    try reader.readNoEof(&buf);
    const diag = try dr.get(allocator);
    defer allocator.free(diag);
    try std.testing.expectEqualStrings("uick brown fox ", diag);
}

test {
    // Reference every imported module so its `test` blocks are analyzed and
    // run together with this file's tests (listed alphabetically).
    _ = AuthenticationRequest;
    _ = BackendKeyData;
    _ = Conn;
    _ = DataRow;
    _ = ErrorResponse;
    _ = ParameterStatus;
    _ = PasswordMessage;
    _ = Query;
    _ = ReadyForQuery;
    _ = StartupMessage;
}