aboutsummaryrefslogtreecommitdiff
path: root/src/main.zig
blob: 8c2aed9aa758f509d39f163e34b09b438ddbcb06 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
const std = @import("std");
const testing = std.testing;
const StartupMessage = @import("startup_message.zig");
const AuthenticationRequest = @import("authentication_request.zig");
const PasswordMessage = @import("password_message.zig");
const ErrorResponse = @import("error_response.zig");
const ReadyForQuery = @import("ready_for_query.zig");
const ParameterStatus = @import("parameter_status.zig");
const BackendKeyData = @import("backend_key_data.zig");
const Query = @import("query.zig");
const DataRow = @import("data_row.zig");
const RowDescription = @import("row_description.zig");
const Conn = @import("conn.zig");

/// Protocol-level failures: malformed, unexpected, or out-of-range message data.
pub const ProtocolError = error{
    InvalidProtocolVersion,
    InvalidKeyValuePair,
    InvalidMessageLength,
    InvalidAuthType,
    MissingField,
    WrongMessageType,
    InvalidTransactionStatus,
    InvalidFormatCode,
};

/// Something was requested that this client does not implement
/// (currently only an unsupported authentication type).
pub const ClientError = error{UnsupportedAuthType};

/// The server reported an error (NOTE(review): presumably surfaced from an
/// ErrorResponse message; see error_response.zig).
pub const ServerError = error{ErrorResponse};

/// Encoding of a value on the wire: 0 = text, 1 = binary.
/// NOTE(review): these appear to be the PostgreSQL protocol format codes; confirm.
pub const FormatCode = enum(u16) {
    Text = 0,
    Binary = 1,
};

/// Fallible counterpart of `@enumFromInt`.
///
/// Converts the integer `i` into a value of the enum type `e`. Returns `null`
/// instead of invoking illegal behavior when `i` matches no named field of `e`.
/// `i` must have exactly the enum's tag type; this is checked at compile time.
/// Note: for a non-exhaustive enum, values without a named field also yield `null`.
pub fn enum_from_int(comptime e: type, i: anytype) ?e {
    const enum_ti = @typeInfo(e);
    if (enum_ti != .Enum) @compileError("e should be an enum but instead it's a " ++ @typeName(e));
    const ei = enum_ti.Enum;
    // Report the expected *tag* type (e.g. u16), not the enum type itself.
    if (@TypeOf(i) != ei.tag_type) @compileError("i should be of type " ++ @typeName(ei.tag_type) ++ " but instead it's " ++ @typeName(@TypeOf(i)));
    // The field list is comptime-known, so this unrolls into one comparison per field.
    inline for (ei.fields) |field| {
        if (field.value == i) {
            return @enumFromInt(i);
        }
    }
    return null;
}

/// Reads one length-prefixed message body from `stream_reader` and decodes it
/// with `msg_type.read(allocator, body)`.
///
/// The message tag must already have been consumed by the caller, since it is
/// what selects `msg_type`. The next 4 bytes are a big-endian u32 length that
/// counts its own 4 bytes, so `len - 4` body bytes follow.
/// Returns `error.InvalidMessageLength` if that length is smaller than 4.
/// The temporary body buffer is freed before returning.
/// NOTE(review): `msg_type.read` must copy anything it keeps from the buffer; confirm.
pub fn read_message(comptime msg_type: type, allocator: std.mem.Allocator, stream_reader: anytype) !msg_type {
    if (!@hasDecl(msg_type, "Tag")) @compileError("msg_type must have a Tag declaration!");
    if (!@hasDecl(msg_type, "read")) @compileError("msg_type must have a read() function!");
    const len = try stream_reader.readIntBig(u32);
    // The length comes from the peer. Anything below 4 is malformed and would
    // otherwise underflow `len - 4` (a panic in safe builds, UB in ReleaseFast).
    // This is the same error value as ProtocolError.InvalidMessageLength.
    if (len < 4) return error.InvalidMessageLength;
    const buf = try allocator.alloc(u8, len - 4);
    defer allocator.free(buf);
    try stream_reader.readNoEof(buf);
    return msg_type.read(allocator, buf);
}


/// Wraps `base_reader` in a `DiagnosticReader` that remembers the last `n` bytes
/// read through it. Store the result in a `var`, because `reader()` takes its address.
pub fn diagnosticReader(comptime n: usize, base_reader: anytype) DiagnosticReader(n, @TypeOf(base_reader)) {
    return .{ .child_reader = base_reader };
}

/// Reader adapter that forwards reads to `child_reader` and keeps a ring
/// buffer of the last `n` bytes that passed through it. Use `get` to retrieve them.
pub fn DiagnosticReader(comptime n: usize, comptime ReaderType: type) type {
    // A zero-sized ring cannot be indexed and `pos % 0` is undefined, so reject it up front.
    if (n == 0) @compileError("DiagnosticReader requires n > 0");
    return struct {
        child_reader: ReaderType,
        /// Ring buffer of recently read bytes; `pos` is the next write index.
        ring: [n]u8 = [_]u8{0} ** n,
        pos: usize = 0,
        /// Number of valid bytes in `ring`; saturates at `n`.
        filled: usize = 0,

        pub const Error = ReaderType.Error;
        pub const Reader = std.io.Reader(*@This(), Error, read);

        /// Reads into `buf` from the child reader and records every byte
        /// obtained in the ring buffer. Returns the number of bytes read.
        pub fn read(self: *@This(), buf: []u8) Error!usize {
            const amt = try self.child_reader.read(buf);
            for (buf[0..amt]) |byte| {
                self.ring[self.pos] = byte;
                self.pos = (self.pos + 1) % n;
            }
            self.filled = @min(self.filled + amt, n);
            return amt;
        }

        /// Returns a `std.io.Reader` interface backed by this adapter.
        pub fn reader(self: *@This()) Reader {
            return .{ .context = self };
        }

        /// Returns the recorded bytes, oldest first. This is at most the last
        /// `n` bytes, and fewer if less than `n` bytes have been read so far.
        /// Caller owns and must free the returned slice.
        pub fn get(self: @This(), allocator: std.mem.Allocator) ![]const u8 {
            const out = try allocator.alloc(u8, self.filled);
            if (self.filled < n) {
                // The ring has not wrapped yet, so the data sits at ring[0..filled].
                @memcpy(out, self.ring[0..self.filled]);
            } else {
                // Full ring: the oldest byte is at `pos`, so unroll ring[pos..] ++ ring[0..pos].
                const older = n - self.pos;
                @memcpy(out[0..older], self.ring[self.pos..]);
                @memcpy(out[older..], self.ring[0..self.pos]);
            }
            return out;
        }
    };
}

test "diagnostic reader" {
    const a = std.testing.allocator;
    const string = "The quick brown fox jumped over the lazy dog";
    var fbs = std.io.fixedBufferStream(string);
    var dr = diagnosticReader(15, fbs.reader());
    // The Reader interface is passed by value, so it never needs to be `var`.
    const reader = dr.reader();
    var buf = [_]u8{0} ** 20;
    try reader.readNoEof(&buf);
    // 20 bytes were read but only the last 15 are retained, oldest first.
    const diag = try dr.get(a);
    defer a.free(diag);
    try std.testing.expectEqualStrings("uick brown fox ", diag);
}

// Reference each imported module so that, when this file is the test root,
// the tests declared inside those modules are compiled and run too.
// Listed alphabetically.
test {
    _ = AuthenticationRequest;
    _ = BackendKeyData;
    _ = Conn;
    _ = DataRow;
    _ = ErrorResponse;
    _ = ParameterStatus;
    _ = PasswordMessage;
    _ = Query;
    _ = ReadyForQuery;
    _ = RowDescription;
    _ = StartupMessage;
}