From 24439a295ca80a3b9a9e65d8b3436859d4ada46a Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Sun, 24 Sep 2023 22:12:52 +0100 Subject: Add RowDescription structure --- src/conn.zig | 2 + src/main.zig | 8 +++ src/row_description.zig | 136 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 146 insertions(+) create mode 100644 src/row_description.zig (limited to 'src') diff --git a/src/conn.zig b/src/conn.zig index 36c89df..99018da 100644 --- a/src/conn.zig +++ b/src/conn.zig @@ -99,6 +99,8 @@ fn deinit(self: *Conn) void { self.stream.close(); } +//pub fn exec(self: *Conn) + test "connect" { // must have a local postgres runnning // TODO maybe use docker to start one? diff --git a/src/main.zig b/src/main.zig index 07b6628..8c2aed9 100644 --- a/src/main.zig +++ b/src/main.zig @@ -9,6 +9,7 @@ const ParameterStatus = @import("parameter_status.zig"); const BackendKeyData = @import("backend_key_data.zig"); const Query = @import("query.zig"); const DataRow = @import("data_row.zig"); +const RowDescription = @import("row_description.zig"); const Conn = @import("conn.zig"); pub const ProtocolError = error{ @@ -19,6 +20,7 @@ pub const ProtocolError = error{ MissingField, WrongMessageType, InvalidTransactionStatus, + InvalidFormatCode, }; pub const ClientError = error{ @@ -29,6 +31,11 @@ pub const ServerError = error{ ErrorResponse, }; +pub const FormatCode = enum(u16) { + Text = 0, + Binary = 1, +}; + // Fallible version of enumFromInt pub fn enum_from_int(comptime e: type, i: anytype) ?e { const enum_ti = @typeInfo(e); @@ -119,4 +126,5 @@ test { _ = BackendKeyData; _ = Query; _ = DataRow; + _ = RowDescription; } diff --git a/src/row_description.zig b/src/row_description.zig new file mode 100644 index 0000000..414c174 --- /dev/null +++ b/src/row_description.zig @@ -0,0 +1,136 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("main.zig").ProtocolError; +const ClientError = @import("main.zig").ClientError; +const enum_from_int = @import("main.zig").enum_from_int; +const FormatCode = @import("main.zig").FormatCode; + +pub const Tag: u8 = 'T'; + +const RowDescription = @This(); + +buf: ?[]const u8 = null, // owned +fields: ?[]Field = null, // owned + +pub const Field = struct { + name: []const u8, + table_oid: u32, + attr_no: u16, + data_type_oid: u32, + data_type_size: i16, + data_type_modifier: u32, + format_code: FormatCode, +}; + +pub fn read(a: std.mem.Allocator, b: []const u8) !RowDescription { + var res: RowDescription = undefined; + res.buf = try a.dupe(u8, b); + errdefer res.deinit(a); + var fbs = std.io.fixedBufferStream(res.buf.?); + var reader = fbs.reader(); + const n_fields = try reader.readIntBig(u16); + res.fields = try a.alloc(Field, n_fields); + for (0..n_fields) |i| { + const name_start = fbs.pos; + try reader.skipUntilDelimiterOrEof(0); + const name_end = fbs.pos-1; + const name = res.buf.?[name_start..name_end]; + const field = Field{ + .name = name, + .table_oid = try reader.readIntBig(u32), + .attr_no = try reader.readIntBig(u16), + .data_type_oid = try reader.readIntBig(u32), + .data_type_size = try reader.readIntBig(i16), + .data_type_modifier = try reader.readIntBig(u32), + .format_code = enum_from_int(FormatCode, try reader.readIntBig(u16)) orelse return ProtocolError.InvalidFormatCode, + }; + res.fields.?[i] = field; + } + return res; +} + +pub fn write(self: RowDescription, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeIntBig(u16, @as(u16, @intCast(self.fields.?.len))); + for (self.fields.?) |field| { + try writer.writeAll(field.name); + try writer.writeByte(0); + try writer.writeIntBig(u32, field.table_oid); + try writer.writeIntBig(u16, field.attr_no); + try writer.writeIntBig(u32, field.data_type_oid); + try writer.writeIntBig(i16, field.data_type_size); + try writer.writeIntBig(u32, field.data_type_modifier); + try writer.writeIntBig(u16, @intFromEnum(field.format_code)); + } + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} + +pub fn deinit(self: *RowDescription, a: std.mem.Allocator) void { + if (self.fields != null) a.free(self.fields.?); + if (self.buf != null) a.free(self.buf.?); +} + +test "round trip" { + const allocator = std.testing.allocator; + var fields = try allocator.alloc(Field, 3); + fields[0] = .{ + .name = "foo", + .table_oid = 1, + .attr_no = 2, + .data_type_oid = 3, + .data_type_size = 4, + .data_type_modifier = 5, + .format_code = .Binary, + }; + fields[1] = .{ + .name = "bar", + .table_oid = 1, + .attr_no = 2, + .data_type_oid = 3, + .data_type_size = 4, + .data_type_modifier = 5, + .format_code = .Binary, + }; + fields[2] = .{ + .name = "BAZZZZZ", + .table_oid = 99, + .attr_no = 98, + .data_type_oid = 97, + .data_type_size = 96, + .data_type_modifier = 95, + .format_code = .Text, + }; + var f0 = fields[0]; + var f1 = fields[1]; + var f2 = fields[2]; + var sm = RowDescription{ + .fields = fields, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try RowDescription.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualDeep(f0, sm2.fields.?[0]); + try std.testing.expectEqualDeep(f1, sm2.fields.?[1]); + try std.testing.expectEqualDeep(f2, sm2.fields.?[2]); +} -- cgit v1.2.3-ZIG