diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/authentication_cleartext_password.zig | 47 | ||||
-rw-r--r-- | src/authentication_ok.zig | 47 | ||||
-rw-r--r-- | src/authentication_request.zig | 75 | ||||
-rw-r--r-- | src/backend_key_data.zig | 52 | ||||
-rw-r--r-- | src/config.zig | 4 | ||||
-rw-r--r-- | src/conn.zig | 94 | ||||
-rw-r--r-- | src/error_response.zig | 2 | ||||
-rw-r--r-- | src/main.zig | 85 | ||||
-rw-r--r-- | src/parameter_status.zig | 64 | ||||
-rw-r--r-- | src/ready_for_query.zig | 55 |
10 files changed, 407 insertions, 118 deletions
diff --git a/src/authentication_cleartext_password.zig b/src/authentication_cleartext_password.zig deleted file mode 100644 index e72f28d..0000000 --- a/src/authentication_cleartext_password.zig +++ /dev/null @@ -1,47 +0,0 @@ -const std = @import("std"); -const ProtocolError = @import("main.zig").ProtocolError; -const AuthType = @import("main.zig").AuthType; -const enum_from_int = @import("main.zig").enum_from_int; -const ClientError = @import("main.zig").ClientError; -const AuthenticationCleartextPassword = @This(); -const ByteArrayList = std.ArrayList(u8); - -pub const Tag: u8 = 'R'; -pub const Type: AuthType = AuthType.AuthTypeCleartextPassword; - -pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationCleartextPassword { - if (b.len != 4) return ProtocolError.InvalidMessageLength; - - const auth_type = enum_from_int(AuthType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType; - if (auth_type != Type) return ProtocolError.InvalidAuthType; - return .{}; -} - -pub fn write(_: AuthenticationCleartextPassword, _: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - try stream_writer.writeIntBig(u32, 8); - try stream_writer.writeIntBig(u32, @intFromEnum(Type)); -} - -pub fn deinit(_: *AuthenticationCleartextPassword, _: std.mem.Allocator) void {} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = AuthenticationCleartextPassword{}; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try AuthenticationCleartextPassword.read(allocator, buf); - defer sm2.deinit(allocator); -} diff --git a/src/authentication_ok.zig b/src/authentication_ok.zig deleted file mode 100644 index 0f0702b..0000000 --- a/src/authentication_ok.zig +++ /dev/null @@ -1,47 +0,0 @@ -const std = @import("std"); -const ProtocolError = @import("main.zig").ProtocolError; -const AuthType = @import("main.zig").AuthType; -const enum_from_int = @import("main.zig").enum_from_int; -const ClientError = @import("main.zig").ClientError; -const AuthenticationOk = @This(); -const ByteArrayList = std.ArrayList(u8); - -pub const Tag: u8 = 'R'; -pub const Type: AuthType = AuthType.AuthTypeCleartextPassword; - -pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationOk { - if (b.len != 4) return ProtocolError.InvalidMessageLength; - - const auth_type = enum_from_int(AuthType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType; - if (auth_type != Type) return ProtocolError.InvalidAuthType; - return .{}; -} - -pub fn write(_: AuthenticationOk, _: std.mem.Allocator, stream_writer: anytype) !void { - try stream_writer.writeByte(Tag); - try stream_writer.writeIntBig(u32, 8); - try stream_writer.writeIntBig(u32, @intFromEnum(Type)); -} - -pub fn deinit(_: *AuthenticationOk, _: std.mem.Allocator) void {} - -test "round trip" { - const allocator = std.testing.allocator; - var sm = AuthenticationOk{}; - defer sm.deinit(allocator); - - var bal = ByteArrayList.init(allocator); - defer bal.deinit(); - try sm.write(allocator, bal.writer()); - - var fbs = std.io.fixedBufferStream(bal.items); - var reader = fbs.reader(); - const tag = try reader.readByte(); - try std.testing.expectEqual(Tag, tag); - const len = try reader.readIntBig(u32); - const buf = try allocator.alloc(u8, len - 4); - defer allocator.free(buf); - try reader.readNoEof(buf); - var sm2 = try AuthenticationOk.read(allocator, buf); - defer sm2.deinit(allocator); -} diff --git a/src/authentication_request.zig b/src/authentication_request.zig new file mode 100644 index 0000000..549a26b --- /dev/null +++ b/src/authentication_request.zig @@ -0,0 +1,75 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("main.zig").ProtocolError; +const ClientError = @import("main.zig").ClientError; +const enum_from_int = @import("main.zig").enum_from_int; + +pub const Tag: u8 = 'R'; + +const AuthenticationRequest = @This(); + +pub const InnerAuthRequestType = enum(u32) { + AuthRequestTypeOk = 0, + AuthRequestTypeCleartextPassword = 3, +}; +pub const InnerAuthRequest = union{ + ok: AuthRequestOk, + cleartext_password: AuthRequestCleartextPassword, +}; +pub const AuthRequestOk = struct {}; +pub const AuthRequestCleartextPassword = struct {}; + +// Authentication requests have multiple subtypes. +// It's not possible to have a tagged union with a custom backing integer, so do it the long way +inner_type: InnerAuthRequestType, +inner: InnerAuthRequest, + +pub fn read(_: std.mem.Allocator, b: []const u8) !AuthenticationRequest { + if (b.len != 4) { + log.err("invalid message length, expected 4 got {}",.{b.len}); + return ProtocolError.InvalidMessageLength; + } + const inner_type = enum_from_int(InnerAuthRequestType, std.mem.readIntBig(u32, b[0..4])) orelse return ClientError.UnsupportedAuthType; + var inner: InnerAuthRequest = switch (inner_type) { + .AuthRequestTypeOk => .{.ok = AuthRequestOk{}}, + .AuthRequestTypeCleartextPassword => .{.cleartext_password = AuthRequestCleartextPassword{}}, + }; + return .{ + .inner_type = inner_type, + .inner = inner, + }; +} + +pub fn write(self: AuthenticationRequest, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 8); + try stream_writer.writeIntBig(u32, @intFromEnum(self.inner_type)); +} + +pub fn deinit(_: *AuthenticationRequest, _: std.mem.Allocator) void {} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = AuthenticationRequest{ + .inner_type = .AuthRequestTypeOk, + .inner = .{.ok = AuthRequestOk{}}, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try AuthenticationRequest.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqual(InnerAuthRequestType.AuthRequestTypeOk, sm2.inner_type); +} diff --git a/src/backend_key_data.zig b/src/backend_key_data.zig new file mode 100644 index 0000000..525c309 --- /dev/null +++ b/src/backend_key_data.zig @@ -0,0 +1,52 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("main.zig").ProtocolError; +const ClientError = @import("main.zig").ClientError; + +const BackendKeyData = @This(); +pub const Tag: u8 = 'K'; + +process_id: u32, +secret_key: u32, + +pub fn read(_: std.mem.Allocator, b: []const u8) !BackendKeyData { + if (b.len != 8) return ProtocolError.InvalidMessageLength; + return .{ + .process_id = std.mem.readIntBig(u32, b[0..4]), + .secret_key = std.mem.readIntBig(u32, b[4..8]), + }; +} +pub fn write(self: BackendKeyData, _: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 12); // length + try stream_writer.writeIntBig(u32, self.process_id); + try stream_writer.writeIntBig(u32, self.secret_key); +} +pub fn deinit(_: *BackendKeyData, _: std.mem.Allocator) void {} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = BackendKeyData{ + .process_id = 123, + .secret_key = 345, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try BackendKeyData.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqual(@as(u32, 123), sm2.process_id); + try std.testing.expectEqual(@as(u32, 345), sm2.secret_key); +} diff --git a/src/config.zig b/src/config.zig index eb2a52b..b4e7cff 100644 --- a/src/config.zig +++ b/src/config.zig @@ -8,6 +8,6 @@ address: union(enum){ net: std.net.Address, unix: []const u8, }, -database: ?[]const u8, +database: ?[]const u8 = null, user: []const u8, -password: []const u8, +password: ?[]const u8 = null, diff --git a/src/conn.zig b/src/conn.zig index 870ab01..36c89df 100644 --- a/src/conn.zig +++ b/src/conn.zig @@ -1,10 +1,17 @@ const std = @import("std"); +const log = std.log.scoped(.pgz); const SSHashMap = std.StringHashMap([]const u8); const Config = @import("config.zig"); const StartupMessage = @import("startup_message.zig"); -const AuthenticationOk = @import("authentication_ok.zig"); -const AuthenticationCleartextPassword = @import("authentication_cleartext_password.zig"); const ErrorResponse = @import("error_response.zig"); +const AuthenticationRequest = @import("authentication_request.zig"); +const ReadyForQuery = @import("ready_for_query.zig"); +const ParameterStatus = @import("parameter_status.zig"); +const BackendKeyData = @import("backend_key_data.zig"); +const read_message = @import("main.zig").read_message; +const ProtocolError = @import("main.zig").ProtocolError; +const ServerError = @import("main.zig").ServerError; +const diagnosticReader = @import("main.zig").diagnosticReader; const Conn = @This(); @@ -18,7 +25,7 @@ const ConnStatus = enum { stream: std.net.Stream, config: Config, -status: ConnStatus, +status: ConnStatus = .connStatusUninitialized, pub fn connect(config: Config) !Conn { const allocator = config.allocator; @@ -26,24 +33,83 @@ pub fn connect(config: Config) !Conn { .net => |addr| try std.net.tcpConnectToAddress(addr), .unix => |path| try std.net.connectUnixSocket(path), }; + var res = Conn{ + .stream = stream, + .config = config, + }; + errdefer res.deinit(); var writer = stream.writer(); - - errdefer stream.close(); + var dr = diagnosticReader(10000, stream.reader()); + var reader = dr.reader(); var params = SSHashMap.init(allocator); - errdefer params.deinit(); try params.put("user", config.user); - if (config.database) |database| try params.put(database); + if (config.database) |database| try params.put("database", database); var sm = StartupMessage{ .parameters = params, }; defer sm.deinit(allocator); try sm.write(allocator, writer); + lp: while (true) { + const response_type = try reader.readByte(); + switch (response_type) { + ErrorResponse.Tag => { + var err = try read_message(ErrorResponse, allocator, reader); + defer err.deinit(allocator); + log.err("Error connecting to server {any}", .{err}); + return ServerError.ErrorResponse; + }, + AuthenticationRequest.Tag => { + var ar = try read_message(AuthenticationRequest, allocator, reader); + defer ar.deinit(allocator); + // TODO handle the authentication request + log.info("authentication request", .{}); + }, + ReadyForQuery.Tag => { + var rfq = try read_message(ReadyForQuery, allocator, reader); + defer rfq.deinit(allocator); + // TODO do something about transaction state? + res.status = .connStatusIdle; + log.info("ready for query", .{}); + break :lp; + }, + ParameterStatus.Tag => { + var ps = try read_message(ParameterStatus, allocator, reader); + defer ps.deinit(allocator); + // TODO Handle this somehow? + log.info("ParameterStatus: {s}:{s}", .{ps.name, ps.value}); + }, + BackendKeyData.Tag =>{ + var bkd = try read_message(BackendKeyData, allocator, reader); + defer bkd.deinit(allocator); + log.info("BackendKeyData process_id {} secret_key {}" , .{bkd.process_id, bkd.secret_key}); + }, + else => { + log.err("unhandled message type [{c}]", .{response_type}); + const diag = try dr.get(allocator); + defer allocator.free(diag); + log.err("diag [{s}]", .{diag}); + return ProtocolError.WrongMessageType; + }, + } + } + return res; +} + +fn deinit(self: *Conn) void { + self.stream.close(); +} + +test "connect" { + // must have a local postgres runnning + // TODO maybe use docker to start one? + const allocator = std.testing.allocator; + const cfg = Config{ + .allocator = allocator, + .address = .{.unix = "/run/postgresql/.s.PGSQL.5432"}, + .database = "martin", + .user = "martin", + }; + var conn = try Conn.connect(cfg); + defer conn.deinit(); } -const StartupMessageResponseType = enum(u8) { - ErrorResponse = 'E', - AuthenticationResponse = AuthenticationOk.Tag, // All the authentication responses share a message type and must be decoded by the next field -}; -const StartupMessageResponse = union(StartupMessageResponseType) { - error: ErrorResponse, -};
\ No newline at end of file diff --git a/src/error_response.zig b/src/error_response.zig index 4d99ba4..3b66c15 100644 --- a/src/error_response.zig +++ b/src/error_response.zig @@ -4,7 +4,7 @@ const ByteArrayList = std.ArrayList(u8); const ProtocolError = @import("main.zig").ProtocolError; const ErrorResponse = @This(); -const Tag: u8 = 'E'; +pub const Tag: u8 = 'E'; buf: ?[]const u8 = null, // owned severity: []const u8, diff --git a/src/main.zig b/src/main.zig index e540256..e539f2f 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,10 +1,13 @@ const std = @import("std"); const testing = std.testing; const StartupMessage = @import("startup_message.zig"); -const AuthenticationOk = @import("authentication_ok.zig"); -const AuthenticationCleartextPassword = @import("authentication_cleartext_password.zig"); +const AuthenticationRequest = @import("authentication_request.zig"); const PasswordMessage = @import("password_message.zig"); const ErrorResponse = @import("error_response.zig"); +const ReadyForQuery = @import("ready_for_query.zig"); +const ParameterStatus = @import("parameter_status.zig"); +const BackendKeyData = @import("backend_key_data.zig"); +const Conn = @import("conn.zig"); pub const ProtocolError = error{ InvalidProtocolVersion, @@ -12,15 +15,16 @@ pub const ProtocolError = error{ InvalidMessageLength, InvalidAuthType, MissingField, + WrongMessageType, + InvalidTransactionStatus, }; pub const ClientError = error{ UnsupportedAuthType, }; -pub const AuthType = enum(u32) { - AuthTypeOk = 0, - AuthTypeCleartextPassword = 3, +pub const ServerError = error{ + ErrorResponse, }; // Fallible version of enumFromInt @@ -38,10 +42,77 @@ pub fn enum_from_int(comptime e: type, i: anytype) ?e { } } +// Tag should already have been read in order to determine msg_type! +pub fn read_message(comptime msg_type: type, allocator: std.mem.Allocator, stream_reader: anytype) !msg_type { + if (!@hasDecl(msg_type, "Tag")) @compileError("msg_type must have a Tag declaration!"); + if (!@hasDecl(msg_type, "read")) @compileError("msg_type must have a read() function!"); + const len = try stream_reader.readIntBig(u32); + const buf = try allocator.alloc(u8, @as(u32, @intCast(len-4))); + defer allocator.free(buf); + try stream_reader.readNoEof(buf); + return try msg_type.read(allocator, buf); +} + + +pub fn diagnosticReader(comptime n: usize, base_reader: anytype) DiagnosticReader(n, @TypeOf(base_reader)) { + return .{.child_reader = base_reader}; +} + +// keeps a buffer of the last n bytes read +pub fn DiagnosticReader(comptime n: usize, comptime ReaderType: anytype) type { + return struct { + child_reader: ReaderType, + ring: [n]u8 = [_]u8{0}**n, + pos: usize = 0, + + pub const Error = ReaderType.Error; + pub const Reader = std.io.Reader(*@This(), Error, read); + + pub fn read(self: *@This(), buf: []u8) Error!usize { + const amt = try self.child_reader.read(buf); + for (0..amt) |i| { + self.ring[self.pos] = buf[i]; + self.pos += 1; + self.pos %= n; + } + return amt; + } + + pub fn reader(self: *@This()) Reader { + return .{ .context = self }; + } + + // Caller frees + pub fn get(self: @This(), allocator: std.mem.Allocator) ![]const u8 { + var buf = try allocator.alloc(u8, n); + errdefer allocator.free(buf); + @memcpy(buf[0..(n-self.pos)], self.ring[self.pos..n]); + @memcpy(buf[(n-self.pos)..n], self.ring[0..self.pos]); + return buf; + } + }; +} + +test "diagnostc reader" { + const a = std.testing.allocator; + const string = "The quick brown fox jumped over the lazy dog"; + var fbs = std.io.fixedBufferStream(string); + var dr = diagnosticReader(15, fbs.reader()); + var reader = dr.reader(); + var buf = [_]u8{0}**20; + try reader.readNoEof(&buf); + const diag = try dr.get(a); + defer a.free(diag); + try std.testing.expectEqualStrings("uick brown fox ", diag); +} + test { _ = StartupMessage; - _ = AuthenticationOk; - _ = AuthenticationCleartextPassword; + _ = AuthenticationRequest; _ = PasswordMessage; _ = ErrorResponse; + _ = Conn; + _ = ReadyForQuery; + _ = ParameterStatus; + _ = BackendKeyData; } diff --git a/src/parameter_status.zig b/src/parameter_status.zig new file mode 100644 index 0000000..d938fd2 --- /dev/null +++ b/src/parameter_status.zig @@ -0,0 +1,64 @@ +const std = @import("std"); +const log = std.log.scoped(.pgz); +const ByteArrayList = std.ArrayList(u8); +const ProtocolError = @import("main.zig").ProtocolError; +const ClientError = @import("main.zig").ClientError; + +const ParameterStatus = @This(); +pub const Tag: u8 = 'S'; + +buf: ?[]const u8 = null, // owned +name: []const u8, +value: []const u8, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !ParameterStatus { + var res: ParameterStatus = undefined; + res.buf = try allocator.dupe(u8, b); + var it = std.mem.splitScalar(u8, res.buf.?, 0); + res.name = it.first(); + res.value = it.next() orelse return ProtocolError.MissingField; + return res; +} +pub fn write(self: ParameterStatus, a: std.mem.Allocator, stream_writer: anytype) !void { + try stream_writer.writeByte(Tag); + var al = ByteArrayList.init(a); + defer al.deinit(); + var cw = std.io.countingWriter(al.writer()); + var writer = cw.writer(); + try writer.writeIntBig(u32, 0); // length placeholder + try writer.writeAll(self.name); + try writer.writeByte(0); + try writer.writeAll(self.value); + try writer.writeByte(0); + std.mem.writeIntBig(u32, al.items[0..4], @as(u32,@intCast(cw.bytes_written))); // Fix length + try stream_writer.writeAll(al.items); +} +pub fn deinit(self: *ParameterStatus, allocator: std.mem.Allocator) void { + if (self.buf != null) allocator.free(self.buf.?); +} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = ParameterStatus{ + .name = "Hello", + .value = "world", + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try ParameterStatus.read(allocator, buf); + defer sm2.deinit(allocator); + try std.testing.expectEqualStrings("Hello", sm2.name); + try std.testing.expectEqualStrings("world", sm2.value); +} diff --git a/src/ready_for_query.zig b/src/ready_for_query.zig new file mode 100644 index 0000000..883d9e1 --- /dev/null +++ b/src/ready_for_query.zig @@ -0,0 +1,55 @@ +const std = @import("std"); +const ProtocolError = @import("main.zig").ProtocolError; +const enum_from_int = @import("main.zig").enum_from_int; +const ByteArrayList = std.ArrayList(u8); + +const ReadyForQuery = @This(); +pub const Tag: u8 = 'Z'; + +const TransactionStatus = enum(u8) { + idle = 'I', + transaction = 'T', + err = 'E', +}; + +transaction_status: TransactionStatus, + +pub fn read(allocator: std.mem.Allocator, b: []const u8) !ReadyForQuery { + _ = allocator; + if (b.len != 1) return ProtocolError.InvalidMessageLength; + return .{ + .transaction_status = enum_from_int(TransactionStatus,b[0]) orelse return ProtocolError.InvalidTransactionStatus + }; +} +pub fn write(self: ReadyForQuery, allocator: std.mem.Allocator, stream_writer: anytype) !void { + _ = allocator; + try stream_writer.writeByte(Tag); + try stream_writer.writeIntBig(u32, 5); + try stream_writer.writeByte(@intFromEnum(self.transaction_status)); +} +pub fn deinit(_: *ReadyForQuery, _: std.mem.Allocator) void {} + +test "round trip" { + const allocator = std.testing.allocator; + var sm = ReadyForQuery{ + .transaction_status = TransactionStatus.idle, + }; + defer sm.deinit(allocator); + + var bal = ByteArrayList.init(allocator); + defer bal.deinit(); + try sm.write(allocator, bal.writer()); + + var fbs = std.io.fixedBufferStream(bal.items); + var reader = fbs.reader(); + const tag = try reader.readByte(); + try std.testing.expectEqual(Tag, tag); + const len = try reader.readIntBig(u32); + const buf = try allocator.alloc(u8, len - 4); + defer allocator.free(buf); + try reader.readNoEof(buf); + var sm2 = try ReadyForQuery.read(allocator, buf); + defer sm2.deinit(allocator); + + try std.testing.expectEqual(TransactionStatus.idle, sm2.transaction_status); +} |