aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/conn/config.zig4
-rw-r--r--src/conn/conn.zig42
-rw-r--r--src/main.zig11
-rw-r--r--src/proto/error_response.zig39
4 files changed, 77 insertions, 19 deletions
diff --git a/src/conn/config.zig b/src/conn/config.zig
index b4e7cff..3f577d1 100644
--- a/src/conn/config.zig
+++ b/src/conn/config.zig
@@ -4,9 +4,9 @@ const SSHashMap = std.StringHashMap([]const u8);
const Config = @This();
allocator: std.mem.Allocator,
-address: union(enum){
+address: union(enum) {
net: std.net.Address,
- unix: []const u8,
+ unix: []const u8, // std.net.Address looks like it handles unix sockets but it doesn't really.
},
database: ?[]const u8 = null,
user: []const u8,
diff --git a/src/conn/conn.zig b/src/conn/conn.zig
index 1b2bf2d..4d62f57 100644
--- a/src/conn/conn.zig
+++ b/src/conn/conn.zig
@@ -6,6 +6,7 @@ const Proto = @import("../proto/proto.zig");
const read_message = @import("../main.zig").read_message;
const ProtocolError = @import("../main.zig").ProtocolError;
const ServerError = @import("../main.zig").ServerError;
+const ClientError = @import("../main.zig").ClientError;
const diagnosticReader = @import("../main.zig").diagnosticReader;
const Conn = @This();
@@ -57,6 +58,17 @@ pub fn connect(config: Config) !Conn {
var ar = try read_message(Proto.AuthenticationRequest, allocator, reader);
defer ar.deinit(allocator);
// TODO handle the authentication request
+ switch (ar.inner_type) {
+ .AuthRequestTypeOk => {}, // fine do nothing!
+ .AuthRequestTypeCleartextPassword => {
+ if (config.password) |password| {
+ const pm = Proto.PasswordMessage{ .password = password };
+ try pm.write(allocator, writer);
+ } else {
+ return ClientError.NoPasswordSupplied;
+ }
+ },
+ }
log.info("authentication request", .{});
},
Proto.ReadyForQuery.Tag => {
@@ -94,9 +106,7 @@ fn deinit(self: *Conn) void {
self.stream.close();
}
-//pub fn exec(self: *Conn)
-
-test "connect" {
+test "connect unix" {
// must have a local postgres runnning
// TODO maybe use docker to start one?
const allocator = std.testing.allocator;
@@ -109,3 +119,29 @@ test "connect" {
var conn = try Conn.connect(cfg);
defer conn.deinit();
}
+
+test "connect tcp with password" {
+ const allocator = std.testing.allocator;
+ const cfg = Config{
+ .allocator = allocator,
+ .address = .{ .net = std.net.Address{ .in = std.net.Ip4Address.init([4]u8{ 127, 0, 0, 1 }, 5432) } },
+ .database = "martin",
+ .user = "martin",
+ .password = "martin",
+ };
+ var conn = try Conn.connect(cfg);
+ defer conn.deinit();
+}
+
+test "connect tcp with wrong password" {
+ // TODO how to disable failing tests on error log
+ // const allocator = std.testing.allocator;
+ // const cfg = Config{
+ // .allocator = allocator,
+ // .address = .{ .net = std.net.Address{ .in = std.net.Ip4Address.init([4]u8{ 127, 0, 0, 1 }, 5432) } },
+ // .database = "martin",
+ // .user = "martin",
+ // .password = "foobar",
+ // };
+ // try std.testing.expectError(ServerError.ErrorResponse, Conn.connect(cfg));
+}
diff --git a/src/main.zig b/src/main.zig
index c8a9d8e..818b3b7 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -12,19 +12,20 @@ pub const ProtocolError = error{
InvalidFormatCode,
};
+pub const FormatCode = enum(u16) {
+ Text = 0,
+ Binary = 1,
+};
+
pub const ClientError = error{
UnsupportedAuthType,
+ NoPasswordSupplied,
};
pub const ServerError = error{
ErrorResponse,
};
-pub const FormatCode = enum(u16) {
- Text = 0,
- Binary = 1,
-};
-
// Fallible version of enumFromInt
pub fn enum_from_int(comptime e: type, i: anytype) ?e {
const enum_ti = @typeInfo(e);
diff --git a/src/proto/error_response.zig b/src/proto/error_response.zig
index 2aafe8d..dc75053 100644
--- a/src/proto/error_response.zig
+++ b/src/proto/error_response.zig
@@ -28,12 +28,18 @@ routine: ?[]const u8 = null,
unknown_fields: HMByteString,
pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse {
- var res: ErrorResponse = undefined;
- res.unknown_fields = HMByteString.init(allocator);
- res.buf = try allocator.dupe(u8, b);
- errdefer allocator.free(res.buf.?);
+ var res = ErrorResponse{
+ .severity = "",
+ .code = "",
+ .message = "",
+ .unknown_fields = HMByteString.init(allocator),
+ .buf = try allocator.dupe(u8, b),
+ };
+ errdefer res.deinit(allocator);
var it = std.mem.splitScalar(u8, res.buf.?, 0);
- var setSev = false; var setCode = false; var setMsg = false;
+ var setSev = false;
+ var setCode = false;
+ var setMsg = false;
while (it.next()) |next| {
if (next.len < 1) break;
switch (next[0]) {
@@ -97,7 +103,7 @@ pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse {
},
else => {
try res.unknown_fields.put(next[0], next[1..]);
- }
+ },
}
}
if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField;
@@ -112,11 +118,10 @@ pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: a
try writer.writeIntBig(u32, 0); // Length placeholder.
try write_field_nt('S', self, "severity", writer);
- if (self.severity_unlocalized) |severity_unlocalized| {
- try write_nt('V', severity_unlocalized, writer);
- }
+ if (self.severity_unlocalized) |severity_unlocalized| try write_nt('V', severity_unlocalized, writer);
try write_field_nt('C', self, "code", writer);
try write_field_nt('M', self, "message", writer);
+ if (self.detail) |detail| try write_nt('D', detail, writer);
// TODO rest of the fields
// replace the length and write it to the actual stream
@@ -137,6 +142,20 @@ pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void {
if (self.buf != null) allocator.free(self.buf.?);
}
+pub fn format(self: ErrorResponse, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
+ _ = options;
+ _ = fmt;
+ try writer.writeAll("ErrorResponse severity [");
+ try writer.writeAll(self.severity);
+ try writer.writeAll("] ");
+ try writer.writeAll("code [");
+ try writer.writeAll(self.code);
+ try writer.writeAll("] ");
+ try writer.writeAll("message [");
+ try writer.writeAll(self.message);
+ try writer.writeAll("]");
+}
+
test "round trip" {
const allocator = std.testing.allocator;
var sm = ErrorResponse{
@@ -144,6 +163,7 @@ test "round trip" {
.severity_unlocalized = "foo_unlocal",
.code = "bar",
.message = "baz",
+ .detail = "bang, and that's the end",
.unknown_fields = HMByteString.init(allocator),
};
defer sm.deinit(allocator);
@@ -167,4 +187,5 @@ test "round trip" {
try std.testing.expectEqualStrings("foo_unlocal", sm2.severity_unlocalized.?);
try std.testing.expectEqualStrings("bar", sm2.code);
try std.testing.expectEqualStrings("baz", sm2.message);
+ try std.testing.expectEqualStrings("bang, and that's the end", sm2.detail.?);
}