aboutsummaryrefslogtreecommitdiff
path: root/src/main.zig
diff options
context:
space:
mode:
authorMartin Ashby <martin@ashbysoft.com>2023-08-24 17:35:37 +0100
committerMartin Ashby <martin@ashbysoft.com>2023-08-24 17:35:37 +0100
commit6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed (patch)
tree6655b794f964a000433916d882de980757bb1913 /src/main.zig
downloadsmtp-zig-6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed.tar.gz
smtp-zig-6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed.tar.bz2
smtp-zig-6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed.tar.xz
smtp-zig-6d3d2869bb2105c3a9ffb54ed816d8f53c56c1ed.zip
Initial
Diffstat (limited to 'src/main.zig')
-rw-r--r--src/main.zig362
1 files changed, 362 insertions, 0 deletions
diff --git a/src/main.zig b/src/main.zig
new file mode 100644
index 0000000..7e48a92
--- /dev/null
+++ b/src/main.zig
@@ -0,0 +1,362 @@
+const std = @import("std");
+const testing = std.testing;
+
+// Library for sending email over SMTP.
+// Inspired by golang's standard lib SMTP client.
+// https://cs.opensource.google/go/go/+/master:src/net/smtp/smtp.go
+
+const log = std.log.scoped(.smtp);
+
+/// Sends an email.
+/// addr is the address of the mail server
+/// auth is authentication details for the mail server, may be null if server doesn't require auth.
+/// from is sender's email address
+/// to is list of recipient email addresses
+/// msg is the actual, interesting message.
+/// Note this library makes no effort to ensure your mail is actually a valid email, it should conform to rfc5322 otherwise servers may reject it.
+pub fn send_mail(allocator: std.mem.Allocator, server: []const u8, auth: ?Auth, from: []const u8, to: []const []const u8, msg: [] const u8) !void {
+ var client = try Client.init(allocator, server, auth);
+ defer client.deinit();
+ try client.send_mail(from, to, msg);
+ try client.quit();
+}
+
+pub const Auth = struct {
+ user: []const u8,
+ pass: []const u8,
+};
+
+/// A Client represents a single client connection to an SMTP server.
+/// which could be used for sending many mails
+pub const Client = struct {
+ const AuthCaps = struct {
+ plain: bool = false,
+ login: bool = false,
+ };
+ const Capabilities = struct {
+ starttls: bool = false,
+ pipelining: bool = false,
+ auth: ?AuthCaps = null,
+
+ fn parse(ehlo_response: []const u8) Capabilities {
+ var caps = Capabilities{};
+ var spl = std.mem.splitSequence(u8, ehlo_response, "\r\n");
+ var nxt = spl.next(); // First line of EHLO response is just hi
+ nxt = spl.next();
+ while (nxt != null): (nxt = spl.next()) {
+ if (std.mem.eql(u8, nxt.?, "STARTTLS")) {
+ caps.starttls = true;
+ } if (std.mem.startsWith(u8, nxt.?, "AUTH")) {
+ caps.auth = AuthCaps{};
+ var spl2 = std.mem.splitScalar(u8, nxt.?, ' ');
+ var nxt2 = spl2.next();
+ nxt2 = spl2.next(); // skip the AUTH part
+ while (nxt2 != null): (nxt2 = spl2.next()) {
+ if (std.mem.eql(u8, nxt2.?, "PLAIN")) {
+ caps.auth.?.plain = true;
+ } else if (std.mem.eql(u8, nxt2.?, "LOGIN")) {
+ caps.auth.?.login = true;
+ } else {
+ log.warn("unrecognised auth mechanism {s}", .{nxt2.?});
+ }
+ }
+
+ } else {
+ log.warn("unrecognised capability {s}", .{nxt.?});
+ }
+ }
+ return caps;
+ }
+ };
+ const WriteError = error{} || std.net.Stream.WriteError || TlsError;
+
+ fn writer(self: *Client) std.io.Writer(*Client, WriteError, write) {
+ return .{
+ .context = self,
+ };
+ }
+
+ fn write(context: *Client, bytes: []const u8) WriteError!usize {
+ if (context.use_tls) {
+ return context.tls_stream.?.write(context.stream, bytes);
+ } else {
+ return context.stream.write(bytes);
+ }
+ }
+
+ const ReadError = error {} || std.net.Stream.ReadError || TlsError;
+
+ fn reader(self: *Client) std.io.Reader(*Client, ReadError, read) {
+ return .{
+ .context = self,
+ };
+ }
+
+ fn read(context: *Client, buffer: []u8) ReadError!usize {
+ if (context.use_tls) {
+ return context.tls_stream.?.read(context.stream, buffer);
+ } else {
+ return context.stream.read(buffer);
+ }
+ }
+
+ allocator: std.mem.Allocator,
+ stream: std.net.Stream,
+
+ capabilities: Capabilities = .{},
+ use_tls: bool = false,
+ tls_stream: ?std.crypto.tls.Client = null,
+ tls_cert_bundle: ?std.crypto.Certificate.Bundle = null,
+
+ /// Creates a new SMTP client
+ /// 'server' must be "hostname:port"
+ pub fn init(allocator: std.mem.Allocator, server: []const u8, maybe_auth: ?Auth) !Client {
+ var spl = std.mem.splitScalar(u8, server, ':');
+ const host = spl.first();
+ const port_str = spl.rest();
+ const port = try std.fmt.parseInt(u16, port_str, 10);
+
+ var stream = try std.net.tcpConnectToHost(allocator, host, port);
+ var client = Client{
+ .allocator = allocator,
+ .stream = stream,
+ };
+ errdefer client.deinit();
+ allocator.free(try client.read_expect_code(220));
+
+ try client.write_line("EHLO localhost");
+ const ehlo_response = try client.read_expect_code(250);
+ defer allocator.free(ehlo_response);
+ client.capabilities = Capabilities.parse(ehlo_response);
+
+ if (client.capabilities.starttls) {
+ try client.write_line("STARTTLS");
+ allocator.free(try client.read_expect_code(220));
+
+ var bundle = std.crypto.Certificate.Bundle{};
+ errdefer bundle.deinit(allocator);
+ try std.crypto.Certificate.Bundle.rescan(&bundle, allocator);
+ client.tls_stream = try std.crypto.tls.Client.init(stream, bundle, host);
+ client.tls_cert_bundle = bundle;
+ client.use_tls = true;
+
+ // Redo the EHLO now we're using TLS, the capabilities might change to e.g. allow login.
+ try client.write_line("EHLO localhost");
+ const ehlo_response_poststarttls = try client.read_expect_code(250);
+ defer allocator.free(ehlo_response_poststarttls);
+ client.capabilities = Capabilities.parse(ehlo_response_poststarttls);
+ }
+
+ if (maybe_auth) |auth| {
+ if (client.capabilities.auth) |auth_caps| {
+ if (auth_caps.plain) {
+ // TODO there has to be a nicer way than this surely
+ // TODO support identity as well as user+pass
+ const z = try std.fmt.allocPrint(allocator, "{s}\x00{s}\x00{s}", .{"", auth.user, auth.pass});
+ defer allocator.free(z);
+ const enc = std.base64.standard.Encoder;
+ var zz = try allocator.alloc(u8, enc.calcSize(z.len));
+ defer allocator.free(zz);
+ const zzz = enc.encode(zz, z);
+ const line = try std.fmt.allocPrint(allocator, "AUTH PLAIN {s}", .{zzz});
+ defer allocator.free(line);
+ try client.write_line(line);
+ allocator.free(try client.read_expect_code(235));
+ }
+ }
+ }
+ return client;
+ }
+
+ fn deinit(self: *Client) void {
+ self.stream.close();
+ if (self.use_tls) {
+ self.tls_cert_bundle.?.deinit(self.allocator);
+ }
+ }
+
+ /// lines in SMTP have \r\n terminator :shrug:
+ /// caller owns the result
+ fn read_line(self: *Client) ![]const u8 {
+ var rdr = self.reader();
+ var res = std.ArrayList(u8).init(self.allocator);
+ var wtr = res.writer();
+ while (true) {
+ const b = try rdr.readByte();
+ if (b == '\r') {
+ const b2 = try rdr.readByte();
+ if (b2 == '\n') {
+ break;
+ } else {
+ try wtr.writeByte(b);
+ try wtr.writeByte(b2);
+ continue;
+ }
+ }
+ try wtr.writeByte(b);
+ }
+ const res2 = try res.toOwnedSlice();
+ log.debug("< {s}", .{res2});
+ return res2;
+ }
+
+ // Reads from the stream.
+ // Collects response lines (without code) into the result, which the caller owns.
+ // Checks the numeric code on response lines.
+ fn read_expect_code(self: *Client, expect_code: u16) ![]const u8 {
+ var res = std.ArrayList(u8).init(self.allocator);
+ errdefer res.deinit();
+
+ while (true) {
+ var line = try self.read_line();
+ defer self.allocator.free(line);
+ var spl = std.mem.splitAny(u8, line, " -");
+ if (std.fmt.parseInt(u16, spl.first(), 10)) |code| {
+ if (code != expect_code) {
+ log.err("invalid response code {}", .{code});
+ return Client.Error.InvalidResponseCode;
+ }
+ } else |_| {
+ log.err("malformatted line [{s}]", .{line});
+ return Client.Error.Malformatted;
+ }
+ if (line.len < 3) {
+ log.err("malformatted line [{s}]", .{line});
+ return Client.Error.Malformatted;
+ }
+ try res.appendSlice(line[4..]);
+ // Continuation lines have "123-FOO"
+ // Final line has a space like "123 BAR"
+ if (line[3] == ' ') {
+ break;
+ } else {
+ try res.appendSlice("\r\n");
+ }
+ }
+ return try res.toOwnedSlice();
+ }
+
+ fn write_line(self: *Client, line: []const u8) !void {
+ log.debug("> {s}", .{line});
+ var wtr = self.writer();
+ try wtr.writeAll(line);
+ try wtr.writeAll("\r\n");
+ }
+
+ pub fn mail(self: *Client, from: []const u8) !void {
+ const line = try std.fmt.allocPrint(self.allocator, "MAIL FROM:<{s}>", .{from});
+ defer self.allocator.free(line);
+ try self.write_line(line);
+ var r = try self.read_expect_code(250);
+ self.allocator.free(r);
+ }
+
+ pub fn rcpt(self: *Client, to: []const u8) !void {
+ const line = try std.fmt.allocPrint(self.allocator, "RCPT TO:<{s}>", .{to});
+ defer self.allocator.free(line);
+ try self.write_line(line);
+ var r = try self.read_expect_code(250);
+ self.allocator.free(r);
+ }
+
+ pub fn data(self: *Client, rdr: anytype) !void {
+ try self.write_line("DATA");
+ self.allocator.free(try self.read_expect_code(354));
+ var buf: [1024]u8 = undefined;
+ var r = try rdr.read(&buf);
+ while (r > 0): (r = try rdr.read(&buf)) {
+ // TODO proper dot encoding, ensure any existing sequences of \r\n.\r\n are escaped I guess
+ log.debug("> {s}", .{buf[0..r]});
+ try self.writer().writeAll(buf[0..r]);
+ }
+ try self.write_line("\r\n.");
+ self.allocator.free(try self.read_expect_code(250));
+ }
+
+ pub fn rset(self: *Client) !void {
+ try self.write_line("RSET");
+ self.allocator.free(try self.read_expect_code(250));
+ }
+
+ pub fn quit(self: *Client) !void {
+ try self.write_line("QUIT");
+ self.allocator.free(try self.read_expect_code(221));
+ }
+
+ pub fn send_mail(self: *Client, from: []const u8, to:[]const []const u8, msg: []const u8) !void {
+ try self.rset();
+ try self.mail(from);
+ for (to) |t| {
+ try self.rcpt(t);
+ }
+ var fbs = std.io.fixedBufferStream(msg);
+ try self.data(fbs.reader());
+ }
+
+ const Error = error{Malformatted, InvalidResponseCode};
+};
+
+// TODO why doesn't tls module define an error set?
+const TlsError = error{
+ Overflow,
+ TlsAlertUnexpectedMessage,
+ TlsAlertBadRecordMac,
+ TlsAlertRecordOverflow,
+ TlsAlertHandshakeFailure,
+ TlsAlertBadCertificate,
+ TlsAlertUnsupportedCertificate,
+ TlsAlertCertificateRevoked,
+ TlsAlertCertificateExpired,
+ TlsAlertCertificateUnknown,
+ TlsAlertIllegalParameter,
+ TlsAlertUnknownCa,
+ TlsAlertAccessDenied,
+ TlsAlertDecodeError,
+ TlsAlertDecryptError,
+ TlsAlertProtocolVersion,
+ TlsAlertInsufficientSecurity,
+ TlsAlertInternalError,
+ TlsAlertInappropriateFallback,
+ TlsAlertMissingExtension,
+ TlsAlertUnsupportedExtension,
+ TlsAlertUnrecognizedName,
+ TlsAlertBadCertificateStatusResponse,
+ TlsAlertUnknownPskIdentity,
+ TlsAlertCertificateRequired,
+ TlsAlertNoApplicationProtocol,
+ TlsAlertUnknown,
+ TlsUnexpectedMessage,
+ TlsIllegalParameter,
+ TlsRecordOverflow,
+ TlsBadRecordMac,
+ TlsConnectionTruncated,
+ TlsDecodeError,
+ TlsBadLength,
+};
+
+test "send email" {
+ const mail =
+ \\Subject: test
+ \\From: <martin@mfashby.net>
+ \\To: <martin@ashbysoft.com>
+ \\
+ \\This is a test message
+ ;
+ const user = std.os.getenv("SMTP_USERNAME").?;
+ const pass = std.os.getenv("SMTP_PASSWORD").?;
+ var client = try Client.init(std.testing.allocator, "mail.mfashby.net:587", .{.user = user, .pass = pass});
+ defer client.deinit();
+ try client.send_mail("martin@mfashby.net", &[_][]const u8{"martin@mfashby.net"}, mail);
+ const mail2 =
+ \\Subject: test2
+ \\From: <martin@mfashby.net>
+ \\To: <martin@ashbysoft.com>
+ \\
+ \\This is another test message
+ ;
+ try client.send_mail("martin@mfashby.net", &[_][]const u8{"martin@mfashby.net"}, mail2);
+ try client.quit(); // Be nice (but we don't have to)
+}
+
+// TODO lots more tests.
+// TODO use mock SMTP server to integration test local only. \ No newline at end of file