diff options
author | Martin Ashby <martin@ashbysoft.com> | 2023-09-27 23:34:46 +0100 |
---|---|---|
committer | Martin Ashby <martin@ashbysoft.com> | 2023-09-27 23:34:46 +0100 |
commit | 747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2 (patch) | |
tree | 7115e12e19f684640bd2aad4e5d998e13bbb5484 /src/proto/proto.zig | |
parent | 08472c27c77d27ea084e3458842540351c5a5c28 (diff) | |
download | pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.tar.gz pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.tar.bz2 pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.tar.xz pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.zip |
Add a tagged union for all backend messages.
Move read_message to proto.zig and make it return the tagged union
rather than expecting a message type.
Diffstat (limited to 'src/proto/proto.zig')
-rw-r--r-- | src/proto/proto.zig | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/src/proto/proto.zig b/src/proto/proto.zig index df1b717..5e4489d 100644 --- a/src/proto/proto.zig +++ b/src/proto/proto.zig @@ -1,3 +1,4 @@ +const std = @import("std"); pub const StartupMessage = @import("startup_message.zig"); pub const AuthenticationRequest = @import("authentication_request.zig"); pub const PasswordMessage = @import("password_message.zig"); @@ -13,6 +14,53 @@ const CopyXResponse = @import("copy_x_response.zig").CopyXResponse; pub const CopyInResponse = CopyXResponse('G'); pub const CopyOutResponse = CopyXResponse('H'); pub const CopyBothResponse = CopyXResponse('W'); +const enum_from_int = @import("../main.zig").enum_from_int; +const ProtocolError = @import("../main.zig").ProtocolError; + +pub const BackendMessage = union(enum) { + AuthenticationRequest: AuthenticationRequest, + ErrorResponse: ErrorResponse, + ReadyForQuery: ReadyForQuery, + ParameterStatus: ParameterStatus, + BackendKeyData: BackendKeyData, + DataRow: DataRow, + RowDescription: RowDescription, + CommandComplete: CommandComplete, + CopyInResponse: CopyInResponse, + CopyOutResponse: CopyOutResponse, + CopyBothResponse: CopyBothResponse, + + pub fn deinit(self: *BackendMessage, a: std.mem.Allocator) void { + switch (self.*) { + inline else => |*sf| { + sf.deinit(a); + }, + } + } +}; + +test { + inline for (@typeInfo(BackendMessage).Union.fields) |field| { + const msg_type = field.type; + if (!@hasDecl(msg_type, "Tag")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .Tag decl"); + if (!@hasDecl(msg_type, "read")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .read decl"); + } +} + +pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !BackendMessage { + const tag = try stream_reader.readByte(); + const len = try stream_reader.readIntBig(u32); + const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4))); + defer allocator.free(buf); + try stream_reader.readNoEof(buf); + inline for (@typeInfo(BackendMessage).Union.fields) |field| { + if (field.type.Tag == tag) { + return @unionInit(BackendMessage, field.name, try field.type.read(allocator, buf)); + } + } else { + return ProtocolError.InvalidMessageType; + } +} test { _ = AuthenticationRequest; |