diff --git a/src/State.zig b/src/State.zig new file mode 100644 index 0000000..dafe3bc --- /dev/null +++ b/src/State.zig @@ -0,0 +1,27 @@ +const std = @import("std"); + +pub const Mapping = struct { + mapped: []const u8, + doc: ?[]const u8, +}; + +pub const MapData = struct { + arena: std.heap.ArenaAllocator, + mappings: std.StringHashMap(Mapping), + renames: std.StringHashMap([]const u8), +}; + +alloc: std.mem.Allocator, +lock: std.Thread.RwLock, +mdata: ?MapData = null, + +const State = @This(); + +pub fn deinit(self: State) void { + var selfv = self; + if (selfv.mdata) |*m| { + m.arena.deinit(); + m.mappings.deinit(); + m.renames.deinit(); + } +} diff --git a/src/main.zig b/src/main.zig index 9aaf641..27af831 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,17 +1,13 @@ const std = @import("std"); const csv_reader = @import("csv_reader.zig"); +const State = @import("State.zig"); const StringPacket = @import("StringPacket.zig"); pub const std_options = std.Options{ .log_level = .debug, }; -const Mapping = struct { - mapped: []const u8, - doc: ?[]const u8, -}; - fn getAddr(alloc: std.mem.Allocator) !std.net.Address { const sockpath = try std.fs.path.join(alloc, &.{ std.posix.getenv("XDG_RUNTIME_DIR") orelse return error.MissingRuntimeDir, @@ -46,9 +42,6 @@ fn map() !void { const stdin = std.io.getStdIn().reader(); - var buf: [128]u8 = undefined; - const request = (try stdin.readUntilDelimiterOrEof(&buf, '\n')) orelse return error.NoInput; - const addr = try getAddr(alloc); const sockfd = try std.posix.socket( @@ -62,6 +55,9 @@ fn map() !void { const stream = std.net.Stream{ .handle = sockfd }; + var buf: [128]u8 = undefined; + const request = (try stdin.readUntilDelimiterOrEof(&buf, '\n')) orelse return error.NoInput; + const pkt = StringPacket{ .str = std.mem.trim(u8, request, &std.ascii.whitespace) }; try pkt.write(stream.writer()); @@ -71,41 +67,38 @@ fn map() !void { try std.io.getStdOut().writer().print("{s}\n", .{res.str}); } -fn runServer() !void { - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - defer _ = gpa.deinit(); - - const alloc = gpa.allocator(); - - var mappings = std.StringHashMap(Mapping).init(alloc); - defer mappings.deinit(); - - var data_arena = std.heap.ArenaAllocator.init(alloc); - defer data_arena.deinit(); - - var renames = std.StringHashMap([]const u8).init(alloc); - defer renames.deinit(); +fn loadData(state: *State) !void { + var data = State.MapData{ + .arena = std.heap.ArenaAllocator.init(state.alloc), + .mappings = std.StringHashMap(State.Mapping).init(state.alloc), + .renames = std.StringHashMap([]const u8).init(state.alloc), + }; + errdefer { + data.arena.deinit(); + data.mappings.deinit(); + data.renames.deinit(); + } if (std.fs.cwd().openFile("renames.csv", .{})) |renames_file| { defer renames_file.close(); var reader = csv_reader.csvReader(std.io.bufferedReader(renames_file.reader())); - while (try reader.next(data_arena.allocator())) |rec| { + while (try reader.next(data.arena.allocator())) |rec| { if (rec.cols.len != 2) { std.log.warn("found rename with invalid record length, skipping", .{}); continue; } - if (renames.contains(rec.cols[0])) { + if (data.renames.contains(rec.cols[0])) { std.log.warn("duplicate rename '{s}'", .{rec.cols[0]}); continue; } - try renames.put(rec.cols[0], rec.cols[1]); + try data.renames.put(rec.cols[0], rec.cols[1]); } - std.log.info("loaded {} renames", .{renames.count()}); + std.log.info("loaded {} renames", .{data.renames.count()}); } else |err| { std.log.warn("couldn't open renames file: {}, skipping", .{err}); } @@ -116,67 +109,171 @@ fn runServer() !void { var mappings_iter = mappings_dir.iterate(); while (try mappings_iter.next()) |entry| { - const fpath = try std.fs.path.join(alloc, &.{ "mappings", entry.name }); - defer alloc.free(fpath); + const fpath = try std.fs.path.join(state.alloc, &.{ "mappings", entry.name }); + defer state.alloc.free(fpath); var mapfile = try std.fs.cwd().openFile(fpath, .{}); defer mapfile.close(); var reader = csv_reader.csvReader(std.io.bufferedReader(mapfile.reader())); - while (try reader.next(data_arena.allocator())) |rec| { + while (try reader.next(data.arena.allocator())) |rec| { if (rec.cols.len < 2) { std.log.warn("found mapping with invalid length, skipping", .{}); continue; } - if (mappings.contains(rec.cols[0])) { + if (data.mappings.contains(rec.cols[0])) { std.log.warn("duplicate mapping '{s}'", .{rec.cols[0]}); continue; } - const mapping = Mapping{ + const mapping = State.Mapping{ .mapped = rec.cols[1], .doc = if (rec.cols.len >= 4) rec.cols[3] else null, }; - try mappings.put(rec.cols[0], mapping); + try data.mappings.put(rec.cols[0], mapping); } } - std.log.info("loaded {} mappings", .{mappings.count()}); + state.lock.lock(); + defer state.lock.unlock(); - const addr = try getAddr(alloc); - var server = try addr.listen(.{}); - defer server.deinit(); - std.log.info("listening on {}", .{addr}); + state.mdata = data; + std.log.info("loaded {} mappings", .{data.mappings.count()}); +} + +fn handleConnection(state: *State, con: std.net.Server.Connection) !void { while (true) { - const con = try server.accept(); - const req = try StringPacket.read(con.stream.reader(), alloc); - defer alloc.free(req.str); + const req = try StringPacket.read(con.stream.reader(), state.alloc); + defer state.alloc.free(req.str); - if (mappings.get(req.str)) |mapping| { - const renamed = renames.get(mapping.mapped); - const res = StringPacket{ .str = renamed orelse mapping.mapped }; - try res.write(con.stream.writer()); + state.lock.lockShared(); + defer state.lock.unlockShared(); - std.log.info( - \\ - \\ Unmapped: {s} - \\ Mapped: {s} - \\ Renamed: {s} - \\ - \\ Doc: {s} - , .{ - req.str, - mapping.mapped, - renamed orelse "", - mapping.doc orelse "", - }); + if (state.mdata) |m| { + if (m.mappings.get(req.str)) |mapping| { + const renamed = m.renames.get(mapping.mapped); + const res = StringPacket{ .str = renamed orelse mapping.mapped }; + try res.write(con.stream.writer()); + + std.log.info( + \\ + \\ Unmapped: {s} + \\ Mapped: {s} + \\ Renamed: {s} + \\ + \\ Doc: {s} + , .{ + req.str, + mapping.mapped, + renamed orelse "", + mapping.doc orelse "", + }); + } else { + const res = StringPacket{ .str = "" }; + try res.write(con.stream.writer()); + } } else { - const res = StringPacket{ .str = "" }; + const res = StringPacket{ .str = "" }; try res.write(con.stream.writer()); } } } + +fn runServer() !void { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + defer _ = gpa.deinit(); + + const alloc = gpa.allocator(); + + var state = State{ + .alloc = alloc, + .lock = .{}, + }; + defer state.deinit(); + + const sigs = comptime sigset: { + var sigset = std.os.linux.empty_sigset; + std.os.linux.sigaddset(&sigset, std.posix.SIG.TERM); + std.os.linux.sigaddset(&sigset, std.posix.SIG.INT); + break :sigset sigset; + }; + + std.posix.sigprocmask(std.posix.SIG.BLOCK, &sigs, null); + + var pool: std.Thread.Pool = undefined; + try std.Thread.Pool.init(&pool, .{ + .allocator = alloc, + .n_jobs = 4, // plenty + }); + defer pool.deinit(); // joins threads + + try pool.spawn((struct { + fn f(s: *State) void { + loadData(s) catch |e| { + std.log.err("failed to load data: {}", .{e}); + s.mdata = .{ + .arena = std.heap.ArenaAllocator.init(s.alloc), + .mappings = std.StringHashMap(State.Mapping).init(s.alloc), + .renames = std.StringHashMap([]const u8).init(s.alloc), + }; + }; + } + }).f, .{&state}); + + const addr = try getAddr(alloc); + var server = try addr.listen(.{}); + defer { + server.deinit(); + const path = std.mem.sliceTo(&addr.un.path, 0); + std.fs.cwd().deleteFile(path) catch |e| + std.log.warn("failed to delete socket: {}", .{e}); + } + std.log.info("listening on {}", .{addr}); + + const sigfd = try std.posix.signalfd(-1, &sigs, 0); + defer std.posix.close(sigfd); + + const epfd = try std.posix.epoll_create1(0); + defer std.posix.close(epfd); + + for ([_]std.posix.fd_t{ server.stream.handle, sigfd }) |fd| { + const ev = std.os.linux.epoll_event{ + .data = .{ .fd = fd }, + .events = std.os.linux.EPOLL.IN, + }; + try std.posix.epoll_ctl(epfd, std.os.linux.EPOLL.CTL_ADD, fd, @constCast(&ev)); + } + + var ev_buf: [32]std.os.linux.epoll_event = undefined; + outer: while (true) { + const evs = ev_buf[0..std.os.linux.epoll_wait(epfd, &ev_buf, ev_buf.len, -1)]; + for (evs) |ev| { + if (ev.data.fd == server.stream.handle) { + const con = try server.accept(); + errdefer con.stream.close(); + try pool.spawn((struct { + fn f(s: *State, conn: std.net.Server.Connection) void { + defer conn.stream.close(); + handleConnection(s, conn) catch |e| switch (e) { + error.EndOfStream => {}, + else => std.log.warn("handing connection: {}", .{e}), + }; + } + }).f, .{ &state, con }); + } else if (ev.data.fd == sigfd) { + var siginfo: std.posix.siginfo_t = undefined; + std.debug.assert(try std.posix.read( + sigfd, + std.mem.asBytes(&siginfo), + ) == @sizeOf(std.posix.siginfo_t)); + + std.log.info("got signal {}, bye!", .{siginfo.signo}); + break :outer; + } + } + } +} diff --git a/vim.lua b/vim.lua index 9f62265..644cc3c 100644 --- a/vim.lua +++ b/vim.lua @@ -12,7 +12,7 @@ vim.keymap.set("n", "m", function() end) vim.keymap.set("n", "M", function() - local nlines = 99 + local nlines = vim.api.nvim_buf_line_count(0) for i = 1, nlines do print(i .. "/" .. nlines) local line = vim.fn.getline(i)