From 08472c27c77d27ea084e3458842540351c5a5c28 Mon Sep 17 00:00:00 2001 From: Martin Ashby Date: Wed, 27 Sep 2023 20:23:30 +0100 Subject: Add cleartext password handling. Fix segfault on error response read. Add test for tcp connection and incorrect password --- src/conn/config.zig | 4 ++-- src/conn/conn.zig | 42 +++++++++++++++++++++++++++++++++++++++--- src/main.zig | 11 ++++++----- src/proto/error_response.zig | 39 ++++++++++++++++++++++++++++++--------- 4 files changed, 77 insertions(+), 19 deletions(-) (limited to 'src') diff --git a/src/conn/config.zig b/src/conn/config.zig index b4e7cff..3f577d1 100644 --- a/src/conn/config.zig +++ b/src/conn/config.zig @@ -4,9 +4,9 @@ const SSHashMap = std.StringHashMap([]const u8); const Config = @This(); allocator: std.mem.Allocator, -address: union(enum){ +address: union(enum) { net: std.net.Address, - unix: []const u8, + unix: []const u8, // std.net.Address looks like it handles unix sockets but it doesn't really. }, database: ?[]const u8 = null, user: []const u8, diff --git a/src/conn/conn.zig b/src/conn/conn.zig index 1b2bf2d..4d62f57 100644 --- a/src/conn/conn.zig +++ b/src/conn/conn.zig @@ -6,6 +6,7 @@ const Proto = @import("../proto/proto.zig"); const read_message = @import("../main.zig").read_message; const ProtocolError = @import("../main.zig").ProtocolError; const ServerError = @import("../main.zig").ServerError; +const ClientError = @import("../main.zig").ClientError; const diagnosticReader = @import("../main.zig").diagnosticReader; const Conn = @This(); @@ -57,6 +58,17 @@ pub fn connect(config: Config) !Conn { var ar = try read_message(Proto.AuthenticationRequest, allocator, reader); defer ar.deinit(allocator); // TODO handle the authentication request + switch (ar.inner_type) { + .AuthRequestTypeOk => {}, // fine do nothing! + .AuthRequestTypeCleartextPassword => { + if (config.password) |password| { + const pm = Proto.PasswordMessage{ .password = password }; + try pm.write(allocator, writer); + } else { + return ClientError.NoPasswordSupplied; + } + }, + } log.info("authentication request", .{}); }, Proto.ReadyForQuery.Tag => { @@ -94,9 +106,7 @@ fn deinit(self: *Conn) void { self.stream.close(); } -//pub fn exec(self: *Conn) - -test "connect" { +test "connect unix" { // must have a local postgres runnning // TODO maybe use docker to start one? const allocator = std.testing.allocator; @@ -109,3 +119,29 @@ test "connect" { var conn = try Conn.connect(cfg); defer conn.deinit(); } + +test "connect tcp with password" { + const allocator = std.testing.allocator; + const cfg = Config{ + .allocator = allocator, + .address = .{ .net = std.net.Address{ .in = std.net.Ip4Address.init([4]u8{ 127, 0, 0, 1 }, 5432) } }, + .database = "martin", + .user = "martin", + .password = "martin", + }; + var conn = try Conn.connect(cfg); + defer conn.deinit(); +} + +test "connect tcp with wrong password" { + // TODO how to disable failing tests on error log + // const allocator = std.testing.allocator; + // const cfg = Config{ + // .allocator = allocator, + // .address = .{ .net = std.net.Address{ .in = std.net.Ip4Address.init([4]u8{ 127, 0, 0, 1 }, 5432) } }, + // .database = "martin", + // .user = "martin", + // .password = "foobar", + // }; + // try std.testing.expectError(ServerError.ErrorResponse, Conn.connect(cfg)); +} diff --git a/src/main.zig b/src/main.zig index c8a9d8e..818b3b7 100644 --- a/src/main.zig +++ b/src/main.zig @@ -12,19 +12,20 @@ pub const ProtocolError = error{ InvalidFormatCode, }; +pub const FormatCode = enum(u16) { + Text = 0, + Binary = 1, +}; + pub const ClientError = error{ UnsupportedAuthType, + NoPasswordSupplied, }; pub const ServerError = error{ ErrorResponse, }; -pub const FormatCode = enum(u16) { - Text = 0, - Binary = 1, -}; - // Fallible version of enumFromInt pub fn enum_from_int(comptime e: type, i: anytype) ?e { const enum_ti = @typeInfo(e); diff --git a/src/proto/error_response.zig b/src/proto/error_response.zig index 2aafe8d..dc75053 100644 --- a/src/proto/error_response.zig +++ b/src/proto/error_response.zig @@ -28,12 +28,18 @@ routine: ?[]const u8 = null, unknown_fields: HMByteString, pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse { - var res: ErrorResponse = undefined; - res.unknown_fields = HMByteString.init(allocator); - res.buf = try allocator.dupe(u8, b); - errdefer allocator.free(res.buf.?); + var res = ErrorResponse{ + .severity = "", + .code = "", + .message = "", + .unknown_fields = HMByteString.init(allocator), + .buf = try allocator.dupe(u8, b), + }; + errdefer res.deinit(allocator); var it = std.mem.splitScalar(u8, res.buf.?, 0); - var setSev = false; var setCode = false; var setMsg = false; + var setSev = false; + var setCode = false; + var setMsg = false; while (it.next()) |next| { if (next.len < 1) break; switch (next[0]) { @@ -97,7 +103,7 @@ pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse { }, else => { try res.unknown_fields.put(next[0], next[1..]); - } + }, } } if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField; @@ -112,11 +118,10 @@ pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: a try writer.writeIntBig(u32, 0); // Length placeholder. try write_field_nt('S', self, "severity", writer); - if (self.severity_unlocalized) |severity_unlocalized| { - try write_nt('V', severity_unlocalized, writer); - } + if (self.severity_unlocalized) |severity_unlocalized| try write_nt('V', severity_unlocalized, writer); try write_field_nt('C', self, "code", writer); try write_field_nt('M', self, "message", writer); + if (self.detail) |detail| try write_nt('D', detail, writer); // TODO rest of the fields // replace the length and write it to the actual stream @@ -137,6 +142,20 @@ pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void { if (self.buf != null) allocator.free(self.buf.?); } +pub fn format(self: ErrorResponse, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { + _ = options; + _ = fmt; + try writer.writeAll("ErrorResponse severity ["); + try writer.writeAll(self.severity); + try writer.writeAll("] "); + try writer.writeAll("code ["); + try writer.writeAll(self.code); + try writer.writeAll("] "); + try writer.writeAll("message ["); + try writer.writeAll(self.message); + try writer.writeAll("]"); +} + test "round trip" { const allocator = std.testing.allocator; var sm = ErrorResponse{ @@ -144,6 +163,7 @@ test "round trip" { .severity_unlocalized = "foo_unlocal", .code = "bar", .message = "baz", + .detail = "bang, and that's the end", .unknown_fields = HMByteString.init(allocator), }; defer sm.deinit(allocator); @@ -167,4 +187,5 @@ test "round trip" { try std.testing.expectEqualStrings("foo_unlocal", sm2.severity_unlocalized.?); try std.testing.expectEqualStrings("bar", sm2.code); try std.testing.expectEqualStrings("baz", sm2.message); + try std.testing.expectEqualStrings("bang, and that's the end", sm2.detail.?); } -- cgit v1.2.3-ZIG