aboutsummaryrefslogtreecommitdiff
path: root/src/main.zig
blob: 954d93e0797aec695bcbece5a10dadc5feaf983c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
const std = @import("std");
const testing = std.testing;

// Library for sending email over SMTP.
// Inspired by golang's standard lib SMTP client.
// https://cs.opensource.google/go/go/+/master:src/net/smtp/smtp.go

// Scoped logger: every message this library logs is tagged with the `.smtp` scope.
const log = std.log.scoped(.smtp);

/// Sends a single email over a fresh connection, then sends QUIT and closes it.
/// server is the address of the mail server, in the form "hostname:port".
/// auth is authentication details for the mail server; may be null if the server doesn't require auth.
/// from is the sender's email address.
/// to is the list of recipient email addresses.
/// msg is the message itself (headers and body).
/// Note this library makes no effort to ensure your mail is actually a valid email; it should
/// conform to RFC 5322, otherwise servers may reject it.
pub fn send_mail(allocator: std.mem.Allocator, server: []const u8, auth: ?Auth, from: []const u8, to: []const []const u8, msg: []const u8) !void {
    var client = try Client.init(allocator, server, auth);
    // Closes the connection even if sending fails before QUIT.
    defer client.deinit();
    try client.send_mail(from, to, msg);
    try client.quit();
}

/// Username/password credentials for logging in to the mail server.
/// The client presents them with AUTH PLAIN when the server offers it.
pub const Auth = struct {
    /// Account name to authenticate as.
    user: []const u8,
    /// Password for `user`.
    pass: []const u8,
};

/// A Client represents a single client connection to an SMTP server,
/// which can be used to send many mails.
/// Create one with `init` and release it with `deinit`.
pub const Client = struct {
    /// SASL mechanisms the server listed in its AUTH extension.
    const AuthCaps = struct {
        plain: bool = false,
        login: bool = false,
    };

    /// Extensions the server advertised in its EHLO reply.
    const Capabilities = struct {
        starttls: bool = false,
        pipelining: bool = false,
        auth: ?AuthCaps = null,

        /// Parses the text of an EHLO reply as returned by `read_expect_code`
        /// (reply codes stripped, lines joined with "\r\n").
        /// The first line is the server's greeting and is skipped.
        /// Keywords are compared case-insensitively (RFC 5321 section 2.4).
        fn parse(ehlo_response: []const u8) Capabilities {
            var caps = Capabilities{};
            var lines = std.mem.splitSequence(u8, ehlo_response, "\r\n");
            _ = lines.next(); // greeting line, not a capability
            while (lines.next()) |line| {
                // Splitting on '=' too accepts the obsolete "AUTH=LOGIN PLAIN" spelling.
                var words = std.mem.tokenizeAny(u8, line, " =");
                const keyword = words.next() orelse continue;
                if (std.ascii.eqlIgnoreCase(keyword, "STARTTLS")) {
                    caps.starttls = true;
                } else if (std.ascii.eqlIgnoreCase(keyword, "PIPELINING")) {
                    caps.pipelining = true;
                } else if (std.ascii.eqlIgnoreCase(keyword, "AUTH")) {
                    // Merge with any earlier AUTH line rather than overwriting it: a server
                    // may send both "AUTH ..." and "AUTH=..." with different mechanism lists.
                    var auth_caps = caps.auth orelse AuthCaps{};
                    while (words.next()) |mechanism| {
                        if (std.ascii.eqlIgnoreCase(mechanism, "PLAIN")) {
                            auth_caps.plain = true;
                        } else if (std.ascii.eqlIgnoreCase(mechanism, "LOGIN")) {
                            auth_caps.login = true;
                        } else {
                            log.warn("unrecognised auth mechanism {s}", .{mechanism});
                        }
                    }
                    caps.auth = auth_caps;
                } else {
                    log.warn("unrecognised capability {s}", .{line});
                }
            }
            return caps;
        }
    };

    /// One complete (possibly multi-line) server reply.
    const Reply = struct {
        /// Three-digit reply code shared by every line of the reply.
        code: u16,
        /// Text of each line with its code removed, joined with "\r\n". Caller owns it.
        text: []const u8,
    };

    /// Position of the DATA encoder within the message, carried across chunks.
    const DotState = enum {
        /// At the start of a line (including the start of the message).
        begin_line,
        /// The previous byte was '\r'.
        cr,
        /// In the middle of a line.
        data,
    };

    /// Longest reply line accepted from the server, so a misbehaving server cannot make
    /// `read_line` allocate without bound. RFC 5321 caps reply lines at 512 octets;
    /// this limit is deliberately generous.
    const max_line_len = 64 * 1024;

    const WriteError = std.net.Stream.WriteError || TlsError;

    fn writer(self: *Client) std.io.Writer(*Client, WriteError, write) {
        return .{
            .context = self,
        };
    }

    /// Writes to the TLS stream once STARTTLS has completed, otherwise to the raw socket.
    fn write(context: *Client, bytes: []const u8) WriteError!usize {
        if (context.use_tls) {
            return context.tls_stream.?.write(context.stream, bytes);
        } else {
            return context.stream.write(bytes);
        }
    }

    const ReadError = std.net.Stream.ReadError || TlsError;

    fn reader(self: *Client) std.io.Reader(*Client, ReadError, read) {
        return .{
            .context = self,
        };
    }

    /// Reads from the TLS stream once STARTTLS has completed, otherwise from the raw socket.
    fn read(context: *Client, buffer: []u8) ReadError!usize {
        if (context.use_tls) {
            return context.tls_stream.?.read(context.stream, buffer);
        } else {
            return context.stream.read(buffer);
        }
    }

    allocator: std.mem.Allocator,
    stream: std.net.Stream,

    capabilities: Capabilities = .{},
    use_tls: bool = false,
    tls_stream: ?std.crypto.tls.Client = null,
    tls_cert_bundle: ?std.crypto.Certificate.Bundle = null,

    /// Creates a new SMTP client connected to `server`, which must be "hostname:port".
    /// Reads the greeting, sends EHLO, upgrades with STARTTLS when the server offers it,
    /// and authenticates with AUTH PLAIN when `maybe_auth` is given and the server offers PLAIN.
    /// The caller must call `deinit` on the returned client.
    pub fn init(allocator: std.mem.Allocator, server: []const u8, maybe_auth: ?Auth) !Client {
        var spl = std.mem.splitScalar(u8, server, ':');
        const host = spl.first();
        const port = try std.fmt.parseInt(u16, spl.rest(), 10);

        const stream = try std.net.tcpConnectToHost(allocator, host, port);
        var client = Client{
            .allocator = allocator,
            .stream = stream,
        };
        errdefer client.deinit();
        allocator.free(try client.read_expect_code(220));

        try client.ehlo();

        if (client.capabilities.starttls) {
            try client.write_line("STARTTLS");
            allocator.free(try client.read_expect_code(220));

            // The bundle is owned by `client` from the moment it exists, so the single
            // errdefer above (via `deinit`) frees it exactly once on any later failure.
            client.tls_cert_bundle = std.crypto.Certificate.Bundle{};
            try client.tls_cert_bundle.?.rescan(allocator);
            client.tls_stream = try std.crypto.tls.Client.init(stream, client.tls_cert_bundle.?, host);
            client.use_tls = true;

            // Redo the EHLO now we're using TLS, the capabilities might change to e.g. allow login.
            try client.ehlo();
        }

        if (maybe_auth) |auth| {
            try client.auth_plain(auth);
        }
        return client;
    }

    /// Closes the connection and frees everything the client owns.
    /// Does not send QUIT; call `quit` first for a graceful shutdown.
    pub fn deinit(self: *Client) void {
        self.stream.close();
        if (self.tls_cert_bundle) |*bundle| {
            bundle.deinit(self.allocator);
        }
        self.tls_cert_bundle = null;
        self.tls_stream = null;
        self.use_tls = false;
    }

    /// Sends EHLO and replaces `capabilities` with what the server advertises.
    fn ehlo(self: *Client) !void {
        try self.write_line("EHLO localhost");
        const response = try self.read_expect_code(250);
        defer self.allocator.free(response);
        self.capabilities = Capabilities.parse(response);
    }

    /// Authenticates with AUTH PLAIN (empty authorization identity).
    /// Logs a warning and continues unauthenticated if the server does not offer PLAIN.
    fn auth_plain(self: *Client, auth: Auth) !void {
        const auth_caps = self.capabilities.auth orelse {
            log.warn("credentials given but server does not advertise AUTH; continuing unauthenticated", .{});
            return;
        };
        if (!auth_caps.plain) {
            log.warn("credentials given but server does not offer AUTH PLAIN; continuing unauthenticated", .{});
            return;
        }
        if (!self.use_tls) {
            log.warn("sending AUTH PLAIN credentials over an unencrypted connection", .{});
        }

        // TODO support identity as well as user+pass
        // PLAIN payload is: authzid NUL authcid NUL password, base64-encoded.
        const payload = try std.fmt.allocPrint(self.allocator, "\x00{s}\x00{s}", .{ auth.user, auth.pass });
        defer self.allocator.free(payload);

        const enc = std.base64.standard.Encoder;
        const prefix = "AUTH PLAIN ";
        const line = try self.allocator.alloc(u8, prefix.len + enc.calcSize(payload.len));
        defer self.allocator.free(line);
        @memcpy(line[0..prefix.len], prefix);
        _ = enc.encode(line[prefix.len..], payload);

        // Written directly rather than via write_line so the credentials are not debug-logged.
        log.debug("> AUTH PLAIN <redacted>", .{});
        const wtr = self.writer();
        try wtr.writeAll(line);
        try wtr.writeAll("\r\n");
        self.allocator.free(try self.read_expect_code(235));
    }

    /// Reads one line terminated by "\n", stripping a preceding "\r" if present
    /// (SMTP lines end in "\r\n"). Lines longer than `max_line_len` are rejected.
    /// Caller owns the result.
    fn read_line(self: *Client) ![]const u8 {
        var line = std.ArrayList(u8).init(self.allocator);
        errdefer line.deinit();
        while (true) {
            const b = try self.reader().readByte();
            if (b == '\n') break;
            if (line.items.len >= max_line_len) {
                log.err("response line exceeds {} bytes", .{max_line_len});
                return Error.LineTooLong;
            }
            try line.append(b);
        }
        if (std.mem.endsWith(u8, line.items, "\r")) {
            line.shrinkRetainingCapacity(line.items.len - 1);
        }
        const res = try line.toOwnedSlice();
        log.debug("< {s}", .{res});
        return res;
    }

    /// Returns the three-digit reply code at the start of `line`, or null if the
    /// line does not start with three ASCII digits.
    fn parse_reply_code(line: []const u8) ?u16 {
        if (line.len < 3) return null;
        var code: u16 = 0;
        for (line[0..3]) |c| {
            if (!std.ascii.isDigit(c)) return null;
            code = code * 10 + (c - '0');
        }
        return code;
    }

    /// Reads one complete reply, consuming every continuation line ("123-...") up to
    /// and including the final line ("123 ..." or a bare "123"), so the stream stays
    /// in sync even when the caller then rejects the code. Caller owns `text`.
    fn read_reply(self: *Client) !Reply {
        var text = std.ArrayList(u8).init(self.allocator);
        errdefer text.deinit();
        var reply_code: ?u16 = null;

        while (true) {
            const line = try self.read_line();
            defer self.allocator.free(line);

            const code = parse_reply_code(line) orelse {
                log.err("malformatted line [{s}]", .{line});
                return Error.Malformatted;
            };
            if (reply_code) |first| {
                if (code != first) {
                    log.err("malformatted line [{s}]", .{line});
                    return Error.Malformatted;
                }
            } else {
                reply_code = code;
            }

            // Continuation lines have "123-FOO"; the final line has a space like "123 BAR",
            // or is just the code with no text at all.
            const is_last = line.len == 3 or line[3] == ' ';
            if (!is_last and line[3] != '-') {
                log.err("malformatted line [{s}]", .{line});
                return Error.Malformatted;
            }
            if (line.len > 4) {
                try text.appendSlice(line[4..]);
            }
            if (is_last) break;
            try text.appendSlice("\r\n");
        }
        return .{ .code = reply_code.?, .text = try text.toOwnedSlice() };
    }

    /// Reads a complete reply and checks its code equals `expect_code`.
    /// Returns the reply text (codes stripped, lines joined with "\r\n"); caller owns it.
    fn read_expect_code(self: *Client, expect_code: u16) ![]const u8 {
        const reply = try self.read_reply();
        if (reply.code != expect_code) {
            self.allocator.free(reply.text);
            log.err("invalid response code {}", .{reply.code});
            return Error.InvalidResponseCode;
        }
        return reply.text;
    }

    /// Writes one command line followed by "\r\n".
    /// Rejects lines containing CR or LF: otherwise caller-supplied text (such as an
    /// address passed to `mail` or `rcpt`) could smuggle in additional SMTP commands.
    fn write_line(self: *Client, line: []const u8) !void {
        if (std.mem.indexOfAny(u8, line, "\r\n") != null) {
            log.err("refusing to send line containing CR or LF", .{});
            return Error.InvalidLine;
        }
        log.debug("> {s}", .{line});
        const wtr = self.writer();
        try wtr.writeAll(line);
        try wtr.writeAll("\r\n");
    }

    /// Starts a mail transaction with the given sender address (MAIL FROM).
    pub fn mail(self: *Client, from: []const u8) !void {
        const line = try std.fmt.allocPrint(self.allocator, "MAIL FROM:<{s}>", .{from});
        defer self.allocator.free(line);
        try self.write_line(line);
        const r = try self.read_expect_code(250);
        self.allocator.free(r);
    }

    /// Adds a recipient (RCPT TO). Accepts 250 and 251 ("user not local; will forward"),
    /// both of which are success replies.
    pub fn rcpt(self: *Client, to: []const u8) !void {
        const line = try std.fmt.allocPrint(self.allocator, "RCPT TO:<{s}>", .{to});
        defer self.allocator.free(line);
        try self.write_line(line);
        const reply = try self.read_reply();
        defer self.allocator.free(reply.text);
        if (reply.code != 250 and reply.code != 251) {
            log.err("invalid response code {}", .{reply.code});
            return Error.InvalidResponseCode;
        }
    }

    /// Copies `input` to `output`, applying SMTP transparency (RFC 5321 section 4.5.2):
    /// a '.' at the start of a line is doubled, and a bare '\n' is sent as "\r\n"
    /// (the same encoding Go's textproto.DotWriter performs).
    /// `state` carries the line position across calls. `output` must hold at least
    /// 2 * input.len bytes. Returns the number of bytes written to `output`.
    fn dot_encode(state: *DotState, input: []const u8, output: []u8) usize {
        std.debug.assert(output.len >= 2 * input.len);
        var n: usize = 0;
        for (input) |c| {
            if (c == '.' and state.* == .begin_line) {
                output[n] = '.';
                n += 1;
            }
            if (c == '\n' and state.* != .cr) {
                output[n] = '\r';
                n += 1;
            }
            output[n] = c;
            n += 1;
            state.* = switch (c) {
                '\n' => .begin_line,
                '\r' => .cr,
                else => .data,
            };
        }
        return n;
    }

    /// Sends the message read from `rdr` (any reader with `read([]u8)`) as the DATA
    /// of the current transaction, dot-stuffing it and terminating it correctly.
    pub fn data(self: *Client, rdr: anytype) !void {
        try self.write_line("DATA");
        self.allocator.free(try self.read_expect_code(354));

        const chunk_len = 1024;
        var in_buf: [chunk_len]u8 = undefined;
        // dot_encode emits at most two bytes per input byte.
        var out_buf: [2 * chunk_len]u8 = undefined;
        var state: DotState = .begin_line;
        while (true) {
            const n = try rdr.read(&in_buf);
            if (n == 0) break;
            const out_len = dot_encode(&state, in_buf[0..n], &out_buf);
            log.debug("> {s}", .{out_buf[0..out_len]});
            try self.writer().writeAll(out_buf[0..out_len]);
        }

        // The end-of-data marker is "." alone on a line; finish the last line first if needed.
        const terminator: []const u8 = switch (state) {
            .begin_line => ".\r\n",
            .cr => "\n.\r\n",
            .data => "\r\n.\r\n",
        };
        log.debug("> {s}", .{terminator});
        try self.writer().writeAll(terminator);
        self.allocator.free(try self.read_expect_code(250));
    }

    /// Aborts the current mail transaction, if any (RSET).
    pub fn rset(self: *Client) !void {
        try self.write_line("RSET");
        self.allocator.free(try self.read_expect_code(250));
    }

    /// Asks the server to close the session (QUIT). Still call `deinit` afterwards.
    pub fn quit(self: *Client) !void {
        try self.write_line("QUIT");
        self.allocator.free(try self.read_expect_code(221));
    }

    /// Sends one message on this connection: RSET, MAIL FROM, RCPT TO for each
    /// recipient, then DATA with `msg`.
    pub fn send_mail(self: *Client, from: []const u8, to: []const []const u8, msg: []const u8) !void {
        try self.rset();
        try self.mail(from);
        for (to) |t| {
            try self.rcpt(t);
        }
        var fbs = std.io.fixedBufferStream(msg);
        try self.data(fbs.reader());
    }

    const Error = error{ Malformatted, InvalidResponseCode, InvalidLine, LineTooLong };
};

// Hand-written error set for TLS reads and writes, merged into Client's ReadError and
// WriteError. Written out manually because std.crypto.tls does not export a named set.
// TODO why doesn't tls module define an error set?
const TlsError = error{
    Overflow,
} || error{
    // TLS alert errors.
    TlsAlertUnexpectedMessage,
    TlsAlertBadRecordMac,
    TlsAlertRecordOverflow,
    TlsAlertHandshakeFailure,
    TlsAlertBadCertificate,
    TlsAlertUnsupportedCertificate,
    TlsAlertCertificateRevoked,
    TlsAlertCertificateExpired,
    TlsAlertCertificateUnknown,
    TlsAlertIllegalParameter,
    TlsAlertUnknownCa,
    TlsAlertAccessDenied,
    TlsAlertDecodeError,
    TlsAlertDecryptError,
    TlsAlertProtocolVersion,
    TlsAlertInsufficientSecurity,
    TlsAlertInternalError,
    TlsAlertInappropriateFallback,
    TlsAlertMissingExtension,
    TlsAlertUnsupportedExtension,
    TlsAlertUnrecognizedName,
    TlsAlertBadCertificateStatusResponse,
    TlsAlertUnknownPskIdentity,
    TlsAlertCertificateRequired,
    TlsAlertNoApplicationProtocol,
    TlsAlertUnknown,
} || error{
    // Record/protocol-level TLS errors.
    TlsUnexpectedMessage,
    TlsIllegalParameter,
    TlsRecordOverflow,
    TlsBadRecordMac,
    TlsConnectionTruncated,
    TlsDecodeError,
    TlsBadLength,
};

// Integration test against a real mail server. Skipped (rather than panicking on a
// null env var) unless SMTP_USERNAME and SMTP_PASSWORD are set, so the rest of the
// test suite can run offline.
test "send email" {
    const user = std.posix.getenv("SMTP_USERNAME") orelse return error.SkipZigTest;
    const pass = std.posix.getenv("SMTP_PASSWORD") orelse return error.SkipZigTest;

    const mail =
        \\Subject: test
        \\From: <martin@mfashby.net>
        \\To: <martin@mfashby.net>
        \\
        \\This is a test message
    ;
    var client = try Client.init(std.testing.allocator, "mail.mfashby.net:587", .{ .user = user, .pass = pass });
    defer client.deinit();
    try client.send_mail("martin@mfashby.net", &[_][]const u8{"martin@mfashby.net"}, mail);

    // A second message on the same connection exercises client reuse.
    const mail2 =
        \\Subject: test2
        \\From: <martin@mfashby.net>
        \\To: <martin@mfashby.net>
        \\
        \\This is another test message
    ;
    try client.send_mail("martin@mfashby.net", &[_][]const u8{"martin@mfashby.net"}, mail2);
    try client.quit(); // Be nice (but we don't have to)
}

// TODO lots more tests.
// TODO use mock SMTP server to integration test local only.