//! A single client connection to a PostgreSQL server (this file is the `Conn` struct),
//! plus iterators for reading query results off that connection.

const std = @import("std");
const log = std.log.scoped(.pgz);
const SSHashMap = std.StringHashMap([]const u8);

const Config = @import("config.zig");
const proto = @import("../proto/proto.zig");
const StartupMessage = proto.StartupMessage;
const PasswordMessage = proto.PasswordMessage;
const BackendMessage = proto.BackendMessage;
const RowDescription = proto.RowDescription;
const read_message = proto.read_message;

const ProtocolError = @import("../main.zig").ProtocolError;
const ServerError = @import("../main.zig").ServerError;
const ClientError = @import("../main.zig").ClientError;
const diagnosticReader = @import("../main.zig").diagnosticReader;

const Conn = @This();

const ConnStatus = enum {
    connStatusUninitialized,
    connStatusConnecting,
    connStatusClosed,
    connStatusIdle,
    connStatusBusy,
};

allocator: std.mem.Allocator,
stream: std.net.Stream,
config: Config,
status: ConnStatus = .connStatusUninitialized,

/// Opens a connection to the server described by `config` and performs the
/// startup handshake. It:
/// - sends a StartupMessage carrying `user` and, if set, `database`;
/// - answers a cleartext-password challenge with `config.password`;
/// - returns once the server sends ReadyForQuery.
///
/// Errors:
/// - `ClientError.NoPasswordSupplied` if the server requests a cleartext
///   password and `config.password` is null.
/// - `ServerError.ErrorResponse` if the server sends a FATAL ErrorResponse.
/// - any network, allocation or protocol error raised while connecting,
///   writing or reading messages.
///
/// The caller owns the returned connection and must call `deinit` on it.
/// On error the socket is closed before returning.
pub fn connect(config: Config) !Conn {
    const allocator = config.allocator;
    const stream = switch (config.address) {
        .net => |addr| try std.net.tcpConnectToAddress(addr),
        .unix => |path| try std.net.connectUnixSocket(path),
    };
    var res = Conn{
        .allocator = allocator,
        .stream = stream,
        .config = config,
    };
    errdefer res.deinit();

    const writer = stream.writer();

    // Hand the (still empty) parameter map to the StartupMessage before filling
    // it. The deferred `sm.deinit` then also releases the map when one of the
    // `put` calls below fails, instead of leaking it.
    var sm = StartupMessage{
        .parameters = SSHashMap.init(allocator),
    };
    defer sm.deinit(allocator);
    try sm.parameters.put("user", config.user);
    if (config.database) |database| try sm.parameters.put("database", database);
    try sm.write(allocator, writer);

    lp: while (true) {
        var anymsg = try res.receive_message();
        defer anymsg.deinit(allocator);
        switch (anymsg) {
            .ReadyForQuery => break :lp,
            .AuthenticationRequest => |ar| switch (ar.inner_type) {
                .AuthRequestTypeOk => {},
                .AuthRequestTypeCleartextPassword => {
                    const password = config.password orelse return ClientError.NoPasswordSupplied;
                    const pm = PasswordMessage{ .password = password };
                    try pm.write(allocator, writer);
                },
            },
            else => {
                // deliberately do nothing, we must wait for ready before the connection can be used.
            },
        }
    }
    return res;
}

/// Reads the next backend message from the socket.
///
/// Messages should always be received through this function, so that
/// connection-level handling happens in one place.
///
/// A FATAL ErrorResponse sets `status` to `.connStatusClosed`. The message is
/// then freed and `ServerError.ErrorResponse` is returned. Every other message
/// is returned to the caller, who owns it and must `deinit` it.
fn receive_message(self: *Conn) !BackendMessage {
    var anymsg = try read_message(self.allocator, self.stream.reader());
    errdefer anymsg.deinit(self.allocator);
    switch (anymsg) {
        .ReadyForQuery => {
            // TODO handle TxStatus
        },
        .ParameterStatus => {
            // TODO handle parameter status
        },
        .ErrorResponse => |err| {
            if (std.mem.eql(u8, err.severity, "FATAL")) {
                self.status = .connStatusClosed;
                // TODO close the connection here? But it should really be the caller's responsibility
                return ServerError.ErrorResponse;
            }
        },
        // .NoticeResponse => {
        //     // TODO handle notice response
        // },
        // .NotificationResponse => {
        //     // TODO handle notificationResponse
        // },
        .BackendKeyData => {
            // TODO handle backend key data
        },
        else => {
            // deliberately do nothing, caller can presumably handle them.
        },
    }
    return anymsg;
}

/// Closes the underlying socket.
pub fn deinit(self: *Conn) void {
    self.stream.close();
}

/// Iterates over the rows of one query result. It keeps its own copies of the
/// RowDescription, the most recent DataRow and the CommandComplete it receives.
/// Call `deinit` to release those copies.
pub const ResultIterator = struct {
    conn: *Conn,
    multi_iterator: ?*MultiResultIterator = null,
    row_description: ?proto.RowDescription = null,
    current_datarow: ?proto.DataRow = null,
    command_complete: ?proto.CommandComplete = null,

    pub fn init(conn: *Conn) ResultIterator {
        return .{
            .conn = conn,
        };
    }

    /// Frees the message copies retained by the iterator.
    pub fn deinit(self: *ResultIterator) void {
        const allocator = self.conn.allocator;
        if (self.row_description) |*rd| rd.deinit(allocator);
        if (self.current_datarow) |*row| row.deinit(allocator);
        if (self.command_complete) |*cc| cc.deinit(allocator);
    }

    /// Advances the iterator to the next row and returns its columns. Returns
    /// null once the result's CommandComplete has been received.
    ///
    /// The returned slice belongs to the iterator's copy of the current
    /// DataRow. Do not use it after the next call to `next_row`, `skip_to_end`
    /// or `deinit`.
    pub fn next_row(self: *ResultIterator) !?[][]const u8 {
        while (self.command_complete == null) {
            var msg = try self.receive_message();
            // `receive_message` keeps its own clone of anything it needs, so the
            // received message is always freed here, including DataRows.
            defer msg.deinit(self.conn.allocator);
            switch (msg) {
                .DataRow => return self.current_datarow.?.columns,
                else => {},
            }
        }
        return null;
    }

    /// Reads and frees the remaining messages of this result until its
    /// CommandComplete has been received.
    pub fn skip_to_end(self: *ResultIterator) !void {
        while (self.command_complete == null) {
            var msg = try self.receive_message();
            msg.deinit(self.conn.allocator);
        }
    }

    /// Receives the next message, either through the owning MultiResultIterator
    /// (when set) or directly from the connection. It stores a copy of any
    /// RowDescription, DataRow or CommandComplete.
    ///
    /// The caller owns the returned message. Returns
    /// `ProtocolError.UnexpectedMessage` on a second RowDescription or
    /// CommandComplete.
    fn receive_message(self: *ResultIterator) !BackendMessage {
        const allocator = self.conn.allocator;
        var msg = if (self.multi_iterator) |mi|
            try mi.receive_message()
        else
            try self.conn.receive_message();
        errdefer msg.deinit(allocator);

        switch (msg) {
            .DataRow => |dr| {
                // Clone before releasing the previous row. A failed clone then
                // leaves `current_datarow` intact rather than pointing at freed
                // memory, which `deinit` would free a second time.
                const row = try dr.clone(allocator);
                if (self.current_datarow) |*old| old.deinit(allocator);
                self.current_datarow = row;
            },
            .RowDescription => |rd| {
                if (self.row_description != null) return ProtocolError.UnexpectedMessage;
                self.row_description = try rd.clone(allocator);
            },
            .CommandComplete => |cc| {
                if (self.command_complete != null) return ProtocolError.UnexpectedMessage;
                self.command_complete = try cc.clone(allocator);
            },
            else => {
                // Other messages need no bookkeeping here; the caller handles them.
            },
        }
        return msg;
    }
};

/// Iterates over the results of a query that may produce several result sets.
/// NOTE(review): unfinished. `next_result` is not implemented yet.
pub const MultiResultIterator = struct {
    conn: *Conn,
    cri: ?*ResultIterator,

    /// Returns the next result iterator, or null if we've reached the end of
    /// the results.
    pub fn next_result(self: *MultiResultIterator) !?*ResultIterator {
        if (self.cri) |cri| {
            try cri.skip_to_end();
        }
        // TODO: read up to the next result (or ReadyForQuery -> return null).
        // Until then, any use of this function fails at compile time with a
        // clear message instead of a missing-return error.
        @compileError("MultiResultIterator.next_result is not implemented yet");
    }

    /// Receives the next message from the connection and passes it through.
    /// The caller owns the returned message.
    fn receive_message(self: *MultiResultIterator) !BackendMessage {
        // TODO: track result-set boundaries here.
        return self.conn.receive_message();
    }
};

// pub fn exec(self: *Conn) {
// }

test "connect unix" {
    // must have a local postgres running
    // TODO maybe use docker to start one?
    const allocator = std.testing.allocator;
    const cfg = Config{
        .allocator = allocator,
        .address = .{ .unix = "/run/postgresql/.s.PGSQL.5432" },
        .database = "martin",
        .user = "martin",
    };
    var conn = try Conn.connect(cfg);
    defer conn.deinit();
}

test "connect tcp with password" {
    const allocator = std.testing.allocator;
    const cfg = Config{
        .allocator = allocator,
        .address = .{ .net = std.net.Address{ .in = std.net.Ip4Address.init([4]u8{ 127, 0, 0, 1 }, 5432) } },
        .database = "martin",
        .user = "martin",
        .password = "martin",
    };
    var conn = try Conn.connect(cfg);
    defer conn.deinit();
}

test "connect tcp with wrong password" {
    // TODO how to disable failing tests on error log
    // const allocator = std.testing.allocator;
    // const cfg = Config{
    //     .allocator = allocator,
    //     .address = .{ .net = std.net.Address{ .in = std.net.Ip4Address.init([4]u8{ 127, 0, 0, 1 }, 5432) } },
    //     .database = "martin",
    //     .user = "martin",
    //     .password = "foobar",
    // };
    // try std.testing.expectError(ServerError.ErrorResponse, Conn.connect(cfg));
}