aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMartin Ashby <martin@ashbysoft.com>2023-09-22 22:55:47 +0100
committerMartin Ashby <martin@ashbysoft.com>2023-09-22 22:55:47 +0100
commit6202dd351c83e9e54bffdbff844414b4dd763eba (patch)
treef9b0ce153a690e09a31c98f305295af29b682452 /src
downloadpgz-6202dd351c83e9e54bffdbff844414b4dd763eba.tar.gz
pgz-6202dd351c83e9e54bffdbff844414b4dd763eba.tar.bz2
pgz-6202dd351c83e9e54bffdbff844414b4dd763eba.tar.xz
pgz-6202dd351c83e9e54bffdbff844414b4dd763eba.zip
Initial: port of pgx to zig (so, pgz)
Starting with message structures, so far we have startup_message and authentication_ok
Diffstat (limited to 'src')
-rw-r--r--src/authentication_ok.zig46
-rw-r--r--src/main.zig39
-rw-r--r--src/startup_message.zig85
3 files changed, 170 insertions, 0 deletions
diff --git a/src/authentication_ok.zig b/src/authentication_ok.zig
new file mode 100644
index 0000000..3c31375
--- /dev/null
+++ b/src/authentication_ok.zig
@@ -0,0 +1,46 @@
+const std = @import("std");
+const ProtocolError = @import("main.zig").ProtocolError;
+const AuthType = @import("main.zig").AuthType;
+const enum_from_int = @import("main.zig").enum_from_int;
+const ClientError = @import("main.zig").ClientError;
+const AuthenticationOk = @This();
+const ByteArrayList = std.ArrayList(u8);
+
+pub const Tag: u8 = 'R';
+
+pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationOk {
+ if (b.len != 4) return ProtocolError.InvalidMessageLength;
+
+ const auth_type = enum_from_int(AuthType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType;
+ if (auth_type != AuthType.AuthTypeOk) return ProtocolError.InvalidAuthType;
+ return .{};
+}
+
+pub fn write(_: AuthenticationOk, _: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ try stream_writer.writeIntBig(u32, 8);
+ try stream_writer.writeIntBig(u32, @intFromEnum(AuthType.AuthTypeOk));
+}
+
+pub fn deinit(_: *AuthenticationOk, _: std.mem.Allocator) void {}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = AuthenticationOk{};
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try AuthenticationOk.read(allocator, buf);
+ defer sm2.deinit(allocator);
+}
diff --git a/src/main.zig b/src/main.zig
new file mode 100644
index 0000000..4bed468
--- /dev/null
+++ b/src/main.zig
@@ -0,0 +1,39 @@
+const std = @import("std");
+const testing = std.testing;
+const StartupMessage = @import("startup_message.zig");
+const AuthenticationOk = @import("authentication_ok.zig");
+
+pub const ProtocolError = error{
+ InvalidProtocolVersion,
+ InvalidKeyValuePair,
+ InvalidMessageLength,
+ InvalidAuthType,
+};
+
+pub const ClientError = error{
+ UnsupportedAuthType,
+};
+
+pub const AuthType = enum(u32) {
+ AuthTypeOk = 0,
+};
+
+// Fallible version of enumFromInt
+pub fn enum_from_int(comptime e: type, i: anytype) ?e {
+ const enum_ti = @typeInfo(e);
+ if (enum_ti != .Enum) @compileError("e should be an enum but instead it's a " ++ @typeName(e));
+ const ei = enum_ti.Enum;
+ if (@TypeOf(i) != ei.tag_type) @compileError("i should be of type " ++ @typeName(e) ++ " but instead it's " ++ @typeName(@TypeOf(i)));
+ inline for (ei.fields) |field| {
+ if (field.value == i) {
+ return @enumFromInt(i);
+ }
+ } else {
+ return null;
+ }
+}
+
+test {
+ _ = StartupMessage;
+ _ = AuthenticationOk;
+}
diff --git a/src/startup_message.zig b/src/startup_message.zig
new file mode 100644
index 0000000..b5031c1
--- /dev/null
+++ b/src/startup_message.zig
@@ -0,0 +1,85 @@
+const std = @import("std");
+const ProtocolError = @import("main.zig").ProtocolError;
+const SSHashMap = std.StringHashMap([]const u8);
+const ByteArrayList = std.ArrayList(u8);
+
+const StartupMessage = @This();
+
+const ProtocolVersionNumber: u32 = 196608; // 3.0
+
+bytes: ?[]const u8 = null, // Owned
+parameters: SSHashMap,
+
+// message length should already have been read, b should contain the payload
+pub fn read(allocator: std.mem.Allocator, b: []const u8) !StartupMessage {
+ if (b.len < 4) return ProtocolError.InvalidMessageLength;
+
+ var bytes = try allocator.dupe(u8, b);
+ errdefer allocator.free(bytes);
+ const protocol_version = std.mem.readIntSliceBig(u32, bytes[0..4]);
+ if (protocol_version != ProtocolVersionNumber) return ProtocolError.InvalidProtocolVersion;
+
+ var parameters = SSHashMap.init(allocator);
+ var it = std.mem.splitScalar(u8, bytes[4..], 0);
+ while (it.next()) |next| {
+ const key = next;
+ const value = it.next() orelse return ProtocolError.InvalidKeyValuePair;
+ try parameters.put(key, value);
+ }
+
+ return .{
+ .bytes = bytes,
+ .parameters = parameters,
+ };
+}
+
+pub fn write(self: StartupMessage, allocator: std.mem.Allocator, stream_writer: anytype) !void {
+ var al = ByteArrayList.init(allocator);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // length placeholder
+ try writer.writeIntBig(u32, ProtocolVersionNumber);
+ var it = self.parameters.iterator();
+ while (it.next()) |entry| {
+ try writer.writeAll(entry.key_ptr.*);
+ try writer.writeByte(0);
+ try writer.writeAll(entry.value_ptr.*);
+ try writer.writeByte(0);
+ }
+ try writer.writeByte(0);
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
+ try stream_writer.writeAll(al.items);
+}
+
+pub fn deinit(self: *StartupMessage, allocator: std.mem.Allocator) void {
+ self.parameters.deinit();
+ if (self.bytes != null) {
+ allocator.free(self.bytes.?);
+ }
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var params = SSHashMap.init(allocator);
+ try params.put("hello", "postgres");
+ var sm = StartupMessage{
+ .parameters = params,
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try StartupMessage.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqualStrings("postgres", sm2.parameters.get("hello").?);
+}