aboutsummaryrefslogtreecommitdiff
path: root/src/proto/proto.zig
diff options
context:
space:
mode:
authorMartin Ashby <martin@ashbysoft.com>2023-09-27 23:34:46 +0100
committerMartin Ashby <martin@ashbysoft.com>2023-09-27 23:34:46 +0100
commit747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2 (patch)
tree7115e12e19f684640bd2aad4e5d998e13bbb5484 /src/proto/proto.zig
parent08472c27c77d27ea084e3458842540351c5a5c28 (diff)
downloadpgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.tar.gz
pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.tar.bz2
pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.tar.xz
pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.zip
Add a tagged union for all backend messages.
Move read_message to proto.zig and make it return the tagged union rather than expecting a message type.
Diffstat (limited to 'src/proto/proto.zig')
-rw-r--r--src/proto/proto.zig48
1 files changed, 48 insertions, 0 deletions
diff --git a/src/proto/proto.zig b/src/proto/proto.zig
index df1b717..5e4489d 100644
--- a/src/proto/proto.zig
+++ b/src/proto/proto.zig
@@ -1,3 +1,4 @@
+const std = @import("std");
pub const StartupMessage = @import("startup_message.zig");
pub const AuthenticationRequest = @import("authentication_request.zig");
pub const PasswordMessage = @import("password_message.zig");
@@ -13,6 +14,53 @@ const CopyXResponse = @import("copy_x_response.zig").CopyXResponse;
pub const CopyInResponse = CopyXResponse('G');
pub const CopyOutResponse = CopyXResponse('H');
pub const CopyBothResponse = CopyXResponse('W');
+const enum_from_int = @import("../main.zig").enum_from_int;
+const ProtocolError = @import("../main.zig").ProtocolError;
+
+pub const BackendMessage = union(enum) {
+ AuthenticationRequest: AuthenticationRequest,
+ ErrorResponse: ErrorResponse,
+ ReadyForQuery: ReadyForQuery,
+ ParameterStatus: ParameterStatus,
+ BackendKeyData: BackendKeyData,
+ DataRow: DataRow,
+ RowDescription: RowDescription,
+ CommandComplete: CommandComplete,
+ CopyInResponse: CopyInResponse,
+ CopyOutResponse: CopyOutResponse,
+ CopyBothResponse: CopyBothResponse,
+
+ pub fn deinit(self: *BackendMessage, a: std.mem.Allocator) void {
+ switch (self.*) {
+ inline else => |*sf| {
+ sf.deinit(a);
+ },
+ }
+ }
+};
+
+test {
+ inline for (@typeInfo(BackendMessage).Union.fields) |field| {
+ const msg_type = field.type;
+ if (!@hasDecl(msg_type, "Tag")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .Tag decl");
+ if (!@hasDecl(msg_type, "read")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .read decl");
+ }
+}
+
+pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !BackendMessage {
+ const tag = try stream_reader.readByte();
+ const len = try stream_reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4)));
+ defer allocator.free(buf);
+ try stream_reader.readNoEof(buf);
+ inline for (@typeInfo(BackendMessage).Union.fields) |field| {
+ if (field.type.Tag == tag) {
+ return @unionInit(BackendMessage, field.name, try field.type.read(allocator, buf));
+ }
+ } else {
+ return ProtocolError.InvalidMessageType;
+ }
+}
test {
_ = AuthenticationRequest;