From 183d60a6e87230cc767c56900b94c9c694596de1 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Tue, 26 Sep 2023 06:51:06 +0100 Subject: Move protocol definitions into a subfolder --- src/authentication_request.zig | 75 ---------------- src/backend_key_data.zig | 52 ----------- src/command_complete.zig | 56 ------------ src/conn.zig | 38 ++++---- src/copy_in_response.zig | 83 ----------------- src/data_row.zig | 88 ------------------ src/error_response.zig | 170 ----------------------------------- src/main.zig | 39 ++------ src/parameter_status.zig | 64 ------------- src/password_message.zig | 45 ---------- src/proto/authentication_request.zig | 75 ++++++++++++++++ src/proto/backend_key_data.zig | 52 +++++++++++ src/proto/command_complete.zig | 56 ++++++++++++ src/proto/copy_in_response.zig | 83 +++++++++++++++++ src/proto/data_row.zig | 88 ++++++++++++++++++ src/proto/error_response.zig | 170 +++++++++++++++++++++++++++++++++++ src/proto/parameter_status.zig | 64 +++++++++++++ src/proto/password_message.zig | 45 ++++++++++ src/proto/proto.zig | 26 ++++++ src/proto/query.zig | 55 ++++++++++++ src/proto/ready_for_query.zig | 53 +++++++++++ src/proto/row_description.zig | 136 ++++++++++++++++++++++++++++ src/proto/startup_message.zig | 85 ++++++++++++++++++ src/query.zig | 55 ------------ src/ready_for_query.zig | 55 ------------ src/row_description.zig | 136 ---------------------------- src/startup_message.zig | 85 ------------------ 27 files changed, 1012 insertions(+), 1017 deletions(-) delete mode 100644 src/authentication_request.zig delete mode 100644 src/backend_key_data.zig delete mode 100644 src/command_complete.zig delete mode 100644 src/copy_in_response.zig delete mode 100644 src/data_row.zig delete mode 100644 src/error_response.zig delete mode 100644 src/parameter_status.zig delete mode 100644 src/password_message.zig create mode 100644 src/proto/authentication_request.zig create mode 100644 src/proto/backend_key_data.zig create mode 100644 src/proto/command_complete.zig create mode 100644 src/proto/copy_in_response.zig create mode 100644 src/proto/data_row.zig create mode 100644 src/proto/error_response.zig create mode 100644 src/proto/parameter_status.zig create mode 100644 src/proto/password_message.zig create mode 100644 src/proto/proto.zig create mode 100644 src/proto/query.zig create mode 100644 src/proto/ready_for_query.zig create mode 100644 src/proto/row_description.zig create mode 100644 src/proto/startup_message.zig delete mode 100644 src/query.zig delete mode 100644 src/ready_for_query.zig delete mode 100644 src/row_description.zig delete mode 100644 src/startup_message.zig (limited to 'src') diff --git a/src/authentication_request.zig b/src/authentication_request.zig deleted file mode 100644 index 549a26b..0000000 --- a/src/authentication_request.zig +++ /dev/null @@ -1,75 +0,0 @@ -const std = @import("std"); -const log = std.log.scoped(.pgz); -const ByteArrayList = std.ArrayList(u8); -const ProtocolError = @import("main.zig").ProtocolError; -const ClientError = @import("main.zig").ClientError; -const enum_from_int = @import("main.zig").enum_from_int; - -pub const Tag: u8 = 'R'; - -const AuthenticationRequest = @This(); - -pub const InnerAuthRequestType = enum(u32) { - AuthRequestTypeOk = 0, - AuthRequestTypeCleartextPassword = 3, -}; -pub const InnerAuthRequest = union{ - ok: AuthRequestOk, - cleartext_password: AuthRequestCleartextPassword, -}; -pub const AuthRequestOk = struct {}; -pub const AuthRequestCleartextPassword = struct {}; - -// Authentication requests have multiple subtypes. -// It's not possible to have a tagged union with a custom backing integer, so do it the long way -inner_type: InnerAuthRequestType, -inner: InnerAuthRequest, - -pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationRequest { - if (b.len != 4) { - log.err("invalid message length, expected 4 got {}",.{b.len}); - return ProtocolError.InvalidMessageLength; - } - const inner_type = enum_from_int(InnerAuthRequestType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType; - var inner: InnerAuthRequest = switch (inner_type) { - .AuthRequestTypeOk => .{.ok = AuthRequestOk{}}, - .AuthRequestTypeCleartextPassword => .{.cleartext_password = AuthRequestCleartextPassword{}}, - }; - return .{ - .inner_type = inner_type, - .inner = inner, - }; -} - -pub fn write(self: AuthenticationRequest, _: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - try stream_writer.writeIntBig(u32, 8); - try stream_writer.writeIntBig(u32, @intFromEnum(self.inner_type)); -} - -pub fn deinit(_: *AuthenticationRequest, _: std.mem.Allocator) void {} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = AuthenticationRequest{ - .inner_type = .AuthRequestTypeOk, - .inner = .{.ok = AuthRequestOk{}}, - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try AuthenticationRequest.read(allocator, buf); - defer sm2.deinit(allocator); - try std.testing.expectEqual(InnerAuthRequestType.AuthRequestTypeOk, sm2.inner_type); -} diff --git a/src/backend_key_data.zig b/src/backend_key_data.zig deleted file mode 100644 index 525c309..0000000 --- a/src/backend_key_data.zig +++ /dev/null @@ -1,52 +0,0 @@ -const std = @import("std"); -const log = std.log.scoped(.pgz); -const ByteArrayList = std.ArrayList(u8); -const ProtocolError = @import("main.zig").ProtocolError; -const ClientError = @import("main.zig").ClientError; - -const BackendKeyData = @This(); -pub const Tag: u8 = 'K'; - -process_id: u32, -secret_key: u32, - -pub fn read(_: std.mem.Allocator, b: []const u8) !BackendKeyData { - if (b.len != 8) return ProtocolError.InvalidMessageLength; - return .{ - .process_id = std.mem.readIntBig(u32, b[0..4]), - .secret_key = std.mem.readIntBig(u32, b[4..8]), - }; -} -pub fn write(self: BackendKeyData, _: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - try stream_writer.writeIntBig(u32, 12); // length - try stream_writer.writeIntBig(u32, self.process_id); - try stream_writer.writeIntBig(u32, self.secret_key); -} -pub fn deinit(_: *BackendKeyData, _: std.mem.Allocator) void {} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = BackendKeyData{ - .process_id = 123, - .secret_key = 345, - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try BackendKeyData.read(allocator, buf); - defer sm2.deinit(allocator); - try std.testing.expectEqual(@as(u32, 123), sm2.process_id); - try std.testing.expectEqual(@as(u32, 345), sm2.secret_key); -} diff --git a/src/command_complete.zig b/src/command_complete.zig deleted file mode 100644 index 5478547..0000000 --- a/src/command_complete.zig +++ /dev/null @@ -1,56 +0,0 @@ -const std = @import("std"); -const log = std.log.scoped(.pgz); -const ByteArrayList = std.ArrayList(u8); -const ProtocolError = @import("main.zig").ProtocolError; -const ClientError = @import("main.zig").ClientError; -const enum_from_int = @import("main.zig").enum_from_int; -const FormatCode = @import("main.zig").FormatCode; - -pub const Tag: u8 = 'C'; - -const CommandComplete = @This(); - -command_tag: []const u8, -owned: bool = false, - -pub fn read(a: std.mem.Allocator, b: []const u8) !CommandComplete { - return .{ - .command_tag = try a.dupe(u8, b), - .owned = true, - }; -} - -pub fn write(self: CommandComplete, _: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - try stream_writer.writeIntBig(u32, @as(u32, @intCast(4+self.command_tag.len))); - try stream_writer.writeAll(self.command_tag); -} - -pub fn deinit(self: *CommandComplete, a: std.mem.Allocator) void { - if (self.owned) a.free(self.command_tag); -} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = CommandComplete{ - .command_tag = "foo", - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try CommandComplete.read(allocator, buf); - defer sm2.deinit(allocator); - - try std.testing.expectEqualStrings("foo", sm2.command_tag); -} diff --git a/src/conn.zig b/src/conn.zig index 99018da..f9f4fb5 100644 --- a/src/conn.zig +++ b/src/conn.zig @@ -2,12 +2,7 @@ const std = @import("std"); const log = std.log.scoped(.pgz); const SSHashMap = std.StringHashMap([]const u8); const Config = @import("config.zig"); -const StartupMessage = @import("startup_message.zig"); -const ErrorResponse = @import("error_response.zig"); -const AuthenticationRequest = @import("authentication_request.zig"); -const ReadyForQuery = @import("ready_for_query.zig"); -const ParameterStatus = @import("parameter_status.zig"); -const BackendKeyData = @import("backend_key_data.zig"); +const Proto = @import("proto/proto.zig"); const read_message = @import("main.zig").read_message; const ProtocolError = @import("main.zig").ProtocolError; const ServerError = @import("main.zig").ServerError; @@ -44,7 +39,7 @@ pub fn connect(config: Config) !Conn { var params = SSHashMap.init(allocator); try params.put("user", config.user); if (config.database) |database| try params.put("database", database); - var sm = StartupMessage{ + var sm = Proto.StartupMessage{ .parameters = params, }; defer sm.deinit(allocator); @@ -52,36 +47,36 @@ pub fn connect(config: Config) !Conn { lp: while (true) { const response_type = try reader.readByte(); switch (response_type) { - ErrorResponse.Tag => { - var err = try read_message(ErrorResponse, allocator, reader); + Proto.ErrorResponse.Tag => { + var err = try read_message(Proto.ErrorResponse, allocator, reader); defer err.deinit(allocator); log.err("Error connecting to server {any}", .{err}); return ServerError.ErrorResponse; }, - AuthenticationRequest.Tag => { - var ar = try read_message(AuthenticationRequest, allocator, reader); + Proto.AuthenticationRequest.Tag => { + var ar = try read_message(Proto.AuthenticationRequest, allocator, reader); defer ar.deinit(allocator); // TODO handle the authentication request log.info("authentication request", .{}); }, - ReadyForQuery.Tag => { - var rfq = try read_message(ReadyForQuery, allocator, reader); + Proto.ReadyForQuery.Tag => { + var rfq = try read_message(Proto.ReadyForQuery, allocator, reader); defer rfq.deinit(allocator); // TODO do something about transaction state? res.status = .connStatusIdle; log.info("ready for query", .{}); break :lp; }, - ParameterStatus.Tag => { - var ps = try read_message(ParameterStatus, allocator, reader); + Proto.ParameterStatus.Tag => { + var ps = try read_message(Proto.ParameterStatus, allocator, reader); defer ps.deinit(allocator); // TODO Handle this somehow? - log.info("ParameterStatus: {s}:{s}", .{ps.name, ps.value}); + log.info("ParameterStatus: {s}:{s}", .{ ps.name, ps.value }); }, - BackendKeyData.Tag =>{ - var bkd = try read_message(BackendKeyData, allocator, reader); + Proto.BackendKeyData.Tag => { + var bkd = try read_message(Proto.BackendKeyData, allocator, reader); defer bkd.deinit(allocator); - log.info("BackendKeyData process_id {} secret_key {}" , .{bkd.process_id, bkd.secret_key}); + log.info("BackendKeyData process_id {} secret_key {}", .{ bkd.process_id, bkd.secret_key }); }, else => { log.err("unhandled message type [{c}]", .{response_type}); @@ -99,7 +94,7 @@ fn deinit(self: *Conn) void { self.stream.close(); } -//pub fn exec(self: *Conn) +//pub fn exec(self: *Conn) test "connect" { // must have a local postgres runnning @@ -107,11 +102,10 @@ test "connect" { const allocator = std.testing.allocator; const cfg = Config{ .allocator = allocator, - .address = .{.unix = "/run/postgresql/.s.PGSQL.5432"}, + .address = .{ .unix = "/run/postgresql/.s.PGSQL.5432" }, .database = "martin", .user = "martin", }; var conn = try Conn.connect(cfg); defer conn.deinit(); } - diff --git a/src/copy_in_response.zig b/src/copy_in_response.zig deleted file mode 100644 index f6ed84b..0000000 --- a/src/copy_in_response.zig +++ /dev/null @@ -1,83 +0,0 @@ -const std = @import("std"); -const log = std.log.scoped(.pgz); -const ByteArrayList = std.ArrayList(u8); -const ProtocolError = @import("main.zig").ProtocolError; -const ClientError = @import("main.zig").ClientError; -const enum_from_int = @import("main.zig").enum_from_int; -const FormatCode = @import("main.zig").FormatCode; - -pub const Tag: u8 = 'G'; -// TODO generics it's the same as CopyOutResponse and CopyBothResponse. -const CopyInResponse = @This(); - -overall_format_code: u8, -format_codes: []const FormatCode, // owned - -pub fn read(a: std.mem.Allocator, b: []const u8) !CopyInResponse { - var fbs = std.io.fixedBufferStream(b); - var reader = fbs.reader(); - const overall_format_code = try reader.readIntBig(u8); - const n_columns = try reader.readIntBig(u16); - var format_codes = try a.alloc(FormatCode, n_columns); - errdefer a.free(format_codes); - for (0..n_columns) |i| { - const int_format_code = try reader.readIntBig(u16); - format_codes[i] = enum_from_int(FormatCode, int_format_code) orelse return ProtocolError.InvalidFormatCode; - } - return .{ - .overall_format_code = overall_format_code, - .format_codes = format_codes, - }; -} - -pub fn write(self: CopyInResponse, a: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeIntBig(u8, Tag); - var al = ByteArrayList.init(a); - defer al.deinit(); - var cw = std.io.countingWriter(al.writer()); - var writer = cw.writer(); - try writer.writeIntBig(u32, 0); // length placeholder - try writer.writeIntBig(u8, self.overall_format_code); - try writer.writeIntBig(u16, @intCast(self.format_codes.len)); - for (self.format_codes) |format_code| { - try writer.writeIntBig(u16, @intFromEnum(format_code)); - } - std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); // Update length - try stream_writer.writeAll(al.items); -} - -pub fn deinit(self: *CopyInResponse, a: std.mem.Allocator) void { - a.free(self.format_codes); -} - -test "round trip" { - const allocator = std.testing.allocator; - var format_codes = try allocator.alloc(FormatCode, 3); - format_codes[0] = .Binary; - format_codes[1] = .Binary; - format_codes[2] = .Text; - var sm = CopyInResponse{ - .overall_format_code = 1, - .format_codes = format_codes, - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try CopyInResponse.read(allocator, buf); - defer sm2.deinit(allocator); - - try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[0]); - try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[1]); - try std.testing.expectEqual(FormatCode.Text, sm2.format_codes[2]); -} diff --git a/src/data_row.zig b/src/data_row.zig deleted file mode 100644 index 558ebca..0000000 --- a/src/data_row.zig +++ /dev/null @@ -1,88 +0,0 @@ -const std = @import("std"); -const log = std.log.scoped(.pgz); -const ByteArrayList = std.ArrayList(u8); -const ProtocolError = @import("main.zig").ProtocolError; -const ClientError = @import("main.zig").ClientError; -const enum_from_int = @import("main.zig").enum_from_int; - -pub const Tag: u8 = 'D'; - -const DataRow = @This(); - -buf: ?[]const u8 = null, // owned -columns: [][]const u8, // also owned - -pub fn read(a: std.mem.Allocator, b: []const u8) !DataRow { - if (b.len < 2) return ProtocolError.InvalidMessageLength; - var buf = try a.dupe(u8, b); - var res: DataRow = undefined; - res.buf = buf; - errdefer res.deinit(a); - - const n_columns = std.mem.readIntBig(u16, buf[0..2]); - const columns = try a.alloc([]const u8, n_columns); - errdefer a.free(columns); - var pos: usize = 2; - for (0..n_columns) |col| { - const len = std.mem.readIntBig(u32, buf[pos..(pos+4)][0..4]); // second slice forces the slice size to be known at comptime and satisfy the type check on readIntBig! - const data = if (len > 0) buf[(pos+4)..(pos+4+len)] else &[_]u8{}; - columns[col] = data; - pos += (4+len); - } - res.columns = columns; - return res; -} - -pub fn write(self: DataRow, a: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - var al = ByteArrayList.init(a); - defer al.deinit(); - var cw = std.io.countingWriter(al.writer()); - var writer = cw.writer(); - try writer.writeIntBig(u32, 0); // length placeholder - try writer.writeIntBig(u16, @as(u16, @intCast(self.columns.len))); - for (self.columns) |column| { - const len = @as(u32, @intCast(column.len)); - try writer.writeIntBig(u32, len); - try writer.writeAll(column); - } - // Fixup the length and write to the original stream - std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); - try stream_writer.writeAll(al.items); -} - -pub fn deinit(self: *DataRow, a: std.mem.Allocator) void { - if (self.buf != null) a.free(self.buf.?); - a.free(self.columns); -} - -test "round trip" { - const allocator = std.testing.allocator; - const columns = try allocator.alloc([]const u8, 3); - columns[0] = "Hello"; - columns[1] = "FooBar"; - columns[2] = ""; - var sm = DataRow{ - .columns = columns, - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try DataRow.read(allocator, buf); - defer sm2.deinit(allocator); - - try std.testing.expectEqualStrings("Hello", sm2.columns[0]); - try std.testing.expectEqualStrings("FooBar", sm2.columns[1]); - try std.testing.expectEqualStrings("", sm2.columns[2]); -} diff --git a/src/error_response.zig b/src/error_response.zig deleted file mode 100644 index 3b66c15..0000000 --- a/src/error_response.zig +++ /dev/null @@ -1,170 +0,0 @@ -const std = @import("std"); -const HMByteString = std.AutoHashMap(u8, []const u8); -const ByteArrayList = std.ArrayList(u8); -const ProtocolError = @import("main.zig").ProtocolError; - -const ErrorResponse = @This(); -pub const Tag: u8 = 'E'; - -buf: ?[]const u8 = null, // owned -severity: []const u8, -severity_unlocalized: ?[]const u8 = null, -code: []const u8, -message: []const u8, -detail: ?[]const u8 = null, -hint: ?[]const u8 = null, -position: ?u32 = null, -internal_position: ?u32 = null, -internal_query: ?[]const u8 = null, -where: ?[]const u8 = null, -schema_name: ?[]const u8 = null, -table_name: ?[]const u8 = null, -column_name: ?[]const u8 = null, -data_type_name: ?[]const u8 = null, -constraint_name: ?[]const u8 = null, -file_name: ?[]const u8 = null, -line: ?u32 = null, -routine: ?[]const u8 = null, -unknown_fields: HMByteString, - -pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse { - var res: ErrorResponse = undefined; - res.unknown_fields = HMByteString.init(allocator); - res.buf = try allocator.dupe(u8, b); - errdefer allocator.free(res.buf.?); - var it = std.mem.splitScalar(u8, res.buf.?, 0); - var setSev = false; var setCode = false; var setMsg = false; - while (it.next()) |next| { - if (next.len < 1) break; - switch (next[0]) { - 0 => break, - 'S' => { - res.severity = next[1..]; - setSev = true; - }, - 'V' => { - res.severity_unlocalized = next[1..]; - }, - 'C' => { - res.code = next[1..]; - setCode = true; - }, - 'M' => { - res.message = next[1..]; - setMsg = true; - }, - 'D' => { - res.detail = next[1..]; - }, - 'H' => { - res.hint = next[1..]; - }, - 'P' => { - res.position = try std.fmt.parseInt(u32, next[1..], 10); - }, - 'p' => { - res.internal_position = try std.fmt.parseInt(u32, next[1..], 10); - }, - 'q' => { - res.internal_query = next[1..]; - }, - 'W' => { - res.where = next[1..]; - }, - 's' => { - res.schema_name = next[1..]; - }, - 't' => { - res.table_name = next[1..]; - }, - 'c' => { - res.column_name = next[1..]; - }, - 'd' => { - res.data_type_name = next[1..]; - }, - 'n' => { - res.constraint_name = next[1..]; - }, - 'F' => { - res.file_name = next[1..]; - }, - 'L' => { - res.line = try std.fmt.parseInt(u32, next[1..], 10); - }, - 'R' => { - res.routine = next[1..]; - }, - else => { - try res.unknown_fields.put(next[0], next[1..]); - } - } - } - if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField; - return res; -} -pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - var al = ByteArrayList.init(allocator); - defer al.deinit(); - var cw = std.io.countingWriter(al.writer()); - var writer = cw.writer(); - try writer.writeIntBig(u32, 0); // Length placeholder. - - try write_field_nt('S', self, "severity", writer); - if (self.severity_unlocalized) |severity_unlocalized| { - try write_nt('V', severity_unlocalized, writer); - } - try write_field_nt('C', self, "code", writer); - try write_field_nt('M', self, "message", writer); - // TODO rest of the fields - - // replace the length and write it to the actual stream - std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); - try stream_writer.writeAll(al.items); -} -fn write_field_nt(comptime tag: u8, self: ErrorResponse, comptime field: []const u8, writer: anytype) !void { - try write_nt(tag, @field(self, field), writer); -} -fn write_nt(comptime tag: u8, value: []const u8, writer: anytype) !void { - try writer.writeByte(tag); - try writer.writeAll(value); - try writer.writeByte(0); -} - -pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void { - self.unknown_fields.deinit(); - if (self.buf != null) allocator.free(self.buf.?); -} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = ErrorResponse{ - .severity = "foo", - .severity_unlocalized = "foo_unlocal", - .code = "bar", - .message = "baz", - .unknown_fields = HMByteString.init(allocator), - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try ErrorResponse.read(allocator, buf); - defer sm2.deinit(allocator); - - try std.testing.expectEqualStrings("foo", sm2.severity); - try std.testing.expectEqualStrings("foo_unlocal", sm2.severity_unlocalized.?); - try std.testing.expectEqualStrings("bar", sm2.code); - try std.testing.expectEqualStrings("baz", sm2.message); -} diff --git a/src/main.zig b/src/main.zig index 8bc4649..f5941b7 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,18 +1,7 @@ const std = @import("std"); const testing = std.testing; -const StartupMessage = @import("startup_message.zig"); -const AuthenticationRequest = @import("authentication_request.zig"); -const PasswordMessage = @import("password_message.zig"); -const ErrorResponse = @import("error_response.zig"); -const ReadyForQuery = @import("ready_for_query.zig"); -const ParameterStatus = @import("parameter_status.zig"); -const BackendKeyData = @import("backend_key_data.zig"); -const Query = @import("query.zig"); -const DataRow = @import("data_row.zig"); -const RowDescription = @import("row_description.zig"); -const CommandComplete = @import("command_complete.zig"); -const CopyInResponse = @import("copy_in_response.zig"); const Conn = @import("conn.zig"); +const Proto = @import("proto/proto.zig"); pub const ProtocolError = error{ InvalidProtocolVersion, @@ -58,22 +47,21 @@ pub fn read_message(comptime msg_type: type, allocator: std.mem.Allocator, strea if (!@hasDecl(msg_type, "Tag")) @compileError("msg_type must have a Tag declaration!"); if (!@hasDecl(msg_type, "read")) @compileError("msg_type must have a read() function!"); const len = try stream_reader.readIntBig(u32); - const buf = try allocator.alloc(u8, @as(u32, @intCast(len-4))); + const buf = try allocator.alloc(u8, @as(u32, @intCast(len - 4))); defer allocator.free(buf); try stream_reader.readNoEof(buf); return try msg_type.read(allocator, buf); } - pub fn diagnosticReader(comptime n: usize, base_reader: anytype) DiagnosticReader(n, @TypeOf(base_reader)) { - return .{.child_reader = base_reader}; + return .{ .child_reader = base_reader }; } // keeps a buffer of the last n bytes read pub fn DiagnosticReader(comptime n: usize, comptime ReaderType: anytype) type { return struct { child_reader: ReaderType, - ring: [n]u8 = [_]u8{0}**n, + ring: [n]u8 = [_]u8{0} ** n, pos: usize = 0, pub const Error = ReaderType.Error; @@ -97,8 +85,8 @@ pub fn DiagnosticReader(comptime n: usize, comptime ReaderType: anytype) type { pub fn get(self: @This(), allocator: std.mem.Allocator) ![]const u8 { var buf = try allocator.alloc(u8, n); errdefer allocator.free(buf); - @memcpy(buf[0..(n-self.pos)], self.ring[self.pos..n]); - @memcpy(buf[(n-self.pos)..n], self.ring[0..self.pos]); + @memcpy(buf[0..(n - self.pos)], self.ring[self.pos..n]); + @memcpy(buf[(n - self.pos)..n], self.ring[0..self.pos]); return buf; } }; @@ -110,7 +98,7 @@ test "diagnostc reader" { var fbs = std.io.fixedBufferStream(string); var dr = diagnosticReader(15, fbs.reader()); var reader = dr.reader(); - var buf = [_]u8{0}**20; + var buf = [_]u8{0} ** 20; try reader.readNoEof(&buf); const diag = try dr.get(a); defer a.free(diag); @@ -118,17 +106,6 @@ test "diagnostc reader" { } test { - _ = StartupMessage; - _ = AuthenticationRequest; - _ = PasswordMessage; - _ = ErrorResponse; + _ = Proto; _ = Conn; - _ = ReadyForQuery; - _ = ParameterStatus; - _ = BackendKeyData; - _ = Query; - _ = DataRow; - _ = RowDescription; - _ = CommandComplete; - _ = CopyInResponse; } diff --git a/src/parameter_status.zig b/src/parameter_status.zig deleted file mode 100644 index d938fd2..0000000 --- a/src/parameter_status.zig +++ /dev/null @@ -1,64 +0,0 @@ -const std = @import("std"); -const log = std.log.scoped(.pgz); -const ByteArrayList = std.ArrayList(u8); -const ProtocolError = @import("main.zig").ProtocolError; -const ClientError = @import("main.zig").ClientError; - -const ParameterStatus = @This(); -pub const Tag: u8 = 'S'; - -buf: ?[]const u8 = null, // owned -name: []const u8, -value: []const u8, - -pub fn read(allocator: std.mem.Allocator, b: []const u8) !ParameterStatus { - var res: ParameterStatus = undefined; - res.buf = try allocator.dupe(u8, b); - var it = std.mem.splitScalar(u8, res.buf.?, 0); - res.name = it.first(); - res.value = it.next() orelse return ProtocolError.MissingField; - return res; -} -pub fn write(self: ParameterStatus, a: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - var al = ByteArrayList.init(a); - defer al.deinit(); - var cw = std.io.countingWriter(al.writer()); - var writer = cw.writer(); - try writer.writeIntBig(u32, 0); // length placeholder - try writer.writeAll(self.name); - try writer.writeByte(0); - try writer.writeAll(self.value); - try writer.writeByte(0); - std.mem.writeIntBig(u32, al.items[0..4], @as(u32,@intCast(cw.bytes_written))); // Fix length - try stream_writer.writeAll(al.items); -} -pub fn deinit(self: *ParameterStatus, allocator: std.mem.Allocator) void { - if (self.buf != null) allocator.free(self.buf.?); -} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = ParameterStatus{ - .name = "Hello", - .value = "world", - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try ParameterStatus.read(allocator, buf); - defer sm2.deinit(allocator); - try std.testing.expectEqualStrings("Hello", sm2.name); - try std.testing.expectEqualStrings("world", sm2.value); -} diff --git a/src/password_message.zig b/src/password_message.zig deleted file mode 100644 index 1a8c17a..0000000 --- a/src/password_message.zig +++ /dev/null @@ -1,45 +0,0 @@ -const std = @import("std"); -const ByteArrayList = std.ArrayList(u8); -const PasswordMessage = @This(); - -pub const Tag: u8 = 'p'; -password: []const u8, -password_owned: bool = false, - -pub fn read(allocator: std.mem.Allocator, b: []const u8) !PasswordMessage { - return .{ .password = try allocator.dupe(u8, b), .password_owned = true }; -} - -pub fn write(self: PasswordMessage, _: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - try stream_writer.writeIntBig(u32, 5 + @as(u32, @intCast(self.password.len))); - try stream_writer.writeAll(self.password); - try stream_writer.writeByte(0); -} - -pub fn deinit(self: *PasswordMessage, allocator: std.mem.Allocator) void { - if (self.password_owned) allocator.free(self.password); -} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = PasswordMessage{ - .password = "foobar", - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try PasswordMessage.read(allocator, buf); - defer sm2.deinit(allocator); -} diff --git a/src/proto/authentication_request.zig b/src/proto/authentication_request.zig new file mode 100644 index 0000000..9203482 --- /dev/null +++ b/src/proto/authentication_request.zig @@ -0,0 +1,75 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; + +pub const Tag: u8 = 'R'; + +const AuthenticationRequest = @This(); + +pub const InnerAuthRequestType = enum(u32) { + AuthRequestTypeOk = 0, + AuthRequestTypeCleartextPassword = 3, +}; +pub const InnerAuthRequest = union { + ok: AuthRequestOk, + cleartext_password: AuthRequestCleartextPassword, +}; +pub const AuthRequestOk = struct {}; +pub const AuthRequestCleartextPassword = struct {}; + +// Authentication requests have multiple subtypes. +// It's not possible to have a tagged union with a custom backing integer, so do it the long way +inner_type: InnerAuthRequestType, +inner: InnerAuthRequest, + +pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationRequest { + if (b.len != 4) { + log.err("invalid message length, expected 4 got {}", .{b.len}); + return ProtocolError.InvalidMessageLength; + } + const inner_type = enum_from_int(InnerAuthRequestType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType; + var inner: InnerAuthRequest = switch (inner_type) { + .AuthRequestTypeOk => .{ .ok = AuthRequestOk{} }, + .AuthRequestTypeCleartextPassword => .{ .cleartext_password = AuthRequestCleartextPassword{} }, + }; + return .{ + .inner_type = inner_type, + .inner = inner, + }; +} + +pub fn write(self: AuthenticationRequest, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 8); + try stream_writer.writeIntBig(u32, @intFromEnum(self.inner_type)); +} + +pub fn deinit(_: *AuthenticationRequest, _: std.mem.Allocator) void {} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = AuthenticationRequest{ + .inner_type = .AuthRequestTypeOk, + .inner = .{ .ok = AuthRequestOk{} }, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try AuthenticationRequest.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqual(InnerAuthRequestType.AuthRequestTypeOk, sm2.inner_type); +} diff --git a/src/proto/backend_key_data.zig b/src/proto/backend_key_data.zig new file mode 100644 index 0000000..7c32178 --- /dev/null +++ b/src/proto/backend_key_data.zig @@ -0,0 +1,52 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; + +const BackendKeyData = @This(); +pub const Tag: u8 = 'K'; + +process_id: u32, +secret_key: u32, + +pub fn read(_: std.mem.Allocator, b: []const u8) !BackendKeyData { + if (b.len != 8) return ProtocolError.InvalidMessageLength; + return .{ + .process_id = std.mem.readIntBig(u32, b[0..4]), + .secret_key = std.mem.readIntBig(u32, b[4..8]), + }; +} +pub fn write(self: BackendKeyData, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 12); // length + try stream_writer.writeIntBig(u32, self.process_id); + try stream_writer.writeIntBig(u32, self.secret_key); +} +pub fn deinit(_: *BackendKeyData, _: std.mem.Allocator) void {} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = BackendKeyData{ + .process_id = 123, + .secret_key = 345, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try BackendKeyData.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqual(@as(u32, 123), sm2.process_id); + try std.testing.expectEqual(@as(u32, 345), sm2.secret_key); +} diff --git a/src/proto/command_complete.zig b/src/proto/command_complete.zig new file mode 100644 index 0000000..80014e9 --- /dev/null +++ b/src/proto/command_complete.zig @@ -0,0 +1,56 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; +const FormatCode = @import("../main.zig").FormatCode; + +pub const Tag: u8 = 'C'; + +const CommandComplete = @This(); + +command_tag: []const u8, +owned: bool = false, + +pub fn read(a: std.mem.Allocator, b: []const u8) !CommandComplete { + return .{ + .command_tag = try a.dupe(u8, b), + .owned = true, + }; +} + +pub fn write(self: CommandComplete, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, @as(u32, @intCast(4+self.command_tag.len))); + try stream_writer.writeAll(self.command_tag); +} + +pub fn deinit(self: *CommandComplete, a: std.mem.Allocator) void { + if (self.owned) a.free(self.command_tag); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = CommandComplete{ + .command_tag = "foo", + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try CommandComplete.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualStrings("foo", sm2.command_tag); +} diff --git a/src/proto/copy_in_response.zig b/src/proto/copy_in_response.zig new file mode 100644 index 0000000..65d9ad7 --- /dev/null +++ b/src/proto/copy_in_response.zig @@ -0,0 +1,83 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; +const FormatCode = @import("../main.zig").FormatCode; + +pub const Tag: u8 = 'G'; +// TODO generics it's the same as CopyOutResponse and CopyBothResponse. +const CopyInResponse = @This(); + +overall_format_code: u8, +format_codes: []const FormatCode, // owned + +pub fn read(a: std.mem.Allocator, b: []const u8) !CopyInResponse { + var fbs = std.io.fixedBufferStream(b); + var reader = fbs.reader(); + const overall_format_code = try reader.readIntBig(u8); + const n_columns = try reader.readIntBig(u16); + var format_codes = try a.alloc(FormatCode, n_columns); + errdefer a.free(format_codes); + for (0..n_columns) |i| { + const int_format_code = try reader.readIntBig(u16); + format_codes[i] = enum_from_int(FormatCode, int_format_code) orelse return ProtocolError.InvalidFormatCode; + } + return .{ + .overall_format_code = overall_format_code, + .format_codes = format_codes, + }; +} + +pub fn write(self: CopyInResponse, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeIntBig(u8, Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeIntBig(u8, self.overall_format_code); + try writer.writeIntBig(u16, @intCast(self.format_codes.len)); + for (self.format_codes) |format_code| { + try writer.writeIntBig(u16, @intFromEnum(format_code)); + } + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); // Update length + try stream_writer.writeAll(al.items); +} + +pub fn deinit(self: *CopyInResponse, a: std.mem.Allocator) void { + a.free(self.format_codes); +} + +test "round trip" { + const allocator = std.testing.allocator; + var format_codes = try allocator.alloc(FormatCode, 3); + format_codes[0] = .Binary; + format_codes[1] = .Binary; + format_codes[2] = .Text; + var sm = CopyInResponse{ + .overall_format_code = 1, + .format_codes = format_codes, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try CopyInResponse.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[0]); + try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[1]); + try std.testing.expectEqual(FormatCode.Text, sm2.format_codes[2]); +} diff --git a/src/proto/data_row.zig b/src/proto/data_row.zig new file mode 100644 index 0000000..c20b794 --- /dev/null +++ b/src/proto/data_row.zig @@ -0,0 +1,88 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; + +pub const Tag: u8 = 'D'; + +const DataRow = @This(); + +buf: ?[]const u8 = null, // owned +columns: [][]const u8, // also owned + +pub fn read(a: std.mem.Allocator, b: []const u8) !DataRow { + if (b.len < 2) return ProtocolError.InvalidMessageLength; + var buf = try a.dupe(u8, b); + var res: DataRow = undefined; + res.buf = buf; + errdefer res.deinit(a); + + const n_columns = std.mem.readIntBig(u16, buf[0..2]); + const columns = try a.alloc([]const u8, n_columns); + errdefer a.free(columns); + var pos: usize = 2; + for (0..n_columns) |col| { + const len = std.mem.readIntBig(u32, buf[pos..(pos+4)][0..4]); // second slice forces the slice size to be known at comptime and satisfy the type check on readIntBig! + const data = if (len > 0) buf[(pos+4)..(pos+4+len)] else &[_]u8{}; + columns[col] = data; + pos += (4+len); + } + res.columns = columns; + return res; +} + +pub fn write(self: DataRow, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeIntBig(u16, @as(u16, @intCast(self.columns.len))); + for (self.columns) |column| { + const len = @as(u32, @intCast(column.len)); + try writer.writeIntBig(u32, len); + try writer.writeAll(column); + } + // Fixup the length and write to the original stream + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} + +pub fn deinit(self: *DataRow, a: std.mem.Allocator) void { + if (self.buf != null) a.free(self.buf.?); + a.free(self.columns); +} + +test "round trip" { + const allocator = std.testing.allocator; + const columns = try allocator.alloc([]const u8, 3); + columns[0] = "Hello"; + columns[1] = "FooBar"; + columns[2] = ""; + var sm = DataRow{ + .columns = columns, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try DataRow.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualStrings("Hello", sm2.columns[0]); + try std.testing.expectEqualStrings("FooBar", sm2.columns[1]); + try std.testing.expectEqualStrings("", sm2.columns[2]); +} diff --git a/src/proto/error_response.zig b/src/proto/error_response.zig new file mode 100644 index 0000000..2aafe8d --- /dev/null +++ b/src/proto/error_response.zig @@ -0,0 +1,170 @@ +const std = @import("std"); +const HMByteString = std.AutoHashMap(u8, []const u8); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; + +const ErrorResponse = @This(); +pub const Tag: u8 = 'E'; + +buf: ?[]const u8 = null, // owned +severity: []const u8, +severity_unlocalized: ?[]const u8 = null, +code: []const u8, +message: []const u8, +detail: ?[]const u8 = null, +hint: ?[]const u8 = null, +position: ?u32 = null, +internal_position: ?u32 = null, +internal_query: ?[]const u8 = null, +where: ?[]const u8 = null, +schema_name: ?[]const u8 = null, +table_name: ?[]const u8 = null, +column_name: ?[]const u8 = null, +data_type_name: ?[]const u8 = null, +constraint_name: ?[]const u8 = null, +file_name: ?[]const u8 = null, +line: ?u32 = null, +routine: ?[]const u8 = null, +unknown_fields: HMByteString, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse { + var res: ErrorResponse = undefined; + res.unknown_fields = HMByteString.init(allocator); + res.buf = try allocator.dupe(u8, b); + errdefer allocator.free(res.buf.?); + var it = std.mem.splitScalar(u8, res.buf.?, 0); + var setSev = false; var setCode = false; var setMsg = false; + while (it.next()) |next| { + if (next.len < 1) break; + switch (next[0]) { + 0 => break, + 'S' => { + res.severity = next[1..]; + setSev = true; + }, + 'V' => { + res.severity_unlocalized = next[1..]; + }, + 'C' => { + res.code = next[1..]; + setCode = true; + }, + 'M' => { + res.message = next[1..]; + setMsg = true; + }, + 'D' => { + res.detail = next[1..]; + }, + 'H' => { + res.hint = next[1..]; + }, + 'P' => { + res.position = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'p' => { + res.internal_position = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'q' => { + res.internal_query = next[1..]; + }, + 'W' => { + res.where = next[1..]; + }, + 's' => { + res.schema_name = next[1..]; + }, + 't' => { + res.table_name = next[1..]; + }, + 'c' => { + res.column_name = next[1..]; + }, + 'd' => { + res.data_type_name = next[1..]; + }, + 'n' => { + res.constraint_name = next[1..]; + }, + 'F' => { + res.file_name = next[1..]; + }, + 'L' => { + res.line = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'R' => { + res.routine = next[1..]; + }, + else => { + try res.unknown_fields.put(next[0], next[1..]); + } + } + } + if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField; + return res; +} +pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(allocator); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // Length placeholder. + + try write_field_nt('S', self, "severity", writer); + if (self.severity_unlocalized) |severity_unlocalized| { + try write_nt('V', severity_unlocalized, writer); + } + try write_field_nt('C', self, "code", writer); + try write_field_nt('M', self, "message", writer); + // TODO rest of the fields + + // replace the length and write it to the actual stream + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} +fn write_field_nt(comptime tag: u8, self: ErrorResponse, comptime field: []const u8, writer: anytype) !void { + try write_nt(tag, @field(self, field), writer); +} +fn write_nt(comptime tag: u8, value: []const u8, writer: anytype) !void { + try writer.writeByte(tag); + try writer.writeAll(value); + try writer.writeByte(0); +} + +pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void { + self.unknown_fields.deinit(); + if (self.buf != null) allocator.free(self.buf.?); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = ErrorResponse{ + .severity = "foo", + .severity_unlocalized = "foo_unlocal", + .code = "bar", + .message = "baz", + .unknown_fields = HMByteString.init(allocator), + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try ErrorResponse.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualStrings("foo", sm2.severity); + try std.testing.expectEqualStrings("foo_unlocal", sm2.severity_unlocalized.?); + try std.testing.expectEqualStrings("bar", sm2.code); + try std.testing.expectEqualStrings("baz", sm2.message); +} diff --git a/src/proto/parameter_status.zig b/src/proto/parameter_status.zig new file mode 100644 index 0000000..5f95695 --- /dev/null +++ b/src/proto/parameter_status.zig @@ -0,0 +1,64 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; + +const ParameterStatus = @This(); +pub const Tag: u8 = 'S'; + +buf: ?[]const u8 = null, // owned +name: []const u8, +value: []const u8, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !ParameterStatus { + var res: ParameterStatus = undefined; + res.buf = try allocator.dupe(u8, b); + var it = std.mem.splitScalar(u8, res.buf.?, 0); + res.name = it.first(); + res.value = it.next() orelse return ProtocolError.MissingField; + return res; +} +pub fn write(self: ParameterStatus, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeAll(self.name); + try writer.writeByte(0); + try writer.writeAll(self.value); + try writer.writeByte(0); + std.mem.writeIntBig(u32, al.items[0..4], @as(u32,@intCast(cw.bytes_written))); // Fix length + try stream_writer.writeAll(al.items); +} +pub fn deinit(self: *ParameterStatus, allocator: std.mem.Allocator) void { + if (self.buf != null) allocator.free(self.buf.?); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = ParameterStatus{ + .name = "Hello", + .value = "world", + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try ParameterStatus.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqualStrings("Hello", sm2.name); + try std.testing.expectEqualStrings("world", sm2.value); +} diff --git a/src/proto/password_message.zig b/src/proto/password_message.zig new file mode 100644 index 0000000..1a8c17a --- /dev/null +++ b/src/proto/password_message.zig @@ -0,0 +1,45 @@ +const std = @import("std"); +const ByteArrayList = std.ArrayList(u8); +const PasswordMessage = @This(); + +pub const Tag: u8 = 'p'; +password: []const u8, +password_owned: bool = false, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !PasswordMessage { + return .{ .password = try allocator.dupe(u8, b), .password_owned = true }; +} + +pub fn write(self: PasswordMessage, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 5 + @as(u32, @intCast(self.password.len))); + try stream_writer.writeAll(self.password); + try stream_writer.writeByte(0); +} + +pub fn deinit(self: *PasswordMessage, allocator: std.mem.Allocator) void { + if (self.password_owned) allocator.free(self.password); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = PasswordMessage{ + .password = "foobar", + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try PasswordMessage.read(allocator, buf); + defer sm2.deinit(allocator); +} diff --git a/src/proto/proto.zig b/src/proto/proto.zig new file mode 100644 index 0000000..4465e4b --- /dev/null +++ b/src/proto/proto.zig @@ -0,0 +1,26 @@ +pub const StartupMessage = @import("startup_message.zig"); +pub const AuthenticationRequest = @import("authentication_request.zig"); +pub const PasswordMessage = @import("password_message.zig"); +pub const ErrorResponse = @import("error_response.zig"); +pub const ReadyForQuery = @import("ready_for_query.zig"); +pub const ParameterStatus = @import("parameter_status.zig"); +pub const BackendKeyData = @import("backend_key_data.zig"); +pub const Query = @import("query.zig"); +pub const DataRow = @import("data_row.zig"); +pub const RowDescription = @import("row_description.zig"); +pub const CommandComplete = @import("command_complete.zig"); +pub const CopyInResponse = @import("copy_in_response.zig"); + +test { + _ = AuthenticationRequest; + _ = PasswordMessage; + _ = ErrorResponse; + _ = ReadyForQuery; + _ = ParameterStatus; + _ = BackendKeyData; + _ = Query; + _ = DataRow; + _ = RowDescription; + _ = CommandComplete; + _ = CopyInResponse; +} diff --git a/src/proto/query.zig b/src/proto/query.zig new file mode 100644 index 0000000..9f238fc --- /dev/null +++ b/src/proto/query.zig @@ -0,0 +1,55 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; + +pub const Tag: u8 = 'Q'; + +const Query = @This(); + +string: []const u8, +owned: bool = false, + +pub fn read(a: std.mem.Allocator, b: []const u8) !Query { + return .{ + .string = try a.dupe(u8, b[0..(b.len-1)]), // Drop the null terminator + .owned = true, + }; +} + +pub fn write(self: Query, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, @as(u32, @intCast(self.string.len+5))); + try stream_writer.writeAll(self.string); + try stream_writer.writeByte(0); +} + +pub fn deinit(self: *Query, a: std.mem.Allocator) void { + if (self.owned) a.free(self.string); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = Query{ + .string = "Hello", + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try Query.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqualStrings("Hello", sm2.string); +} diff --git a/src/proto/ready_for_query.zig b/src/proto/ready_for_query.zig new file mode 100644 index 0000000..ef99e60 --- /dev/null +++ b/src/proto/ready_for_query.zig @@ -0,0 +1,53 @@ +const std = @import("std"); +const ProtocolError = @import("../main.zig").ProtocolError; +const enum_from_int = @import("../main.zig").enum_from_int; +const ByteArrayList = std.ArrayList(u8); + +const ReadyForQuery = @This(); +pub const Tag: u8 = 'Z'; + +const TransactionStatus = enum(u8) { + idle = 'I', + transaction = 'T', + err = 'E', +}; + +transaction_status: TransactionStatus, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !ReadyForQuery { + _ = allocator; + if (b.len != 1) return ProtocolError.InvalidMessageLength; + return .{ .transaction_status = enum_from_int(TransactionStatus, b[0]) orelse return ProtocolError.InvalidTransactionStatus }; +} +pub fn write(self: ReadyForQuery, allocator: std.mem.Allocator, stream_writer: anytype) !void { + _ = allocator; + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 5); + try stream_writer.writeByte(@intFromEnum(self.transaction_status)); +} +pub fn deinit(_: *ReadyForQuery, _: std.mem.Allocator) void {} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = ReadyForQuery{ + .transaction_status = TransactionStatus.idle, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try ReadyForQuery.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqual(TransactionStatus.idle, sm2.transaction_status); +} diff --git a/src/proto/row_description.zig b/src/proto/row_description.zig new file mode 100644 index 0000000..a0e8810 --- /dev/null +++ b/src/proto/row_description.zig @@ -0,0 +1,136 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; +const FormatCode = @import("../main.zig").FormatCode; + +pub const Tag: u8 = 'T'; + +const RowDescription = @This(); + +buf: ?[]const u8 = null, // owned +fields: ?[]Field = null, // owned + +pub const Field = struct { + name: []const u8, + table_oid: u32, + attr_no: u16, + data_type_oid: u32, + data_type_size: i16, + data_type_modifier: u32, + format_code: FormatCode, +}; + +pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription { + var res: RowDescription = undefined; + res.buf = try a.dupe(u8, b); + errdefer res.deinit(a); + var fbs = std.io.fixedBufferStream(res.buf.?); + var reader = fbs.reader(); + const n_fields = try reader.readIntBig(u16); + res.fields = try a.alloc(Field, n_fields); + for (0..n_fields) |i| { + const name_start = fbs.pos; + try reader.skipUntilDelimiterOrEof(0); + const name_end = fbs.pos-1; + const name = res.buf.?[name_start..name_end]; + const field = Field{ + .name = name, + .table_oid = try reader.readIntBig(u32), + .attr_no = try reader.readIntBig(u16), + .data_type_oid = try reader.readIntBig(u32), + .data_type_size = try reader.readIntBig(i16), + .data_type_modifier = try reader.readIntBig(u32), + .format_code = enum_from_int(FormatCode, try reader.readIntBig(u16)) orelse return ProtocolError.InvalidFormatCode, + }; + res.fields.?[i] = field; + } + return res; +} + +pub fn write(self: RowDescription, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeIntBig(u16, @as(u16, @intCast(self.fields.?.len))); + for (self.fields.?) |field| { + try writer.writeAll(field.name); + try writer.writeByte(0); + try writer.writeIntBig(u32, field.table_oid); + try writer.writeIntBig(u16, field.attr_no); + try writer.writeIntBig(u32, field.data_type_oid); + try writer.writeIntBig(i16, field.data_type_size); + try writer.writeIntBig(u32, field.data_type_modifier); + try writer.writeIntBig(u16, @intFromEnum(field.format_code)); + } + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} + +pub fn deinit(self: *RowDescription, a: std.mem.Allocator) void { + if (self.fields != null) a.free(self.fields.?); + if (self.buf != null) a.free(self.buf.?); +} + +test "round trip" { + const allocator = std.testing.allocator; + var fields = try allocator.alloc(Field, 3); + fields[0] = .{ + .name = "foo", + .table_oid = 1, + .attr_no = 2, + .data_type_oid = 3, + .data_type_size = 4, + .data_type_modifier = 5, + .format_code = .Binary, + }; + fields[1] = .{ + .name = "bar", + .table_oid = 1, + .attr_no = 2, + .data_type_oid = 3, + .data_type_size = 4, + .data_type_modifier = 5, + .format_code = .Binary, + }; + fields[2] = .{ + .name = "BAZZZZZ", + .table_oid = 99, + .attr_no = 98, + .data_type_oid = 97, + .data_type_size = 96, + .data_type_modifier = 95, + .format_code = .Text, + }; + var f0 = fields[0]; + var f1 = fields[1]; + var f2 = fields[2]; + var sm = RowDescription{ + .fields = fields, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try RowDescription.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualDeep(f0, sm2.fields.?[0]); + try std.testing.expectEqualDeep(f1, sm2.fields.?[1]); + try std.testing.expectEqualDeep(f2, sm2.fields.?[2]); +} diff --git a/src/proto/startup_message.zig b/src/proto/startup_message.zig new file mode 100644 index 0000000..8224bdb --- /dev/null +++ b/src/proto/startup_message.zig @@ -0,0 +1,85 @@ +const std = @import("std"); +const ProtocolError = @import("../main.zig").ProtocolError; +const SSHashMap = std.StringHashMap([]const u8); +const ByteArrayList = std.ArrayList(u8); + +const StartupMessage = @This(); + +const ProtocolVersionNumber: u32 = 196608; // 3.0 + +bytes: ?[]const u8 = null, // Owned +parameters: SSHashMap, + +// message length should already have been read, b should contain the payload +pub fn read(allocator: std.mem.Allocator, b: []const u8) !StartupMessage { + if (b.len < 4) return ProtocolError.InvalidMessageLength; + + var bytes = try allocator.dupe(u8, b); + errdefer allocator.free(bytes); + const protocol_version = std.mem.readIntSliceBig(u32, bytes[0..4]); + if (protocol_version != ProtocolVersionNumber) return ProtocolError.InvalidProtocolVersion; + + var parameters = SSHashMap.init(allocator); + var it = std.mem.splitScalar(u8, bytes[4..], 0); + while (it.next()) |next| { + const key = next; + const value = it.next() orelse return ProtocolError.InvalidKeyValuePair; + try parameters.put(key, value); + } + + return .{ + .bytes = bytes, + .parameters = parameters, + }; +} + +pub fn write(self: StartupMessage, allocator: std.mem.Allocator, stream_writer: anytype) !void { + var al = ByteArrayList.init(allocator); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeIntBig(u32, ProtocolVersionNumber); + var it = self.parameters.iterator(); + while (it.next()) |entry| { + try writer.writeAll(entry.key_ptr.*); + try writer.writeByte(0); + try writer.writeAll(entry.value_ptr.*); + try writer.writeByte(0); + } + try writer.writeByte(0); + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} + +pub fn deinit(self: *StartupMessage, allocator: std.mem.Allocator) void { + self.parameters.deinit(); + if (self.bytes != null) { + allocator.free(self.bytes.?); + } +} + +test "round trip" { + const allocator = std.testing.allocator; + var params = SSHashMap.init(allocator); + try params.put("hello", "postgres"); + var sm = StartupMessage{ + .parameters = params, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try StartupMessage.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualStrings("postgres", sm2.parameters.get("hello").?); +} diff --git a/src/query.zig b/src/query.zig deleted file mode 100644 index 8f07b5f..0000000 --- a/src/query.zig +++ /dev/null @@ -1,55 +0,0 @@ -const std = @import("std"); -const log = std.log.scoped(.pgz); -const ByteArrayList = std.ArrayList(u8); -const ProtocolError = @import("main.zig").ProtocolError; -const ClientError = @import("main.zig").ClientError; -const enum_from_int = @import("main.zig").enum_from_int; - -pub const Tag: u8 = 'Q'; - -const Query = @This(); - -string: []const u8, -owned: bool = false, - -pub fn read(a: std.mem.Allocator, b: []const u8) !Query { - return .{ - .string = try a.dupe(u8, b[0..(b.len-1)]), // Drop the null terminator - .owned = true, - }; -} - -pub fn write(self: Query, _: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - try stream_writer.writeIntBig(u32, @as(u32, @intCast(self.string.len+5))); - try stream_writer.writeAll(self.string); - try stream_writer.writeByte(0); -} - -pub fn deinit(self: *Query, a: std.mem.Allocator) void { - if (self.owned) a.free(self.string); -} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = Query{ - .string = "Hello", - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try Query.read(allocator, buf); - defer sm2.deinit(allocator); - try std.testing.expectEqualStrings("Hello", sm2.string); -} diff --git a/src/ready_for_query.zig b/src/ready_for_query.zig deleted file mode 100644 index 883d9e1..0000000 --- a/src/ready_for_query.zig +++ /dev/null @@ -1,55 +0,0 @@ -const std = @import("std"); -const ProtocolError = @import("main.zig").ProtocolError; -const enum_from_int = @import("main.zig").enum_from_int; -const ByteArrayList = std.ArrayList(u8); - -const ReadyForQuery = @This(); -pub const Tag: u8 = 'Z'; - -const TransactionStatus = enum(u8) { - idle = 'I', - transaction = 'T', - err = 'E', -}; - -transaction_status: TransactionStatus, - -pub fn read(allocator: std.mem.Allocator, b: []const u8) !ReadyForQuery { - _ = allocator; - if (b.len != 1) return ProtocolError.InvalidMessageLength; - return .{ - .transaction_status = enum_from_int(TransactionStatus,b[0]) orelse return ProtocolError.InvalidTransactionStatus - }; -} -pub fn write(self: ReadyForQuery, allocator: std.mem.Allocator, stream_writer: anytype) !void { - _ = allocator; - try stream_writer.writeByte(Tag); - try stream_writer.writeIntBig(u32, 5); - try stream_writer.writeByte(@intFromEnum(self.transaction_status)); -} -pub fn deinit(_: *ReadyForQuery, _: std.mem.Allocator) void {} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = ReadyForQuery{ - .transaction_status = TransactionStatus.idle, - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try ReadyForQuery.read(allocator, buf); - defer sm2.deinit(allocator); - - try std.testing.expectEqual(TransactionStatus.idle, sm2.transaction_status); -} diff --git a/src/row_description.zig b/src/row_description.zig deleted file mode 100644 index 414c174..0000000 --- a/src/row_description.zig +++ /dev/null @@ -1,136 +0,0 @@ -const std = @import("std"); -const log = std.log.scoped(.pgz); -const ByteArrayList = std.ArrayList(u8); -const ProtocolError = @import("main.zig").ProtocolError; -const ClientError = @import("main.zig").ClientError; -const enum_from_int = @import("main.zig").enum_from_int; -const FormatCode = @import("main.zig").FormatCode; - -pub const Tag: u8 = 'T'; - -const RowDescription = @This(); - -buf: ?[]const u8 = null, // owned -fields: ?[]Field = null, // owned - -pub const Field = struct { - name: []const u8, - table_oid: u32, - attr_no: u16, - data_type_oid: u32, - data_type_size: i16, - data_type_modifier: u32, - format_code: FormatCode, -}; - -pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription { - var res: RowDescription = undefined; - res.buf = try a.dupe(u8, b); - errdefer res.deinit(a); - var fbs = std.io.fixedBufferStream(res.buf.?); - var reader = fbs.reader(); - const n_fields = try reader.readIntBig(u16); - res.fields = try a.alloc(Field, n_fields); - for (0..n_fields) |i| { - const name_start = fbs.pos; - try reader.skipUntilDelimiterOrEof(0); - const name_end = fbs.pos-1; - const name = res.buf.?[name_start..name_end]; - const field = Field{ - .name = name, - .table_oid = try reader.readIntBig(u32), - .attr_no = try reader.readIntBig(u16), - .data_type_oid = try reader.readIntBig(u32), - .data_type_size = try reader.readIntBig(i16), - .data_type_modifier = try reader.readIntBig(u32), - .format_code = enum_from_int(FormatCode, try reader.readIntBig(u16)) orelse return ProtocolError.InvalidFormatCode, - }; - res.fields.?[i] = field; - } - return res; -} - -pub fn write(self: RowDescription, a: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - var al = ByteArrayList.init(a); - defer al.deinit(); - var cw = std.io.countingWriter(al.writer()); - var writer = cw.writer(); - try writer.writeIntBig(u32, 0); // length placeholder - try writer.writeIntBig(u16, @as(u16, @intCast(self.fields.?.len))); - for (self.fields.?) |field| { - try writer.writeAll(field.name); - try writer.writeByte(0); - try writer.writeIntBig(u32, field.table_oid); - try writer.writeIntBig(u16, field.attr_no); - try writer.writeIntBig(u32, field.data_type_oid); - try writer.writeIntBig(i16, field.data_type_size); - try writer.writeIntBig(u32, field.data_type_modifier); - try writer.writeIntBig(u16, @intFromEnum(field.format_code)); - } - std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); - try stream_writer.writeAll(al.items); -} - -pub fn deinit(self: *RowDescription, a: std.mem.Allocator) void { - if (self.fields != null) a.free(self.fields.?); - if (self.buf != null) a.free(self.buf.?); -} - -test "round trip" { - const allocator = std.testing.allocator; - var fields = try allocator.alloc(Field, 3); - fields[0] = .{ - .name = "foo", - .table_oid = 1, - .attr_no = 2, - .data_type_oid = 3, - .data_type_size = 4, - .data_type_modifier = 5, - .format_code = .Binary, - }; - fields[1] = .{ - .name = "bar", - .table_oid = 1, - .attr_no = 2, - .data_type_oid = 3, - .data_type_size = 4, - .data_type_modifier = 5, - .format_code = .Binary, - }; - fields[2] = .{ - .name = "BAZZZZZ", - .table_oid = 99, - .attr_no = 98, - .data_type_oid = 97, - .data_type_size = 96, - .data_type_modifier = 95, - .format_code = .Text, - }; - var f0 = fields[0]; - var f1 = fields[1]; - var f2 = fields[2]; - var sm = RowDescription{ - .fields = fields, - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try RowDescription.read(allocator, buf); - defer sm2.deinit(allocator); - - try std.testing.expectEqualDeep(f0, sm2.fields.?[0]); - try std.testing.expectEqualDeep(f1, sm2.fields.?[1]); - try std.testing.expectEqualDeep(f2, sm2.fields.?[2]); -} diff --git a/src/startup_message.zig b/src/startup_message.zig deleted file mode 100644 index b5031c1..0000000 --- a/src/startup_message.zig +++ /dev/null @@ -1,85 +0,0 @@ -const std = @import("std"); -const ProtocolError = @import("main.zig").ProtocolError; -const SSHashMap = std.StringHashMap([]const u8); -const ByteArrayList = std.ArrayList(u8); - -const StartupMessage = @This(); - -const ProtocolVersionNumber: u32 = 196608; // 3.0 - -bytes: ?[]const u8 = null, // Owned -parameters: SSHashMap, - -// message length should already have been read, b should contain the payload -pub fn read(allocator: std.mem.Allocator, b: []const u8) !StartupMessage { - if (b.len < 4) return ProtocolError.InvalidMessageLength; - - var bytes = try allocator.dupe(u8, b); - errdefer allocator.free(bytes); - const protocol_version = std.mem.readIntSliceBig(u32, bytes[0..4]); - if (protocol_version != ProtocolVersionNumber) return ProtocolError.InvalidProtocolVersion; - - var parameters = SSHashMap.init(allocator); - var it = std.mem.splitScalar(u8, bytes[4..], 0); - while (it.next()) |next| { - const key = next; - const value = it.next() orelse return ProtocolError.InvalidKeyValuePair; - try parameters.put(key, value); - } - - return .{ - .bytes = bytes, - .parameters = parameters, - }; -} - -pub fn write(self: StartupMessage, allocator: std.mem.Allocator, stream_writer: anytype) !void { - var al = ByteArrayList.init(allocator); - defer al.deinit(); - var cw = std.io.countingWriter(al.writer()); - var writer = cw.writer(); - try writer.writeIntBig(u32, 0); // length placeholder - try writer.writeIntBig(u32, ProtocolVersionNumber); - var it = self.parameters.iterator(); - while (it.next()) |entry| { - try writer.writeAll(entry.key_ptr.*); - try writer.writeByte(0); - try writer.writeAll(entry.value_ptr.*); - try writer.writeByte(0); - } - try writer.writeByte(0); - std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); - try stream_writer.writeAll(al.items); -} - -pub fn deinit(self: *StartupMessage, allocator: std.mem.Allocator) void { - self.parameters.deinit(); - if (self.bytes != null) { - allocator.free(self.bytes.?); - } -} - -test "round trip" { - const allocator = std.testing.allocator; - var params = SSHashMap.init(allocator); - try params.put("hello", "postgres"); - var sm = StartupMessage{ - .parameters = params, - }; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try StartupMessage.read(allocator, buf); - defer sm2.deinit(allocator); - - try std.testing.expectEqualStrings("postgres", sm2.parameters.get("hello").?); -} -- cgit v1.2.3-ZIG