aboutsummaryrefslogtreecommitdiff
path: root/src/proto
diff options
context:
space:
mode:
authorMartin Ashby <martin@ashbysoft.com>2023-09-26 06:51:06 +0100
committerMartin Ashby <martin@ashbysoft.com>2023-09-26 06:51:06 +0100
commit183d60a6e87230cc767c56900b94c9c694596de1 (patch)
treec08b473a293dc465989a09c5d681191898cf2c2f /src/proto
parent02f9e99bfccad8837d327880f756ec7bab711783 (diff)
downloadpgz-183d60a6e87230cc767c56900b94c9c694596de1.tar.gz
pgz-183d60a6e87230cc767c56900b94c9c694596de1.tar.bz2
pgz-183d60a6e87230cc767c56900b94c9c694596de1.tar.xz
pgz-183d60a6e87230cc767c56900b94c9c694596de1.zip
Move protocol definitions into a subfolder
Diffstat (limited to 'src/proto')
-rw-r--r--src/proto/authentication_request.zig75
-rw-r--r--src/proto/backend_key_data.zig52
-rw-r--r--src/proto/command_complete.zig56
-rw-r--r--src/proto/copy_in_response.zig83
-rw-r--r--src/proto/data_row.zig88
-rw-r--r--src/proto/error_response.zig170
-rw-r--r--src/proto/parameter_status.zig64
-rw-r--r--src/proto/password_message.zig45
-rw-r--r--src/proto/proto.zig26
-rw-r--r--src/proto/query.zig55
-rw-r--r--src/proto/ready_for_query.zig53
-rw-r--r--src/proto/row_description.zig136
-rw-r--r--src/proto/startup_message.zig85
13 files changed, 988 insertions, 0 deletions
diff --git a/src/proto/authentication_request.zig b/src/proto/authentication_request.zig
new file mode 100644
index 0000000..9203482
--- /dev/null
+++ b/src/proto/authentication_request.zig
@@ -0,0 +1,75 @@
+const std = @import("std");
+const log = std.log.scoped(.pgz);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("../main.zig").ProtocolError;
+const ClientError = @import("../main.zig").ClientError;
+const enum_from_int = @import("../main.zig").enum_from_int;
+
+pub const Tag: u8 = 'R';
+
+const AuthenticationRequest = @This();
+
+pub const InnerAuthRequestType = enum(u32) {
+ AuthRequestTypeOk = 0,
+ AuthRequestTypeCleartextPassword = 3,
+};
+pub const InnerAuthRequest = union {
+ ok: AuthRequestOk,
+ cleartext_password: AuthRequestCleartextPassword,
+};
+pub const AuthRequestOk = struct {};
+pub const AuthRequestCleartextPassword = struct {};
+
+// Authentication requests have multiple subtypes.
+// It's not possible to have a tagged union with a custom backing integer, so do it the long way
+inner_type: InnerAuthRequestType,
+inner: InnerAuthRequest,
+
+pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationRequest {
+ if (b.len != 4) {
+ log.err("invalid message length, expected 4 got {}", .{b.len});
+ return ProtocolError.InvalidMessageLength;
+ }
+ const inner_type = enum_from_int(InnerAuthRequestType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType;
+ var inner: InnerAuthRequest = switch (inner_type) {
+ .AuthRequestTypeOk => .{ .ok = AuthRequestOk{} },
+ .AuthRequestTypeCleartextPassword => .{ .cleartext_password = AuthRequestCleartextPassword{} },
+ };
+ return .{
+ .inner_type = inner_type,
+ .inner = inner,
+ };
+}
+
+pub fn write(self: AuthenticationRequest, _: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ try stream_writer.writeIntBig(u32, 8);
+ try stream_writer.writeIntBig(u32, @intFromEnum(self.inner_type));
+}
+
+pub fn deinit(_: *AuthenticationRequest, _: std.mem.Allocator) void {}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = AuthenticationRequest{
+ .inner_type = .AuthRequestTypeOk,
+ .inner = .{ .ok = AuthRequestOk{} },
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try AuthenticationRequest.read(allocator, buf);
+ defer sm2.deinit(allocator);
+ try std.testing.expectEqual(InnerAuthRequestType.AuthRequestTypeOk, sm2.inner_type);
+}
diff --git a/src/proto/backend_key_data.zig b/src/proto/backend_key_data.zig
new file mode 100644
index 0000000..7c32178
--- /dev/null
+++ b/src/proto/backend_key_data.zig
@@ -0,0 +1,52 @@
+const std = @import("std");
+const log = std.log.scoped(.pgz);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("../main.zig").ProtocolError;
+const ClientError = @import("../main.zig").ClientError;
+
+const BackendKeyData = @This();
+pub const Tag: u8 = 'K';
+
+process_id: u32,
+secret_key: u32,
+
+pub fn read(_: std.mem.Allocator, b: []const u8) !BackendKeyData {
+ if (b.len != 8) return ProtocolError.InvalidMessageLength;
+ return .{
+ .process_id = std.mem.readIntBig(u32, b[0..4]),
+ .secret_key = std.mem.readIntBig(u32, b[4..8]),
+ };
+}
+pub fn write(self: BackendKeyData, _: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ try stream_writer.writeIntBig(u32, 12); // length
+ try stream_writer.writeIntBig(u32, self.process_id);
+ try stream_writer.writeIntBig(u32, self.secret_key);
+}
+pub fn deinit(_: *BackendKeyData, _: std.mem.Allocator) void {}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = BackendKeyData{
+ .process_id = 123,
+ .secret_key = 345,
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try BackendKeyData.read(allocator, buf);
+ defer sm2.deinit(allocator);
+ try std.testing.expectEqual(@as(u32, 123), sm2.process_id);
+ try std.testing.expectEqual(@as(u32, 345), sm2.secret_key);
+}
diff --git a/src/proto/command_complete.zig b/src/proto/command_complete.zig
new file mode 100644
index 0000000..80014e9
--- /dev/null
+++ b/src/proto/command_complete.zig
@@ -0,0 +1,56 @@
+const std = @import("std");
+const log = std.log.scoped(.pgz);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("../main.zig").ProtocolError;
+const ClientError = @import("../main.zig").ClientError;
+const enum_from_int = @import("../main.zig").enum_from_int;
+const FormatCode = @import("../main.zig").FormatCode;
+
+pub const Tag: u8 = 'C';
+
+const CommandComplete = @This();
+
+command_tag: []const u8,
+owned: bool = false,
+
+pub fn read(a: std.mem.Allocator, b: []const u8) !CommandComplete {
+ return .{
+ .command_tag = try a.dupe(u8, b),
+ .owned = true,
+ };
+}
+
+pub fn write(self: CommandComplete, _: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ try stream_writer.writeIntBig(u32, @as(u32, @intCast(4+self.command_tag.len)));
+ try stream_writer.writeAll(self.command_tag);
+}
+
+pub fn deinit(self: *CommandComplete, a: std.mem.Allocator) void {
+ if (self.owned) a.free(self.command_tag);
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = CommandComplete{
+ .command_tag = "foo",
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try CommandComplete.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqualStrings("foo", sm2.command_tag);
+}
diff --git a/src/proto/copy_in_response.zig b/src/proto/copy_in_response.zig
new file mode 100644
index 0000000..65d9ad7
--- /dev/null
+++ b/src/proto/copy_in_response.zig
@@ -0,0 +1,83 @@
+const std = @import("std");
+const log = std.log.scoped(.pgz);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("../main.zig").ProtocolError;
+const ClientError = @import("../main.zig").ClientError;
+const enum_from_int = @import("../main.zig").enum_from_int;
+const FormatCode = @import("../main.zig").FormatCode;
+
+pub const Tag: u8 = 'G';
+// TODO generics it's the same as CopyOutResponse and CopyBothResponse.
+const CopyInResponse = @This();
+
+overall_format_code: u8,
+format_codes: []const FormatCode, // owned
+
+pub fn read(a: std.mem.Allocator, b: []const u8) !CopyInResponse {
+ var fbs = std.io.fixedBufferStream(b);
+ var reader = fbs.reader();
+ const overall_format_code = try reader.readIntBig(u8);
+ const n_columns = try reader.readIntBig(u16);
+ var format_codes = try a.alloc(FormatCode, n_columns);
+ errdefer a.free(format_codes);
+ for (0..n_columns) |i| {
+ const int_format_code = try reader.readIntBig(u16);
+ format_codes[i] = enum_from_int(FormatCode, int_format_code) orelse return ProtocolError.InvalidFormatCode;
+ }
+ return .{
+ .overall_format_code = overall_format_code,
+ .format_codes = format_codes,
+ };
+}
+
+pub fn write(self: CopyInResponse, a: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeIntBig(u8, Tag);
+ var al = ByteArrayList.init(a);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // length placeholder
+ try writer.writeIntBig(u8, self.overall_format_code);
+ try writer.writeIntBig(u16, @intCast(self.format_codes.len));
+ for (self.format_codes) |format_code| {
+ try writer.writeIntBig(u16, @intFromEnum(format_code));
+ }
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); // Update length
+ try stream_writer.writeAll(al.items);
+}
+
+pub fn deinit(self: *CopyInResponse, a: std.mem.Allocator) void {
+ a.free(self.format_codes);
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var format_codes = try allocator.alloc(FormatCode, 3);
+ format_codes[0] = .Binary;
+ format_codes[1] = .Binary;
+ format_codes[2] = .Text;
+ var sm = CopyInResponse{
+ .overall_format_code = 1,
+ .format_codes = format_codes,
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try CopyInResponse.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[0]);
+ try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[1]);
+ try std.testing.expectEqual(FormatCode.Text, sm2.format_codes[2]);
+}
diff --git a/src/proto/data_row.zig b/src/proto/data_row.zig
new file mode 100644
index 0000000..c20b794
--- /dev/null
+++ b/src/proto/data_row.zig
@@ -0,0 +1,88 @@
+const std = @import("std");
+const log = std.log.scoped(.pgz);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("../main.zig").ProtocolError;
+const ClientError = @import("../main.zig").ClientError;
+const enum_from_int = @import("../main.zig").enum_from_int;
+
+pub const Tag: u8 = 'D';
+
+const DataRow = @This();
+
+buf: ?[]const u8 = null, // owned
+columns: [][]const u8, // also owned
+
+pub fn read(a: std.mem.Allocator, b: []const u8) !DataRow {
+ if (b.len < 2) return ProtocolError.InvalidMessageLength;
+ var buf = try a.dupe(u8, b);
+ var res: DataRow = undefined;
+ res.buf = buf;
+ errdefer res.deinit(a);
+
+ const n_columns = std.mem.readIntBig(u16, buf[0..2]);
+ const columns = try a.alloc([]const u8, n_columns);
+ errdefer a.free(columns);
+ var pos: usize = 2;
+ for (0..n_columns) |col| {
+ const len = std.mem.readIntBig(u32, buf[pos..(pos+4)][0..4]); // second slice forces the slice size to be known at comptime and satisfy the type check on readIntBig!
+ const data = if (len > 0) buf[(pos+4)..(pos+4+len)] else &[_]u8{};
+ columns[col] = data;
+ pos += (4+len);
+ }
+ res.columns = columns;
+ return res;
+}
+
+pub fn write(self: DataRow, a: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ var al = ByteArrayList.init(a);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // length placeholder
+ try writer.writeIntBig(u16, @as(u16, @intCast(self.columns.len)));
+ for (self.columns) |column| {
+ const len = @as(u32, @intCast(column.len));
+ try writer.writeIntBig(u32, len);
+ try writer.writeAll(column);
+ }
+ // Fixup the length and write to the original stream
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
+ try stream_writer.writeAll(al.items);
+}
+
+pub fn deinit(self: *DataRow, a: std.mem.Allocator) void {
+ if (self.buf != null) a.free(self.buf.?);
+ a.free(self.columns);
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ const columns = try allocator.alloc([]const u8, 3);
+ columns[0] = "Hello";
+ columns[1] = "FooBar";
+ columns[2] = "";
+ var sm = DataRow{
+ .columns = columns,
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try DataRow.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqualStrings("Hello", sm2.columns[0]);
+ try std.testing.expectEqualStrings("FooBar", sm2.columns[1]);
+ try std.testing.expectEqualStrings("", sm2.columns[2]);
+}
diff --git a/src/proto/error_response.zig b/src/proto/error_response.zig
new file mode 100644
index 0000000..2aafe8d
--- /dev/null
+++ b/src/proto/error_response.zig
@@ -0,0 +1,170 @@
+const std = @import("std");
+const HMByteString = std.AutoHashMap(u8, []const u8);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("../main.zig").ProtocolError;
+
+const ErrorResponse = @This();
+pub const Tag: u8 = 'E';
+
+buf: ?[]const u8 = null, // owned
+severity: []const u8,
+severity_unlocalized: ?[]const u8 = null,
+code: []const u8,
+message: []const u8,
+detail: ?[]const u8 = null,
+hint: ?[]const u8 = null,
+position: ?u32 = null,
+internal_position: ?u32 = null,
+internal_query: ?[]const u8 = null,
+where: ?[]const u8 = null,
+schema_name: ?[]const u8 = null,
+table_name: ?[]const u8 = null,
+column_name: ?[]const u8 = null,
+data_type_name: ?[]const u8 = null,
+constraint_name: ?[]const u8 = null,
+file_name: ?[]const u8 = null,
+line: ?u32 = null,
+routine: ?[]const u8 = null,
+unknown_fields: HMByteString,
+
+pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse {
+ var res: ErrorResponse = undefined;
+ res.unknown_fields = HMByteString.init(allocator);
+ res.buf = try allocator.dupe(u8, b);
+ errdefer allocator.free(res.buf.?);
+ var it = std.mem.splitScalar(u8, res.buf.?, 0);
+ var setSev = false; var setCode = false; var setMsg = false;
+ while (it.next()) |next| {
+ if (next.len < 1) break;
+ switch (next[0]) {
+ 0 => break,
+ 'S' => {
+ res.severity = next[1..];
+ setSev = true;
+ },
+ 'V' => {
+ res.severity_unlocalized = next[1..];
+ },
+ 'C' => {
+ res.code = next[1..];
+ setCode = true;
+ },
+ 'M' => {
+ res.message = next[1..];
+ setMsg = true;
+ },
+ 'D' => {
+ res.detail = next[1..];
+ },
+ 'H' => {
+ res.hint = next[1..];
+ },
+ 'P' => {
+ res.position = try std.fmt.parseInt(u32, next[1..], 10);
+ },
+ 'p' => {
+ res.internal_position = try std.fmt.parseInt(u32, next[1..], 10);
+ },
+ 'q' => {
+ res.internal_query = next[1..];
+ },
+ 'W' => {
+ res.where = next[1..];
+ },
+ 's' => {
+ res.schema_name = next[1..];
+ },
+ 't' => {
+ res.table_name = next[1..];
+ },
+ 'c' => {
+ res.column_name = next[1..];
+ },
+ 'd' => {
+ res.data_type_name = next[1..];
+ },
+ 'n' => {
+ res.constraint_name = next[1..];
+ },
+ 'F' => {
+ res.file_name = next[1..];
+ },
+ 'L' => {
+ res.line = try std.fmt.parseInt(u32, next[1..], 10);
+ },
+ 'R' => {
+ res.routine = next[1..];
+ },
+ else => {
+ try res.unknown_fields.put(next[0], next[1..]);
+ }
+ }
+ }
+ if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField;
+ return res;
+}
+pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ var al = ByteArrayList.init(allocator);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // Length placeholder.
+
+ try write_field_nt('S', self, "severity", writer);
+ if (self.severity_unlocalized) |severity_unlocalized| {
+ try write_nt('V', severity_unlocalized, writer);
+ }
+ try write_field_nt('C', self, "code", writer);
+ try write_field_nt('M', self, "message", writer);
+ // TODO rest of the fields
+
+ // replace the length and write it to the actual stream
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
+ try stream_writer.writeAll(al.items);
+}
+fn write_field_nt(comptime tag: u8, self: ErrorResponse, comptime field: []const u8, writer: anytype) !void {
+ try write_nt(tag, @field(self, field), writer);
+}
+fn write_nt(comptime tag: u8, value: []const u8, writer: anytype) !void {
+ try writer.writeByte(tag);
+ try writer.writeAll(value);
+ try writer.writeByte(0);
+}
+
+pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void {
+ self.unknown_fields.deinit();
+ if (self.buf != null) allocator.free(self.buf.?);
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = ErrorResponse{
+ .severity = "foo",
+ .severity_unlocalized = "foo_unlocal",
+ .code = "bar",
+ .message = "baz",
+ .unknown_fields = HMByteString.init(allocator),
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try ErrorResponse.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqualStrings("foo", sm2.severity);
+ try std.testing.expectEqualStrings("foo_unlocal", sm2.severity_unlocalized.?);
+ try std.testing.expectEqualStrings("bar", sm2.code);
+ try std.testing.expectEqualStrings("baz", sm2.message);
+}
diff --git a/src/proto/parameter_status.zig b/src/proto/parameter_status.zig
new file mode 100644
index 0000000..5f95695
--- /dev/null
+++ b/src/proto/parameter_status.zig
@@ -0,0 +1,64 @@
+const std = @import("std");
+const log = std.log.scoped(.pgz);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("../main.zig").ProtocolError;
+const ClientError = @import("../main.zig").ClientError;
+
+const ParameterStatus = @This();
+pub const Tag: u8 = 'S';
+
+buf: ?[]const u8 = null, // owned
+name: []const u8,
+value: []const u8,
+
+pub fn read(allocator: std.mem.Allocator, b: []const u8) !ParameterStatus {
+ var res: ParameterStatus = undefined;
+ res.buf = try allocator.dupe(u8, b);
+ var it = std.mem.splitScalar(u8, res.buf.?, 0);
+ res.name = it.first();
+ res.value = it.next() orelse return ProtocolError.MissingField;
+ return res;
+}
+pub fn write(self: ParameterStatus, a: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ var al = ByteArrayList.init(a);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // length placeholder
+ try writer.writeAll(self.name);
+ try writer.writeByte(0);
+ try writer.writeAll(self.value);
+ try writer.writeByte(0);
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32,@intCast(cw.bytes_written))); // Fix length
+ try stream_writer.writeAll(al.items);
+}
+pub fn deinit(self: *ParameterStatus, allocator: std.mem.Allocator) void {
+ if (self.buf != null) allocator.free(self.buf.?);
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = ParameterStatus{
+ .name = "Hello",
+ .value = "world",
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try ParameterStatus.read(allocator, buf);
+ defer sm2.deinit(allocator);
+ try std.testing.expectEqualStrings("Hello", sm2.name);
+ try std.testing.expectEqualStrings("world", sm2.value);
+}
diff --git a/src/proto/password_message.zig b/src/proto/password_message.zig
new file mode 100644
index 0000000..1a8c17a
--- /dev/null
+++ b/src/proto/password_message.zig
@@ -0,0 +1,45 @@
+const std = @import("std");
+const ByteArrayList = std.ArrayList(u8);
+const PasswordMessage = @This();
+
+pub const Tag: u8 = 'p';
+password: []const u8,
+password_owned: bool = false,
+
+pub fn read(allocator: std.mem.Allocator, b: []const u8) !PasswordMessage {
+ return .{ .password = try allocator.dupe(u8, b), .password_owned = true };
+}
+
+pub fn write(self: PasswordMessage, _: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ try stream_writer.writeIntBig(u32, 5 + @as(u32, @intCast(self.password.len)));
+ try stream_writer.writeAll(self.password);
+ try stream_writer.writeByte(0);
+}
+
+pub fn deinit(self: *PasswordMessage, allocator: std.mem.Allocator) void {
+ if (self.password_owned) allocator.free(self.password);
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = PasswordMessage{
+ .password = "foobar",
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try PasswordMessage.read(allocator, buf);
+ defer sm2.deinit(allocator);
+}
diff --git a/src/proto/proto.zig b/src/proto/proto.zig
new file mode 100644
index 0000000..4465e4b
--- /dev/null
+++ b/src/proto/proto.zig
@@ -0,0 +1,26 @@
+pub const StartupMessage = @import("startup_message.zig");
+pub const AuthenticationRequest = @import("authentication_request.zig");
+pub const PasswordMessage = @import("password_message.zig");
+pub const ErrorResponse = @import("error_response.zig");
+pub const ReadyForQuery = @import("ready_for_query.zig");
+pub const ParameterStatus = @import("parameter_status.zig");
+pub const BackendKeyData = @import("backend_key_data.zig");
+pub const Query = @import("query.zig");
+pub const DataRow = @import("data_row.zig");
+pub const RowDescription = @import("row_description.zig");
+pub const CommandComplete = @import("command_complete.zig");
+pub const CopyInResponse = @import("copy_in_response.zig");
+
+test {
+ _ = AuthenticationRequest;
+ _ = PasswordMessage;
+ _ = ErrorResponse;
+ _ = ReadyForQuery;
+ _ = ParameterStatus;
+ _ = BackendKeyData;
+ _ = Query;
+ _ = DataRow;
+ _ = RowDescription;
+ _ = CommandComplete;
+ _ = CopyInResponse;
+}
diff --git a/src/proto/query.zig b/src/proto/query.zig
new file mode 100644
index 0000000..9f238fc
--- /dev/null
+++ b/src/proto/query.zig
@@ -0,0 +1,55 @@
+const std = @import("std");
+const log = std.log.scoped(.pgz);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("../main.zig").ProtocolError;
+const ClientError = @import("../main.zig").ClientError;
+const enum_from_int = @import("../main.zig").enum_from_int;
+
+pub const Tag: u8 = 'Q';
+
+const Query = @This();
+
+string: []const u8,
+owned: bool = false,
+
+pub fn read(a: std.mem.Allocator, b: []const u8) !Query {
+ return .{
+ .string = try a.dupe(u8, b[0..(b.len-1)]), // Drop the null terminator
+ .owned = true,
+ };
+}
+
+pub fn write(self: Query, _: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ try stream_writer.writeIntBig(u32, @as(u32, @intCast(self.string.len+5)));
+ try stream_writer.writeAll(self.string);
+ try stream_writer.writeByte(0);
+}
+
+pub fn deinit(self: *Query, a: std.mem.Allocator) void {
+ if (self.owned) a.free(self.string);
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = Query{
+ .string = "Hello",
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try Query.read(allocator, buf);
+ defer sm2.deinit(allocator);
+ try std.testing.expectEqualStrings("Hello", sm2.string);
+}
diff --git a/src/proto/ready_for_query.zig b/src/proto/ready_for_query.zig
new file mode 100644
index 0000000..ef99e60
--- /dev/null
+++ b/src/proto/ready_for_query.zig
@@ -0,0 +1,53 @@
+const std = @import("std");
+const ProtocolError = @import("../main.zig").ProtocolError;
+const enum_from_int = @import("../main.zig").enum_from_int;
+const ByteArrayList = std.ArrayList(u8);
+
+const ReadyForQuery = @This();
+pub const Tag: u8 = 'Z';
+
+const TransactionStatus = enum(u8) {
+ idle = 'I',
+ transaction = 'T',
+ err = 'E',
+};
+
+transaction_status: TransactionStatus,
+
+pub fn read(allocator: std.mem.Allocator, b: []const u8) !ReadyForQuery {
+ _ = allocator;
+ if (b.len != 1) return ProtocolError.InvalidMessageLength;
+ return .{ .transaction_status = enum_from_int(TransactionStatus, b[0]) orelse return ProtocolError.InvalidTransactionStatus };
+}
+pub fn write(self: ReadyForQuery, allocator: std.mem.Allocator, stream_writer: anytype) !void {
+ _ = allocator;
+ try stream_writer.writeByte(Tag);
+ try stream_writer.writeIntBig(u32, 5);
+ try stream_writer.writeByte(@intFromEnum(self.transaction_status));
+}
+pub fn deinit(_: *ReadyForQuery, _: std.mem.Allocator) void {}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = ReadyForQuery{
+ .transaction_status = TransactionStatus.idle,
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try ReadyForQuery.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqual(TransactionStatus.idle, sm2.transaction_status);
+}
diff --git a/src/proto/row_description.zig b/src/proto/row_description.zig
new file mode 100644
index 0000000..a0e8810
--- /dev/null
+++ b/src/proto/row_description.zig
@@ -0,0 +1,136 @@
+const std = @import("std");
+const log = std.log.scoped(.pgz);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("../main.zig").ProtocolError;
+const ClientError = @import("../main.zig").ClientError;
+const enum_from_int = @import("../main.zig").enum_from_int;
+const FormatCode = @import("../main.zig").FormatCode;
+
+pub const Tag: u8 = 'T';
+
+const RowDescription = @This();
+
+buf: ?[]const u8 = null, // owned
+fields: ?[]Field = null, // owned
+
+pub const Field = struct {
+ name: []const u8,
+ table_oid: u32,
+ attr_no: u16,
+ data_type_oid: u32,
+ data_type_size: i16,
+ data_type_modifier: u32,
+ format_code: FormatCode,
+};
+
+pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription {
+ var res: RowDescription = undefined;
+ res.buf = try a.dupe(u8, b);
+ errdefer res.deinit(a);
+ var fbs = std.io.fixedBufferStream(res.buf.?);
+ var reader = fbs.reader();
+ const n_fields = try reader.readIntBig(u16);
+ res.fields = try a.alloc(Field, n_fields);
+ for (0..n_fields) |i| {
+ const name_start = fbs.pos;
+ try reader.skipUntilDelimiterOrEof(0);
+ const name_end = fbs.pos-1;
+ const name = res.buf.?[name_start..name_end];
+ const field = Field{
+ .name = name,
+ .table_oid = try reader.readIntBig(u32),
+ .attr_no = try reader.readIntBig(u16),
+ .data_type_oid = try reader.readIntBig(u32),
+ .data_type_size = try reader.readIntBig(i16),
+ .data_type_modifier = try reader.readIntBig(u32),
+ .format_code = enum_from_int(FormatCode, try reader.readIntBig(u16)) orelse return ProtocolError.InvalidFormatCode,
+ };
+ res.fields.?[i] = field;
+ }
+ return res;
+}
+
+pub fn write(self: RowDescription, a: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ var al = ByteArrayList.init(a);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // length placeholder
+ try writer.writeIntBig(u16, @as(u16, @intCast(self.fields.?.len)));
+ for (self.fields.?) |field| {
+ try writer.writeAll(field.name);
+ try writer.writeByte(0);
+ try writer.writeIntBig(u32, field.table_oid);
+ try writer.writeIntBig(u16, field.attr_no);
+ try writer.writeIntBig(u32, field.data_type_oid);
+ try writer.writeIntBig(i16, field.data_type_size);
+ try writer.writeIntBig(u32, field.data_type_modifier);
+ try writer.writeIntBig(u16, @intFromEnum(field.format_code));
+ }
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
+ try stream_writer.writeAll(al.items);
+}
+
+pub fn deinit(self: *RowDescription, a: std.mem.Allocator) void {
+ if (self.fields != null) a.free(self.fields.?);
+ if (self.buf != null) a.free(self.buf.?);
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var fields = try allocator.alloc(Field, 3);
+ fields[0] = .{
+ .name = "foo",
+ .table_oid = 1,
+ .attr_no = 2,
+ .data_type_oid = 3,
+ .data_type_size = 4,
+ .data_type_modifier = 5,
+ .format_code = .Binary,
+ };
+ fields[1] = .{
+ .name = "bar",
+ .table_oid = 1,
+ .attr_no = 2,
+ .data_type_oid = 3,
+ .data_type_size = 4,
+ .data_type_modifier = 5,
+ .format_code = .Binary,
+ };
+ fields[2] = .{
+ .name = "BAZZZZZ",
+ .table_oid = 99,
+ .attr_no = 98,
+ .data_type_oid = 97,
+ .data_type_size = 96,
+ .data_type_modifier = 95,
+ .format_code = .Text,
+ };
+ var f0 = fields[0];
+ var f1 = fields[1];
+ var f2 = fields[2];
+ var sm = RowDescription{
+ .fields = fields,
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try RowDescription.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqualDeep(f0, sm2.fields.?[0]);
+ try std.testing.expectEqualDeep(f1, sm2.fields.?[1]);
+ try std.testing.expectEqualDeep(f2, sm2.fields.?[2]);
+}
diff --git a/src/proto/startup_message.zig b/src/proto/startup_message.zig
new file mode 100644
index 0000000..8224bdb
--- /dev/null
+++ b/src/proto/startup_message.zig
@@ -0,0 +1,85 @@
+const std = @import("std");
+const ProtocolError = @import("../main.zig").ProtocolError;
+const SSHashMap = std.StringHashMap([]const u8);
+const ByteArrayList = std.ArrayList(u8);
+
+const StartupMessage = @This();
+
+const ProtocolVersionNumber: u32 = 196608; // 3.0
+
+bytes: ?[]const u8 = null, // Owned
+parameters: SSHashMap,
+
+// message length should already have been read, b should contain the payload
+pub fn read(allocator: std.mem.Allocator, b: []const u8) !StartupMessage {
+ if (b.len < 4) return ProtocolError.InvalidMessageLength;
+
+ var bytes = try allocator.dupe(u8, b);
+ errdefer allocator.free(bytes);
+ const protocol_version = std.mem.readIntSliceBig(u32, bytes[0..4]);
+ if (protocol_version != ProtocolVersionNumber) return ProtocolError.InvalidProtocolVersion;
+
+ var parameters = SSHashMap.init(allocator);
+ var it = std.mem.splitScalar(u8, bytes[4..], 0);
+ while (it.next()) |next| {
+ const key = next;
+ const value = it.next() orelse return ProtocolError.InvalidKeyValuePair;
+ try parameters.put(key, value);
+ }
+
+ return .{
+ .bytes = bytes,
+ .parameters = parameters,
+ };
+}
+
+pub fn write(self: StartupMessage, allocator: std.mem.Allocator, stream_writer: anytype) !void {
+ var al = ByteArrayList.init(allocator);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // length placeholder
+ try writer.writeIntBig(u32, ProtocolVersionNumber);
+ var it = self.parameters.iterator();
+ while (it.next()) |entry| {
+ try writer.writeAll(entry.key_ptr.*);
+ try writer.writeByte(0);
+ try writer.writeAll(entry.value_ptr.*);
+ try writer.writeByte(0);
+ }
+ try writer.writeByte(0);
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
+ try stream_writer.writeAll(al.items);
+}
+
+pub fn deinit(self: *StartupMessage, allocator: std.mem.Allocator) void {
+ self.parameters.deinit();
+ if (self.bytes != null) {
+ allocator.free(self.bytes.?);
+ }
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var params = SSHashMap.init(allocator);
+ try params.put("hello", "postgres");
+ var sm = StartupMessage{
+ .parameters = params,
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try StartupMessage.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqualStrings("postgres", sm2.parameters.get("hello").?);
+}