From 747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Wed, 27 Sep 2023 23:34:46 +0100 Subject: Add a tagged union for all backend messages. Move read_message to proto.zig and make it return the tagged union rather than expecting a message type. --- src/conn/conn.zig | 74 ++++++++++++++++++++++-------------- src/main.zig | 16 ++------ src/proto/authentication_request.zig | 6 ++- src/proto/proto.zig | 48 +++++++++++++++++++++++ src/proto/row_description.zig | 2 +- 5 files changed, 103 insertions(+), 43 deletions(-) (limited to 'src') diff --git a/src/conn/conn.zig b/src/conn/conn.zig index 4d62f57..d97eca2 100644 --- a/src/conn/conn.zig +++ b/src/conn/conn.zig @@ -2,8 +2,11 @@ const std = @import("std"); const log = std.log.scoped(.pgz); const SSHashMap = std.StringHashMap([]const u8); const Config = @import("config.zig"); -const Proto = @import("../proto/proto.zig"); -const read_message = @import("../main.zig").read_message; +const proto = @import("../proto/proto.zig"); +const StartupMessage = proto.StartupMessage; +const PasswordMessage = proto.PasswordMessage; +const BackendMessage = proto.BackendMessage; +const read_message = proto.read_message; const ProtocolError = @import("../main.zig").ProtocolError; const ServerError = @import("../main.zig").ServerError; const ClientError = @import("../main.zig").ClientError; @@ -35,63 +38,52 @@ pub fn connect(config: Config) !Conn { }; errdefer res.deinit(); var writer = stream.writer(); - var dr = diagnosticReader(10000, stream.reader()); + var dr = diagnosticReader(100, stream.reader()); var reader = dr.reader(); var params = SSHashMap.init(allocator); try params.put("user", config.user); if (config.database) |database| try params.put("database", database); - var sm = Proto.StartupMessage{ + var sm = StartupMessage{ .parameters = params, }; defer sm.deinit(allocator); try sm.write(allocator, writer); lp: while (true) { - const response_type = try reader.readByte(); - switch (response_type) { - Proto.ErrorResponse.Tag => { - var err = try read_message(Proto.ErrorResponse, allocator, reader); - defer err.deinit(allocator); + var anymsg = try read_message(allocator, reader); + defer anymsg.deinit(allocator); + switch (anymsg) { + .ErrorResponse => |err| { log.err("Error connecting to server {any}", .{err}); return ServerError.ErrorResponse; }, - Proto.AuthenticationRequest.Tag => { - var ar = try read_message(Proto.AuthenticationRequest, allocator, reader); - defer ar.deinit(allocator); - // TODO handle the authentication request + .AuthenticationRequest => |ar| { switch (ar.inner_type) { .AuthRequestTypeOk => {}, // fine do nothing! .AuthRequestTypeCleartextPassword => { if (config.password) |password| { - const pm = Proto.PasswordMessage{ .password = password }; + const pm = PasswordMessage{ .password = password }; try pm.write(allocator, writer); } else { return ClientError.NoPasswordSupplied; } }, } - log.info("authentication request", .{}); }, - Proto.ReadyForQuery.Tag => { - var rfq = try read_message(Proto.ReadyForQuery, allocator, reader); - defer rfq.deinit(allocator); + .ReadyForQuery => |rfq| { // TODO do something about transaction state? res.status = .connStatusIdle; - log.info("ready for query", .{}); + log.info("ready for query {any}", .{rfq}); break :lp; }, - Proto.ParameterStatus.Tag => { - var ps = try read_message(Proto.ParameterStatus, allocator, reader); - defer ps.deinit(allocator); + .ParameterStatus => |ps| { // TODO Handle this somehow? log.info("ParameterStatus: {s}:{s}", .{ ps.name, ps.value }); }, - Proto.BackendKeyData.Tag => { - var bkd = try read_message(Proto.BackendKeyData, allocator, reader); - defer bkd.deinit(allocator); + .BackendKeyData => |bkd| { log.info("BackendKeyData process_id {} secret_key {}", .{ bkd.process_id, bkd.secret_key }); }, - else => { - log.err("unhandled message type [{c}]", .{response_type}); + else => |response_type| { + log.err("unhandled message type [{}]", .{response_type}); const diag = try dr.get(allocator); defer allocator.free(diag); log.err("diag [{s}]", .{diag}); @@ -102,10 +94,36 @@ pub fn connect(config: Config) !Conn { return res; } -fn deinit(self: *Conn) void { +pub fn deinit(self: *Conn) void { self.stream.close(); } +// How to handle this ... +// The Go code relies on polymorphism to generically read any message type. +// I _could_ have a tagged union type thing +// pub const ResultIterator = struct { +// conn: *Conn, +// command_concluded: bool = false, +// // NextRow advances the ResultReader to the next row and returns true if a row is available. +// pub fn next_row(self: *ResultIterator) bool { +// // TODO implement +// var reader = self.conn.stream.reader(); +// switch (try reader.readByte()) { +// case +// } +// return false; +// } +// }; + +// pub const MultiResultIterator = struct { +// conn: *Conn, +// fn next() ? +// }; + +// pub fn exec(self: *Conn) { + +// } + test "connect unix" { // must have a local postgres runnning // TODO maybe use docker to start one? diff --git a/src/main.zig b/src/main.zig index 818b3b7..9ce8de9 100644 --- a/src/main.zig +++ b/src/main.zig @@ -3,8 +3,9 @@ const testing = std.testing; pub const ProtocolError = error{ InvalidProtocolVersion, - InvalidKeyValuePair, + InvalidMessageType, InvalidMessageLength, + InvalidKeyValuePair, InvalidAuthType, MissingField, WrongMessageType, @@ -41,17 +42,6 @@ pub fn enum_from_int(comptime e: type, i: anytype) ?e { } } -// Tag should already have been read in order to determine msg_type! -pub fn read_message(comptime msg_type: type, allocator: std.mem.Allocator, stream_reader: anytype) !msg_type { - if (!@hasDecl(msg_type, "Tag")) @compileError("msg_type must have a Tag declaration!"); - if (!@hasDecl(msg_type, "read")) @compileError("msg_type must have a read() function!"); - const len = try stream_reader.readIntBig(u32); - const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4))); - defer allocator.free(buf); - try stream_reader.readNoEof(buf); - return try msg_type.read(allocator, buf); -} - pub fn diagnosticReader(comptime n: usize, base_reader: anytype) DiagnosticReader(n, @TypeOf(base_reader)) { return .{ .child_reader = base_reader }; } @@ -81,7 +71,7 @@ pub fn DiagnosticReader(comptime n: usize, comptime ReaderType: anytype) type { } // Caller frees - pub fn get(self: @This(), allocator: std.mem.Allocator) ![]const u8 { + pub fn get(self: @This(), allocator: std.mem.Allocator) ![]u8 { var buf = try allocator.alloc(u8, n); errdefer allocator.free(buf); @memcpy(buf[0..(n - self.pos)], self.ring[self.pos..n]); diff --git a/src/proto/authentication_request.zig b/src/proto/authentication_request.zig index 9203482..3ea5cd1 100644 --- a/src/proto/authentication_request.zig +++ b/src/proto/authentication_request.zig @@ -30,7 +30,11 @@ pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationRequest { log.err("invalid message length, expected 4 got {}", .{b.len}); return ProtocolError.InvalidMessageLength; } - const inner_type = enum_from_int(InnerAuthRequestType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType; + const auth_type_int = std.mem.readIntBig(u32, b[0..4]); + const inner_type = enum_from_int(InnerAuthRequestType, auth_type_int) orelse { + log.err("Unsupported auth type {}", .{auth_type_int}); + return ClientError.UnsupportedAuthType; + }; var inner: InnerAuthRequest = switch (inner_type) { .AuthRequestTypeOk => .{ .ok = AuthRequestOk{} }, .AuthRequestTypeCleartextPassword => .{ .cleartext_password = AuthRequestCleartextPassword{} }, diff --git a/src/proto/proto.zig b/src/proto/proto.zig index df1b717..5e4489d 100644 --- a/src/proto/proto.zig +++ b/src/proto/proto.zig @@ -1,3 +1,4 @@ +const std = @import("std"); pub const StartupMessage = @import("startup_message.zig"); pub const AuthenticationRequest = @import("authentication_request.zig"); pub const PasswordMessage = @import("password_message.zig"); @@ -13,6 +14,53 @@ const CopyXResponse = @import("copy_x_response.zig").CopyXResponse; pub const CopyInResponse = CopyXResponse('G'); pub const CopyOutResponse = CopyXResponse('H'); pub const CopyBothResponse = CopyXResponse('W'); +const enum_from_int = @import("../main.zig").enum_from_int; +const ProtocolError = @import("../main.zig").ProtocolError; + +pub const BackendMessage = union(enum) { + AuthenticationRequest: AuthenticationRequest, + ErrorResponse: ErrorResponse, + ReadyForQuery: ReadyForQuery, + ParameterStatus: ParameterStatus, + BackendKeyData: BackendKeyData, + DataRow: DataRow, + RowDescription: RowDescription, + CommandComplete: CommandComplete, + CopyInResponse: CopyInResponse, + CopyOutResponse: CopyOutResponse, + CopyBothResponse: CopyBothResponse, + + pub fn deinit(self: *BackendMessage, a: std.mem.Allocator) void { + switch (self.*) { + inline else => |*sf| { + sf.deinit(a); + }, + } + } +}; + +test { + inline for (@typeInfo(BackendMessage).Union.fields) |field| { + const msg_type = field.type; + if (!@hasDecl(msg_type, "Tag")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .Tag decl"); + if (!@hasDecl(msg_type, "read")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .read decl"); + } +} + +pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !BackendMessage { + const tag = try stream_reader.readByte(); + const len = try stream_reader.readIntBig(u32); + const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4))); + defer allocator.free(buf); + try stream_reader.readNoEof(buf); + inline for (@typeInfo(BackendMessage).Union.fields) |field| { + if (field.type.Tag == tag) { + return @unionInit(BackendMessage, field.name, try field.type.read(allocator, buf)); + } + } else { + return ProtocolError.InvalidMessageType; + } +} test { _ = AuthenticationRequest; diff --git a/src/proto/row_description.zig b/src/proto/row_description.zig index a0e8810..b8105e2 100644 --- a/src/proto/row_description.zig +++ b/src/proto/row_description.zig @@ -34,7 +34,7 @@ pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription { for (0..n_fields) |i| { const name_start = fbs.pos; try reader.skipUntilDelimiterOrEof(0); - const name_end = fbs.pos-1; + const name_end = fbs.pos - 1; const name = res.buf.?[name_start..name_end]; const field = Field{ .name = name, -- cgit v1.2.3-ZIG