const std = @import("std");
const log = std.log.scoped(.pgz);
const ByteArrayList = std.ArrayList(u8);
const ProtocolError = @import("../main.zig").ProtocolError;
const ClientError = @import("../main.zig").ClientError;
const enum_from_int = @import("../main.zig").enum_from_int;
const FormatCode = @import("../main.zig").FormatCode;

/// Returns a message type identified by the one-byte `tag`.
///
/// Wire layout written by `write` (and whose body is parsed by `read`):
///   tag (u8) | length (u32 BE, counts itself but not the tag) |
///   overall format code (u8) | column count (u16 BE) |
///   one format code (u16 BE) per column.
/// NOTE(review): presumably shared by the PostgreSQL CopyIn/CopyOut/CopyBoth
/// responses, which have this layout; the test here only exercises 'G'.
pub fn CopyXResponse(comptime tag: u8) type {
    return struct {
        pub const Tag: u8 = tag;

        /// Overall format code byte; read and written verbatim.
        overall_format_code: u8,
        /// Per-column format codes. Owned: allocated by `read`, freed by `deinit`.
        format_codes: []const FormatCode,

        /// Parses a message body (the bytes that follow the tag and length).
        ///
        /// Takes ownership of `buf`: it is freed with `a` before returning,
        /// on both success and error. On success the returned value owns
        /// `format_codes`; release it with `deinit`.
        ///
        /// Errors: `ProtocolError.InvalidFormatCode` for an unknown column
        /// format code, a reader error if `buf` is truncated, or OOM.
        pub fn read(a: std.mem.Allocator, buf: []const u8) !@This() {
            defer a.free(buf);
            var fbs = std.io.fixedBufferStream(buf);
            const reader = fbs.reader();

            const overall_format_code = try reader.readIntBig(u8);
            const n_columns = try reader.readIntBig(u16);

            const format_codes = try a.alloc(FormatCode, n_columns);
            errdefer a.free(format_codes);
            for (format_codes) |*format_code| {
                const int_format_code = try reader.readIntBig(u16);
                format_code.* = enum_from_int(FormatCode, int_format_code) orelse
                    return ProtocolError.InvalidFormatCode;
            }

            return .{
                .overall_format_code = overall_format_code,
                .format_codes = format_codes,
            };
        }

        /// Serializes the full message (tag, length, body) to `stream_writer`.
        ///
        /// The message is assembled in a temporary buffer allocated with `a`
        /// before anything is written. An allocation failure or an
        /// unrepresentable column count therefore leaves `stream_writer`
        /// untouched.
        ///
        /// Errors: `error.TooManyFormatCodes` if there are more format codes
        /// than fit in the u16 column count, OOM, or any `stream_writer` error.
        pub fn write(self: @This(), a: std.mem.Allocator, stream_writer: anytype) !void {
            // The column count is a u16 on the wire. Reject instead of hitting
            // the safety-checked @intCast (UB in ReleaseFast).
            const n_columns = std.math.cast(u16, self.format_codes.len) orelse
                return error.TooManyFormatCodes;

            var al = ByteArrayList.init(a);
            defer al.deinit();
            const writer = al.writer();

            try writer.writeIntBig(u32, 0); // length placeholder, patched below
            try writer.writeIntBig(u8, self.overall_format_code);
            try writer.writeIntBig(u16, n_columns);
            for (self.format_codes) |format_code| {
                try writer.writeIntBig(u16, @intFromEnum(format_code));
            }

            // Length covers the length field itself plus the body, not the tag.
            // At most 4 + 1 + 2 + 2 * 65535 bytes, so it always fits in a u32.
            std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(al.items.len)));

            try stream_writer.writeIntBig(u8, Tag);
            try stream_writer.writeAll(al.items);
        }

        /// Frees `format_codes`, which must have been allocated with `a`.
        pub fn deinit(self: *@This(), a: std.mem.Allocator) void {
            a.free(self.format_codes);
        }
    };
}

test "round trip" {
    const Tag: u8 = 'G';
    const CopyInResponse = CopyXResponse(Tag);
    const allocator = std.testing.allocator;

    const format_codes = try allocator.alloc(FormatCode, 3);
    format_codes[0] = .Binary;
    format_codes[1] = .Binary;
    format_codes[2] = .Text;
    var sm = CopyInResponse{
        .overall_format_code = 1,
        .format_codes = format_codes,
    };
    defer sm.deinit(allocator);

    var bal = ByteArrayList.init(allocator);
    defer bal.deinit();
    try sm.write(allocator, bal.writer());

    var fbs = std.io.fixedBufferStream(bal.items);
    const reader = fbs.reader();
    const tag = try reader.readByte();
    try std.testing.expectEqual(Tag, tag);
    const len = try reader.readIntBig(u32);
    // 4 (length) + 1 (overall format) + 2 (column count) + 3 * 2 (codes)
    try std.testing.expectEqual(@as(u32, 13), len);

    const buf = try allocator.alloc(u8, len - 4);
    // `read` takes ownership of `buf`; free it here only if we fail first.
    reader.readNoEof(buf) catch |err| {
        allocator.free(buf);
        return err;
    };

    var sm2 = try CopyInResponse.read(allocator, buf);
    defer sm2.deinit(allocator);
    try std.testing.expectEqual(@as(u8, 1), sm2.overall_format_code);
    try std.testing.expectEqual(@as(usize, 3), sm2.format_codes.len);
    try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[0]);
    try std.testing.expectEqual(FormatCode.Binary, sm2.format_codes[1]);
    try std.testing.expectEqual(FormatCode.Text, sm2.format_codes[2]);
}

test "write rejects more format codes than fit in a u16" {
    const CopyInResponse = CopyXResponse('G');
    const allocator = std.testing.allocator;

    const format_codes = try allocator.alloc(FormatCode, @as(usize, std.math.maxInt(u16)) + 1);
    defer allocator.free(format_codes);
    @memset(format_codes, .Text);
    const sm = CopyInResponse{
        .overall_format_code = 0,
        .format_codes = format_codes,
    };

    var bal = ByteArrayList.init(allocator);
    defer bal.deinit();
    try std.testing.expectError(error.TooManyFormatCodes, sm.write(allocator, bal.writer()));
    // Nothing, not even the tag byte, reaches the stream on failure.
    try std.testing.expectEqual(@as(usize, 0), bal.items.len);
}