diff options
Diffstat (limited to 'src/proto')
-rw-r--r-- | src/proto/authentication_request.zig | 75 | ||||
-rw-r--r-- | src/proto/backend_key_data.zig | 52 | ||||
-rw-r--r-- | src/proto/command_complete.zig | 56 | ||||
-rw-r--r-- | src/proto/copy_in_response.zig | 83 | ||||
-rw-r--r-- | src/proto/data_row.zig | 88 | ||||
-rw-r--r-- | src/proto/error_response.zig | 170 | ||||
-rw-r--r-- | src/proto/parameter_status.zig | 64 | ||||
-rw-r--r-- | src/proto/password_message.zig | 45 | ||||
-rw-r--r-- | src/proto/proto.zig | 26 | ||||
-rw-r--r-- | src/proto/query.zig | 55 | ||||
-rw-r--r-- | src/proto/ready_for_query.zig | 53 | ||||
-rw-r--r-- | src/proto/row_description.zig | 136 | ||||
-rw-r--r-- | src/proto/startup_message.zig | 85 |
13 files changed, 988 insertions, 0 deletions
diff --git a/src/proto/authentication_request.zig b/src/proto/authentication_request.zig new file mode 100644 index 0000000..9203482 --- /dev/null +++ b/src/proto/authentication_request.zig @@ -0,0 +1,75 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; + +pub const Tag: u8 = 'R'; + +const AuthenticationRequest = @This(); + +pub const InnerAuthRequestType = enum(u32) { + AuthRequestTypeOk = 0, + AuthRequestTypeCleartextPassword = 3, +}; +pub const InnerAuthRequest = union { + ok: AuthRequestOk, + cleartext_password: AuthRequestCleartextPassword, +}; +pub const AuthRequestOk = struct {}; +pub const AuthRequestCleartextPassword = struct {}; + +// Authentication requests have multiple subtypes. +// It's not possible to have a tagged union with a custom backing integer, so do it the long way +inner_type: InnerAuthRequestType, +inner: InnerAuthRequest, + +pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationRequest { + if (b.len != 4) { + log.err("invalid message length, expected 4 got {}", .{b.len}); + return ProtocolError.InvalidMessageLength; + } + const inner_type = enum_from_int(InnerAuthRequestType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType; + var inner: InnerAuthRequest = switch (inner_type) { + .AuthRequestTypeOk => .{ .ok = AuthRequestOk{} }, + .AuthRequestTypeCleartextPassword => .{ .cleartext_password = AuthRequestCleartextPassword{} }, + }; + return .{ + .inner_type = inner_type, + .inner = inner, + }; +} + +pub fn write(self: AuthenticationRequest, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 8); + try stream_writer.writeIntBig(u32, @intFromEnum(self.inner_type)); +} + +pub fn deinit(_: *AuthenticationRequest, _: std.mem.Allocator) void {} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = AuthenticationRequest{ + .inner_type = .AuthRequestTypeOk, + .inner = .{ .ok = AuthRequestOk{} }, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try AuthenticationRequest.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqual(InnerAuthRequestType.AuthRequestTypeOk, sm2.inner_type); +} diff --git a/src/proto/backend_key_data.zig b/src/proto/backend_key_data.zig new file mode 100644 index 0000000..7c32178 --- /dev/null +++ b/src/proto/backend_key_data.zig @@ -0,0 +1,52 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; + +const BackendKeyData = @This(); +pub const Tag: u8 = 'K'; + +process_id: u32, +secret_key: u32, + +pub fn read(_: std.mem.Allocator, b: []const u8) !BackendKeyData { + if (b.len != 8) return ProtocolError.InvalidMessageLength; + return .{ + .process_id = std.mem.readIntBig(u32, b[0..4]), + .secret_key = std.mem.readIntBig(u32, b[4..8]), + }; +} +pub fn write(self: BackendKeyData, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 12); // length + try stream_writer.writeIntBig(u32, self.process_id); + try stream_writer.writeIntBig(u32, self.secret_key); +} +pub fn deinit(_: *BackendKeyData, _: std.mem.Allocator) void {} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = BackendKeyData{ + .process_id = 123, + .secret_key = 345, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try BackendKeyData.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqual(@as(u32, 123), sm2.process_id); + try std.testing.expectEqual(@as(u32, 345), sm2.secret_key); +} diff --git a/src/proto/command_complete.zig b/src/proto/command_complete.zig new file mode 100644 index 0000000..80014e9 --- /dev/null +++ b/src/proto/command_complete.zig @@ -0,0 +1,56 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; +const FormatCode = @import("../main.zig").FormatCode; + +pub const Tag: u8 = 'C'; + +const CommandComplete = @This(); + +command_tag: []const u8, +owned: bool = false, + +pub fn read(a: std.mem.Allocator, b: []const u8) !CommandComplete { + return .{ + .command_tag = try a.dupe(u8, b), + .owned = true, + }; +} + +pub fn write(self: CommandComplete, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, @as(u32, @intCast(4+self.command_tag.len))); + try stream_writer.writeAll(self.command_tag); +} + +pub fn deinit(self: *CommandComplete, a: std.mem.Allocator) void { + if (self.owned) a.free(self.command_tag); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = CommandComplete{ + .command_tag = "foo", + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try CommandComplete.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualStrings("foo", sm2.command_tag); +} diff --git a/src/proto/copy_in_response.zig b/src/proto/copy_in_response.zig new file mode 100644 index 0000000..65d9ad7 --- /dev/null +++ b/src/proto/copy_in_response.zig @@ -0,0 +1,83 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; +const FormatCode = @import("../main.zig").FormatCode; + +pub const Tag: u8 = 'G'; +// TODO generics it's the same as CopyOutResponse and CopyBothResponse. +const CopyInResponse = @This(); + +overall_format_code: u8, +format_codes: []const FormatCode, // owned + +pub fn read(a: std.mem.Allocator, b: []const u8) !CopyInResponse { + var fbs = std.io.fixedBufferStream(b); + var reader = fbs.reader(); + const overall_format_code = try reader.readIntBig(u8); + const n_columns = try reader.readIntBig(u16); + var format_codes = try a.alloc(FormatCode, n_columns); + errdefer a.free(format_codes); + for (0..n_columns) |i| { + const int_format_code = try reader.readIntBig(u16); + format_codes[i] = enum_from_int(FormatCode, int_format_code) orelse return ProtocolError.InvalidFormatCode; + } + return .{ + .overall_format_code = overall_format_code, + .format_codes = format_codes, + }; +} + +pub fn write(self: CopyInResponse, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeIntBig(u8, Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeIntBig(u8, self.overall_format_code); + try writer.writeIntBig(u16, @intCast(self.format_codes.len)); + for (self.format_codes) |format_code| { + try writer.writeIntBig(u16, @intFromEnum(format_code)); + } + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); // Update length + try stream_writer.writeAll(al.items); +} + +pub fn deinit(self: *CopyInResponse, a: std.mem.Allocator) void { + a.free(self.format_codes); +} + +test "round trip" { + const allocator = std.testing.allocator; + var format_codes = try allocator.alloc(FormatCode, 3); + format_codes[0] = .Binary; + format_codes[1] = .Binary; + format_codes[2] = .Text; + var sm = CopyInResponse{ + .overall_format_code = 1, + .format_codes = format_codes, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try CopyInResponse.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[0]); + try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[1]); + try std.testing.expectEqual(FormatCode.Text, sm2.format_codes[2]); +} diff --git a/src/proto/data_row.zig b/src/proto/data_row.zig new file mode 100644 index 0000000..c20b794 --- /dev/null +++ b/src/proto/data_row.zig @@ -0,0 +1,88 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; + +pub const Tag: u8 = 'D'; + +const DataRow = @This(); + +buf: ?[]const u8 = null, // owned +columns: [][]const u8, // also owned + +pub fn read(a: std.mem.Allocator, b: []const u8) !DataRow { + if (b.len < 2) return ProtocolError.InvalidMessageLength; + var buf = try a.dupe(u8, b); + var res: DataRow = undefined; + res.buf = buf; + errdefer res.deinit(a); + + const n_columns = std.mem.readIntBig(u16, buf[0..2]); + const columns = try a.alloc([]const u8, n_columns); + errdefer a.free(columns); + var pos: usize = 2; + for (0..n_columns) |col| { + const len = std.mem.readIntBig(u32, buf[pos..(pos+4)][0..4]); // second slice forces the slice size to be known at comptime and satisfy the type check on readIntBig! + const data = if (len > 0) buf[(pos+4)..(pos+4+len)] else &[_]u8{}; + columns[col] = data; + pos += (4+len); + } + res.columns = columns; + return res; +} + +pub fn write(self: DataRow, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeIntBig(u16, @as(u16, @intCast(self.columns.len))); + for (self.columns) |column| { + const len = @as(u32, @intCast(column.len)); + try writer.writeIntBig(u32, len); + try writer.writeAll(column); + } + // Fixup the length and write to the original stream + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} + +pub fn deinit(self: *DataRow, a: std.mem.Allocator) void { + if (self.buf != null) a.free(self.buf.?); + a.free(self.columns); +} + +test "round trip" { + const allocator = std.testing.allocator; + const columns = try allocator.alloc([]const u8, 3); + columns[0] = "Hello"; + columns[1] = "FooBar"; + columns[2] = ""; + var sm = DataRow{ + .columns = columns, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try DataRow.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualStrings("Hello", sm2.columns[0]); + try std.testing.expectEqualStrings("FooBar", sm2.columns[1]); + try std.testing.expectEqualStrings("", sm2.columns[2]); +} diff --git a/src/proto/error_response.zig b/src/proto/error_response.zig new file mode 100644 index 0000000..2aafe8d --- /dev/null +++ b/src/proto/error_response.zig @@ -0,0 +1,170 @@ +const std = @import("std"); +const HMByteString = std.AutoHashMap(u8, []const u8); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; + +const ErrorResponse = @This(); +pub const Tag: u8 = 'E'; + +buf: ?[]const u8 = null, // owned +severity: []const u8, +severity_unlocalized: ?[]const u8 = null, +code: []const u8, +message: []const u8, +detail: ?[]const u8 = null, +hint: ?[]const u8 = null, +position: ?u32 = null, +internal_position: ?u32 = null, +internal_query: ?[]const u8 = null, +where: ?[]const u8 = null, +schema_name: ?[]const u8 = null, +table_name: ?[]const u8 = null, +column_name: ?[]const u8 = null, +data_type_name: ?[]const u8 = null, +constraint_name: ?[]const u8 = null, +file_name: ?[]const u8 = null, +line: ?u32 = null, +routine: ?[]const u8 = null, +unknown_fields: HMByteString, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse { + var res: ErrorResponse = undefined; + res.unknown_fields = HMByteString.init(allocator); + res.buf = try allocator.dupe(u8, b); + errdefer allocator.free(res.buf.?); + var it = std.mem.splitScalar(u8, res.buf.?, 0); + var setSev = false; var setCode = false; var setMsg = false; + while (it.next()) |next| { + if (next.len < 1) break; + switch (next[0]) { + 0 => break, + 'S' => { + res.severity = next[1..]; + setSev = true; + }, + 'V' => { + res.severity_unlocalized = next[1..]; + }, + 'C' => { + res.code = next[1..]; + setCode = true; + }, + 'M' => { + res.message = next[1..]; + setMsg = true; + }, + 'D' => { + res.detail = next[1..]; + }, + 'H' => { + res.hint = next[1..]; + }, + 'P' => { + res.position = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'p' => { + res.internal_position = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'q' => { + res.internal_query = next[1..]; + }, + 'W' => { + res.where = next[1..]; + }, + 's' => { + res.schema_name = next[1..]; + }, + 't' => { + res.table_name = next[1..]; + }, + 'c' => { + res.column_name = next[1..]; + }, + 'd' => { + res.data_type_name = next[1..]; + }, + 'n' => { + res.constraint_name = next[1..]; + }, + 'F' => { + res.file_name = next[1..]; + }, + 'L' => { + res.line = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'R' => { + res.routine = next[1..]; + }, + else => { + try res.unknown_fields.put(next[0], next[1..]); + } + } + } + if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField; + return res; +} +pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(allocator); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // Length placeholder. + + try write_field_nt('S', self, "severity", writer); + if (self.severity_unlocalized) |severity_unlocalized| { + try write_nt('V', severity_unlocalized, writer); + } + try write_field_nt('C', self, "code", writer); + try write_field_nt('M', self, "message", writer); + // TODO rest of the fields + + // replace the length and write it to the actual stream + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} +fn write_field_nt(comptime tag: u8, self: ErrorResponse, comptime field: []const u8, writer: anytype) !void { + try write_nt(tag, @field(self, field), writer); +} +fn write_nt(comptime tag: u8, value: []const u8, writer: anytype) !void { + try writer.writeByte(tag); + try writer.writeAll(value); + try writer.writeByte(0); +} + +pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void { + self.unknown_fields.deinit(); + if (self.buf != null) allocator.free(self.buf.?); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = ErrorResponse{ + .severity = "foo", + .severity_unlocalized = "foo_unlocal", + .code = "bar", + .message = "baz", + .unknown_fields = HMByteString.init(allocator), + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try ErrorResponse.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualStrings("foo", sm2.severity); + try std.testing.expectEqualStrings("foo_unlocal", sm2.severity_unlocalized.?); + try std.testing.expectEqualStrings("bar", sm2.code); + try std.testing.expectEqualStrings("baz", sm2.message); +} diff --git a/src/proto/parameter_status.zig b/src/proto/parameter_status.zig new file mode 100644 index 0000000..5f95695 --- /dev/null +++ b/src/proto/parameter_status.zig @@ -0,0 +1,64 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; + +const ParameterStatus = @This(); +pub const Tag: u8 = 'S'; + +buf: ?[]const u8 = null, // owned +name: []const u8, +value: []const u8, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !ParameterStatus { + var res: ParameterStatus = undefined; + res.buf = try allocator.dupe(u8, b); + var it = std.mem.splitScalar(u8, res.buf.?, 0); + res.name = it.first(); + res.value = it.next() orelse return ProtocolError.MissingField; + return res; +} +pub fn write(self: ParameterStatus, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeAll(self.name); + try writer.writeByte(0); + try writer.writeAll(self.value); + try writer.writeByte(0); + std.mem.writeIntBig(u32, al.items[0..4], @as(u32,@intCast(cw.bytes_written))); // Fix length + try stream_writer.writeAll(al.items); +} +pub fn deinit(self: *ParameterStatus, allocator: std.mem.Allocator) void { + if (self.buf != null) allocator.free(self.buf.?); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = ParameterStatus{ + .name = "Hello", + .value = "world", + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try ParameterStatus.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqualStrings("Hello", sm2.name); + try std.testing.expectEqualStrings("world", sm2.value); +} diff --git a/src/proto/password_message.zig b/src/proto/password_message.zig new file mode 100644 index 0000000..1a8c17a --- /dev/null +++ b/src/proto/password_message.zig @@ -0,0 +1,45 @@ +const std = @import("std"); +const ByteArrayList = std.ArrayList(u8); +const PasswordMessage = @This(); + +pub const Tag: u8 = 'p'; +password: []const u8, +password_owned: bool = false, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !PasswordMessage { + return .{ .password = try allocator.dupe(u8, b), .password_owned = true }; +} + +pub fn write(self: PasswordMessage, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 5 + @as(u32, @intCast(self.password.len))); + try stream_writer.writeAll(self.password); + try stream_writer.writeByte(0); +} + +pub fn deinit(self: *PasswordMessage, allocator: std.mem.Allocator) void { + if (self.password_owned) allocator.free(self.password); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = PasswordMessage{ + .password = "foobar", + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try PasswordMessage.read(allocator, buf); + defer sm2.deinit(allocator); +} diff --git a/src/proto/proto.zig b/src/proto/proto.zig new file mode 100644 index 0000000..4465e4b --- /dev/null +++ b/src/proto/proto.zig @@ -0,0 +1,26 @@ +pub const StartupMessage = @import("startup_message.zig"); +pub const AuthenticationRequest = @import("authentication_request.zig"); +pub const PasswordMessage = @import("password_message.zig"); +pub const ErrorResponse = @import("error_response.zig"); +pub const ReadyForQuery = @import("ready_for_query.zig"); +pub const ParameterStatus = @import("parameter_status.zig"); +pub const BackendKeyData = @import("backend_key_data.zig"); +pub const Query = @import("query.zig"); +pub const DataRow = @import("data_row.zig"); +pub const RowDescription = @import("row_description.zig"); +pub const CommandComplete = @import("command_complete.zig"); +pub const CopyInResponse = @import("copy_in_response.zig"); + +test { + _ = AuthenticationRequest; + _ = PasswordMessage; + _ = ErrorResponse; + _ = ReadyForQuery; + _ = ParameterStatus; + _ = BackendKeyData; + _ = Query; + _ = DataRow; + _ = RowDescription; + _ = CommandComplete; + _ = CopyInResponse; +} diff --git a/src/proto/query.zig b/src/proto/query.zig new file mode 100644 index 0000000..9f238fc --- /dev/null +++ b/src/proto/query.zig @@ -0,0 +1,55 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; + +pub const Tag: u8 = 'Q'; + +const Query = @This(); + +string: []const u8, +owned: bool = false, + +pub fn read(a: std.mem.Allocator, b: []const u8) !Query { + return .{ + .string = try a.dupe(u8, b[0..(b.len-1)]), // Drop the null terminator + .owned = true, + }; +} + +pub fn write(self: Query, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, @as(u32, @intCast(self.string.len+5))); + try stream_writer.writeAll(self.string); + try stream_writer.writeByte(0); +} + +pub fn deinit(self: *Query, a: std.mem.Allocator) void { + if (self.owned) a.free(self.string); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = Query{ + .string = "Hello", + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try Query.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqualStrings("Hello", sm2.string); +} diff --git a/src/proto/ready_for_query.zig b/src/proto/ready_for_query.zig new file mode 100644 index 0000000..ef99e60 --- /dev/null +++ b/src/proto/ready_for_query.zig @@ -0,0 +1,53 @@ +const std = @import("std"); +const ProtocolError = @import("../main.zig").ProtocolError; +const enum_from_int = @import("../main.zig").enum_from_int; +const ByteArrayList = std.ArrayList(u8); + +const ReadyForQuery = @This(); +pub const Tag: u8 = 'Z'; + +const TransactionStatus = enum(u8) { + idle = 'I', + transaction = 'T', + err = 'E', +}; + +transaction_status: TransactionStatus, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !ReadyForQuery { + _ = allocator; + if (b.len != 1) return ProtocolError.InvalidMessageLength; + return .{ .transaction_status = enum_from_int(TransactionStatus, b[0]) orelse return ProtocolError.InvalidTransactionStatus }; +} +pub fn write(self: ReadyForQuery, allocator: std.mem.Allocator, stream_writer: anytype) !void { + _ = allocator; + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 5); + try stream_writer.writeByte(@intFromEnum(self.transaction_status)); +} +pub fn deinit(_: *ReadyForQuery, _: std.mem.Allocator) void {} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = ReadyForQuery{ + .transaction_status = TransactionStatus.idle, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try ReadyForQuery.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqual(TransactionStatus.idle, sm2.transaction_status); +} diff --git a/src/proto/row_description.zig b/src/proto/row_description.zig new file mode 100644 index 0000000..a0e8810 --- /dev/null +++ b/src/proto/row_description.zig @@ -0,0 +1,136 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("../main.zig").ProtocolError; +const ClientError = @import("../main.zig").ClientError; +const enum_from_int = @import("../main.zig").enum_from_int; +const FormatCode = @import("../main.zig").FormatCode; + +pub const Tag: u8 = 'T'; + +const RowDescription = @This(); + +buf: ?[]const u8 = null, // owned +fields: ?[]Field = null, // owned + +pub const Field = struct { + name: []const u8, + table_oid: u32, + attr_no: u16, + data_type_oid: u32, + data_type_size: i16, + data_type_modifier: u32, + format_code: FormatCode, +}; + +pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription { + var res: RowDescription = undefined; + res.buf = try a.dupe(u8, b); + errdefer res.deinit(a); + var fbs = std.io.fixedBufferStream(res.buf.?); + var reader = fbs.reader(); + const n_fields = try reader.readIntBig(u16); + res.fields = try a.alloc(Field, n_fields); + for (0..n_fields) |i| { + const name_start = fbs.pos; + try reader.skipUntilDelimiterOrEof(0); + const name_end = fbs.pos-1; + const name = res.buf.?[name_start..name_end]; + const field = Field{ + .name = name, + .table_oid = try reader.readIntBig(u32), + .attr_no = try reader.readIntBig(u16), + .data_type_oid = try reader.readIntBig(u32), + .data_type_size = try reader.readIntBig(i16), + .data_type_modifier = try reader.readIntBig(u32), + .format_code = enum_from_int(FormatCode, try reader.readIntBig(u16)) orelse return ProtocolError.InvalidFormatCode, + }; + res.fields.?[i] = field; + } + return res; +} + +pub fn write(self: RowDescription, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeIntBig(u16, @as(u16, @intCast(self.fields.?.len))); + for (self.fields.?) |field| { + try writer.writeAll(field.name); + try writer.writeByte(0); + try writer.writeIntBig(u32, field.table_oid); + try writer.writeIntBig(u16, field.attr_no); + try writer.writeIntBig(u32, field.data_type_oid); + try writer.writeIntBig(i16, field.data_type_size); + try writer.writeIntBig(u32, field.data_type_modifier); + try writer.writeIntBig(u16, @intFromEnum(field.format_code)); + } + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} + +pub fn deinit(self: *RowDescription, a: std.mem.Allocator) void { + if (self.fields != null) a.free(self.fields.?); + if (self.buf != null) a.free(self.buf.?); +} + +test "round trip" { + const allocator = std.testing.allocator; + var fields = try allocator.alloc(Field, 3); + fields[0] = .{ + .name = "foo", + .table_oid = 1, + .attr_no = 2, + .data_type_oid = 3, + .data_type_size = 4, + .data_type_modifier = 5, + .format_code = .Binary, + }; + fields[1] = .{ + .name = "bar", + .table_oid = 1, + .attr_no = 2, + .data_type_oid = 3, + .data_type_size = 4, + .data_type_modifier = 5, + .format_code = .Binary, + }; + fields[2] = .{ + .name = "BAZZZZZ", + .table_oid = 99, + .attr_no = 98, + .data_type_oid = 97, + .data_type_size = 96, + .data_type_modifier = 95, + .format_code = .Text, + }; + var f0 = fields[0]; + var f1 = fields[1]; + var f2 = fields[2]; + var sm = RowDescription{ + .fields = fields, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try RowDescription.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualDeep(f0, sm2.fields.?[0]); + try std.testing.expectEqualDeep(f1, sm2.fields.?[1]); + try std.testing.expectEqualDeep(f2, sm2.fields.?[2]); +} diff --git a/src/proto/startup_message.zig b/src/proto/startup_message.zig new file mode 100644 index 0000000..8224bdb --- /dev/null +++ b/src/proto/startup_message.zig @@ -0,0 +1,85 @@ +const std = @import("std"); +const ProtocolError = @import("../main.zig").ProtocolError; +const SSHashMap = std.StringHashMap([]const u8); +const ByteArrayList = std.ArrayList(u8); + +const StartupMessage = @This(); + +const ProtocolVersionNumber: u32 = 196608; // 3.0 + +bytes: ?[]const u8 = null, // Owned +parameters: SSHashMap, + +// message length should already have been read, b should contain the payload +pub fn read(allocator: std.mem.Allocator, b: []const u8) !StartupMessage { + if (b.len < 4) return ProtocolError.InvalidMessageLength; + + var bytes = try allocator.dupe(u8, b); + errdefer allocator.free(bytes); + const protocol_version = std.mem.readIntSliceBig(u32, bytes[0..4]); + if (protocol_version != ProtocolVersionNumber) return ProtocolError.InvalidProtocolVersion; + + var parameters = SSHashMap.init(allocator); + var it = std.mem.splitScalar(u8, bytes[4..], 0); + while (it.next()) |next| { + const key = next; + const value = it.next() orelse return ProtocolError.InvalidKeyValuePair; + try parameters.put(key, value); + } + + return .{ + .bytes = bytes, + .parameters = parameters, + }; +} + +pub fn write(self: StartupMessage, allocator: std.mem.Allocator, stream_writer: anytype) !void { + var al = ByteArrayList.init(allocator); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeIntBig(u32, ProtocolVersionNumber); + var it = self.parameters.iterator(); + while (it.next()) |entry| { + try writer.writeAll(entry.key_ptr.*); + try writer.writeByte(0); + try writer.writeAll(entry.value_ptr.*); + try writer.writeByte(0); + } + try writer.writeByte(0); + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} + +pub fn deinit(self: *StartupMessage, allocator: std.mem.Allocator) void { + self.parameters.deinit(); + if (self.bytes != null) { + allocator.free(self.bytes.?); + } +} + +test "round trip" { + const allocator = std.testing.allocator; + var params = SSHashMap.init(allocator); + try params.put("hello", "postgres"); + var sm = StartupMessage{ + .parameters = params, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try StartupMessage.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualStrings("postgres", sm2.parameters.get("hello").?); +} |