const std = @import("std");
const testing = std.testing;

const StartupMessage = @import("startup_message.zig");
const AuthenticationRequest = @import("authentication_request.zig");
const PasswordMessage = @import("password_message.zig");
const ErrorResponse = @import("error_response.zig");
const ReadyForQuery = @import("ready_for_query.zig");
const ParameterStatus = @import("parameter_status.zig");
const BackendKeyData = @import("backend_key_data.zig");
const Query = @import("query.zig");
const DataRow = @import("data_row.zig");
const RowDescription = @import("row_description.zig");
const CommandComplete = @import("command_complete.zig");
const Conn = @import("conn.zig");

/// Error set for malformed or unexpected protocol data.
/// NOTE(review): the sites that return most of these errors are outside this
/// part of the file; the meanings are implied by the names only.
pub const ProtocolError = error{
    InvalidProtocolVersion,
    InvalidKeyValuePair,
    InvalidMessageLength,
    InvalidAuthType,
    MissingField,
    WrongMessageType,
    InvalidTransactionStatus,
    InvalidFormatCode,
};

/// Error set for conditions this client does not support.
pub const ClientError = error{
    UnsupportedAuthType,
};

/// Error set for failures reported by the server.
pub const ServerError = error{
    ErrorResponse,
};

/// Format selector stored as a `u16`: 0 is text, 1 is binary.
pub const FormatCode = enum(u16) {
    Text = 0,
    Binary = 1,
};

/// Fallible version of `@enumFromInt`.
///
/// Returns the member of enum type `e` whose declared value equals `i`, or
/// `null` when no declared field of `e` has that value.
///
/// `e` must be an enum and `i` must have exactly the enum's tag type; both
/// conditions are checked at compile time and violations are reported with
/// `@compileError`.
pub fn enum_from_int(comptime e: type, i: anytype) ?e {
    const enum_info = switch (@typeInfo(e)) {
        .Enum => |info| info,
        else => @compileError("e should be an enum but instead it's a " ++ @typeName(e)),
    };

    if (@TypeOf(i) != enum_info.tag_type) {
        // Report the tag type the caller has to pass, not the enum type itself.
        @compileError("i should be of type " ++ @typeName(enum_info.tag_type) ++ " but instead it's " ++ @typeName(@TypeOf(i)));
    }

    // The field list is comptime-known, so the loop unrolls into one
    // comparison per declared field.
    inline for (enum_info.fields) |field| {
        if (field.value == i) return @as(e, @enumFromInt(i));
    }
    return null;
}

// Tag should already have been read in order to determine msg_type!
/// Reads the remainder of one message from `stream_reader` and decodes it as
/// `msg_type`.
///
/// The stream must be positioned at the big-endian `u32` length field, i.e.
/// the message tag is not read here. The payload size is taken as that length
/// minus four (the length is treated as including its own four bytes). The
/// payload is read into a temporary buffer, handed to
/// `msg_type.read(allocator, buf)`, and freed before this function returns, so
/// the decoded value must not keep references into that buffer.
///
/// `msg_type` must declare `Tag` and `read`; this is checked at compile time.
///
/// Returns `error.InvalidMessageLength` when the length field is smaller than
/// four. All other errors come from the reader, the allocator, or
/// `msg_type.read`.
pub fn read_message(comptime msg_type: type, allocator: std.mem.Allocator, stream_reader: anytype) !msg_type {
    if (!@hasDecl(msg_type, "Tag")) @compileError("msg_type must have a Tag declaration!");
    if (!@hasDecl(msg_type, "read")) @compileError("msg_type must have a read() function!");

    const len = try stream_reader.readIntBig(u32);
    // The length comes straight off the wire: reject values that would make
    // the subtraction below underflow.
    if (len < 4) return error.InvalidMessageLength;

    const buf = try allocator.alloc(u8, len - 4);
    defer allocator.free(buf);
    try stream_reader.readNoEof(buf);
    return try msg_type.read(allocator, buf);
}

/// Wraps `base_reader` in a `DiagnosticReader` that keeps the last `n` bytes read.
pub fn diagnosticReader(comptime n: usize, base_reader: anytype) DiagnosticReader(n, @TypeOf(base_reader)) {
    return .{ .child_reader = base_reader };
}

/// Reader adapter that keeps a buffer of the last `n` bytes read.
///
/// Every byte returned by the wrapped reader is also written into a ring
/// buffer of `n` bytes, which `get` can copy out in read order. `n` must be at
/// least one.
pub fn DiagnosticReader(comptime n: usize, comptime ReaderType: type) type {
    if (n == 0) @compileError("DiagnosticReader requires n to be at least 1");

    return struct {
        const Self = @This();

        child_reader: ReaderType,
        /// Ring storage; starts zero-filled.
        ring: [n]u8 = [_]u8{0} ** n,
        /// Index in `ring` where the next byte is written (also the oldest byte
        /// once the ring has wrapped).
        pos: usize = 0,

        pub const Error = ReaderType.Error;
        pub const Reader = std.io.Reader(*Self, Error, read);

        /// Reads from the wrapped reader into `buf` and records the bytes that
        /// were actually read. Returns the number of bytes read.
        pub fn read(self: *Self, buf: []u8) Error!usize {
            const amt = try self.child_reader.read(buf);
            for (buf[0..amt]) |byte| {
                self.ring[self.pos] = byte;
                self.pos = (self.pos + 1) % n;
            }
            return amt;
        }

        /// Returns a `std.io.Reader` whose context is a pointer to `self`.
        pub fn reader(self: *Self) Reader {
            return .{ .context = self };
        }

        /// Returns a newly allocated `n`-byte copy of the ring, oldest byte
        /// first. Caller frees.
        ///
        /// If fewer than `n` bytes have been read so far, the leading bytes are
        /// the zeros the ring was initialised with.
        pub fn get(self: Self, allocator: std.mem.Allocator) ![]const u8 {
            const buf = try allocator.alloc(u8, n);
            // ring[pos..] holds the older part, ring[0..pos] the newer part.
            const older_len = n - self.pos;
            @memcpy(buf[0..older_len], self.ring[self.pos..]);
            @memcpy(buf[older_len..], self.ring[0..self.pos]);
            return buf;
        }
    };
}

test "diagnostic reader" {
    const a = std.testing.allocator;
    const string = "The quick brown fox jumped over the lazy dog";
    var fbs = std.io.fixedBufferStream(string);
    var dr = diagnosticReader(15, fbs.reader());
    const reader = dr.reader();
    var buf = [_]u8{0} ** 20;
    try reader.readNoEof(&buf);
    const diag = try dr.get(a);
    defer a.free(diag);
    try std.testing.expectEqualStrings("uick brown fox ", diag);
}

test "read_message rejects a length smaller than the length field" {
    const Dummy = struct {
        pub const Tag: u8 = 'X';
        pub fn read(_: std.mem.Allocator, _: []const u8) !@This() {
            return .{};
        }
    };
    var fbs = std.io.fixedBufferStream(&[_]u8{ 0, 0, 0, 3 });
    try std.testing.expectError(
        error.InvalidMessageLength,
        read_message(Dummy, std.testing.allocator, fbs.reader()),
    );
}

// Reference every imported message module from a test so that the tests they
// declare are compiled and run together with this file's tests.
test {
    _ = StartupMessage;
    _ = AuthenticationRequest;
    _ = PasswordMessage;
    _ = ErrorResponse;
    _ = Conn;
    _ = ReadyForQuery;
    _ = ParameterStatus;
    _ = BackendKeyData;
    _ = Query;
    _ = DataRow;
    _ = RowDescription;
    _ = CommandComplete;
}