From 747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Wed, 27 Sep 2023 23:34:46 +0100 Subject: Add a tagged union for all backend messages. Move read_message to proto.zig and make it return the tagged union rather than expecting a message type. --- src/proto/authentication_request.zig | 6 ++++- src/proto/proto.zig | 48 ++++++++++++++++++++++++++++++++++++ src/proto/row_description.zig | 2 +- 3 files changed, 54 insertions(+), 2 deletions(-) (limited to 'src/proto') diff --git a/src/proto/authentication_request.zig b/src/proto/authentication_request.zig index 9203482..3ea5cd1 100644 --- a/src/proto/authentication_request.zig +++ b/src/proto/authentication_request.zig @@ -30,7 +30,11 @@ pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationRequest { log.err("invalid message length, expected 4 got {}", .{b.len}); return ProtocolError.InvalidMessageLength; } - const inner_type = enum_from_int(InnerAuthRequestType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType; + const auth_type_int = std.mem.readIntBig(u32, b[0..4]); + const inner_type = enum_from_int(InnerAuthRequestType, auth_type_int) orelse { + log.err("Unsupported auth type {}", .{auth_type_int}); + return ClientError.UnsupportedAuthType; + }; var inner: InnerAuthRequest = switch (inner_type) { .AuthRequestTypeOk => .{ .ok = AuthRequestOk{} }, .AuthRequestTypeCleartextPassword => .{ .cleartext_password = AuthRequestCleartextPassword{} }, diff --git a/src/proto/proto.zig b/src/proto/proto.zig index df1b717..5e4489d 100644 --- a/src/proto/proto.zig +++ b/src/proto/proto.zig @@ -1,3 +1,4 @@ +const std = @import("std"); pub const StartupMessage = @import("startup_message.zig"); pub const AuthenticationRequest = @import("authentication_request.zig"); pub const PasswordMessage = @import("password_message.zig"); @@ -13,6 +14,53 @@ const CopyXResponse = @import("copy_x_response.zig").CopyXResponse; pub const CopyInResponse = CopyXResponse('G'); pub const CopyOutResponse = CopyXResponse('H'); pub const CopyBothResponse = CopyXResponse('W'); +const enum_from_int = @import("../main.zig").enum_from_int; +const ProtocolError = @import("../main.zig").ProtocolError; + +pub const BackendMessage = union(enum) { + AuthenticationRequest: AuthenticationRequest, + ErrorResponse: ErrorResponse, + ReadyForQuery: ReadyForQuery, + ParameterStatus: ParameterStatus, + BackendKeyData: BackendKeyData, + DataRow: DataRow, + RowDescription: RowDescription, + CommandComplete: CommandComplete, + CopyInResponse: CopyInResponse, + CopyOutResponse: CopyOutResponse, + CopyBothResponse: CopyBothResponse, + + pub fn deinit(self: *BackendMessage, a: std.mem.Allocator) void { + switch (self.*) { + inline else => |*sf| { + sf.deinit(a); + }, + } + } +}; + +test { + inline for (@typeInfo(BackendMessage).Union.fields) |field| { + const msg_type = field.type; + if (!@hasDecl(msg_type, "Tag")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .Tag decl"); + if (!@hasDecl(msg_type, "read")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .read decl"); + } +} + +pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !BackendMessage { + const tag = try stream_reader.readByte(); + const len = try stream_reader.readIntBig(u32); + const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4))); + defer allocator.free(buf); + try stream_reader.readNoEof(buf); + inline for (@typeInfo(BackendMessage).Union.fields) |field| { + if (field.type.Tag == tag) { + return @unionInit(BackendMessage, field.name, try field.type.read(allocator, buf)); + } + } else { + return ProtocolError.InvalidMessageType; + } +} test { _ = AuthenticationRequest; diff --git a/src/proto/row_description.zig b/src/proto/row_description.zig index a0e8810..b8105e2 100644 --- a/src/proto/row_description.zig +++ b/src/proto/row_description.zig @@ -34,7 +34,7 @@ pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription { for (0..n_fields) |i| { const name_start = fbs.pos; try reader.skipUntilDelimiterOrEof(0); - const name_end = fbs.pos-1; + const name_end = fbs.pos - 1; const name = res.buf.?[name_start..name_end]; const field = Field{ .name = name, -- cgit v1.2.3-ZIG