aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--README.md4
-rw-r--r--build.zig47
-rw-r--r--src/authentication_ok.zig46
-rw-r--r--src/main.zig39
-rw-r--r--src/startup_message.zig85
6 files changed, 223 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ee7098f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+zig-out/
+zig-cache/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7bb69da
--- /dev/null
+++ b/README.md
@@ -0,0 +1,4 @@
+# pgz
+
+A zig port of [pgx](https://github.com/jackc/pgx)
+
diff --git a/build.zig b/build.zig
new file mode 100644
index 0000000..891c1bb
--- /dev/null
+++ b/build.zig
@@ -0,0 +1,47 @@
+const std = @import("std");
+
+// Although this function looks imperative, note that its job is to
+// declaratively construct a build graph that will be executed by an external
+// runner.
+pub fn build(b: *std.Build) void {
+ // Standard target options allows the person running `zig build` to choose
+ // what target to build for. Here we do not override the defaults, which
+ // means any target is allowed, and the default is native. Other options
+ // for restricting supported target set are available.
+ const target = b.standardTargetOptions(.{});
+
+ // Standard optimization options allow the person running `zig build` to select
+ // between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. Here we do not
+ // set a preferred release mode, allowing the user to decide how to optimize.
+ const optimize = b.standardOptimizeOption(.{});
+
+ const lib = b.addStaticLibrary(.{
+ .name = "pgz",
+ // In this case the main source file is merely a path, however, in more
+ // complicated build scripts, this could be a generated file.
+ .root_source_file = .{ .path = "src/main.zig" },
+ .target = target,
+ .optimize = optimize,
+ });
+
+ // This declares intent for the library to be installed into the standard
+ // location when the user invokes the "install" step (the default step when
+ // running `zig build`).
+ b.installArtifact(lib);
+
+ // Creates a step for unit testing. This only builds the test executable
+ // but does not run it.
+ const main_tests = b.addTest(.{
+ .root_source_file = .{ .path = "src/main.zig" },
+ .target = target,
+ .optimize = optimize,
+ });
+
+ const run_main_tests = b.addRunArtifact(main_tests);
+
+ // This creates a build step. It will be visible in the `zig build --help` menu,
+ // and can be selected like this: `zig build test`
+ // This will evaluate the `test` step rather than the default, which is "install".
+ const test_step = b.step("test", "Run library tests");
+ test_step.dependOn(&run_main_tests.step);
+}
diff --git a/src/authentication_ok.zig b/src/authentication_ok.zig
new file mode 100644
index 0000000..3c31375
--- /dev/null
+++ b/src/authentication_ok.zig
@@ -0,0 +1,46 @@
+const std = @import("std");
+const ProtocolError = @import("main.zig").ProtocolError;
+const AuthType = @import("main.zig").AuthType;
+const enum_from_int = @import("main.zig").enum_from_int;
+const ClientError = @import("main.zig").ClientError;
+const AuthenticationOk = @This();
+const ByteArrayList = std.ArrayList(u8);
+
+pub const Tag: u8 = 'R';
+
+pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationOk {
+ if (b.len != 4) return ProtocolError.InvalidMessageLength;
+
+ const auth_type = enum_from_int(AuthType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType;
+ if (auth_type != AuthType.AuthTypeOk) return ProtocolError.InvalidAuthType;
+ return .{};
+}
+
+pub fn write(_: AuthenticationOk, _: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ try stream_writer.writeIntBig(u32, 8);
+ try stream_writer.writeIntBig(u32, @intFromEnum(AuthType.AuthTypeOk));
+}
+
+pub fn deinit(_: *AuthenticationOk, _: std.mem.Allocator) void {}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = AuthenticationOk{};
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try AuthenticationOk.read(allocator, buf);
+ defer sm2.deinit(allocator);
+}
diff --git a/src/main.zig b/src/main.zig
new file mode 100644
index 0000000..4bed468
--- /dev/null
+++ b/src/main.zig
@@ -0,0 +1,39 @@
+const std = @import("std");
+const testing = std.testing;
+const StartupMessage = @import("startup_message.zig");
+const AuthenticationOk = @import("authentication_ok.zig");
+
+pub const ProtocolError = error{
+ InvalidProtocolVersion,
+ InvalidKeyValuePair,
+ InvalidMessageLength,
+ InvalidAuthType,
+};
+
+pub const ClientError = error{
+ UnsupportedAuthType,
+};
+
+pub const AuthType = enum(u32) {
+ AuthTypeOk = 0,
+};
+
+// Fallible version of enumFromInt
+pub fn enum_from_int(comptime e: type, i: anytype) ?e {
+ const enum_ti = @typeInfo(e);
+ if (enum_ti != .Enum) @compileError("e should be an enum but instead it's a " ++ @typeName(e));
+ const ei = enum_ti.Enum;
+ if (@TypeOf(i) != ei.tag_type) @compileError("i should be of type " ++ @typeName(e) ++ " but instead it's " ++ @typeName(@TypeOf(i)));
+ inline for (ei.fields) |field| {
+ if (field.value == i) {
+ return @enumFromInt(i);
+ }
+ } else {
+ return null;
+ }
+}
+
+test {
+ _ = StartupMessage;
+ _ = AuthenticationOk;
+}
diff --git a/src/startup_message.zig b/src/startup_message.zig
new file mode 100644
index 0000000..b5031c1
--- /dev/null
+++ b/src/startup_message.zig
@@ -0,0 +1,85 @@
+const std = @import("std");
+const ProtocolError = @import("main.zig").ProtocolError;
+const SSHashMap = std.StringHashMap([]const u8);
+const ByteArrayList = std.ArrayList(u8);
+
+const StartupMessage = @This();
+
+const ProtocolVersionNumber: u32 = 196608; // 3.0
+
+bytes: ?[]const u8 = null, // Owned
+parameters: SSHashMap,
+
+// message length should already have been read, b should contain the payload
+pub fn read(allocator: std.mem.Allocator, b: []const u8) !StartupMessage {
+ if (b.len < 4) return ProtocolError.InvalidMessageLength;
+
+ var bytes = try allocator.dupe(u8, b);
+ errdefer allocator.free(bytes);
+ const protocol_version = std.mem.readIntSliceBig(u32, bytes[0..4]);
+ if (protocol_version != ProtocolVersionNumber) return ProtocolError.InvalidProtocolVersion;
+
+ var parameters = SSHashMap.init(allocator);
+ var it = std.mem.splitScalar(u8, bytes[4..], 0);
+ while (it.next()) |next| {
+ const key = next;
+ const value = it.next() orelse return ProtocolError.InvalidKeyValuePair;
+ try parameters.put(key, value);
+ }
+
+ return .{
+ .bytes = bytes,
+ .parameters = parameters,
+ };
+}
+
+pub fn write(self: StartupMessage, allocator: std.mem.Allocator, stream_writer: anytype) !void {
+ var al = ByteArrayList.init(allocator);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // length placeholder
+ try writer.writeIntBig(u32, ProtocolVersionNumber);
+ var it = self.parameters.iterator();
+ while (it.next()) |entry| {
+ try writer.writeAll(entry.key_ptr.*);
+ try writer.writeByte(0);
+ try writer.writeAll(entry.value_ptr.*);
+ try writer.writeByte(0);
+ }
+ try writer.writeByte(0);
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
+ try stream_writer.writeAll(al.items);
+}
+
+pub fn deinit(self: *StartupMessage, allocator: std.mem.Allocator) void {
+ self.parameters.deinit();
+ if (self.bytes != null) {
+ allocator.free(self.bytes.?);
+ }
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var params = SSHashMap.init(allocator);
+ try params.put("hello", "postgres");
+ var sm = StartupMessage{
+ .parameters = params,
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try StartupMessage.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqualStrings("postgres", sm2.parameters.get("hello").?);
+}