aboutsummaryrefslogtreecommitdiff
path: root/src/main.zig
blob: 818b3b7df3a25d84ac4cbfad2c5b3de0b1f5736d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
const std = @import("std");
const testing = std.testing;

/// Error set naming protocol-level failures: bad version, key/value pair,
/// message length, auth type, transaction status or format code, a missing
/// field, or a message of the wrong type.
/// NOTE(review): presumably returned by the message parsing code elsewhere in
/// the project — confirm at the use sites.
pub const ProtocolError = error{
    InvalidProtocolVersion,
    InvalidKeyValuePair,
    InvalidMessageLength,
    InvalidAuthType,
    MissingField,
    WrongMessageType,
    InvalidTransactionStatus,
    InvalidFormatCode,
};

/// Format code backed by a `u16`: `Text` is 0 and `Binary` is 1.
/// NOTE(review): these look like wire-protocol value format codes — confirm
/// against the protocol documentation.
pub const FormatCode = enum(u16) {
    Text = 0,
    Binary = 1,
};

/// Client-side error set with two members.
/// NOTE(review): judging by the names, both relate to authentication (an auth
/// type the client cannot handle, and a password that was required but not
/// provided) — confirm at the use sites.
pub const ClientError = error{
    UnsupportedAuthType,
    NoPasswordSupplied,
};

/// Error set with the single member `ErrorResponse`.
/// NOTE(review): presumably signals that the server replied with an error
/// message — confirm at the use sites.
pub const ServerError = error{
    ErrorResponse,
};

/// Fallible counterpart of `@enumFromInt`.
///
/// Returns the member of enum type `e` whose tag value equals `i`, or `null`
/// when no declared member has that value. `e` must be an enum and `i` must
/// have exactly the enum's tag type; both are checked at compile time.
pub fn enum_from_int(comptime e: type, i: anytype) ?e {
    const enum_ti = @typeInfo(e);
    if (enum_ti != .Enum) @compileError("e should be an enum but instead it's a " ++ @typeName(e));
    const ei = enum_ti.Enum;
    // Name the expected *tag* type in the message (not the enum itself), since
    // that is what `i` is being compared against.
    if (@TypeOf(i) != ei.tag_type) @compileError("i should be of type " ++ @typeName(ei.tag_type) ++ " but instead it's " ++ @typeName(@TypeOf(i)));
    // Unrolled at compile time: one comparison per declared enum field.
    inline for (ei.fields) |field| {
        if (field.value == i) {
            return @enumFromInt(i);
        }
    } else {
        return null;
    }
}

/// Reads the remainder of one length-prefixed message from `stream_reader` and
/// parses it with `msg_type.read`.
///
/// The message tag must already have been consumed by the caller, since that
/// is how `msg_type` gets chosen. This function reads a big-endian `u32`
/// length that counts its own four bytes, then `length - 4` payload bytes.
///
/// `msg_type` must declare `Tag` and `read`; both are checked at compile time.
/// The payload buffer is allocated from `allocator` and freed before this
/// function returns.
/// NOTE(review): because of that, `msg_type.read` must copy anything it wants
/// to keep from the buffer — confirm against the message implementations.
///
/// Returns `error.InvalidMessageLength` when the length field is smaller than
/// 4, and otherwise propagates errors from the reader, the allocator and
/// `msg_type.read`.
pub fn read_message(comptime msg_type: type, allocator: std.mem.Allocator, stream_reader: anytype) !msg_type {
    if (!@hasDecl(msg_type, "Tag")) @compileError("msg_type must have a Tag declaration!");
    if (!@hasDecl(msg_type, "read")) @compileError("msg_type must have a read() function!");
    const len = try stream_reader.readIntBig(u32);
    // The length field includes itself, so anything below 4 is malformed and
    // would make `len - 4` underflow.
    if (len < 4) return error.InvalidMessageLength;
    const buf = try allocator.alloc(u8, len - 4);
    defer allocator.free(buf);
    try stream_reader.readNoEof(buf);
    return try msg_type.read(allocator, buf);
}

/// Wraps `base_reader` in a `DiagnosticReader` whose history holds `n` bytes.
/// The wrapper type is derived from the type of `base_reader`; the ring buffer
/// and position take their declared defaults.
pub fn diagnosticReader(comptime n: usize, base_reader: anytype) DiagnosticReader(n, @TypeOf(base_reader)) {
    const Wrapper = DiagnosticReader(n, @TypeOf(base_reader));
    return Wrapper{ .child_reader = base_reader };
}

/// Returns a reader wrapper type that forwards every read to a `ReaderType`
/// while keeping a buffer of the last `n` bytes read.
///
/// `n` must be greater than zero (enforced at compile time). `ReaderType` must
/// expose an `Error` declaration and a `read([]u8)` method.
pub fn DiagnosticReader(comptime n: usize, comptime ReaderType: type) type {
    // A zero-length ring cannot hold history, and `% n` would divide by zero.
    if (n == 0) @compileError("DiagnosticReader requires n > 0");
    return struct {
        /// Underlying reader that every `read` call is forwarded to.
        child_reader: ReaderType,
        /// Circular history buffer; zero-filled until bytes have been read.
        ring: [n]u8 = [_]u8{0} ** n,
        /// Next write index into `ring`; also the index of the oldest byte.
        pos: usize = 0,

        pub const Error = ReaderType.Error;
        pub const Reader = std.io.Reader(*@This(), Error, read);

        /// Reads from the child reader into `buf` and records every byte that
        /// was actually read in the ring. Returns the child's byte count.
        pub fn read(self: *@This(), buf: []u8) Error!usize {
            const amt = try self.child_reader.read(buf);
            for (buf[0..amt]) |byte| {
                self.ring[self.pos] = byte;
                self.pos = (self.pos + 1) % n;
            }
            return amt;
        }

        /// Returns a `std.io.Reader` whose context is this wrapper.
        pub fn reader(self: *@This()) Reader {
            return .{ .context = self };
        }

        /// Returns a newly allocated copy of the history, oldest byte first.
        /// The result is always `n` bytes long; positions that have not been
        /// written yet still hold their initial zero. Caller frees.
        pub fn get(self: @This(), allocator: std.mem.Allocator) ![]const u8 {
            const buf = try allocator.alloc(u8, n);
            // Un-rotate the ring: the oldest bytes sit at `pos..n`, the newest
            // at `0..pos`.
            const older_len = n - self.pos;
            @memcpy(buf[0..older_len], self.ring[self.pos..n]);
            @memcpy(buf[older_len..n], self.ring[0..self.pos]);
            return buf;
        }
    };
}

// Reads 20 bytes through a 15-byte diagnostic reader and checks that the
// history holds exactly the last 15 of them, oldest first.
test "diagnostic reader" {
    const a = std.testing.allocator;
    const string = "The quick brown fox jumped over the lazy dog";
    var fbs = std.io.fixedBufferStream(string);
    var dr = diagnosticReader(15, fbs.reader());
    const reader = dr.reader();
    var buf = [_]u8{0} ** 20;
    try reader.readNoEof(&buf);
    const diag = try dr.get(a);
    defer a.free(diag);
    // Bytes 5..20 of the input: the first 20 bytes minus the oldest 5.
    try std.testing.expectEqualStrings("uick brown fox ", diag);
}

// Reference the sub-modules from a test block so that their own `test`
// declarations are picked up when this file is tested.
test {
    _ = @import("conn/conn.zig");
    _ = @import("proto/proto.zig");
}