448 lines
13 KiB
Zig
448 lines
13 KiB
Zig
const coral = @import("./coral.zig");
|
|
|
|
const std = @import("std");
|
|
|
|
/// Generic binary search tree mapping `Key` to `Value`, ordered by the comptime `traits.compare` function.
///
/// Nodes are created with the allocator passed to `insert`. When an entry leaves the tree (via `remove` or
/// `clear`), its node is parked on an internal free list and reused by later inserts. `deinit` releases both
/// the live nodes and the free list, so it must be given the same allocator used for `insert`.
///
/// The tree performs no rebalancing. Lookup, insertion and removal therefore cost time proportional to the tree
/// height, which becomes linear when keys are inserted in sorted order.
pub fn Binary(comptime Key: type, comptime Value: type, comptime traits: Traits(Key)) type {
    return struct {
        /// Recycled nodes waiting to be reused by `insert`.
        free_nodes: Pool = .{},

        /// Top-most node, or null when the tree holds no entries.
        has_root: ?*Node = null,

        const Node = struct {
            key: Key,
            value: Value,
            has_lesser: ?*Node = null,
            has_greater: ?*Node = null,

            /// Parent link while the node is in the tree; reused as the free-list link while it is pooled.
            has_parent: ?*Node = null,

            /// Returns the node in this subtree whose key compares `.equal` to `key`, or null if there is none.
            fn find(self: *Node, key: Key) ?*Node {
                var node = self;

                while (true) {
                    const has_next = switch (traits.compare(key, node.key)) {
                        .lesser => node.has_lesser,
                        .greater => node.has_greater,
                        .equal => return node,
                    };

                    node = has_next orelse return null;
                }
            }

            /// Returns the right-most (greatest) node of this subtree.
            fn getMax(self: *Node) *Node {
                var node = self;

                // Advance `node` itself; testing `self` here would never terminate.
                while (node.has_greater) |greater| {
                    node = greater;
                }

                return node;
            }

            /// Returns the left-most (least) node of this subtree.
            fn getMin(self: *Node) *Node {
                var node = self;

                // Advance `node` itself; testing `self` here would never terminate.
                while (node.has_lesser) |lesser| {
                    node = lesser;
                }

                return node;
            }

            /// Returns the node that follows this one in key order, or null if this is the greatest node.
            fn getSuccessor(self: *Node) ?*Node {
                if (self.has_greater) |greater| {
                    return greater.getMin();
                }

                // With no right subtree, climb until we step up out of a left child.
                // That parent is the next node in key order.
                var child = self;

                while (child.has_parent) |parent| {
                    if (parent.has_lesser == child) {
                        return parent;
                    }

                    child = parent;
                }

                return null;
            }
        };

        /// In-order (ascending key) iterator over the tree's entries.
        ///
        /// The iterator never modifies the tree. Entries must not be removed while it is in use, because it holds
        /// a pointer to the next node to yield and `remove` recycles nodes.
        pub const KeyValues = struct {
            /// Next node to yield, or null once every entry has been visited.
            nodes: ?*Node,

            /// Returns the next entry in ascending key order, or null when iteration is exhausted.
            pub fn next(self: *KeyValues) ?coral.KeyValuePair(Key, *Value) {
                const node = self.nodes orelse {
                    return null;
                };

                self.nodes = node.getSuccessor();

                return .{
                    .key = node.key,
                    .value = &node.value,
                };
            }

            /// Returns only the key of the next entry, or null when iteration is exhausted.
            pub fn nextKey(self: *KeyValues) ?Key {
                return if (self.next()) |key_value| key_value.key else null;
            }

            /// Returns only a pointer to the value of the next entry, or null when iteration is exhausted.
            pub fn nextValue(self: *KeyValues) ?*Value {
                return if (self.next()) |key_value| key_value.value else null;
            }
        };

        /// Intrusive singly linked free list of recycled nodes, threaded through `Node.has_parent`.
        const Pool = struct {
            has_node: ?*Node = null,

            /// Returns a node initialised to `node`, reusing a pooled node before allocating a new one.
            fn create(self: *Pool, allocator: std.mem.Allocator, node: Node) std.mem.Allocator.Error!*Node {
                if (self.has_node) |free_node| {
                    self.has_node = free_node.has_parent;
                    free_node.* = node;

                    return free_node;
                }

                const created_node = try allocator.create(Node);

                created_node.* = node;

                return created_node;
            }

            /// Pushes `node` onto the free list. Its key and value become undefined and its links are reset.
            fn destroy(self: *Pool, node: *Node) void {
                node.* = .{
                    .key = undefined,
                    .value = undefined,
                    .has_parent = self.has_node,
                    .has_lesser = null,
                    .has_greater = null,
                };

                self.has_node = node;
            }
        };

        const Self = @This();

        /// Replaces the subtree rooted at `node` with `has_replacement` in `node`'s parent.
        /// When `node` is the root, `has_root` is updated instead.
        fn transplant(self: *Self, node: *Node, has_replacement: ?*Node) void {
            if (node.has_parent) |parent| {
                if (parent.has_lesser == node) {
                    parent.has_lesser = has_replacement;
                } else {
                    parent.has_greater = has_replacement;
                }
            } else {
                self.has_root = has_replacement;
            }

            if (has_replacement) |replacement| {
                replacement.has_parent = node.has_parent;
            }
        }

        /// Detaches `node` from the tree while keeping the remaining nodes ordered.
        /// A node with two children is replaced by its in-order successor.
        fn unlink(self: *Self, node: *Node) void {
            if (node.has_lesser == null) {
                self.transplant(node, node.has_greater);
            } else if (node.has_greater == null) {
                self.transplant(node, node.has_lesser);
            } else {
                const greater = node.has_greater.?;
                const successor = greater.getMin();

                if (successor != greater) {
                    // Move the successor's right subtree into the successor's old slot.
                    // Then give the successor the removed node's whole right subtree.
                    self.transplant(successor, successor.has_greater);
                    successor.has_greater = greater;
                    greater.has_parent = successor;
                }

                self.transplant(node, successor);

                // The successor is a subtree minimum, so its left slot is free for the removed node's left subtree.
                const lesser = node.has_lesser.?;

                successor.has_lesser = lesser;
                lesser.has_parent = successor;
            }
        }

        /// Removes every entry, moving all nodes onto the free list for reuse (nothing is freed).
        pub fn clear(self: *Self) void {
            // Explicit stack of pending nodes, linked through `has_parent`, so deep trees cannot overflow the call stack.
            var free_nodes: ?*Node = null;

            if (self.has_root) |root| {
                root.has_parent = free_nodes;
                free_nodes = root;
                self.has_root = null;
            }

            while (free_nodes) |node| {
                free_nodes = node.has_parent;

                if (node.has_lesser) |left| {
                    left.has_parent = free_nodes;
                    free_nodes = left;
                }

                if (node.has_greater) |right| {
                    right.has_parent = free_nodes;
                    free_nodes = right;
                }

                // The children were read above, so overwriting this node's links is safe.
                self.free_nodes.destroy(node);
            }
        }

        /// Frees every node, both live and pooled, using `allocator`. The tree is unusable afterwards.
        pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
            // Move all live nodes onto the free list iteratively, then release that list in one pass.
            self.clear();

            while (self.free_nodes.has_node) |node| {
                self.free_nodes.has_node = node.has_parent;
                allocator.destroy(node);
            }

            self.* = undefined;
        }

        /// Inserts `value` under `key` and returns a pointer to the stored value.
        ///
        /// Returns null without modifying anything when an entry with an equal key already exists.
        /// Fails only if a new node cannot be allocated.
        pub fn insert(self: *Self, allocator: std.mem.Allocator, key: Key, value: Value) std.mem.Allocator.Error!?*Value {
            var node = self.has_root orelse {
                self.has_root = try self.free_nodes.create(allocator, .{
                    .key = key,
                    .value = value,
                });

                return &self.has_root.?.value;
            };

            while (true) {
                switch (traits.compare(key, node.key)) {
                    .equal => {
                        return null;
                    },

                    .lesser => {
                        node = node.has_lesser orelse {
                            node.has_lesser = try self.free_nodes.create(allocator, .{
                                .key = key,
                                .value = value,
                                .has_parent = node,
                            });

                            return &node.has_lesser.?.value;
                        };
                    },

                    .greater => {
                        node = node.has_greater orelse {
                            node.has_greater = try self.free_nodes.create(allocator, .{
                                .key = key,
                                .value = value,
                                .has_parent = node,
                            });

                            return &node.has_greater.?.value;
                        };
                    },
                }
            }
        }

        /// Tree with no entries and an empty free list.
        pub const empty = Self{
            .has_root = null,
        };

        /// Returns a pointer to the value stored under `key`, or null if there is no such entry.
        pub fn get(self: Self, key: Key) ?*Value {
            if (self.has_root) |root| {
                if (root.find(key)) |node| {
                    return &node.value;
                }
            }

            return null;
        }

        /// Returns the stored key that compares equal to `key`, or null if there is no such entry.
        pub fn getKey(self: Self, key: Key) ?Key {
            if (self.has_root) |root| {
                if (root.find(key)) |node| {
                    return node.key;
                }
            }

            return null;
        }

        /// Reports whether the tree holds no entries.
        pub fn isEmpty(self: *const Self) bool {
            return self.has_root == null;
        }

        /// Returns an iterator that yields every entry in ascending key order without modifying the tree.
        pub fn keyValues(self: *const Self) KeyValues {
            return .{ .nodes = if (self.has_root) |root| root.getMin() else null };
        }

        /// Removes the entry stored under `key` and returns its key and value, or null if there is no such entry.
        /// The freed node is kept for reuse by later inserts.
        pub fn remove(self: *Self, key: Key) ?coral.KeyValuePair(Key, Value) {
            const root = self.has_root orelse {
                return null;
            };

            const node = root.find(key) orelse {
                return null;
            };

            // `unlink` also updates `has_root` when the root itself is removed.
            self.unlink(node);

            // The return value is copied out before this deferred recycle overwrites the node.
            defer self.free_nodes.destroy(node);

            return .{
                .key = node.key,
                .value = node.value,
            };
        }
    };
}
|
|
|
|
/// Three-way ordering result, encoded as the sign of the comparison: lesser = -1, equal = 0, greater = +1.
pub const Comparison = enum(i2) { lesser = -1, equal = 0, greater = 1 };
|
|
|
|
/// Returns the comptime bundle of key operations that `Binary` is parameterised over.
///
/// `compare(a, b)` is expected to return `.lesser` when `a` sorts before `b`, `.greater` when it sorts after,
/// and `.equal` when the two keys identify the same entry.
pub fn Traits(comptime Key: type) type {
    return struct { compare: fn (Key, Key) Comparison };
}
|
|
|
|
/// Builds `Traits` that order integers, enums (by tag value) and pointers (by address) using `<` and `>`.
///
/// Any other type is rejected with a compile error.
pub fn scalarTraits(comptime Scalar: type) Traits(Scalar) {
    // Each supported scalar is projected onto an integer of this type, so one comparison routine serves all.
    const Ordinal = switch (@typeInfo(Scalar)) {
        .@"enum" => |info| info.tag_type,
        .pointer => usize,
        .int => Scalar,

        else => @compileError(std.fmt.comptimePrint("parameter `Scalar` must be a scalar type, not {s}", .{
            @typeName(Scalar),
        })),
    };

    const impl = struct {
        /// Converts `value` into its integer position; only the prong matching `Scalar` is analysed.
        fn ordinal(value: Scalar) Ordinal {
            return switch (@typeInfo(Scalar)) {
                .@"enum" => @intFromEnum(value),
                .pointer => @intFromPtr(value),
                else => value,
            };
        }

        fn compare(a: Scalar, b: Scalar) Comparison {
            const lhs = ordinal(a);
            const rhs = ordinal(b);

            if (lhs == rhs) {
                return .equal;
            }

            return if (lhs < rhs) .lesser else .greater;
        }
    };

    return .{ .compare = impl.compare };
}
|
|
|
|
/// Builds `Traits` that order slices lexicographically, element by element, using `std.mem.order`.
///
/// `Slice` must be a slice type such as `[]const u8`. Single-item and many-item pointers are rejected at compile
/// time because they carry no length. The element type must be one that `std.mem.order` can compare.
pub fn sliceTraits(comptime Slice: type) Traits(Slice) {
    const not_a_slice = std.fmt.comptimePrint("parameter `Slice` must be a slice type, not {s}", .{
        @typeName(Slice),
    });

    const slice_pointer = switch (@typeInfo(Slice)) {
        .pointer => |pointer| pointer,
        else => @compileError(not_a_slice),
    };

    // `@typeInfo` also reports `*T` and `[*]T` as `.pointer`, so check the size explicitly.
    // Otherwise those types would fail later inside `std.mem.order` with a far less helpful error.
    if (slice_pointer.size != .slice) {
        @compileError(not_a_slice);
    }

    const traits = struct {
        fn compare(a: Slice, b: Slice) Comparison {
            return switch (std.mem.order(slice_pointer.child, a, b)) {
                .lt => .lesser,
                .gt => .greater,
                .eq => .equal,
            };
        }
    };

    return .{
        .compare = traits.compare,
    };
}
|