diff options
author | Martin Ashby <martin@ashbysoft.com> | 2023-08-24 17:35:37 +0100 |
---|---|---|
committer | Martin Ashby <martin@ashbysoft.com> | 2023-08-24 17:35:37 +0100 |
commit | 6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed (patch) | |
tree | 6655b794f964a000433916d882de980757bb1913 /src/main.zig | |
download | smtp-zig-6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed.tar.gz smtp-zig-6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed.tar.bz2 smtp-zig-6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed.tar.xz smtp-zig-6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed.zip |
Initial
Diffstat (limited to 'src/main.zig')
-rw-r--r-- | src/main.zig | 362 |
1 files changed, 362 insertions, 0 deletions
diff --git a/src/main.zig b/src/main.zig new file mode 100644 index 0000000..7e48a92 --- /dev/null +++ b/src/main.zig @@ -0,0 +1,362 @@ +const std = @import("std"); +const testing = std.testing; + +// Library for sending email over SMTP. +// Inspired by golang's standard lib SMTP client. +// https://cs.opensource.google/go/go/+/master:src/net/smtp/smtp.go + +const log = std.log.scoped(.smtp); + +/// Sends an email. +/// addr is the address of the mail server +/// auth is authentication details for the mail server, may be null if server doesn't require auth. +/// from is sender's email address +/// to is list of recipient email addresses +/// msg is the actual, interesting message. +/// Note this library makes no effort to ensure your mail is actually a valid email, it should conform to rfc5322 otherwise servers may reject it. +pub fn send_mail(allocator: std.mem.Allocator, server: []const u8, auth: ?Auth, from: []const u8, to: []const []const u8, msg: [] const u8) !void { + var client = try Client.init(allocator, server, auth); + defer client.deinit(); + try client.send_mail(from, to, msg); + try client.quit(); +} + +pub const Auth = struct { + user: []const u8, + pass: []const u8, +}; + +/// A Client represents a single client connection to an SMTP server. +/// which could be used for sending many mails +pub const Client = struct { + const AuthCaps = struct { + plain: bool = false, + login: bool = false, + }; + const Capabilities = struct { + starttls: bool = false, + pipelining: bool = false, + auth: ?AuthCaps = null, + + fn parse(ehlo_response: []const u8) Capabilities { + var caps = Capabilities{}; + var spl = std.mem.splitSequence(u8, ehlo_response, "\r\n"); + var nxt = spl.next(); // First line of EHLO response is just hi + nxt = spl.next(); + while (nxt != null): (nxt = spl.next()) { + if (std.mem.eql(u8, nxt.?, "STARTTLS")) { + caps.starttls = true; + } if (std.mem.startsWith(u8, nxt.?, "AUTH")) { + caps.auth = AuthCaps{}; + var spl2 = std.mem.splitScalar(u8, nxt.?, ' '); + var nxt2 = spl2.next(); + nxt2 = spl2.next(); // skip the AUTH part + while (nxt2 != null): (nxt2 = spl2.next()) { + if (std.mem.eql(u8, nxt2.?, "PLAIN")) { + caps.auth.?.plain = true; + } else if (std.mem.eql(u8, nxt2.?, "LOGIN")) { + caps.auth.?.login = true; + } else { + log.warn("unrecognised auth mechanism {s}", .{nxt2.?}); + } + } + + } else { + log.warn("unrecognised capability {s}", .{nxt.?}); + } + } + return caps; + } + }; + const WriteError = error{} || std.net.Stream.WriteError || TlsError; + + fn writer(self: *Client) std.io.Writer(*Client, WriteError, write) { + return .{ + .context = self, + }; + } + + fn write(context: *Client, bytes: []const u8) WriteError!usize { + if (context.use_tls) { + return context.tls_stream.?.write(context.stream, bytes); + } else { + return context.stream.write(bytes); + } + } + + const ReadError = error {} || std.net.Stream.ReadError || TlsError; + + fn reader(self: *Client) std.io.Reader(*Client, ReadError, read) { + return .{ + .context = self, + }; + } + + fn read(context: *Client, buffer: []u8) ReadError!usize { + if (context.use_tls) { + return context.tls_stream.?.read(context.stream, buffer); + } else { + return context.stream.read(buffer); + } + } + + allocator: std.mem.Allocator, + stream: std.net.Stream, + + capabilities: Capabilities = .{}, + use_tls: bool = false, + tls_stream: ?std.crypto.tls.Client = null, + tls_cert_bundle: ?std.crypto.Certificate.Bundle = null, + + /// Creates a new SMTP client + /// 'server' must be "hostname:port" + pub fn init(allocator: std.mem.Allocator, server: []const u8, maybe_auth: ?Auth) !Client { + var spl = std.mem.splitScalar(u8, server, ':'); + const host = spl.first(); + const port_str = spl.rest(); + const port = try std.fmt.parseInt(u16, port_str, 10); + + var stream = try std.net.tcpConnectToHost(allocator, host, port); + var client = Client{ + .allocator = allocator, + .stream = stream, + }; + errdefer client.deinit(); + allocator.free(try client.read_expect_code(220)); + + try client.write_line("EHLO localhost"); + const ehlo_response = try client.read_expect_code(250); + defer allocator.free(ehlo_response); + client.capabilities = Capabilities.parse(ehlo_response); + + if (client.capabilities.starttls) { + try client.write_line("STARTTLS"); + allocator.free(try client.read_expect_code(220)); + + var bundle = std.crypto.Certificate.Bundle{}; + errdefer bundle.deinit(allocator); + try std.crypto.Certificate.Bundle.rescan(&bundle, allocator); + client.tls_stream = try std.crypto.tls.Client.init(stream, bundle, host); + client.tls_cert_bundle = bundle; + client.use_tls = true; + + // Redo the EHLO now we're using TLS, the capabilities might change to e.g. allow login. + try client.write_line("EHLO localhost"); + const ehlo_response_poststarttls = try client.read_expect_code(250); + defer allocator.free(ehlo_response_poststarttls); + client.capabilities = Capabilities.parse(ehlo_response_poststarttls); + } + + if (maybe_auth) |auth| { + if (client.capabilities.auth) |auth_caps| { + if (auth_caps.plain) { + // TODO there has to be a nicer way than this surely + // TODO support identity as well as user+pass + const z = try std.fmt.allocPrint(allocator, "{s}\x00{s}\x00{s}", .{"", auth.user, auth.pass}); + defer allocator.free(z); + const enc = std.base64.standard.Encoder; + var zz = try allocator.alloc(u8, enc.calcSize(z.len)); + defer allocator.free(zz); + const zzz = enc.encode(zz, z); + const line = try std.fmt.allocPrint(allocator, "AUTH PLAIN {s}", .{zzz}); + defer allocator.free(line); + try client.write_line(line); + allocator.free(try client.read_expect_code(235)); + } + } + } + return client; + } + + fn deinit(self: *Client) void { + self.stream.close(); + if (self.use_tls) { + self.tls_cert_bundle.?.deinit(self.allocator); + } + } + + /// lines in SMTP have \r\n terminator :shrug: + /// caller owns the result + fn read_line(self: *Client) ![]const u8 { + var rdr = self.reader(); + var res = std.ArrayList(u8).init(self.allocator); + var wtr = res.writer(); + while (true) { + const b = try rdr.readByte(); + if (b == '\r') { + const b2 = try rdr.readByte(); + if (b2 == '\n') { + break; + } else { + try wtr.writeByte(b); + try wtr.writeByte(b2); + continue; + } + } + try wtr.writeByte(b); + } + const res2 = try res.toOwnedSlice(); + log.debug("< {s}", .{res2}); + return res2; + } + + // Reads from the stream. + // Collects response lines (without code) into the result, which the caller owns. + // Checks the numeric code on response lines. + fn read_expect_code(self: *Client, expect_code: u16) ![]const u8 { + var res = std.ArrayList(u8).init(self.allocator); + errdefer res.deinit(); + + while (true) { + var line = try self.read_line(); + defer self.allocator.free(line); + var spl = std.mem.splitAny(u8, line, " -"); + if (std.fmt.parseInt(u16, spl.first(), 10)) |code| { + if (code != expect_code) { + log.err("invalid response code {}", .{code}); + return Client.Error.InvalidResponseCode; + } + } else |_| { + log.err("malformatted line [{s}]", .{line}); + return Client.Error.Malformatted; + } + if (line.len < 3) { + log.err("malformatted line [{s}]", .{line}); + return Client.Error.Malformatted; + } + try res.appendSlice(line[4..]); + // Continuation lines have "123-FOO" + // Final line has a space like "123 BAR" + if (line[3] == ' ') { + break; + } else { + try res.appendSlice("\r\n"); + } + } + return try res.toOwnedSlice(); + } + + fn write_line(self: *Client, line: []const u8) !void { + log.debug("> {s}", .{line}); + var wtr = self.writer(); + try wtr.writeAll(line); + try wtr.writeAll("\r\n"); + } + + pub fn mail(self: *Client, from: []const u8) !void { + const line = try std.fmt.allocPrint(self.allocator, "MAIL FROM:<{s}>", .{from}); + defer self.allocator.free(line); + try self.write_line(line); + var r = try self.read_expect_code(250); + self.allocator.free(r); + } + + pub fn rcpt(self: *Client, to: []const u8) !void { + const line = try std.fmt.allocPrint(self.allocator, "RCPT TO:<{s}>", .{to}); + defer self.allocator.free(line); + try self.write_line(line); + var r = try self.read_expect_code(250); + self.allocator.free(r); + } + + pub fn data(self: *Client, rdr: anytype) !void { + try self.write_line("DATA"); + self.allocator.free(try self.read_expect_code(354)); + var buf: [1024]u8 = undefined; + var r = try rdr.read(&buf); + while (r > 0): (r = try rdr.read(&buf)) { + // TODO proper dot encoding, ensure any existing sequences of \r\n.\r\n are escaped I guess + log.debug("> {s}", .{buf[0..r]}); + try self.writer().writeAll(buf[0..r]); + } + try self.write_line("\r\n."); + self.allocator.free(try self.read_expect_code(250)); + } + + pub fn rset(self: *Client) !void { + try self.write_line("RSET"); + self.allocator.free(try self.read_expect_code(250)); + } + + pub fn quit(self: *Client) !void { + try self.write_line("QUIT"); + self.allocator.free(try self.read_expect_code(221)); + } + + pub fn send_mail(self: *Client, from: []const u8, to:[]const []const u8, msg: []const u8) !void { + try self.rset(); + try self.mail(from); + for (to) |t| { + try self.rcpt(t); + } + var fbs = std.io.fixedBufferStream(msg); + try self.data(fbs.reader()); + } + + const Error = error{Malformatted, InvalidResponseCode}; +}; + +// TODO why doesn't tls module define an error set? +const TlsError = error{ + Overflow, + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsRecordOverflow, + TlsBadRecordMac, + TlsConnectionTruncated, + TlsDecodeError, + TlsBadLength, +}; + +test "send email" { + const mail = + \\Subject: test + \\From: <martin@mfashby.net> + \\To: <martin@ashbysoft.com> + \\ + \\This is a test message + ; + const user = std.os.getenv("SMTP_USERNAME").?; + const pass = std.os.getenv("SMTP_PASSWORD").?; + var client = try Client.init(std.testing.allocator, "mail.mfashby.net:587", .{.user = user, .pass = pass}); + defer client.deinit(); + try client.send_mail("martin@mfashby.net", &[_][]const u8{"martin@mfashby.net"}, mail); + const mail2 = + \\Subject: test2 + \\From: <martin@mfashby.net> + \\To: <martin@ashbysoft.com> + \\ + \\This is another test message + ; + try client.send_mail("martin@mfashby.net", &[_][]const u8{"martin@mfashby.net"}, mail2); + try client.quit(); // Be nice (but we don't have to) +} + +// TODO lots more tests. +// TODO use mock SMTP server to integration test local only.
\ No newline at end of file |