From 747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Wed, 27 Sep 2023 23:34:46 +0100 Subject: Add a tagged union for all backend messages. Move read_message to proto.zig and make it return the tagged union rather than expecting a message type. --- src/proto/proto.zig | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) (limited to 'src/proto/proto.zig') diff --git a/src/proto/proto.zig b/src/proto/proto.zig index df1b717..5e4489d 100644 --- a/src/proto/proto.zig +++ b/src/proto/proto.zig @@ -1,3 +1,4 @@ +const std = @import("std"); pub const StartupMessage = @import("startup_message.zig"); pub const AuthenticationRequest = @import("authentication_request.zig"); pub const PasswordMessage = @import("password_message.zig"); @@ -13,6 +14,53 @@ const CopyXResponse = @import("copy_x_response.zig").CopyXResponse; pub const CopyInResponse = CopyXResponse('G'); pub const CopyOutResponse = CopyXResponse('H'); pub const CopyBothResponse = CopyXResponse('W'); +const enum_from_int = @import("../main.zig").enum_from_int; +const ProtocolError = @import("../main.zig").ProtocolError; + +pub const BackendMessage = union(enum) { + AuthenticationRequest: AuthenticationRequest, + ErrorResponse: ErrorResponse, + ReadyForQuery: ReadyForQuery, + ParameterStatus: ParameterStatus, + BackendKeyData: BackendKeyData, + DataRow: DataRow, + RowDescription: RowDescription, + CommandComplete: CommandComplete, + CopyInResponse: CopyInResponse, + CopyOutResponse: CopyOutResponse, + CopyBothResponse: CopyBothResponse, + + pub fn deinit(self: *BackendMessage, a: std.mem.Allocator) void { + switch (self.*) { + inline else => |*sf| { + sf.deinit(a); + }, + } + } +}; + +test { + inline for (@typeInfo(BackendMessage).Union.fields) |field| { + const msg_type = field.type; + if (!@hasDecl(msg_type, "Tag")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .Tag decl"); + if (!@hasDecl(msg_type, "read")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .read decl"); + } +} + +pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !BackendMessage { + const tag = try stream_reader.readByte(); + const len = try stream_reader.readIntBig(u32); + const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4))); + defer allocator.free(buf); + try stream_reader.readNoEof(buf); + inline for (@typeInfo(BackendMessage).Union.fields) |field| { + if (field.type.Tag == tag) { + return @unionInit(BackendMessage, field.name, try field.type.read(allocator, buf)); + } + } else { + return ProtocolError.InvalidMessageType; + } +} test { _ = AuthenticationRequest; -- cgit v1.2.3-ZIG