From 449b56947ed82abfc9d650d59ebde01b46655f01 Mon Sep 17 00:00:00 2001 From: kayomn Date: Sat, 15 Oct 2022 21:11:15 +0100 Subject: [PATCH] Improve memory safety of Oar archive lookups --- src/oar.zig | 58 +++++++++++++++++++++++++++++++---------------------- src/sys.zig | 2 +- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/src/oar.zig b/src/oar.zig index ed8db92..1a1cb77 100644 --- a/src/oar.zig +++ b/src/oar.zig @@ -24,7 +24,7 @@ pub const Archive = struct { /// /// /// - const IndexCache = table.Hashed([]const u8, Entry.Header, table.string_context); + const IndexCache = table.Hashed([]const u8, u64, table.string_context); /// /// Finds an entry matching `entry_path` in `archive`. @@ -34,40 +34,50 @@ pub const Archive = struct { pub fn find(archive: *Archive, entry_path: []const u8) FindError!Entry { return Entry{ .header = find_header: { - if (archive.index_cache.lookup(entry_path)) |entry_header| - break: find_header entry_header.*; - - // Start from beginning of archive. - try archive.file_access.seek(0); - - var entry_header = Entry.Header{ + var header = Entry.Header{ .revision = 0, .file_size = 0, - .file_offset = 0 + .absolute_offset = 0 }; - const read_buffer = std.mem.asBytes(&entry_header); + const header_size = @sizeOf(Entry.Header); - // Read first entry. - while ((try archive.file_access.read(read_buffer)) == @sizeOf(Entry.Header)) { - if (std.mem.eql(u8, entry_path, entry_header. - name_buffer[0 .. entry_header.name_length])) { + if (archive.index_cache.lookup(entry_path)) |cursor| { + try archive.file_access.seek(cursor); - // If caching fails... oh well... - archive.index_cache.insert(entry_path, entry_header) catch {}; + if ((try archive.file_access.read(std.mem.asBytes(&header))) != header_size) { + std.debug.assert(archive.index_cache.remove(entry_path) != null); - break: find_header entry_header; + return error.EntryNotFound; } - // Move over file data following the entry. - var to_skip = entry_header.file_size; + break: find_header header; + } else { + const mem = std.mem; - while (to_skip != 0) { - const skipped = std.math.min(to_skip, std.math.maxInt(i64)); + // Start from beginning of archive. + try archive.file_access.seek(0); - try archive.file_access.skip(@intCast(i64, skipped)); + // Read first entry. + while ((try archive.file_access.read(mem.asBytes(&header))) == header_size) { + if (mem.eql(u8, entry_path, header.name_buffer[0 .. header.name_length])) { + // If caching fails... oh well... + archive.index_cache.insert(entry_path, header.absolute_offset) catch {}; - to_skip -= skipped; + break: find_header header; + } + + // Move over file data following the entry. + var to_skip = header.file_size; + + while (to_skip != 0) { + const math = std.math; + const skipped = math.min(to_skip, math.maxInt(i64)); + + try archive.file_access.skip(@intCast(i64, skipped)); + + to_skip -= skipped; + } } } @@ -109,7 +119,7 @@ pub const Entry = struct { name_buffer: [255]u8 = std.mem.zeroes([255]u8), name_length: u8 = 0, file_size: u64, - file_offset: u64, + absolute_offset: u64, padding: [232]u8 = std.mem.zeroes([232]u8), comptime { diff --git a/src/sys.zig b/src/sys.zig index 6b3dc10..d0a4aee 100644 --- a/src/sys.zig +++ b/src/sys.zig @@ -487,7 +487,7 @@ pub const FileSystem = union(enum) { if (archive_entry.cursor >= archive_entry.header.file_size) return error.FileInaccessible; - try file_access.seek(archive_entry.header.file_offset); + try file_access.seek(archive_entry.header.absolute_offset); return file_access.read(buffer[0 .. std.math.min( buffer.len, archive_entry.header.file_size)]);