aboutsummaryrefslogtreecommitdiff
path: root/src/proto/startup_message.zig
diff options
context:
space:
mode:
Diffstat (limited to 'src/proto/startup_message.zig')
-rw-r--r--src/proto/startup_message.zig85
1 files changed, 85 insertions, 0 deletions
diff --git a/src/proto/startup_message.zig b/src/proto/startup_message.zig
new file mode 100644
index 0000000..8224bdb
--- /dev/null
+++ b/src/proto/startup_message.zig
@@ -0,0 +1,85 @@
+const std = @import("std");
+const ProtocolError = @import("../main.zig").ProtocolError;
+const SSHashMap = std.StringHashMap([]const u8);
+const ByteArrayList = std.ArrayList(u8);
+
+const StartupMessage = @This();
+
+const ProtocolVersionNumber: u32 = 196608; // 3.0
+
+bytes: ?[]const u8 = null, // Owned
+parameters: SSHashMap,
+
+// message length should already have been read, b should contain the payload
+pub fn read(allocator: std.mem.Allocator, b: []const u8) !StartupMessage {
+ if (b.len < 4) return ProtocolError.InvalidMessageLength;
+
+ var bytes = try allocator.dupe(u8, b);
+ errdefer allocator.free(bytes);
+ const protocol_version = std.mem.readIntSliceBig(u32, bytes[0..4]);
+ if (protocol_version != ProtocolVersionNumber) return ProtocolError.InvalidProtocolVersion;
+
+ var parameters = SSHashMap.init(allocator);
+ var it = std.mem.splitScalar(u8, bytes[4..], 0);
+ while (it.next()) |next| {
+ const key = next;
+ const value = it.next() orelse return ProtocolError.InvalidKeyValuePair;
+ try parameters.put(key, value);
+ }
+
+ return .{
+ .bytes = bytes,
+ .parameters = parameters,
+ };
+}
+
+pub fn write(self: StartupMessage, allocator: std.mem.Allocator, stream_writer: anytype) !void {
+ var al = ByteArrayList.init(allocator);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // length placeholder
+ try writer.writeIntBig(u32, ProtocolVersionNumber);
+ var it = self.parameters.iterator();
+ while (it.next()) |entry| {
+ try writer.writeAll(entry.key_ptr.*);
+ try writer.writeByte(0);
+ try writer.writeAll(entry.value_ptr.*);
+ try writer.writeByte(0);
+ }
+ try writer.writeByte(0);
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
+ try stream_writer.writeAll(al.items);
+}
+
+pub fn deinit(self: *StartupMessage, allocator: std.mem.Allocator) void {
+ self.parameters.deinit();
+ if (self.bytes != null) {
+ allocator.free(self.bytes.?);
+ }
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var params = SSHashMap.init(allocator);
+ try params.put("hello", "postgres");
+ var sm = StartupMessage{
+ .parameters = params,
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try StartupMessage.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqualStrings("postgres", sm2.parameters.get("hello").?);
+}