aboutsummaryrefslogtreecommitdiff
path: root/src/proto/startup_message.zig
blob: 8224bdb4e140db4a1d75724efe6b8df82a0de048 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
const std = @import("std");
const ProtocolError = @import("../main.zig").ProtocolError;
const SSHashMap = std.StringHashMap([]const u8);
const ByteArrayList = std.ArrayList(u8);

/// The struct type defined by this file.
const StartupMessage = @This();

/// Protocol version 3.0: major version (3) in the upper 16 bits, minor (0) in the lower 16.
const ProtocolVersionNumber: u32 = 3 << 16;

/// Owned copy of the payload handed to `read`; when produced by `read`, every
/// key and value in `parameters` is a slice into this buffer. Freed by `deinit`.
/// `null` when the message was assembled by hand.
bytes: ?[]const u8 = null,
/// Startup parameters by name. The map stores the key/value slices it is given
/// without copying them.
parameters: SSHashMap,

/// Parse a StartupMessage payload.
///
/// `b` must contain only the payload: the leading length word has already been
/// consumed by the caller. Layout: a big-endian u32 protocol version followed by
/// zero or more `key\0value\0` pairs and one terminating `\0`, which must be the
/// final byte of the payload.
///
/// The payload is copied with `allocator`; the keys and values stored in
/// `parameters` are slices into that copy. Release with `deinit(allocator)`.
///
/// Errors: `InvalidMessageLength` if `b` is shorter than the version word,
/// `InvalidProtocolVersion` for any version other than 3.0,
/// `InvalidKeyValuePair` if a key has no value, the terminator is missing, or
/// bytes follow the terminator; plus allocation failures.
pub fn read(allocator: std.mem.Allocator, b: []const u8) !StartupMessage {
    if (b.len < 4) return ProtocolError.InvalidMessageLength;

    // Check the version before allocating anything.
    const protocol_version = std.mem.readIntSliceBig(u32, b[0..4]);
    if (protocol_version != ProtocolVersionNumber) return ProtocolError.InvalidProtocolVersion;

    const bytes = try allocator.dupe(u8, b);
    errdefer allocator.free(bytes);

    var parameters = SSHashMap.init(allocator);
    // Without this the map leaks if `put` fails or a later pair is malformed.
    errdefer parameters.deinit();

    var rest: []const u8 = bytes[4..];
    while (true) {
        const key_end = std.mem.indexOfScalar(u8, rest, 0) orelse
            return ProtocolError.InvalidKeyValuePair;
        const key = rest[0..key_end];
        rest = rest[key_end + 1 ..];

        // An empty key is the list terminator. It is not a parameter, and it
        // must be the last byte of the payload.
        if (key.len == 0) {
            if (rest.len != 0) return ProtocolError.InvalidKeyValuePair;
            break;
        }

        const value_end = std.mem.indexOfScalar(u8, rest, 0) orelse
            return ProtocolError.InvalidKeyValuePair;
        try parameters.put(key, rest[0..value_end]);
        rest = rest[value_end + 1 ..];
    }

    return .{
        .bytes = bytes,
        .parameters = parameters,
    };
}

test "read does not store the terminator as a parameter" {
    const allocator = std.testing.allocator;
    var sm = try StartupMessage.read(allocator, "\x00\x03\x00\x00user\x00postgres\x00\x00");
    defer sm.deinit(allocator);
    try std.testing.expectEqual(@as(u32, 1), sm.parameters.count());
    try std.testing.expectEqualStrings("postgres", sm.parameters.get("user").?);
}

test "read accepts an empty parameter list" {
    const allocator = std.testing.allocator;
    var sm = try StartupMessage.read(allocator, "\x00\x03\x00\x00\x00");
    defer sm.deinit(allocator);
    try std.testing.expectEqual(@as(u32, 0), sm.parameters.count());
}

test "read rejects malformed parameter lists without leaking" {
    const allocator = std.testing.allocator;
    // Key without a value.
    try std.testing.expectError(ProtocolError.InvalidKeyValuePair, StartupMessage.read(allocator, "\x00\x03\x00\x00user\x00"));
    // Missing terminator after a complete pair (the map already holds an entry).
    try std.testing.expectError(ProtocolError.InvalidKeyValuePair, StartupMessage.read(allocator, "\x00\x03\x00\x00user\x00postgres\x00"));
    // Bytes after the terminator.
    try std.testing.expectError(ProtocolError.InvalidKeyValuePair, StartupMessage.read(allocator, "\x00\x03\x00\x00user\x00postgres\x00\x00x"));
    // Wrong protocol version.
    try std.testing.expectError(ProtocolError.InvalidProtocolVersion, StartupMessage.read(allocator, "\x00\x02\x00\x00\x00"));
}

/// Serialize the message to `stream_writer` as a complete StartupMessage:
/// a big-endian u32 length (which counts itself), the protocol version, each
/// `key\0value\0` pair, and a terminating `\0`.
///
/// The message is assembled in a temporary buffer from `allocator`, so it
/// reaches `stream_writer` in a single `writeAll` call.
///
/// Errors: `InvalidKeyValuePair` if a key is empty or a key/value contains a
/// NUL byte (either would corrupt the framing), `InvalidMessageLength` if the
/// encoded message does not fit in the u32 length field, plus allocation and
/// `stream_writer` errors. Nothing is written to `stream_writer` on error
/// before the final `writeAll`.
pub fn write(self: StartupMessage, allocator: std.mem.Allocator, stream_writer: anytype) !void {
    var buf = ByteArrayList.init(allocator);
    defer buf.deinit();
    const writer = buf.writer();

    try writer.writeIntBig(u32, 0); // length placeholder, patched below
    try writer.writeIntBig(u32, ProtocolVersionNumber);

    var it = self.parameters.iterator();
    while (it.next()) |entry| {
        const key = entry.key_ptr.*;
        const value = entry.value_ptr.*;
        // An empty key would be read back as the list terminator, and an
        // embedded NUL would split a field, so neither can be encoded.
        if (key.len == 0 or
            std.mem.indexOfScalar(u8, key, 0) != null or
            std.mem.indexOfScalar(u8, value, 0) != null)
        {
            return ProtocolError.InvalidKeyValuePair;
        }
        try writer.writeAll(key);
        try writer.writeByte(0);
        try writer.writeAll(value);
        try writer.writeByte(0);
    }
    try writer.writeByte(0); // parameter list terminator

    // The buffer holds exactly the bytes written, so its length is the message length.
    const len = std.math.cast(u32, buf.items.len) orelse return ProtocolError.InvalidMessageLength;
    std.mem.writeIntBig(u32, buf.items[0..4], len);
    try stream_writer.writeAll(buf.items);
}

test "write rejects parameters that cannot be framed" {
    const allocator = std.testing.allocator;
    var params = SSHashMap.init(allocator);
    defer params.deinit();
    try params.put("user", "post\x00gres");
    const sm = StartupMessage{ .parameters = params };

    var out = ByteArrayList.init(allocator);
    defer out.deinit();
    try std.testing.expectError(ProtocolError.InvalidKeyValuePair, sm.write(allocator, out.writer()));
    try std.testing.expectEqual(@as(usize, 0), out.items.len);
}

/// Release the `parameters` map and, when present, the owned `bytes` buffer.
///
/// `allocator` must be the allocator that produced `bytes` (the one passed to
/// `read`). Keys and values that do not live in `bytes` are not freed; they
/// belong to whoever inserted them. The message must not be used afterwards.
pub fn deinit(self: *StartupMessage, allocator: std.mem.Allocator) void {
    self.parameters.deinit();
    if (self.bytes) |bytes| allocator.free(bytes);
    // Poison the struct so use-after-deinit is caught in safe builds.
    self.* = undefined;
}

test "round trip" {
    const gpa = std.testing.allocator;

    // Build a message by hand; `bytes` stays null, so only the map is owned.
    var params = SSHashMap.init(gpa);
    try params.put("hello", "postgres");
    var original: StartupMessage = .{ .parameters = params };
    defer original.deinit(gpa);

    // Encode it into memory.
    var encoded = ByteArrayList.init(gpa);
    defer encoded.deinit();
    try original.write(gpa, encoded.writer());

    // Consume the length word the way a server would, then hand the rest to `read`.
    var stream = std.io.fixedBufferStream(encoded.items);
    var src = stream.reader();
    const total_len = try src.readIntBig(u32);
    const payload = try gpa.alloc(u8, total_len - 4);
    defer gpa.free(payload);
    try src.readNoEof(payload);

    var decoded = try StartupMessage.read(gpa, payload);
    defer decoded.deinit(gpa);

    const got = decoded.parameters.get("hello").?;
    try std.testing.expectEqualStrings("postgres", got);
}