From 5a91b37ee7dd36db52dfde1727b780ec3fa4c67d Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Sat, 23 Sep 2023 15:18:38 +0100 Subject: Add error_response Start adding connection abstraction --- src/config.zig | 13 ++++ src/conn.zig | 49 ++++++++++++++ src/error_response.zig | 170 +++++++++++++++++++++++++++++++++++++++++++++++ src/main.zig | 3 + src/password_message.zig | 3 +- 5 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 src/config.zig create mode 100644 src/conn.zig create mode 100644 src/error_response.zig (limited to 'src') diff --git a/src/config.zig b/src/config.zig new file mode 100644 index 0000000..eb2a52b --- /dev/null +++ b/src/config.zig @@ -0,0 +1,13 @@ +const std = @import("std"); +const SSHashMap = std.StringHashMap([]const u8); + +const Config = @This(); + +allocator: std.mem.Allocator, +address: union(enum){ + net: std.net.Address, + unix: []const u8, +}, +database: ?[]const u8, +user: []const u8, +password: []const u8, diff --git a/src/conn.zig b/src/conn.zig new file mode 100644 index 0000000..870ab01 --- /dev/null +++ b/src/conn.zig @@ -0,0 +1,49 @@ +const std = @import("std"); +const SSHashMap = std.StringHashMap([]const u8); +const Config = @import("config.zig"); +const StartupMessage = @import("startup_message.zig"); +const AuthenticationOk = @import("authentication_ok.zig"); +const AuthenticationCleartextPassword = @import("authentication_cleartext_password.zig"); +const ErrorResponse = @import("error_response.zig"); + +const Conn = @This(); + +const ConnStatus = enum { + connStatusUninitialized, + connStatusConnecting, + connStatusClosed, + connStatusIdle, + connStatusBusy, +}; + +stream: std.net.Stream, +config: Config, +status: ConnStatus, + +pub fn connect(config: Config) !Conn { + const allocator = config.allocator; + var stream = switch (config.address) { + .net => |addr| try std.net.tcpConnectToAddress(addr), + .unix => |path| try std.net.connectUnixSocket(path), + }; + var writer = stream.writer(); + + errdefer stream.close(); + var params = SSHashMap.init(allocator); + errdefer params.deinit(); + try params.put("user", config.user); + if (config.database) |database| try params.put(database); + var sm = StartupMessage{ + .parameters = params, + }; + defer sm.deinit(allocator); + try sm.write(allocator, writer); +} + +const StartupMessageResponseType = enum(u8) { + ErrorResponse = 'E', + AuthenticationResponse = AuthenticationOk.Tag, // All the authentication responses share a message type and must be decoded by the next field +}; +const StartupMessageResponse = union(StartupMessageResponseType) { + error: ErrorResponse, +}; \ No newline at end of file diff --git a/src/error_response.zig b/src/error_response.zig new file mode 100644 index 0000000..4d99ba4 --- /dev/null +++ b/src/error_response.zig @@ -0,0 +1,170 @@ +const std = @import("std"); +const HMByteString = std.AutoHashMap(u8, []const u8); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("main.zig").ProtocolError; + +const ErrorResponse = @This(); +const Tag: u8 = 'E'; + +buf: ?[]const u8 = null, // owned +severity: []const u8, +severity_unlocalized: ?[]const u8 = null, +code: []const u8, +message: []const u8, +detail: ?[]const u8 = null, +hint: ?[]const u8 = null, +position: ?u32 = null, +internal_position: ?u32 = null, +internal_query: ?[]const u8 = null, +where: ?[]const u8 = null, +schema_name: ?[]const u8 = null, +table_name: ?[]const u8 = null, +column_name: ?[]const u8 = null, +data_type_name: ?[]const u8 = null, +constraint_name: ?[]const u8 = null, +file_name: ?[]const u8 = null, +line: ?u32 = null, +routine: ?[]const u8 = null, +unknown_fields: HMByteString, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse { + var res: ErrorResponse = undefined; + res.unknown_fields = HMByteString.init(allocator); + res.buf = try allocator.dupe(u8, b); + errdefer allocator.free(res.buf.?); + var it = std.mem.splitScalar(u8, res.buf.?, 0); + var setSev = false; var setCode = false; var setMsg = false; + while (it.next()) |next| { + if (next.len < 1) break; + switch (next[0]) { + 0 => break, + 'S' => { + res.severity = next[1..]; + setSev = true; + }, + 'V' => { + res.severity_unlocalized = next[1..]; + }, + 'C' => { + res.code = next[1..]; + setCode = true; + }, + 'M' => { + res.message = next[1..]; + setMsg = true; + }, + 'D' => { + res.detail = next[1..]; + }, + 'H' => { + res.hint = next[1..]; + }, + 'P' => { + res.position = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'p' => { + res.internal_position = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'q' => { + res.internal_query = next[1..]; + }, + 'W' => { + res.where = next[1..]; + }, + 's' => { + res.schema_name = next[1..]; + }, + 't' => { + res.table_name = next[1..]; + }, + 'c' => { + res.column_name = next[1..]; + }, + 'd' => { + res.data_type_name = next[1..]; + }, + 'n' => { + res.constraint_name = next[1..]; + }, + 'F' => { + res.file_name = next[1..]; + }, + 'L' => { + res.line = try std.fmt.parseInt(u32, next[1..], 10); + }, + 'R' => { + res.routine = next[1..]; + }, + else => { + try res.unknown_fields.put(next[0], next[1..]); + } + } + } + if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField; + return res; +} +pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(allocator); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // Length placeholder. + + try write_field_nt('S', self, "severity", writer); + if (self.severity_unlocalized) |severity_unlocalized| { + try write_nt('V', severity_unlocalized, writer); + } + try write_field_nt('C', self, "code", writer); + try write_field_nt('M', self, "message", writer); + // TODO rest of the fields + + // replace the length and write it to the actual stream + std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written))); + try stream_writer.writeAll(al.items); +} +fn write_field_nt(comptime tag: u8, self: ErrorResponse, comptime field: []const u8, writer: anytype) !void { + try write_nt(tag, @field(self, field), writer); +} +fn write_nt(comptime tag: u8, value: []const u8, writer: anytype) !void { + try writer.writeByte(tag); + try writer.writeAll(value); + try writer.writeByte(0); +} + +pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void { + self.unknown_fields.deinit(); + if (self.buf != null) allocator.free(self.buf.?); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = ErrorResponse{ + .severity = "foo", + .severity_unlocalized = "foo_unlocal", + .code = "bar", + .message = "baz", + .unknown_fields = HMByteString.init(allocator), + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try ErrorResponse.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqualStrings("foo", sm2.severity); + try std.testing.expectEqualStrings("foo_unlocal", sm2.severity_unlocalized.?); + try std.testing.expectEqualStrings("bar", sm2.code); + try std.testing.expectEqualStrings("baz", sm2.message); +} diff --git a/src/main.zig b/src/main.zig index 24e86c4..e540256 100644 --- a/src/main.zig +++ b/src/main.zig @@ -4,12 +4,14 @@ const StartupMessage = @import("startup_message.zig"); const AuthenticationOk = @import("authentication_ok.zig"); const AuthenticationCleartextPassword = @import("authentication_cleartext_password.zig"); const PasswordMessage = @import("password_message.zig"); +const ErrorResponse = @import("error_response.zig"); pub const ProtocolError = error{ InvalidProtocolVersion, InvalidKeyValuePair, InvalidMessageLength, InvalidAuthType, + MissingField, }; pub const ClientError = error{ @@ -41,4 +43,5 @@ test { _ = AuthenticationOk; _ = AuthenticationCleartextPassword; _ = PasswordMessage; + _ = ErrorResponse; } diff --git a/src/password_message.zig b/src/password_message.zig index 33214bf..1a8c17a 100644 --- a/src/password_message.zig +++ b/src/password_message.zig @@ -6,16 +6,17 @@ pub const Tag: u8 = 'p'; password: []const u8, password_owned: bool = false, - pub fn read(allocator: std.mem.Allocator, b: []const u8) !PasswordMessage { return .{ .password = try allocator.dupe(u8, b), .password_owned = true }; } + pub fn write(self: PasswordMessage, _: std.mem.Allocator, stream_writer: anytype) !void { try stream_writer.writeByte(Tag); try stream_writer.writeIntBig(u32, 5 + @as(u32, @intCast(self.password.len))); try stream_writer.writeAll(self.password); try stream_writer.writeByte(0); } + pub fn deinit(self: *PasswordMessage, allocator: std.mem.Allocator) void { if (self.password_owned) allocator.free(self.password); } -- cgit v1.2.3-ZIG