aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/config.zig13
-rw-r--r--src/conn.zig49
-rw-r--r--src/error_response.zig170
-rw-r--r--src/main.zig3
-rw-r--r--src/password_message.zig3
5 files changed, 237 insertions, 1 deletions
diff --git a/src/config.zig b/src/config.zig
new file mode 100644
index 0000000..eb2a52b
--- /dev/null
+++ b/src/config.zig
@@ -0,0 +1,13 @@
+const std = @import("std");
+const SSHashMap = std.StringHashMap([]const u8);
+
+const Config = @This();
+
+allocator: std.mem.Allocator,
+address: union(enum){
+ net: std.net.Address,
+ unix: []const u8,
+},
+database: ?[]const u8,
+user: []const u8,
+password: []const u8,
diff --git a/src/conn.zig b/src/conn.zig
new file mode 100644
index 0000000..870ab01
--- /dev/null
+++ b/src/conn.zig
@@ -0,0 +1,49 @@
+const std = @import("std");
+const SSHashMap = std.StringHashMap([]const u8);
+const Config = @import("config.zig");
+const StartupMessage = @import("startup_message.zig");
+const AuthenticationOk = @import("authentication_ok.zig");
+const AuthenticationCleartextPassword = @import("authentication_cleartext_password.zig");
+const ErrorResponse = @import("error_response.zig");
+
+const Conn = @This();
+
+const ConnStatus = enum {
+ connStatusUninitialized,
+ connStatusConnecting,
+ connStatusClosed,
+ connStatusIdle,
+ connStatusBusy,
+};
+
+stream: std.net.Stream,
+config: Config,
+status: ConnStatus,
+
+pub fn connect(config: Config) !Conn {
+ const allocator = config.allocator;
+ var stream = switch (config.address) {
+ .net => |addr| try std.net.tcpConnectToAddress(addr),
+ .unix => |path| try std.net.connectUnixSocket(path),
+ };
+ var writer = stream.writer();
+
+ errdefer stream.close();
+ var params = SSHashMap.init(allocator);
+ errdefer params.deinit();
+ try params.put("user", config.user);
+ if (config.database) |database| try params.put(database);
+ var sm = StartupMessage{
+ .parameters = params,
+ };
+ defer sm.deinit(allocator);
+ try sm.write(allocator, writer);
+}
+
+const StartupMessageResponseType = enum(u8) {
+ ErrorResponse = 'E',
+ AuthenticationResponse = AuthenticationOk.Tag, // All the authentication responses share a message type and must be decoded by the next field
+};
+const StartupMessageResponse = union(StartupMessageResponseType) {
+ error: ErrorResponse,
+}; \ No newline at end of file
diff --git a/src/error_response.zig b/src/error_response.zig
new file mode 100644
index 0000000..4d99ba4
--- /dev/null
+++ b/src/error_response.zig
@@ -0,0 +1,170 @@
+const std = @import("std");
+const HMByteString = std.AutoHashMap(u8, []const u8);
+const ByteArrayList = std.ArrayList(u8);
+const ProtocolError = @import("main.zig").ProtocolError;
+
+const ErrorResponse = @This();
+const Tag: u8 = 'E';
+
+buf: ?[]const u8 = null, // owned
+severity: []const u8,
+severity_unlocalized: ?[]const u8 = null,
+code: []const u8,
+message: []const u8,
+detail: ?[]const u8 = null,
+hint: ?[]const u8 = null,
+position: ?u32 = null,
+internal_position: ?u32 = null,
+internal_query: ?[]const u8 = null,
+where: ?[]const u8 = null,
+schema_name: ?[]const u8 = null,
+table_name: ?[]const u8 = null,
+column_name: ?[]const u8 = null,
+data_type_name: ?[]const u8 = null,
+constraint_name: ?[]const u8 = null,
+file_name: ?[]const u8 = null,
+line: ?u32 = null,
+routine: ?[]const u8 = null,
+unknown_fields: HMByteString,
+
+pub fn read(allocator: std.mem.Allocator, b: []const u8) !ErrorResponse {
+ var res: ErrorResponse = undefined;
+ res.unknown_fields = HMByteString.init(allocator);
+ res.buf = try allocator.dupe(u8, b);
+ errdefer allocator.free(res.buf.?);
+ var it = std.mem.splitScalar(u8, res.buf.?, 0);
+ var setSev = false; var setCode = false; var setMsg = false;
+ while (it.next()) |next| {
+ if (next.len < 1) break;
+ switch (next[0]) {
+ 0 => break,
+ 'S' => {
+ res.severity = next[1..];
+ setSev = true;
+ },
+ 'V' => {
+ res.severity_unlocalized = next[1..];
+ },
+ 'C' => {
+ res.code = next[1..];
+ setCode = true;
+ },
+ 'M' => {
+ res.message = next[1..];
+ setMsg = true;
+ },
+ 'D' => {
+ res.detail = next[1..];
+ },
+ 'H' => {
+ res.hint = next[1..];
+ },
+ 'P' => {
+ res.position = try std.fmt.parseInt(u32, next[1..], 10);
+ },
+ 'p' => {
+ res.internal_position = try std.fmt.parseInt(u32, next[1..], 10);
+ },
+ 'q' => {
+ res.internal_query = next[1..];
+ },
+ 'W' => {
+ res.where = next[1..];
+ },
+ 's' => {
+ res.schema_name = next[1..];
+ },
+ 't' => {
+ res.table_name = next[1..];
+ },
+ 'c' => {
+ res.column_name = next[1..];
+ },
+ 'd' => {
+ res.data_type_name = next[1..];
+ },
+ 'n' => {
+ res.constraint_name = next[1..];
+ },
+ 'F' => {
+ res.file_name = next[1..];
+ },
+ 'L' => {
+ res.line = try std.fmt.parseInt(u32, next[1..], 10);
+ },
+ 'R' => {
+ res.routine = next[1..];
+ },
+ else => {
+ try res.unknown_fields.put(next[0], next[1..]);
+ }
+ }
+ }
+ if (!(setSev and setCode and setMsg)) return ProtocolError.MissingField;
+ return res;
+}
+pub fn write(self: ErrorResponse, allocator: std.mem.Allocator, stream_writer: anytype) !void {
+ try stream_writer.writeByte(Tag);
+ var al = ByteArrayList.init(allocator);
+ defer al.deinit();
+ var cw = std.io.countingWriter(al.writer());
+ var writer = cw.writer();
+ try writer.writeIntBig(u32, 0); // Length placeholder.
+
+ try write_field_nt('S', self, "severity", writer);
+ if (self.severity_unlocalized) |severity_unlocalized| {
+ try write_nt('V', severity_unlocalized, writer);
+ }
+ try write_field_nt('C', self, "code", writer);
+ try write_field_nt('M', self, "message", writer);
+ // TODO rest of the fields
+
+ // replace the length and write it to the actual stream
+ std.mem.writeIntBig(u32, al.items[0..4], @as(u32, @intCast(cw.bytes_written)));
+ try stream_writer.writeAll(al.items);
+}
+fn write_field_nt(comptime tag: u8, self: ErrorResponse, comptime field: []const u8, writer: anytype) !void {
+ try write_nt(tag, @field(self, field), writer);
+}
+fn write_nt(comptime tag: u8, value: []const u8, writer: anytype) !void {
+ try writer.writeByte(tag);
+ try writer.writeAll(value);
+ try writer.writeByte(0);
+}
+
+pub fn deinit(self: *ErrorResponse, allocator: std.mem.Allocator) void {
+ self.unknown_fields.deinit();
+ if (self.buf != null) allocator.free(self.buf.?);
+}
+
+test "round trip" {
+ const allocator = std.testing.allocator;
+ var sm = ErrorResponse{
+ .severity = "foo",
+ .severity_unlocalized = "foo_unlocal",
+ .code = "bar",
+ .message = "baz",
+ .unknown_fields = HMByteString.init(allocator),
+ };
+ defer sm.deinit(allocator);
+
+ var bal = ByteArrayList.init(allocator);
+ defer bal.deinit();
+ try sm.write(allocator, bal.writer());
+
+ var fbs = std.io.fixedBufferStream(bal.items);
+ var reader = fbs.reader();
+ const tag = try reader.readByte();
+ try std.testing.expectEqual(Tag, tag);
+ const len = try reader.readIntBig(u32);
+ const buf = try allocator.alloc(u8, len - 4);
+ defer allocator.free(buf);
+ try reader.readNoEof(buf);
+ var sm2 = try ErrorResponse.read(allocator, buf);
+ defer sm2.deinit(allocator);
+
+ try std.testing.expectEqualStrings("foo", sm2.severity);
+ try std.testing.expectEqualStrings("foo_unlocal", sm2.severity_unlocalized.?);
+ try std.testing.expectEqualStrings("bar", sm2.code);
+ try std.testing.expectEqualStrings("baz", sm2.message);
+}
diff --git a/src/main.zig b/src/main.zig
index 24e86c4..e540256 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -4,12 +4,14 @@ const StartupMessage = @import("startup_message.zig");
const AuthenticationOk = @import("authentication_ok.zig");
const AuthenticationCleartextPassword = @import("authentication_cleartext_password.zig");
const PasswordMessage = @import("password_message.zig");
+const ErrorResponse = @import("error_response.zig");
pub const ProtocolError = error{
InvalidProtocolVersion,
InvalidKeyValuePair,
InvalidMessageLength,
InvalidAuthType,
+ MissingField,
};
pub const ClientError = error{
@@ -41,4 +43,5 @@ test {
_ = AuthenticationOk;
_ = AuthenticationCleartextPassword;
_ = PasswordMessage;
+ _ = ErrorResponse;
}
diff --git a/src/password_message.zig b/src/password_message.zig
index 33214bf..1a8c17a 100644
--- a/src/password_message.zig
+++ b/src/password_message.zig
@@ -6,16 +6,17 @@ pub const Tag: u8 = 'p';
password: []const u8,
password_owned: bool = false,
-
pub fn read(allocator: std.mem.Allocator, b: []const u8) !PasswordMessage {
return .{ .password = try allocator.dupe(u8, b), .password_owned = true };
}
+
pub fn write(self: PasswordMessage, _: std.mem.Allocator, stream_writer: anytype) !void {
try stream_writer.writeByte(Tag);
try stream_writer.writeIntBig(u32, 5 + @as(u32, @intCast(self.password.len)));
try stream_writer.writeAll(self.password);
try stream_writer.writeByte(0);
}
+
pub fn deinit(self: *PasswordMessage, allocator: std.mem.Allocator) void {
if (self.password_owned) allocator.free(self.password);
}