const std = @import("std");
const ByteArrayList = std.ArrayList(u8);
const log = std.log.scoped(.pgz);

pub const StartupMessage = @import("startup_message.zig");
pub const AuthenticationRequest = @import("authentication_request.zig");
pub const PasswordMessage = @import("password_message.zig");
pub const ErrorResponse = @import("error_response.zig").ErrorNoticeResponse('E');
pub const NoticeResponse = @import("error_response.zig").ErrorNoticeResponse('N');
pub const ReadyForQuery = @import("ready_for_query.zig");
pub const ParameterStatus = @import("parameter_status.zig");
pub const BackendKeyData = @import("backend_key_data.zig");
pub const Query = @import("query.zig");
pub const DataRow = @import("data_row.zig");
pub const RowDescription = @import("row_description.zig");
pub const CommandComplete = @import("command_complete.zig");
const CopyXResponse = @import("copy_x_response.zig").CopyXResponse;
pub const CopyInResponse = CopyXResponse('G');
pub const CopyOutResponse = CopyXResponse('H');
pub const CopyBothResponse = CopyXResponse('W');

const enum_from_int = @import("../main.zig").enum_from_int;
const ProtocolError = @import("../main.zig").ProtocolError;

/// A message received from the server, tagged by its concrete payload type.
///
/// Every payload type must declare a `Tag` byte and a `read` function; the
/// test directly below enforces this at compile time, and `read_message`
/// dispatches on both.
pub const BackendMessage = union(enum) {
    AuthenticationRequest: AuthenticationRequest,
    ErrorResponse: ErrorResponse,
    NoticeResponse: NoticeResponse,
    ReadyForQuery: ReadyForQuery,
    ParameterStatus: ParameterStatus,
    BackendKeyData: BackendKeyData,
    DataRow: DataRow,
    RowDescription: RowDescription,
    CommandComplete: CommandComplete,
    CopyInResponse: CopyInResponse,
    CopyOutResponse: CopyOutResponse,
    CopyBothResponse: CopyBothResponse,

    /// Release the resources held by the active payload by forwarding to its
    /// own `deinit` with the same allocator.
    pub fn deinit(self: *BackendMessage, a: std.mem.Allocator) void {
        switch (self.*) {
            inline else => |*sf| {
                sf.deinit(a);
            },
        }
    }
};

// Compile-time check: every BackendMessage payload type must expose the
// `Tag` and `read` declarations that `read_message` relies on.
test {
    inline for (@typeInfo(BackendMessage).Union.fields) |field| {
        const msg_type = field.type;
        if (!@hasDecl(msg_type, "Tag"))
            @compileError("message type " ++ @typeName(msg_type) ++ " must have a .Tag decl");
        if (!@hasDecl(msg_type, "read"))
            @compileError("message type " ++ @typeName(msg_type) ++ " must have a .read decl");
    }
}

/// Read one backend message from `stream_reader`.
///
/// Wire layout: a 1-byte tag, a big-endian `u32` length that counts its own
/// 4 bytes (but not the tag), then `length - 4` body bytes.
///
/// The body is read into a buffer allocated with `allocator` and passed to the
/// matching payload type's `read`. On success the returned message owns that
/// buffer; release it with `BackendMessage.deinit`. On any error the buffer is
/// freed before returning.
///
/// Errors: `error.InvalidMessageLength` if the length field is below 4,
/// `ProtocolError.InvalidMessageType` for an unrecognised tag, plus any error
/// from the reader, the allocator, or the payload's `read`.
pub fn read_message(allocator: std.mem.Allocator, stream_reader: anytype) !BackendMessage {
    const tag = try stream_reader.readByte();
    const len = try stream_reader.readIntBig(u32);

    // `len` comes from the peer: a value smaller than the length field itself
    // would make `len - 4` underflow (panic in safe builds, UB in ReleaseFast).
    if (len < 4) {
        log.err("InvalidMessageLength {d}", .{len});
        return error.InvalidMessageLength;
    }

    const buf = try allocator.alloc(u8, len - 4);
    // Frees the body on a short read, a failing payload `read`, or an unknown
    // tag — the same ownership rule `clone_message` follows.
    errdefer allocator.free(buf);
    // The full body is consumed before dispatching, so even for an unknown tag
    // the reader is left positioned after this message.
    try stream_reader.readNoEof(buf);

    inline for (@typeInfo(BackendMessage).Union.fields) |field| {
        if (field.type.Tag == tag) {
            return @unionInit(BackendMessage, field.name, try field.type.read(allocator, buf));
        }
    }

    log.err("InvalidMessageType {c}", .{tag});
    return ProtocolError.InvalidMessageType;
}

/// Deep-copy a message by serialising it with its own `write` and parsing the
/// produced bytes back with its type's `read`.
///
/// `self` must be one of the message types above, i.e. a type providing
/// `write(allocator, writer)` and `read(allocator, buf)`.
/// Caller owns the resulting message.
pub fn clone_message(self: anytype, a: std.mem.Allocator) !@TypeOf(self) {
    var ba = ByteArrayList.init(a);
    defer ba.deinit();
    try self.write(a, ba.writer());

    var fbs = std.io.fixedBufferStream(ba.items);
    const reader = fbs.reader();
    // The tag byte is skipped: the target type is already known statically.
    _ = try reader.readByte();
    const len = try reader.readIntBig(u32);
    // Guard the subtraction below against a malformed `write` output.
    if (len < 4) return error.InvalidMessageLength;

    const buf = try a.alloc(u8, len - 4);
    errdefer a.free(buf);
    try reader.readNoEof(buf);
    return try @TypeOf(self).read(a, buf);
}

// Reference every message module so that their own `test` blocks are
// included when this file is tested.
test {
    _ = StartupMessage;
    _ = AuthenticationRequest;
    _ = PasswordMessage;
    _ = ErrorResponse;
    _ = ReadyForQuery;
    _ = ParameterStatus;
    _ = BackendKeyData;
    _ = Query;
    _ = DataRow;
    _ = RowDescription;
    _ = CommandComplete;
    _ = CopyXResponse;
}