diff options
author | Martin Ashby <martin@ashbysoft.com> | 2023-09-29 09:44:54 +0100 |
---|---|---|
committer | Martin Ashby <martin@ashbysoft.com> | 2023-09-29 09:44:54 +0100 |
commit | fada72cd26ad31e1fc834788c1224ed05a78143b (patch) | |
tree | db71f2cbc6cbf9c47148271682641c940eb75650 | |
parent | 6de632a41bdd127e92de68d61a18dfee91b8b188 (diff) | |
download | pgz-main.tar.gz pgz-main.tar.bz2 pgz-main.tar.xz pgz-main.zip |
structure.
Add a very basic test for running an actual query
-rw-r--r-- | src/conn/conn.zig | 36 | ||||
-rw-r--r-- | src/proto/command_complete.zig | 7 | ||||
-rw-r--r-- | src/proto/error_response.zig | 297 | ||||
-rw-r--r-- | src/proto/proto.zig | 23 |
4 files changed, 202 insertions, 161 deletions
diff --git a/src/conn/conn.zig b/src/conn/conn.zig index 9d5a8e1..378a8d1 100644 --- a/src/conn/conn.zig +++ b/src/conn/conn.zig @@ -8,6 +8,7 @@ const PasswordMessage = proto.PasswordMessage; const BackendMessage = proto.BackendMessage; const RowDescription = proto.RowDescription; const read_message = proto.read_message; +const clone_message = proto.clone_message; const ProtocolError = @import("../main.zig").ProtocolError; const ServerError = @import("../main.zig").ServerError; const ClientError = @import("../main.zig").ClientError; @@ -99,9 +100,10 @@ fn receive_message(self: *Conn) !BackendMessage { return ServerError.ErrorResponse; } }, - // .NoticeResponse => { - // // TODO handle notice response - // }, + .NoticeResponse => |nr| { + // log it? + log.warn("NOTICE {}", .{nr}); + }, // .NotificationResponse => { // // TODO handle notificationResponse // }, @@ -154,7 +156,8 @@ pub const ResultIterator = struct { pub fn skip_to_end(self: *ResultIterator) !void { while (self.command_complete == null) { - _ = try self.receive_message(); + var msg = try self.receive_message(); + msg.deinit(self.conn.allocator); } } @@ -163,16 +166,17 @@ pub const ResultIterator = struct { switch (msg) { .DataRow => |dr| { if (self.current_datarow != null) self.current_datarow.?.deinit(self.conn.allocator); - self.current_datarow = try dr.clone(self.conn.allocator); + self.current_datarow = try clone_message(dr, self.conn.allocator); }, .RowDescription => |rd| { if (self.row_description != null) return ProtocolError.UnexpectedMessage; - self.row_description = try rd.clone(self.conn.allocator); + self.row_description = try clone_message(rd, self.conn.allocator); }, .CommandComplete => |cc| { if (self.command_complete != null) return ProtocolError.UnexpectedMessage; - self.command_complete = try cc.clone(self.conn.allocator); + self.command_complete = try clone_message(cc, self.conn.allocator); }, + else => {}, } return msg; } @@ -252,3 +256,21 @@ test "connect tcp with wrong password" { // }; // try std.testing.expectError(ServerError.ErrorResponse, Conn.connect(cfg)); } + +test "exec" { + // must have a local postgres runnning + // TODO maybe use docker to start one? + const allocator = std.testing.allocator; + const cfg = Config{ + .allocator = allocator, + .address = .{ .unix = "/run/postgresql/.s.PGSQL.5432" }, + .database = "martin", + .user = "martin", + }; + var conn = try Conn.connect(cfg); + defer conn.deinit(); + var ri = try conn.exec("create table if not exists foo (col1 int not null)"); + defer ri.deinit(); + try ri.skip_to_end(); + +}
\ No newline at end of file diff --git a/src/proto/command_complete.zig b/src/proto/command_complete.zig index f9a9e26..ed8e052 100644 --- a/src/proto/command_complete.zig +++ b/src/proto/command_complete.zig @@ -30,13 +30,6 @@ pub fn deinit(self: *CommandComplete, a: std.mem.Allocator) void { if (self.buf != null) a.free(self.buf.?); } -pub fn clone(self: CommandComplete, a: std.mem.Allocator) !CommandComplete { - var ba = ByteArrayList.init(a); - errdefer ba.deinit(); - try self.write(a, ba.writer()); - return try CommandComplete.read(a, ba.items); -} - test "round trip" { const allocator = std.testing.allocator; var sm = CommandComplete{ diff --git a/src/proto/error_response.zig b/src/proto/error_response.zig index 58ca06e..f572f0a 100644 --- a/src/proto/error_response.zig +++ b/src/proto/error_response.zig @@ -3,160 +3,165 @@ const HMByteString = std.AutoHashMap(u8, []const u8); const ByteArrayList = std.ArrayList(u8); const ProtocolError = @import("../main.zig").ProtocolError; -const ErrorResponse = @This(); -pub const Tag: u8 = 'E'; +pub fn ErrorNoticeResponse(comptime tag:u8) type { + return struct { + pub const Tag: u8 = tag; + buf: ?[]const u8 = null, // owned + severity: []const u8, + severity_unlocalized: ?[]const u8 = null, + code: []const u8, + message: []const u8, + detail: ?[]const u8 = null, + hint: ?[]const u8 = null, + position: ?u32 = null, + internal_position: ?u32 = null, + internal_query: ?[]const u8 = null, + where: ?[]const u8 = null, + schema_name: ?[]const u8 = null, + table_name: ?[]const u8 = null, + column_name: ?[]const u8 = null, + data_type_name: ?[]const u8 = null, + constraint_name: ?[]const u8 = null, + file_name: ?[]const u8 = null, + line: ?u32 = null, + routine: ?[]const u8 = null, + unknown_fields: HMByteString, -buf: ?[]const u8 = null, // owned -severity: []const u8, -severity_unlocalized: ?[]const u8 = null, -code: []const u8, -message: []const u8, -detail: ?[]const u8 = null, -hint: ?[]const u8 = null, -position: ?u32 = null, -internal_position: ?u32 = null, -internal_query: ?[]const u8 = null, -where: ?[]const u8 = null, -schema_name: ?[]const u8 = null, -table_name: ?[]const u8 = null, -column_name: ?[]const u8 = null, -data_type_name: ?[]const u8 = null, -constraint_name: ?[]const u8 = null, -file_name: ?[]const u8 = null, -line: ?u32 = null, -routine: ?[]const u8 = null, -unknown_fields: HMByteString, - -pub fn read(allocator: std.mem.Allocator, buf: []const u8) !ErrorResponse { - var res = ErrorResponse{ - .severity = "", - .code = "", - .message = "", - .unknown_fields = HMByteString.init(allocator), - .buf = buf, - }; - errdefer res.deinit(allocator); - var it = std.mem.splitScalar(u8, res.buf.?, 0); - var setSev = false; - var setCode = false; - var setMsg = false; - while (it.next()) |next| { - if (next.len < 1) break; - switch (next[0]) { - 0 => break, - 'S' => { - res.severity = next[1..]; - setSev = true; - }, - 'V' => { - res.severity_unlocalized = next[1..]; - }, - 'C' => { - res.code = next[1..]; - setCode = true; - }, - 'M' => { - res.message = next[1..]; - setMsg = true; - }, - 'D' => { - res.detail = next[1..]; - }, - 'H' => { - res.hint = next[1..]; - }, - 'P' => { - res.position = try std.fmt.parseInt(u32, next[1..], 10); - }, - 'p' => { - res.internal_position = try std.fmt.parseInt(u32, next[1..], 10); - }, - 'q' => { - res.internal_query = next[1..]; - }, - 'W' => { - res.where = next[1..]; - }, - 's' => { - res.schema_name = next[1..]; - }, - 't' => { - res.table_name = next[1..]; - }, - 'c' => { - res.column_name = next[1..]; - }, - 'd' => { - res.data_type_name = next[1..]; - }, - 'n' => { - res.constraint_name = next[1..]; - }, - 'F' => { - res.file_name = next[1..]; - }, - 'L' => { - res.line = try std.fmt.parseInt(u32, next[1..], 10); - }, - 'R' => { - res.routine = next[1..]; - }, - else => { - try res.unknown_fields.put(next[0], next[1..]); - }, + pub fn read(allocator: std.mem.Allocator, buf: []const u8) !@This() { + var res = @This(){ + .severity = "", + .code = "", + .message = "", + .unknown_fields = HMByteString.init(allocator), + .buf = buf, + }; + errdefer res.deinit(allocator); + var it = std.mem.splitScalar(u8, res.buf.?, 0); + var setSev = false; + var setCode = false; + var setMsg = false; + while (it.next()) |next| { + if (next.len < 1) break; + switch (next[0]) { + 0 => break, + 'S' => { + res.severity = next[1..]; + setSev = true; + }, + 'V' => { + res.severity_unlocalized = next[1..]; + }, + 'C' => { + res.code = next[1..]; + setCode = true; + }, + 'M' => { + res.message = next[1..]; + setMsg = true; + }, + 'D' => { + res.detail = next[1..]; + }, + 'H' => { + res.hint = next[1..]; + }, + 'P' => { + res.position = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'p' => { + res.internal_position = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'q' => { + res.internal_query = next[1..]; + }, + 'W' => { + res.where = next[1..]; + }, + 's' => { + res.schema_name = next[1..]; + }, + 't' => { + res.table_name = next[1..]; + }, + 'c' => { + res.column_name = next[1..]; + }, + 'd' => { + res.data_type_name = next[1..]; + }, + 'n' => { + res.constraint_name = next[1..]; + }, + 'F' => { + res.file_name = next[1..]; + }, + 'L' => { + res.line = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'R' => { + res.routine = next[1..]; + }, + else => { + try res.unknown_fields.put(next[0], next[1..]); + }, + } + } + if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField; + return res; } - } - if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField; - return res; -} -pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - var al = ByteArrayList.init(allocator); - defer al.deinit(); - var cw = std.io.countingWriter(al.writer()); - var writer = cw.writer(); - try writer.writeIntBig(u32, 0); // Length placeholder. + pub fn write(self: @This(), allocator: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(allocator); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // Length placeholder. - try write_field_nt('S', self, "severity", writer); - if (self.severity_unlocalized) |severity_unlocalized| try write_nt('V', severity_unlocalized, writer); - try write_field_nt('C', self, "code", writer); - try write_field_nt('M', self, "message", writer); - if (self.detail) |detail| try write_nt('D', detail, writer); - // TODO rest of the fields + try write_field_nt('S', self, "severity", writer); + if (self.severity_unlocalized) |severity_unlocalized| try write_nt('V', severity_unlocalized, writer); + try write_field_nt('C', self, "code", writer); + try write_field_nt('M', self, "message", writer); + if (self.detail) |detail| try write_nt('D', detail, writer); + // TODO rest of the fields - // replace the length and write it to the actual stream - std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); - try stream_writer.writeAll(al.items); -} -fn write_field_nt(comptime tag: u8, self: ErrorResponse, comptime field: []const u8, writer: anytype) !void { - try write_nt(tag, @field(self, field), writer); -} -fn write_nt(comptime tag: u8, value: []const u8, writer: anytype) !void { - try writer.writeByte(tag); - try writer.writeAll(value); - try writer.writeByte(0); -} + // replace the length and write it to the actual stream + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); + } + fn write_field_nt(comptime label: u8, self: @This(), comptime field: []const u8, writer: anytype) !void { + try write_nt(label, @field(self, field), writer); + } + fn write_nt(comptime label: u8, value: []const u8, writer: anytype) !void { + try writer.writeByte(label); + try writer.writeAll(value); + try writer.writeByte(0); + } -pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void { - self.unknown_fields.deinit(); - if (self.buf != null) allocator.free(self.buf.?); -} + pub fn deinit(self: *@This(), allocator: std.mem.Allocator) void { + self.unknown_fields.deinit(); + if (self.buf != null) allocator.free(self.buf.?); + } -pub fn format(self: ErrorResponse, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { - _ = options; - _ = fmt; - try writer.writeAll("ErrorResponse severity ["); - try writer.writeAll(self.severity); - try writer.writeAll("] "); - try writer.writeAll("code ["); - try writer.writeAll(self.code); - try writer.writeAll("] "); - try writer.writeAll("message ["); - try writer.writeAll(self.message); - try writer.writeAll("]"); + pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { + _ = options; + _ = fmt; + try writer.writeAll(@typeName(@TypeOf(@This()))); + try writer.writeAll(" severity ["); + try writer.writeAll(self.severity); + try writer.writeAll("] "); + try writer.writeAll("code ["); + try writer.writeAll(self.code); + try writer.writeAll("] "); + try writer.writeAll("message ["); + try writer.writeAll(self.message); + try writer.writeAll("]"); + } + }; } + test "round trip" { + const ErrorResponse = ErrorNoticeResponse('E'); const allocator = std.testing.allocator; var sm = ErrorResponse{ .severity = "foo", @@ -175,7 +180,7 @@ test "round trip" { var fbs = std.io.fixedBufferStream(bal.items); var reader = fbs.reader(); const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); + try std.testing.expectEqual(ErrorResponse.Tag, tag); const len = try reader.readIntBig(u32); const buf = try allocator.alloc(u8, len - 4); try reader.readNoEof(buf); diff --git a/src/proto/proto.zig b/src/proto/proto.zig index 9025347..261ed9e 100644 --- a/src/proto/proto.zig +++ b/src/proto/proto.zig @@ -1,8 +1,11 @@ const std = @import("std"); +const ByteArrayList = std.ArrayList(u8); +const log = std.log.scoped(.pgz); pub const StartupMessage = @import("startup_message.zig"); pub const AuthenticationRequest = @import("authentication_request.zig"); pub const PasswordMessage = @import("password_message.zig"); -pub const ErrorResponse = @import("error_response.zig"); +pub const ErrorResponse = @import("error_response.zig").ErrorNoticeResponse('E'); +pub const NoticeResponse = @import("error_response.zig").ErrorNoticeResponse('N'); pub const ReadyForQuery = @import("ready_for_query.zig"); pub const ParameterStatus = @import("parameter_status.zig"); pub const BackendKeyData = @import("backend_key_data.zig"); @@ -20,6 +23,7 @@ const ProtocolError = @import("../main.zig").ProtocolError; pub const BackendMessage = union(enum) { AuthenticationRequest: AuthenticationRequest, ErrorResponse: ErrorResponse, + NoticeResponse: NoticeResponse, ReadyForQuery: ReadyForQuery, ParameterStatus: ParameterStatus, BackendKeyData: BackendKeyData, @@ -58,10 +62,27 @@ pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !Backe } } else { allocator.free(buf); + log.err("InvalidMessageType {c}", .{tag}); return ProtocolError.InvalidMessageType; } } +// Caller owns the resulting message. +// 'self' must be one of the message types above. +pub fn clone_message(self: anytype, a: std.mem.Allocator) !@TypeOf(self) { + var ba = ByteArrayList.init(a); + defer ba.deinit(); + try self.write(a, ba.writer()); + var fbs = std.io.fixedBufferStream(ba.items); + var reader = fbs.reader(); + _ = try reader.readByte(); + const len = try reader.readIntBig(u32); + var buf = try a.alloc(u8, len-4); + errdefer a.free(buf); + try reader.readNoEof(buf); + return try @TypeOf(self).read(a, buf); +} + test { _ = AuthenticationRequest; _ = PasswordMessage; |