const std = @import("std");
const log = std.log.scoped(.pgz);
const SSHashMap = std.StringHashMap([]const u8);

const Config = @import("config.zig");
const proto = @import("../proto/proto.zig");
const StartupMessage = proto.StartupMessage;
const PasswordMessage = proto.PasswordMessage;
const BackendMessage = proto.BackendMessage;
const read_message = proto.read_message;

const ProtocolError = @import("../main.zig").ProtocolError;
const ServerError = @import("../main.zig").ServerError;
const ClientError = @import("../main.zig").ClientError;
const diagnosticReader = @import("../main.zig").diagnosticReader;

/// A single connection to a PostgreSQL server (this file is the struct).
const Conn = @This();

/// Lifecycle state of a `Conn`.
const ConnStatus = enum {
    connStatusUninitialized,
    connStatusConnecting,
    connStatusClosed,
    connStatusIdle,
    connStatusBusy,
};

allocator: std.mem.Allocator,
stream: std.net.Stream,
config: Config,
status: ConnStatus = .connStatusUninitialized,

/// Open the connection described by `config` and run the startup handshake.
///
/// Connects over TCP (`.net`) or a Unix socket (`.unix`), then sends a
/// StartupMessage carrying `user` (and `database` when set). If the server
/// asks for a cleartext password, it answers with `config.password`. It
/// returns once the server reports ReadyForQuery.
///
/// Errors:
/// - `ServerError.ErrorResponse`: the server sent an ErrorResponse.
/// - `ClientError.NoPasswordSupplied`: a password was requested but
///   `config.password` is null.
/// - `ProtocolError.WrongMessageType`: an unexpected message arrived.
/// - any socket or allocation error from the underlying calls.
///
/// On error the socket is closed. On success the caller owns the returned
/// connection and must call `deinit`.
pub fn connect(config: Config) !Conn {
    const allocator = config.allocator;

    const stream = switch (config.address) {
        .net => |addr| try std.net.tcpConnectToAddress(addr),
        .unix => |path| try std.net.connectUnixSocket(path),
    };
    var res = Conn{
        // `allocator` has no default value, so it must be set explicitly.
        .allocator = allocator,
        .stream = stream,
        .config = config,
        .status = .connStatusConnecting,
    };
    // Close the socket if any later step of the handshake fails.
    errdefer res.deinit();

    const writer = stream.writer();
    // Wrapped reader; `dr.get` is used below to dump diagnostics when an
    // unexpected message arrives.
    var dr = diagnosticReader(100, stream.reader());
    const reader = dr.reader();

    // Build the parameters directly inside `sm` so that its deferred deinit
    // also covers the error paths of the `put` calls.
    // NOTE(review): assumes `StartupMessage.deinit` releases `parameters`
    // (the original relied on the same thing) — confirm in proto.
    var sm = StartupMessage{
        .parameters = SSHashMap.init(allocator),
    };
    defer sm.deinit(allocator);
    try sm.parameters.put("user", config.user);
    if (config.database) |database| try sm.parameters.put("database", database);
    try sm.write(allocator, writer);

    while (true) {
        var anymsg = try read_message(allocator, reader);
        defer anymsg.deinit(allocator);

        switch (anymsg) {
            .ErrorResponse => |err| {
                log.err("Error connecting to server {any}", .{err});
                return ServerError.ErrorResponse;
            },
            .AuthenticationRequest => |ar| switch (ar.inner_type) {
                // Authenticated; keep reading until ReadyForQuery.
                .AuthRequestTypeOk => {},
                .AuthRequestTypeCleartextPassword => {
                    const password = config.password orelse return ClientError.NoPasswordSupplied;
                    const pm = PasswordMessage{ .password = password };
                    try pm.write(allocator, writer);
                },
            },
            .ReadyForQuery => |rfq| {
                // TODO do something about transaction state?
                res.status = .connStatusIdle;
                log.info("ready for query {any}", .{rfq});
                return res;
            },
            .ParameterStatus => |ps| {
                // TODO Handle this somehow?
                log.info("ParameterStatus: {s}:{s}", .{ ps.name, ps.value });
            },
            .BackendKeyData => |bkd| {
                log.info("BackendKeyData process_id {} secret_key {}", .{ bkd.process_id, bkd.secret_key });
            },
            else => {
                // Log the tag name only; the payload may not format cleanly.
                log.err("unhandled message type [{s}]", .{@tagName(anymsg)});
                const diag = try dr.get(allocator);
                defer allocator.free(diag);
                log.err("diag [{s}]", .{diag});
                return ProtocolError.WrongMessageType;
            },
        }
    }
}

/// Read the next message the server sends on an established connection.
///
/// The message type is a tagged union (`BackendMessage`), so callers switch
/// on its tag instead of relying on polymorphism as the Go version does.
///
/// - ReadyForQuery marks the connection idle.
/// - ParameterStatus is logged.
/// - ErrorResponse is logged, freed, and returned as
///   `ServerError.ErrorResponse`.
///
/// In every other case the caller owns the returned message and must call
/// `deinit(self.allocator)` on it.
///
/// TODO: port the rest of the Go (pgconn) receiveMessage behaviour:
/// - record the transaction status and parameter statuses;
/// - close the connection only on a FATAL ErrorResponse;
/// - dispatch Notice/Notification responses to configurable callbacks.
fn receiveMessage(self: *Conn) !BackendMessage {
    const allocator = self.allocator;
    var msg = try read_message(allocator, self.stream.reader());

    switch (msg) {
        .ReadyForQuery => {
            self.status = .connStatusIdle;
        },
        .ParameterStatus => |ps| {
            log.info("ParameterStatus: {s}:{s}", .{ ps.name, ps.value });
        },
        .ErrorResponse => |err| {
            // Log before freeing: `err` may point into the message's memory.
            log.err("Error response from server {any}", .{err});
            msg.deinit(allocator);
            return ServerError.ErrorResponse;
        },
        else => {},
    }
    return msg;
}

/// Close the underlying socket and mark the connection closed.
/// The connection must not be used for I/O afterwards.
pub fn deinit(self: *Conn) void {
    self.stream.close();
    self.status = .connStatusClosed;
}
/// Iterates over the rows produced by a single command.
///
/// TODO: not implemented yet.
pub const ResultIterator = struct {
    conn: *Conn,
    command_concluded: bool = false,

    /// Advance to the next row; return true if a row is available.
    ///
    /// TODO: read DataRow messages from `conn` until CommandComplete and set
    /// `command_concluded`. This stub always returns false. It deliberately
    /// does not consume bytes from the stream, because that would desync the
    /// protocol.
    pub fn next_row(self: *ResultIterator) bool {
        _ = self;
        return false;
    }
};

/// Iterates over the results of a multi-statement query. It presumably
/// yields one `ResultIterator` per command (TODO confirm the design).
///
/// TODO: not implemented yet.
pub const MultiResultIterator = struct {
    conn: *Conn,

    /// Return the next command's result, or null when none remain.
    /// This stub always returns null.
    pub fn next(self: *MultiResultIterator) ?ResultIterator {
        _ = self;
        return null;
    }
};

// TODO: pub fn exec(self: *Conn) ...

test "connect unix" {
    // Requires a local PostgreSQL server listening on this Unix socket.
    // TODO maybe use docker to start one?
    const allocator = std.testing.allocator;
    const cfg = Config{
        .allocator = allocator,
        .address = .{ .unix = "/run/postgresql/.s.PGSQL.5432" },
        .database = "martin",
        .user = "martin",
    };
    var conn = try Conn.connect(cfg);
    defer conn.deinit();
}

test "connect tcp with password" {
    const allocator = std.testing.allocator;
    const cfg = Config{
        .allocator = allocator,
        .address = .{ .net = std.net.Address{ .in = std.net.Ip4Address.init([4]u8{ 127, 0, 0, 1 }, 5432) } },
        .database = "martin",
        .user = "martin",
        .password = "martin",
    };
    var conn = try Conn.connect(cfg);
    defer conn.deinit();
}

test "connect tcp with wrong password" {
    // Disabled: `connect` reports the server's ErrorResponse via log.err, and
    // the Zig test runner fails any test that logs at error level.
    // TODO how to disable failing tests on error log.
    // Skip explicitly instead of passing vacuously.
    return error.SkipZigTest;
    // const allocator = std.testing.allocator;
    // const cfg = Config{
    //     .allocator = allocator,
    //     .address = .{ .net = std.net.Address{ .in = std.net.Ip4Address.init([4]u8{ 127, 0, 0, 1 }, 5432) } },
    //     .database = "martin",
    //     .user = "martin",
    //     .password = "foobar",
    // };
    // try std.testing.expectError(ServerError.ErrorResponse, Conn.connect(cfg));
}