1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
|
const std = @import("std");
const ProtocolError = @import("main.zig").ProtocolError;
const SSHashMap = std.StringHashMap([]const u8);
const ByteArrayList = std.ArrayList(u8);
/// This file is itself the `StartupMessage` struct; the alias lets the code below name it.
const StartupMessage = @This();
/// Protocol version accepted by `read` and emitted by `write`.
/// 196608 == 3 << 16, i.e. version 3.0 (presumably major in the high 16 bits, minor in the
/// low 16 bits).
const ProtocolVersionNumber: u32 = 196608; // 3.0
/// Owned backing buffer. `read` stores its private copy of the payload here and `deinit`
/// frees it; it stays null when the struct is built directly with only `parameters`.
bytes: ?[]const u8 = null, // Owned
/// Startup parameters by name. For messages produced by `read`, the keys and values are
/// slices into `bytes`, not separate copies.
parameters: SSHashMap,
/// Parses the payload of a startup message.
///
/// The 4-byte message length must already have been consumed by the caller; `b` holds
/// everything after it: a big-endian u32 protocol version, then NUL-terminated key/value
/// strings, then a single NUL terminator byte.
///
/// `b` is duplicated with `allocator`. The keys and values of the returned `parameters`
/// map are slices into that copy (stored in `bytes`), so `b` may be reused after this
/// returns. Release the result with `deinit`, passing the same allocator.
///
/// Errors:
/// - `ProtocolError.InvalidMessageLength` if `b` is too short to hold the version.
/// - `ProtocolError.InvalidProtocolVersion` if the version is not `ProtocolVersionNumber`.
/// - `ProtocolError.InvalidKeyValuePair` if the list ends after a key with no value, or
///   ends right after a value's NUL without the terminator byte.
/// - Any allocation failure from `allocator`.
pub fn read(allocator: std.mem.Allocator, b: []const u8) !StartupMessage {
    if (b.len < 4) return ProtocolError.InvalidMessageLength;

    const bytes = try allocator.dupe(u8, b);
    errdefer allocator.free(bytes);

    const protocol_version = std.mem.readIntSliceBig(u32, bytes[0..4]);
    if (protocol_version != ProtocolVersionNumber) return ProtocolError.InvalidProtocolVersion;

    var parameters = SSHashMap.init(allocator);
    // Without this, the map's storage leaks when a later `put` or validation step fails.
    errdefer parameters.deinit();

    var it = std.mem.splitScalar(u8, bytes[4..], 0);
    while (it.next()) |key| {
        if (key.len == 0) {
            // An empty key is the list terminator, not a parameter: storing it would add a
            // bogus "" => "" entry that `write` would later emit in the middle of the list.
            // The split only yields another token after this one if the terminator byte is
            // really present, so a missing token means the payload was cut short.
            if (it.next() == null) return ProtocolError.InvalidKeyValuePair;
            break;
        }
        const value = it.next() orelse return ProtocolError.InvalidKeyValuePair;
        try parameters.put(key, value);
    }

    return .{
        .bytes = bytes,
        .parameters = parameters,
    };
}
/// Serializes this message, length prefix included, and sends it to `stream_writer`.
///
/// The message is first assembled in a temporary buffer allocated from `allocator`, because
/// the leading big-endian u32 length (which counts itself) is only known once every
/// parameter has been written. The buffer is released before returning.
///
/// Layout: length, protocol version, then `key NUL value NUL` for each parameter in map
/// iteration order, then one terminating NUL byte.
pub fn write(self: StartupMessage, allocator: std.mem.Allocator, stream_writer: anytype) !void {
    var message = ByteArrayList.init(allocator);
    defer message.deinit();

    const out = message.writer();
    // Reserve room for the length prefix; it is patched in below.
    try out.writeIntBig(u32, 0);
    try out.writeIntBig(u32, ProtocolVersionNumber);

    var params = self.parameters.iterator();
    while (params.next()) |param| {
        try out.writeAll(param.key_ptr.*);
        try out.writeByte(0);
        try out.writeAll(param.value_ptr.*);
        try out.writeByte(0);
    }
    // End-of-parameters marker.
    try out.writeByte(0);

    // Everything written so far, including the four placeholder bytes, is the message length.
    const total_len: u32 = @intCast(message.items.len);
    std.mem.writeIntBig(u32, message.items[0..4], total_len);

    try stream_writer.writeAll(message.items);
}
/// Releases the parameter map's storage and, when one is held, the owned byte buffer.
/// `allocator` must be the allocator that produced `bytes` (the one given to `read`).
pub fn deinit(self: *StartupMessage, allocator: std.mem.Allocator) void {
    self.parameters.deinit();
    if (self.bytes) |owned| allocator.free(owned);
}
test "round trip" {
    const gpa = std.testing.allocator;

    // Build a message holding a single parameter; `outgoing.deinit` releases the map.
    var source_params = SSHashMap.init(gpa);
    try source_params.put("hello", "postgres");
    var outgoing = StartupMessage{ .parameters = source_params };
    defer outgoing.deinit(gpa);

    // Serialize into an in-memory buffer.
    var wire = ByteArrayList.init(gpa);
    defer wire.deinit();
    try outgoing.write(gpa, wire.writer());

    // Consume the length prefix first, then hand only the remaining payload to `read`.
    var stream = std.io.fixedBufferStream(wire.items);
    const input = stream.reader();
    const total_len = try input.readIntBig(u32);
    const payload = try gpa.alloc(u8, total_len - 4);
    defer gpa.free(payload);
    try input.readNoEof(payload);

    var incoming = try StartupMessage.read(gpa, payload);
    defer incoming.deinit(gpa);

    const value = incoming.parameters.get("hello").?;
    try std.testing.expectEqualStrings("postgres", value);
}
|