smtp-zig

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

main.zig (13091B)


      1 const std = @import("std");
      2 const testing = std.testing;
      3 
      4 // Library for sending email over SMTP.
      5 // Inspired by golang's standard lib SMTP client.
      6 // https://cs.opensource.google/go/go/+/master:src/net/smtp/smtp.go
      7 
      8 const log = std.log.scoped(.smtp);
      9 
     10 /// Sends an email.
     11 /// addr is the address of the mail server
     12 /// auth is authentication details for the mail server, may be null if server doesn't require auth.
     13 /// from is sender's email address
     14 /// to is list of recipient email addresses
     15 /// msg is the actual, interesting message.
     16 /// Note this library makes no effort to ensure your mail is actually a valid email, it should conform to rfc5322 otherwise servers may reject it.
     17 pub fn send_mail(allocator: std.mem.Allocator, server: []const u8, auth: ?Auth, from: []const u8, to: []const []const u8, msg: []const u8) !void {
     18     var client = try Client.init(allocator, server, auth);
     19     defer client.deinit();
     20     try client.send_mail(from, to, msg);
     21     try client.quit();
     22 }
     23 
     24 pub const Auth = struct {
     25     user: []const u8,
     26     pass: []const u8,
     27 };
     28 
     29 /// A Client represents a single client connection to an SMTP server.
     30 /// which could be used for sending many mails
     31 pub const Client = struct {
     32     const AuthCaps = struct {
     33         plain: bool = false,
     34         login: bool = false,
     35     };
     36     const Capabilities = struct {
     37         starttls: bool = false,
     38         pipelining: bool = false,
     39         auth: ?AuthCaps = null,
     40 
     41         fn parse(ehlo_response: []const u8) Capabilities {
     42             var caps = Capabilities{};
     43             var spl = std.mem.splitSequence(u8, ehlo_response, "\r\n");
     44             var nxt = spl.next(); // First line of EHLO response is just hi
     45             nxt = spl.next();
     46             while (nxt != null) : (nxt = spl.next()) {
     47                 if (std.mem.eql(u8, nxt.?, "STARTTLS")) {
     48                     caps.starttls = true;
     49                 } else if (std.mem.startsWith(u8, nxt.?, "AUTH")) {
     50                     caps.auth = AuthCaps{};
     51                     var spl2 = std.mem.splitScalar(u8, nxt.?, ' ');
     52                     var nxt2 = spl2.next();
     53                     nxt2 = spl2.next(); // skip the AUTH part
     54                     while (nxt2 != null) : (nxt2 = spl2.next()) {
     55                         if (std.mem.eql(u8, nxt2.?, "PLAIN")) {
     56                             caps.auth.?.plain = true;
     57                         } else if (std.mem.eql(u8, nxt2.?, "LOGIN")) {
     58                             caps.auth.?.login = true;
     59                         } else {
     60                             log.warn("unrecognised auth mechanism {s}", .{nxt2.?});
     61                         }
     62                     }
     63                 } else {
     64                     log.warn("unrecognised capability {s}", .{nxt.?});
     65                 }
     66             }
     67             return caps;
     68         }
     69     };
     70     const WriteError = error{} || std.net.Stream.WriteError || TlsError;
     71 
     72     fn writer(self: *Client) std.io.Writer(*Client, WriteError, write) {
     73         return .{
     74             .context = self,
     75         };
     76     }
     77 
     78     fn write(context: *Client, bytes: []const u8) WriteError!usize {
     79         if (context.use_tls) {
     80             return context.tls_stream.?.write(context.stream, bytes);
     81         } else {
     82             return context.stream.write(bytes);
     83         }
     84     }
     85 
     86     const ReadError = error{} || std.net.Stream.ReadError || TlsError;
     87 
     88     fn reader(self: *Client) std.io.Reader(*Client, ReadError, read) {
     89         return .{
     90             .context = self,
     91         };
     92     }
     93 
     94     fn read(context: *Client, buffer: []u8) ReadError!usize {
     95         if (context.use_tls) {
     96             return context.tls_stream.?.read(context.stream, buffer);
     97         } else {
     98             return context.stream.read(buffer);
     99         }
    100     }
    101 
    102     allocator: std.mem.Allocator,
    103     stream: std.net.Stream,
    104 
    105     capabilities: Capabilities = .{},
    106     use_tls: bool = false,
    107     tls_stream: ?std.crypto.tls.Client = null,
    108     tls_cert_bundle: ?std.crypto.Certificate.Bundle = null,
    109 
    110     /// Creates a new SMTP client
    111     /// 'server' must be "hostname:port"
    112     pub fn init(allocator: std.mem.Allocator, server: []const u8, maybe_auth: ?Auth) !Client {
    113         var spl = std.mem.splitScalar(u8, server, ':');
    114         const host = spl.first();
    115         const port_str = spl.rest();
    116         const port = try std.fmt.parseInt(u16, port_str, 10);
    117 
    118         const stream = try std.net.tcpConnectToHost(allocator, host, port);
    119         var client = Client{
    120             .allocator = allocator,
    121             .stream = stream,
    122         };
    123         errdefer client.deinit();
    124         allocator.free(try client.read_expect_code(220));
    125 
    126         try client.write_line("EHLO localhost");
    127         const ehlo_response = try client.read_expect_code(250);
    128         defer allocator.free(ehlo_response);
    129         client.capabilities = Capabilities.parse(ehlo_response);
    130 
    131         if (client.capabilities.starttls) {
    132             try client.write_line("STARTTLS");
    133             allocator.free(try client.read_expect_code(220));
    134 
    135             var bundle = std.crypto.Certificate.Bundle{};
    136             errdefer bundle.deinit(allocator);
    137             try std.crypto.Certificate.Bundle.rescan(&bundle, allocator);
    138             client.tls_stream = try std.crypto.tls.Client.init(stream, bundle, host);
    139             client.tls_cert_bundle = bundle;
    140             client.use_tls = true;
    141 
    142             // Redo the EHLO now we're using TLS, the capabilities might change to e.g. allow login.
    143             try client.write_line("EHLO localhost");
    144             const ehlo_response_poststarttls = try client.read_expect_code(250);
    145             defer allocator.free(ehlo_response_poststarttls);
    146             client.capabilities = Capabilities.parse(ehlo_response_poststarttls);
    147         }
    148 
    149         if (maybe_auth) |auth| {
    150             if (client.capabilities.auth) |auth_caps| {
    151                 if (auth_caps.plain) {
    152                     // TODO there has to be a nicer way than this surely
    153                     // TODO support identity as well as user+pass
    154                     const z = try std.fmt.allocPrint(allocator, "{s}\x00{s}\x00{s}", .{ "", auth.user, auth.pass });
    155                     defer allocator.free(z);
    156                     const enc = std.base64.standard.Encoder;
    157                     const zz = try allocator.alloc(u8, enc.calcSize(z.len));
    158                     defer allocator.free(zz);
    159                     const zzz = enc.encode(zz, z);
    160                     const line = try std.fmt.allocPrint(allocator, "AUTH PLAIN {s}", .{zzz});
    161                     defer allocator.free(line);
    162                     try client.write_line(line);
    163                     allocator.free(try client.read_expect_code(235));
    164                 }
    165             }
    166         }
    167         return client;
    168     }
    169 
    170     fn deinit(self: *Client) void {
    171         self.stream.close();
    172         if (self.use_tls) {
    173             self.tls_cert_bundle.?.deinit(self.allocator);
    174         }
    175     }
    176 
    177     /// lines in SMTP have \r\n terminator :shrug:
    178     /// caller owns the result
    179     fn read_line(self: *Client) ![]const u8 {
    180         var rdr = self.reader();
    181         var res = std.ArrayList(u8).init(self.allocator);
    182         var wtr = res.writer();
    183         while (true) {
    184             const b = try rdr.readByte();
    185             if (b == '\r') {
    186                 const b2 = try rdr.readByte();
    187                 if (b2 == '\n') {
    188                     break;
    189                 } else {
    190                     try wtr.writeByte(b);
    191                     try wtr.writeByte(b2);
    192                     continue;
    193                 }
    194             }
    195             try wtr.writeByte(b);
    196         }
    197         const res2 = try res.toOwnedSlice();
    198         log.debug("< {s}", .{res2});
    199         return res2;
    200     }
    201 
    202     // Reads from the stream.
    203     // Collects response lines (without code) into the result, which the caller owns.
    204     // Checks the numeric code on response lines.
    205     fn read_expect_code(self: *Client, expect_code: u16) ![]const u8 {
    206         var res = std.ArrayList(u8).init(self.allocator);
    207         errdefer res.deinit();
    208 
    209         while (true) {
    210             var line = try self.read_line();
    211             defer self.allocator.free(line);
    212             var spl = std.mem.splitAny(u8, line, " -");
    213             if (std.fmt.parseInt(u16, spl.first(), 10)) |code| {
    214                 if (code != expect_code) {
    215                     log.err("invalid response code {}", .{code});
    216                     return Client.Error.InvalidResponseCode;
    217                 }
    218             } else |_| {
    219                 log.err("malformatted line [{s}]", .{line});
    220                 return Client.Error.Malformatted;
    221             }
    222             if (line.len < 3) {
    223                 log.err("malformatted line [{s}]", .{line});
    224                 return Client.Error.Malformatted;
    225             }
    226             try res.appendSlice(line[4..]);
    227             // Continuation lines have "123-FOO"
    228             // Final line has a space like "123 BAR"
    229             if (line[3] == ' ') {
    230                 break;
    231             } else {
    232                 try res.appendSlice("\r\n");
    233             }
    234         }
    235         return try res.toOwnedSlice();
    236     }
    237 
    238     fn write_line(self: *Client, line: []const u8) !void {
    239         log.debug("> {s}", .{line});
    240         var wtr = self.writer();
    241         try wtr.writeAll(line);
    242         try wtr.writeAll("\r\n");
    243     }
    244 
    245     pub fn mail(self: *Client, from: []const u8) !void {
    246         const line = try std.fmt.allocPrint(self.allocator, "MAIL FROM:<{s}>", .{from});
    247         defer self.allocator.free(line);
    248         try self.write_line(line);
    249         const r = try self.read_expect_code(250);
    250         self.allocator.free(r);
    251     }
    252 
    253     pub fn rcpt(self: *Client, to: []const u8) !void {
    254         const line = try std.fmt.allocPrint(self.allocator, "RCPT TO:<{s}>", .{to});
    255         defer self.allocator.free(line);
    256         try self.write_line(line);
    257         const r = try self.read_expect_code(250);
    258         self.allocator.free(r);
    259     }
    260 
    261     pub fn data(self: *Client, rdr: anytype) !void {
    262         try self.write_line("DATA");
    263         self.allocator.free(try self.read_expect_code(354));
    264         var buf: [1024]u8 = undefined;
    265         var r = try rdr.read(&buf);
    266         while (r > 0) : (r = try rdr.read(&buf)) {
    267             // TODO proper dot encoding, ensure any existing sequences of \r\n.\r\n are escaped I guess
    268             log.debug("> {s}", .{buf[0..r]});
    269             try self.writer().writeAll(buf[0..r]);
    270         }
    271         try self.write_line("\r\n.");
    272         self.allocator.free(try self.read_expect_code(250));
    273     }
    274 
    275     pub fn rset(self: *Client) !void {
    276         try self.write_line("RSET");
    277         self.allocator.free(try self.read_expect_code(250));
    278     }
    279 
    280     pub fn quit(self: *Client) !void {
    281         try self.write_line("QUIT");
    282         self.allocator.free(try self.read_expect_code(221));
    283     }
    284 
    285     pub fn send_mail(self: *Client, from: []const u8, to: []const []const u8, msg: []const u8) !void {
    286         try self.rset();
    287         try self.mail(from);
    288         for (to) |t| {
    289             try self.rcpt(t);
    290         }
    291         var fbs = std.io.fixedBufferStream(msg);
    292         try self.data(fbs.reader());
    293     }
    294 
    295     const Error = error{ Malformatted, InvalidResponseCode };
    296 };
    297 
    298 // TODO why doesn't tls module define an error set?
    299 const TlsError = error{
    300     Overflow,
    301     TlsAlertUnexpectedMessage,
    302     TlsAlertBadRecordMac,
    303     TlsAlertRecordOverflow,
    304     TlsAlertHandshakeFailure,
    305     TlsAlertBadCertificate,
    306     TlsAlertUnsupportedCertificate,
    307     TlsAlertCertificateRevoked,
    308     TlsAlertCertificateExpired,
    309     TlsAlertCertificateUnknown,
    310     TlsAlertIllegalParameter,
    311     TlsAlertUnknownCa,
    312     TlsAlertAccessDenied,
    313     TlsAlertDecodeError,
    314     TlsAlertDecryptError,
    315     TlsAlertProtocolVersion,
    316     TlsAlertInsufficientSecurity,
    317     TlsAlertInternalError,
    318     TlsAlertInappropriateFallback,
    319     TlsAlertMissingExtension,
    320     TlsAlertUnsupportedExtension,
    321     TlsAlertUnrecognizedName,
    322     TlsAlertBadCertificateStatusResponse,
    323     TlsAlertUnknownPskIdentity,
    324     TlsAlertCertificateRequired,
    325     TlsAlertNoApplicationProtocol,
    326     TlsAlertUnknown,
    327     TlsUnexpectedMessage,
    328     TlsIllegalParameter,
    329     TlsRecordOverflow,
    330     TlsBadRecordMac,
    331     TlsConnectionTruncated,
    332     TlsDecodeError,
    333     TlsBadLength,
    334 };
    335 
    336 test "send email" {
    337     const mail =
    338         \\Subject: test
    339         \\From: <martin@mfashby.net>
    340         \\To: <martin@mfashby.net>
    341         \\
    342         \\This is a test message
    343     ;
    344     const user = std.posix.getenv("SMTP_USERNAME").?;
    345     const pass = std.posix.getenv("SMTP_PASSWORD").?;
    346     var client = try Client.init(std.testing.allocator, "mail.mfashby.net:587", .{ .user = user, .pass = pass });
    347     defer client.deinit();
    348     try client.send_mail("martin@mfashby.net", &[_][]const u8{"martin@mfashby.net"}, mail);
    349     const mail2 =
    350         \\Subject: test2
    351         \\From: <martin@mfashby.net>
    352         \\To: <martin@mfashby.net>
    353         \\
    354         \\This is another test message
    355     ;
    356     try client.send_mail("martin@mfashby.net", &[_][]const u8{"martin@mfashby.net"}, mail2);
    357     try client.quit(); // Be nice (but we don't have to)
    358 }
    359 
    360 // TODO lots more tests.
    361 // TODO use mock SMTP server to integration test local only.