aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Ashby <martin@ashbysoft.com>2023-09-27 23:34:46 +0100
committerMartin Ashby <martin@ashbysoft.com>2023-09-27 23:34:46 +0100
commit747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2 (patch)
tree7115e12e19f684640bd2aad4e5d998e13bbb5484
parent08472c27c77d27ea084e3458842540351c5a5c28 (diff)
downloadpgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.tar.gz
pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.tar.bz2
pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.tar.xz
pgz-747c6e55cbe2283fd85ef8cd930e88d2bb0b7db2.zip
Add a tagged union for all backend messages.
Move read_message to proto.zig and make it return the tagged union rather than expecting a message type.
-rw-r--r--src/conn/conn.zig74
-rw-r--r--src/main.zig16
-rw-r--r--src/proto/authentication_request.zig6
-rw-r--r--src/proto/proto.zig48
-rw-r--r--src/proto/row_description.zig2
5 files changed, 103 insertions, 43 deletions
diff --git a/src/conn/conn.zig b/src/conn/conn.zig
index 4d62f57..d97eca2 100644
--- a/src/conn/conn.zig
+++ b/src/conn/conn.zig
@@ -2,8 +2,11 @@ const std = @import("std");
const log = std.log.scoped(.pgz);
const SSHashMap = std.StringHashMap([]const u8);
const Config = @import("config.zig");
-const Proto = @import("../proto/proto.zig");
-const read_message = @import("../main.zig").read_message;
+const proto = @import("../proto/proto.zig");
+const StartupMessage = proto.StartupMessage;
+const PasswordMessage = proto.PasswordMessage;
+const BackendMessage = proto.BackendMessage;
+const read_message = proto.read_message;
const ProtocolError = @import("../main.zig").ProtocolError;
const ServerError = @import("../main.zig").ServerError;
const ClientError = @import("../main.zig").ClientError;
@@ -35,63 +38,52 @@ pub fn connect(config: Config) !Conn {
};
errdefer res.deinit();
var writer = stream.writer();
- var dr = diagnosticReader(10000, stream.reader());
+ var dr = diagnosticReader(100, stream.reader());
var reader = dr.reader();
var params = SSHashMap.init(allocator);
try params.put("user", config.user);
if (config.database) |database| try params.put("database", database);
- var sm = Proto.StartupMessage{
+ var sm = StartupMessage{
.parameters = params,
};
defer sm.deinit(allocator);
try sm.write(allocator, writer);
lp: while (true) {
- const response_type = try reader.readByte();
- switch (response_type) {
- Proto.ErrorResponse.Tag => {
- var err = try read_message(Proto.ErrorResponse, allocator, reader);
- defer err.deinit(allocator);
+ var anymsg = try read_message(allocator, reader);
+ defer anymsg.deinit(allocator);
+ switch (anymsg) {
+ .ErrorResponse => |err| {
log.err("Error connecting to server {any}", .{err});
return ServerError.ErrorResponse;
},
- Proto.AuthenticationRequest.Tag => {
- var ar = try read_message(Proto.AuthenticationRequest, allocator, reader);
- defer ar.deinit(allocator);
- // TODO handle the authentication request
+ .AuthenticationRequest => |ar| {
switch (ar.inner_type) {
.AuthRequestTypeOk => {}, // fine do nothing!
.AuthRequestTypeCleartextPassword => {
if (config.password) |password| {
- const pm = Proto.PasswordMessage{ .password = password };
+ const pm = PasswordMessage{ .password = password };
try pm.write(allocator, writer);
} else {
return ClientError.NoPasswordSupplied;
}
},
}
- log.info("authentication request", .{});
},
- Proto.ReadyForQuery.Tag => {
- var rfq = try read_message(Proto.ReadyForQuery, allocator, reader);
- defer rfq.deinit(allocator);
+ .ReadyForQuery => |rfq| {
// TODO do something about transaction state?
res.status = .connStatusIdle;
- log.info("ready for query", .{});
+ log.info("ready for query {any}", .{rfq});
break :lp;
},
- Proto.ParameterStatus.Tag => {
- var ps = try read_message(Proto.ParameterStatus, allocator, reader);
- defer ps.deinit(allocator);
+ .ParameterStatus => |ps| {
// TODO Handle this somehow?
log.info("ParameterStatus: {s}:{s}", .{ ps.name, ps.value });
},
- Proto.BackendKeyData.Tag => {
- var bkd = try read_message(Proto.BackendKeyData, allocator, reader);
- defer bkd.deinit(allocator);
+ .BackendKeyData => |bkd| {
log.info("BackendKeyData process_id {} secret_key {}", .{ bkd.process_id, bkd.secret_key });
},
- else => {
- log.err("unhandled message type [{c}]", .{response_type});
+ else => |response_type| {
+ log.err("unhandled message type [{}]", .{response_type});
const diag = try dr.get(allocator);
defer allocator.free(diag);
log.err("diag [{s}]", .{diag});
@@ -102,10 +94,36 @@ pub fn connect(config: Config) !Conn {
return res;
}
-fn deinit(self: *Conn) void {
+pub fn deinit(self: *Conn) void {
self.stream.close();
}
+// How to handle this ...
+// The Go code relies on polymorphism to generically read any message type.
+// I _could_ have a tagged union type thing
+// pub const ResultIterator = struct {
+// conn: *Conn,
+// command_concluded: bool = false,
+// // NextRow advances the ResultReader to the next row and returns true if a row is available.
+// pub fn next_row(self: *ResultIterator) bool {
+// // TODO implement
+// var reader = self.conn.stream.reader();
+// switch (try reader.readByte()) {
+// case
+// }
+// return false;
+// }
+// };
+
+// pub const MultiResultIterator = struct {
+// conn: *Conn,
+// fn next() ?
+// };
+
+// pub fn exec(self: *Conn) {
+
+// }
+
test "connect unix" {
// must have a local postgres runnning
// TODO maybe use docker to start one?
diff --git a/src/main.zig b/src/main.zig
index 818b3b7..9ce8de9 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -3,8 +3,9 @@ const testing = std.testing;
pub const ProtocolError = error{
InvalidProtocolVersion,
- InvalidKeyValuePair,
+ InvalidMessageType,
InvalidMessageLength,
+ InvalidKeyValuePair,
InvalidAuthType,
MissingField,
WrongMessageType,
@@ -41,17 +42,6 @@ pub fn enum_from_int(comptime e: type, i: anytype) ?e {
}
}
-// Tag should already have been read in order to determine msg_type!
-pub fn read_message(comptime msg_type: type, allocator: std.mem.Allocator, stream_reader: anytype) !msg_type {
- if (!@hasDecl(msg_type, "Tag")) @compileError("msg_type must have a Tag declaration!");
- if (!@hasDecl(msg_type, "read")) @compileError("msg_type must have a read() function!");
- const len = try stream_reader.readIntBig(u32);
- const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4)));
- defer allocator.free(buf);
- try stream_reader.readNoEof(buf);
- return try msg_type.read(allocator, buf);
-}
-
pub fn diagnosticReader(comptime n: usize, base_reader: anytype) DiagnosticReader(n, @TypeOf(base_reader)) {
return .{ .child_reader = base_reader };
}
@@ -81,7 +71,7 @@ pub fn DiagnosticReader(comptime n: usize, comptime ReaderType: anytype) type {
}
// Caller frees
- pub fn get(self: @This(), allocator: std.mem.Allocator) ![]const u8 {
+ pub fn get(self: @This(), allocator: std.mem.Allocator) ![]u8 {
var buf = try allocator.alloc(u8, n);
errdefer allocator.free(buf);
@memcpy(buf[0..(n - self.pos)], self.ring[self.pos..n]);
diff --git a/src/proto/authentication_request.zig b/src/proto/authentication_request.zig
index 9203482..3ea5cd1 100644
--- a/src/proto/authentication_request.zig
+++ b/src/proto/authentication_request.zig
@@ -30,7 +30,11 @@ pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationRequest {
log.err("invalid message length, expected 4 got {}", .{b.len});
return ProtocolError.InvalidMessageLength;
}
- const inner_type = enum_from_int(InnerAuthRequestType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType;
+ const auth_type_int = std.mem.readIntBig(u32, b[0..4]);
+ const inner_type = enum_from_int(InnerAuthRequestType, auth_type_int) orelse {
+ log.err("Unsupported auth type {}", .{auth_type_int});
+ return ClientError.UnsupportedAuthType;
+ };
var inner: InnerAuthRequest = switch (inner_type) {
.AuthRequestTypeOk => .{ .ok = AuthRequestOk{} },
.AuthRequestTypeCleartextPassword => .{ .cleartext_password = AuthRequestCleartextPassword{} },
diff --git a/src/proto/proto.zig b/src/proto/proto.zig
index df1b717..5e4489d 100644
--- a/src/proto/proto.zig
+++ b/src/proto/proto.zig
@@ -1,3 +1,4 @@
+const std = @import("std");
pub const StartupMessage = @import("startup_message.zig");
pub const AuthenticationRequest = @import("authentication_request.zig");
pub const PasswordMessage = @import("password_message.zig");
@@ -13,6 +14,53 @@ const CopyXResponse = @import("copy_x_response.zig").CopyXResponse;
pub const CopyInResponse = CopyXResponse('G');
pub const CopyOutResponse = CopyXResponse('H');
pub const CopyBothResponse = CopyXResponse('W');
+const enum_from_int = @import("../main.zig").enum_from_int;
+const ProtocolError = @import("../main.zig").ProtocolError;
+
+pub const BackendMessage = union(enum) {
+ AuthenticationRequest: AuthenticationRequest,
+ ErrorResponse: ErrorResponse,
+ ReadyForQuery: ReadyForQuery,
+ ParameterStatus: ParameterStatus,
+ BackendKeyData: BackendKeyData,
+ DataRow: DataRow,
+ RowDescription: RowDescription,
+ CommandComplete: CommandComplete,
+ CopyInResponse: CopyInResponse,
+ CopyOutResponse: CopyOutResponse,
+ CopyBothResponse: CopyBothResponse,
+
+ pub fn deinit(self: *BackendMessage, a: std.mem.Allocator) void {
+ switch (self.*) {
+ inline else => |*sf| {
+ sf.deinit(a);
+ },
+ }
+ }
+};
+
+test {
+ inline for (@typeInfo(BackendMessage).Union.fields) |field| {
+ const msg_type = field.type;
+ if (!@hasDecl(msg_type, "Tag")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .Tag decl");
+ if (!@hasDecl(msg_type, "read")) @compileError("message type " ++ @typeName(msg_type) ++ " must have a .read decl");
+ }
+}
+
+pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !BackendMessage {
+ const tag = try stream_reader.readByte();
+ const len = try stream_reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4)));
+ defer allocator.free(buf);
+ try stream_reader.readNoEof(buf);
+ inline for (@typeInfo(BackendMessage).Union.fields) |field| {
+ if (field.type.Tag == tag) {
+ return @unionInit(BackendMessage, field.name, try field.type.read(allocator, buf));
+ }
+ } else {
+ return ProtocolError.InvalidMessageType;
+ }
+}
test {
_ = AuthenticationRequest;
diff --git a/src/proto/row_description.zig b/src/proto/row_description.zig
index a0e8810..b8105e2 100644
--- a/src/proto/row_description.zig
+++ b/src/proto/row_description.zig
@@ -34,7 +34,7 @@ pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription {
for (0..n_fields) |i| {
const name_start = fbs.pos;
try reader.skipUntilDelimiterOrEof(0);
- const name_end = fbs.pos-1;
+ const name_end = fbs.pos - 1;
const name = res.buf.?[name_start..name_end];
const field = Field{
.name = name,