aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Ashby <martin@ashbysoft.com>2023-09-29 09:44:54 +0100
committerMartin Ashby <martin@ashbysoft.com>2023-09-29 09:44:54 +0100
commitfada72cd26ad31e1fc834788c1224ed05a78143b (patch)
treedb71f2cbc6cbf9c47148271682641c940eb75650
parent6de632a41bdd127e92de68d61a18dfee91b8b188 (diff)
downloadpgz-fada72cd26ad31e1fc834788c1224ed05a78143b.tar.gz
pgz-fada72cd26ad31e1fc834788c1224ed05a78143b.tar.bz2
pgz-fada72cd26ad31e1fc834788c1224ed05a78143b.tar.xz
pgz-fada72cd26ad31e1fc834788c1224ed05a78143b.zip
Generify ErrorResponse to allow for NoticeResponse which shares it'sHEADmain
structure. Add a very basic test for running an actual query
-rw-r--r--src/conn/conn.zig36
-rw-r--r--src/proto/command_complete.zig7
-rw-r--r--src/proto/error_response.zig297
-rw-r--r--src/proto/proto.zig23
4 files changed, 202 insertions, 161 deletions
diff --git a/src/conn/conn.zig b/src/conn/conn.zig
index 9d5a8e1..378a8d1 100644
--- a/src/conn/conn.zig
+++ b/src/conn/conn.zig
@@ -8,6 +8,7 @@ const PasswordMessage = proto.PasswordMessage;
const BackendMessage = proto.BackendMessage;
const RowDescription = proto.RowDescription;
const read_message = proto.read_message;
+const clone_message = proto.clone_message;
const ProtocolError = @import("../main.zig").ProtocolError;
const ServerError = @import("../main.zig").ServerError;
const ClientError = @import("../main.zig").ClientError;
@@ -99,9 +100,10 @@ fn receive_message(self: *Conn) !BackendMessage {
return ServerError.ErrorResponse;
}
},
- // .NoticeResponse => {
- // // TODO handle notice response
- // },
+ .NoticeResponse => |nr| {
+ // log it?
+ log.warn("NOTICE {}", .{nr});
+ },
// .NotificationResponse => {
// // TODO handle notificationResponse
// },
@@ -154,7 +156,8 @@ pub const ResultIterator = struct {
pub fn skip_to_end(self: *ResultIterator) !void {
while (self.command_complete == null) {
- _ = try self.receive_message();
+ var msg = try self.receive_message();
+ msg.deinit(self.conn.allocator);
}
}
@@ -163,16 +166,17 @@ pub const ResultIterator = struct {
switch (msg) {
.DataRow => |dr| {
if (self.current_datarow != null) self.current_datarow.?.deinit(self.conn.allocator);
- self.current_datarow = try dr.clone(self.conn.allocator);
+ self.current_datarow = try clone_message(dr, self.conn.allocator);
},
.RowDescription => |rd| {
if (self.row_description != null) return ProtocolError.UnexpectedMessage;
- self.row_description = try rd.clone(self.conn.allocator);
+ self.row_description = try clone_message(rd, self.conn.allocator);
},
.CommandComplete => |cc| {
if (self.command_complete != null) return ProtocolError.UnexpectedMessage;
- self.command_complete = try cc.clone(self.conn.allocator);
+ self.command_complete = try clone_message(cc, self.conn.allocator);
},
+ else => {},
}
return msg;
}
@@ -252,3 +256,21 @@ test "connect tcp with wrong password" {
// };
// try std.testing.expectError(ServerError.ErrorResponse, Conn.connect(cfg));
}
+
+test "exec" {
+ // must have a local postgres runnning
+ // TODO maybe use docker to start one?
+ const allocator = std.testing.allocator;
+ const cfg = Config{
+ .allocator = allocator,
+ .address = .{ .unix = "/run/postgresql/.s.PGSQL.5432" },
+ .database = "martin",
+ .user = "martin",
+ };
+ var conn = try Conn.connect(cfg);
+ defer conn.deinit();
+ var ri = try conn.exec("create table if not exists foo (col1 int not null)");
+ defer ri.deinit();
+ try ri.skip_to_end();
+
+} \ No newline at end of file
diff --git a/src/proto/command_complete.zig b/src/proto/command_complete.zig
index f9a9e26..ed8e052 100644
--- a/src/proto/command_complete.zig
+++ b/src/proto/command_complete.zig
@@ -30,13 +30,6 @@ pub fn deinit(self: *CommandComplete, a: std.mem.Allocator) void {
if (self.buf != null) a.free(self.buf.?);
}
-pub fn clone(self: CommandComplete, a: std.mem.Allocator) !CommandComplete {
- var ba = ByteArrayList.init(a);
- errdefer ba.deinit();
- try self.write(a, ba.writer());
- return try CommandComplete.read(a, ba.items);
-}
-
test "round trip" {
const allocator = std.testing.allocator;
var sm = CommandComplete{
diff --git a/src/proto/error_response.zig b/src/proto/error_response.zig
index 58ca06e..f572f0a 100644
--- a/src/proto/error_response.zig
+++ b/src/proto/error_response.zig
@@ -3,160 +3,165 @@ const HMByteString = std.AutoHashMap(u8, []const u8);
const ByteArrayList = std.ArrayList(u8);
const ProtocolError = @import("../main.zig").ProtocolError;
-const ErrorResponse = @This();
-pub const Tag: u8 = 'E';
+pub fn ErrorNoticeResponse(comptime tag:u8) type {
+ return struct {
+ pub const Tag: u8 = tag;
+ buf: ?[]const u8 = null, // owned
+ severity: []const u8,
+ severity_unlocalized: ?[]const u8 = null,
+ code: []const u8,
+ message: []const u8,
+ detail: ?[]const u8 = null,
+ hint: ?[]const u8 = null,
+ position: ?u32 = null,
+ internal_position: ?u32 = null,
+ internal_query: ?[]const u8 = null,
+ where: ?[]const u8 = null,
+ schema_name: ?[]const u8 = null,
+ table_name: ?[]const u8 = null,
+ column_name: ?[]const u8 = null,
+ data_type_name: ?[]const u8 = null,
+ constraint_name: ?[]const u8 = null,
+ file_name: ?[]const u8 = null,
+ line: ?u32 = null,
+ routine: ?[]const u8 = null,
+ unknown_fields: HMByteString,
-buf: ?[]const u8 = null, // owned
-severity: []const u8,
-severity_unlocalized: ?[]const u8 = null,
-code: []const u8,
-message: []const u8,
-detail: ?[]const u8 = null,
-hint: ?[]const u8 = null,
-position: ?u32 = null,
-internal_position: ?u32 = null,
-internal_query: ?[]const u8 = null,
-where: ?[]const u8 = null,
-schema_name: ?[]const u8 = null,
-table_name: ?[]const u8 = null,
-column_name: ?[]const u8 = null,
-data_type_name: ?[]const u8 = null,
-constraint_name: ?[]const u8 = null,
-file_name: ?[]const u8 = null,
-line: ?u32 = null,
-routine: ?[]const u8 = null,
-unknown_fields: HMByteString,
-
-pub fn read(allocator: std.mem.Allocator, buf: []const u8) !ErrorResponse {
- var res = ErrorResponse{
- .severity = "",
- .code = "",
- .message = "",
- .unknown_fields = HMByteString.init(allocator),
- .buf = buf,
- };
- errdefer res.deinit(allocator);
- var it = std.mem.splitScalar(u8, res.buf.?, 0);
- var setSev = false;
- var setCode = false;
- var setMsg = false;
- while (it.next()) |next| {
- if (next.len < 1) break;
- switch (next[0]) {
- 0 => break,
- 'S' => {
- res.severity = next[1..];
- setSev = true;
- },
- 'V' => {
- res.severity_unlocalized = next[1..];
- },
- 'C' => {
- res.code = next[1..];
- setCode = true;
- },
- 'M' => {
- res.message = next[1..];
- setMsg = true;
- },
- 'D' => {
- res.detail = next[1..];
- },
- 'H' => {
- res.hint = next[1..];
- },
- 'P' => {
- res.position = try std.fmt.parseInt(u32, next[1..], 10);
- },
- 'p' => {
- res.internal_position = try std.fmt.parseInt(u32, next[1..], 10);
- },
- 'q' => {
- res.internal_query = next[1..];
- },
- 'W' => {
- res.where = next[1..];
- },
- 's' => {
- res.schema_name = next[1..];
- },
- 't' => {
- res.table_name = next[1..];
- },
- 'c' => {
- res.column_name = next[1..];
- },
- 'd' => {
- res.data_type_name = next[1..];
- },
- 'n' => {
- res.constraint_name = next[1..];
- },
- 'F' => {
- res.file_name = next[1..];
- },
- 'L' => {
- res.line = try std.fmt.parseInt(u32, next[1..], 10);
- },
- 'R' => {
- res.routine = next[1..];
- },
- else => {
- try res.unknown_fields.put(next[0], next[1..]);
- },
+ pub fn read(allocator: std.mem.Allocator, buf: []const u8) !@This() {
+ var res = @This(){
+ .severity = "",
+ .code = "",
+ .message = "",
+ .unknown_fields = HMByteString.init(allocator),
+ .buf = buf,
+ };
+ errdefer res.deinit(allocator);
+ var it = std.mem.splitScalar(u8, res.buf.?, 0);
+ var setSev = false;
+ var setCode = false;
+ var setMsg = false;
+ while (it.next()) |next| {
+ if (next.len < 1) break;
+ switch (next[0]) {
+ 0 => break,
+ 'S' => {
+ res.severity = next[1..];
+ setSev = true;
+ },
+ 'V' => {
+ res.severity_unlocalized = next[1..];
+ },
+ 'C' => {
+ res.code = next[1..];
+ setCode = true;
+ },
+ 'M' => {
+ res.message = next[1..];
+ setMsg = true;
+ },
+ 'D' => {
+ res.detail = next[1..];
+ },
+ 'H' => {
+ res.hint = next[1..];
+ },
+ 'P' => {
+ res.position = try std.fmt.parseInt(u32, next[1..], 10);
+ },
+ 'p' => {
+ res.internal_position = try std.fmt.parseInt(u32, next[1..], 10);
+ },
+ 'q' => {
+ res.internal_query = next[1..];
+ },
+ 'W' => {
+ res.where = next[1..];
+ },
+ 's' => {
+ res.schema_name = next[1..];
+ },
+ 't' => {
+ res.table_name = next[1..];
+ },
+ 'c' => {
+ res.column_name = next[1..];
+ },
+ 'd' => {
+ res.data_type_name = next[1..];
+ },
+ 'n' => {
+ res.constraint_name = next[1..];
+ },
+ 'F' => {
+ res.file_name = next[1..];
+ },
+ 'L' => {
+ res.line = try std.fmt.parseInt(u32, next[1..], 10);
+ },
+ 'R' => {
+ res.routine = next[1..];
+ },
+ else => {
+ try res.unknown_fields.put(next[0], next[1..]);
+ },
+ }
+ }
+ if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField;
+ return res;
}
- }
- if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField;
- return res;
-}
-pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: anytype) !void {
- try stream_writer.writeByte(Tag);
- var al = ByteArrayList.init(allocator);
- defer al.deinit();
- var cw = std.io.countingWriter(al.writer());
- var writer = cw.writer();
- try writer.writeIntBig(u32, 0); // Length placeholder.
+ pub fn write(self: @This(), allocator: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ var al = ByteArrayList.init(allocator);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // Length placeholder.
- try write_field_nt('S', self, "severity", writer);
- if (self.severity_unlocalized) |severity_unlocalized| try write_nt('V', severity_unlocalized, writer);
- try write_field_nt('C', self, "code", writer);
- try write_field_nt('M', self, "message", writer);
- if (self.detail) |detail| try write_nt('D', detail, writer);
- // TODO rest of the fields
+ try write_field_nt('S', self, "severity", writer);
+ if (self.severity_unlocalized) |severity_unlocalized| try write_nt('V', severity_unlocalized, writer);
+ try write_field_nt('C', self, "code", writer);
+ try write_field_nt('M', self, "message", writer);
+ if (self.detail) |detail| try write_nt('D', detail, writer);
+ // TODO rest of the fields
- // replace the length and write it to the actual stream
- std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
- try stream_writer.writeAll(al.items);
-}
-fn write_field_nt(comptime tag: u8, self: ErrorResponse, comptime field: []const u8, writer: anytype) !void {
- try write_nt(tag, @field(self, field), writer);
-}
-fn write_nt(comptime tag: u8, value: []const u8, writer: anytype) !void {
- try writer.writeByte(tag);
- try writer.writeAll(value);
- try writer.writeByte(0);
-}
+ // replace the length and write it to the actual stream
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
+ try stream_writer.writeAll(al.items);
+ }
+ fn write_field_nt(comptime label: u8, self: @This(), comptime field: []const u8, writer: anytype) !void {
+ try write_nt(label, @field(self, field), writer);
+ }
+ fn write_nt(comptime label: u8, value: []const u8, writer: anytype) !void {
+ try writer.writeByte(label);
+ try writer.writeAll(value);
+ try writer.writeByte(0);
+ }
-pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void {
- self.unknown_fields.deinit();
- if (self.buf != null) allocator.free(self.buf.?);
-}
+ pub fn deinit(self: *@This(), allocator: std.mem.Allocator) void {
+ self.unknown_fields.deinit();
+ if (self.buf != null) allocator.free(self.buf.?);
+ }
-pub fn format(self: ErrorResponse, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
- _ = options;
- _ = fmt;
- try writer.writeAll("ErrorResponse severity [");
- try writer.writeAll(self.severity);
- try writer.writeAll("] ");
- try writer.writeAll("code [");
- try writer.writeAll(self.code);
- try writer.writeAll("] ");
- try writer.writeAll("message [");
- try writer.writeAll(self.message);
- try writer.writeAll("]");
+ pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
+ _ = options;
+ _ = fmt;
+ try writer.writeAll(@typeName(@TypeOf(@This())));
+ try writer.writeAll(" severity [");
+ try writer.writeAll(self.severity);
+ try writer.writeAll("] ");
+ try writer.writeAll("code [");
+ try writer.writeAll(self.code);
+ try writer.writeAll("] ");
+ try writer.writeAll("message [");
+ try writer.writeAll(self.message);
+ try writer.writeAll("]");
+ }
+ };
}
+
test "round trip" {
+ const ErrorResponse = ErrorNoticeResponse('E');
const allocator = std.testing.allocator;
var sm = ErrorResponse{
.severity = "foo",
@@ -175,7 +180,7 @@ test "round trip" {
var fbs = std.io.fixedBufferStream(bal.items);
var reader = fbs.reader();
const tag = try reader.readByte();
- try std.testing.expectEqual(Tag, tag);
+ try std.testing.expectEqual(ErrorResponse.Tag, tag);
const len = try reader.readIntBig(u32);
const buf = try allocator.alloc(u8, len - 4);
try reader.readNoEof(buf);
diff --git a/src/proto/proto.zig b/src/proto/proto.zig
index 9025347..261ed9e 100644
--- a/src/proto/proto.zig
+++ b/src/proto/proto.zig
@@ -1,8 +1,11 @@
const std = @import("std");
+const ByteArrayList = std.ArrayList(u8);
+const log = std.log.scoped(.pgz);
pub const StartupMessage = @import("startup_message.zig");
pub const AuthenticationRequest = @import("authentication_request.zig");
pub const PasswordMessage = @import("password_message.zig");
-pub const ErrorResponse = @import("error_response.zig");
+pub const ErrorResponse = @import("error_response.zig").ErrorNoticeResponse('E');
+pub const NoticeResponse = @import("error_response.zig").ErrorNoticeResponse('N');
pub const ReadyForQuery = @import("ready_for_query.zig");
pub const ParameterStatus = @import("parameter_status.zig");
pub const BackendKeyData = @import("backend_key_data.zig");
@@ -20,6 +23,7 @@ const ProtocolError = @import("../main.zig").ProtocolError;
pub const BackendMessage = union(enum) {
AuthenticationRequest: AuthenticationRequest,
ErrorResponse: ErrorResponse,
+ NoticeResponse: NoticeResponse,
ReadyForQuery: ReadyForQuery,
ParameterStatus: ParameterStatus,
BackendKeyData: BackendKeyData,
@@ -58,10 +62,27 @@ pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !Backe
}
} else {
allocator.free(buf);
+ log.err("InvalidMessageType {c}", .{tag});
return ProtocolError.InvalidMessageType;
}
}
+// Caller owns the resulting message.
+// 'self' must be one of the message types above.
+pub fn clone_message(self: anytype, a: std.mem.Allocator) !@TypeOf(self) {
+ var ba = ByteArrayList.init(a);
+ defer ba.deinit();
+ try self.write(a, ba.writer());
+ var fbs = std.io.fixedBufferStream(ba.items);
+ var reader = fbs.reader();
+ _ = try reader.readByte();
+ const len = try reader.readIntBig(u32);
+ var buf = try a.alloc(u8, len-4);
+ errdefer a.free(buf);
+ try reader.readNoEof(buf);
+ return try @TypeOf(self).read(a, buf);
+}
+
test {
_ = AuthenticationRequest;
_ = PasswordMessage;