ona/src/coral/tree.zig
kayomn 89d975f668
Some checks reported errors
continuous-integration/drone/push Build was killed
Add deferred system operations
2025-08-11 10:13:21 +01:00

448 lines
13 KiB
Zig

const coral = @import("./coral.zig");
const std = @import("std");
/// Unbalanced binary search tree that maps `Key` to `Value`, ordered by `traits.compare`.
///
/// The tree does not rebalance itself. Lookups, insertions, and removals therefore cost O(height), which degrades to
/// O(n) when keys arrive in sorted order. For that reason every traversal below is iterative rather than recursive.
///
/// `remove` and `clear` keep released nodes on an internal free list, and later `insert` calls reuse them. `deinit`
/// returns all node memory, including that free list, and must be given the same allocator that `insert` used.
pub fn Binary(comptime Key: type, comptime Value: type, comptime traits: Traits(Key)) type {
    return struct {
        free_nodes: Pool = .{},
        has_root: ?*Node = null,

        const Node = struct {
            key: Key,
            value: Value,
            has_lesser: ?*Node = null,
            has_greater: ?*Node = null,
            has_parent: ?*Node = null,

            /// Returns the node in the subtree rooted at `self` whose key compares `.equal` to `key`, or `null`.
            fn find(self: *Node, key: Key) ?*Node {
                var node = self;

                while (true) {
                    const next_node = switch (traits.compare(key, node.key)) {
                        .lesser => node.has_lesser,
                        .greater => node.has_greater,
                        .equal => return node,
                    };

                    node = next_node orelse return null;
                }
            }

            /// Returns the node holding the largest key in the subtree rooted at `self`.
            fn getMax(self: *Node) *Node {
                var node = self;

                while (node.has_greater) |greater| {
                    node = greater;
                }

                return node;
            }

            /// Returns the node holding the smallest key in the subtree rooted at `self`.
            fn getMin(self: *Node) *Node {
                var node = self;

                while (node.has_lesser) |lesser| {
                    node = lesser;
                }

                return node;
            }

            /// Returns the node that follows `self` in key order, or `null` if `self` holds the largest key.
            fn getSuccessor(self: *Node) ?*Node {
                if (self.has_greater) |greater| {
                    return greater.getMin();
                }

                // With no greater subtree, the successor is the first ancestor reached from its lesser side.
                var child = self;
                var has_ancestor = self.has_parent;

                while (has_ancestor) |ancestor| {
                    if (ancestor.has_lesser == child) {
                        return ancestor;
                    }

                    child = ancestor;
                    has_ancestor = ancestor.has_parent;
                }

                return null;
            }
        };

        /// Ascending (in-order) iterator over the entries of a tree.
        ///
        /// `nodes` holds the next node to yield, or `null` once iteration is finished. The iterator only reads the tree.
        /// Modifying the tree while iterating is not supported.
        pub const KeyValues = struct {
            nodes: ?*Node,

            /// Returns the next entry in ascending key order, or `null` when exhausted.
            pub fn next(self: *KeyValues) ?coral.KeyValuePair(Key, *Value) {
                const node = self.nodes orelse return null;

                self.nodes = node.getSuccessor();

                return .{
                    .key = node.key,
                    .value = &node.value,
                };
            }

            /// Returns the key of the next entry, or `null` when exhausted.
            pub fn nextKey(self: *KeyValues) ?Key {
                return if (self.next()) |key_value| key_value.key else null;
            }

            /// Returns a pointer to the value of the next entry, or `null` when exhausted.
            pub fn nextValue(self: *KeyValues) ?*Value {
                return if (self.next()) |key_value| key_value.value else null;
            }
        };

        /// Intrusive free list of released nodes, linked through `has_parent`.
        const Pool = struct {
            has_node: ?*Node = null,

            /// Returns a node initialized to `node`. A free-listed node is reused if one exists. Otherwise a new node
            /// is allocated from `allocator`.
            fn create(self: *Pool, allocator: std.mem.Allocator, node: Node) std.mem.Allocator.Error!*Node {
                if (self.has_node) |free_node| {
                    self.has_node = free_node.has_parent;
                    free_node.* = node;

                    return free_node;
                }

                const created_node = try allocator.create(Node);

                created_node.* = node;

                return created_node;
            }

            /// Pushes `node` onto the free list. Its key and value become undefined.
            fn destroy(self: *Pool, node: *Node) void {
                node.* = .{
                    .key = undefined,
                    .value = undefined,
                    .has_parent = self.has_node,
                    .has_lesser = null,
                    .has_greater = null,
                };

                self.has_node = node;
            }

            /// Frees every node on the free list back to `allocator`.
            fn deinit(self: *Pool, allocator: std.mem.Allocator) void {
                while (self.has_node) |node| {
                    self.has_node = node.has_parent;
                    allocator.destroy(node);
                }
            }
        };

        const Self = @This();

        /// Removes every entry. All nodes move onto the free list for reuse, and no memory is freed.
        pub fn clear(self: *Self) void {
            // Depth-first walk that reuses `has_parent` as the link of an explicit stack, so no recursion or
            // allocation is needed.
            var stack: ?*Node = null;

            if (self.has_root) |root| {
                root.has_parent = null;
                stack = root;
                self.has_root = null;
            }

            while (stack) |node| {
                stack = node.has_parent;

                if (node.has_lesser) |lesser| {
                    lesser.has_parent = stack;
                    stack = lesser;
                }

                if (node.has_greater) |greater| {
                    greater.has_parent = stack;
                    stack = greater;
                }

                // The children have already been read above, so overwriting the node here is safe.
                self.free_nodes.destroy(node);
            }
        }

        /// Frees every node, both live and free-listed, back to `allocator`, then leaves `self` undefined.
        pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
            // Park all live nodes on the free list first so a single iterative pass frees everything.
            self.clear();
            self.free_nodes.deinit(allocator);

            self.* = undefined;
        }

        /// Inserts `key` with `value` and returns a pointer to the stored value. If a key comparing `.equal` already
        /// exists, returns `null` and leaves the tree unchanged. Fails only when a new node must be allocated and
        /// `allocator` cannot provide it.
        pub fn insert(self: *Self, allocator: std.mem.Allocator, key: Key, value: Value) std.mem.Allocator.Error!?*Value {
            var node = self.has_root orelse {
                self.has_root = try self.free_nodes.create(allocator, .{
                    .key = key,
                    .value = value,
                });

                return &self.has_root.?.value;
            };

            while (true) {
                switch (traits.compare(key, node.key)) {
                    .equal => {
                        return null;
                    },

                    .lesser => {
                        node = node.has_lesser orelse {
                            node.has_lesser = try self.free_nodes.create(allocator, .{
                                .key = key,
                                .value = value,
                                .has_parent = node,
                            });

                            return &node.has_lesser.?.value;
                        };
                    },

                    .greater => {
                        node = node.has_greater orelse {
                            node.has_greater = try self.free_nodes.create(allocator, .{
                                .key = key,
                                .value = value,
                                .has_parent = node,
                            });

                            return &node.has_greater.?.value;
                        };
                    },
                }
            }
        }

        /// A tree with no entries and an empty free list.
        pub const empty = Self{
            .has_root = null,
        };

        /// Returns a pointer to the value stored under `key`, or `null` if absent.
        pub fn get(self: Self, key: Key) ?*Value {
            if (self.has_root) |root| {
                if (root.find(key)) |node| {
                    return &node.value;
                }
            }

            return null;
        }

        /// Returns the stored key that compares `.equal` to `key`, or `null` if absent.
        pub fn getKey(self: Self, key: Key) ?Key {
            if (self.has_root) |root| {
                if (root.find(key)) |node| {
                    return node.key;
                }
            }

            return null;
        }

        /// Returns `true` if the tree has no entries.
        pub fn isEmpty(self: *const Self) bool {
            return self.has_root == null;
        }

        /// Returns an iterator over all entries in ascending key order. The tree is not modified by iteration.
        pub fn keyValues(self: *const Self) KeyValues {
            return .{
                .nodes = if (self.has_root) |root| root.getMin() else null,
            };
        }

        /// Removes the entry whose key compares `.equal` to `key` and returns its key and value, or `null` if absent.
        /// The node goes onto the free list rather than being freed.
        pub fn remove(self: *Self, key: Key) ?coral.KeyValuePair(Key, Value) {
            const root = self.has_root orelse return null;
            const node = root.find(key) orelse return null;

            if (node.has_lesser) |lesser| {
                if (node.has_greater) |greater| {
                    // Two children: the in-order successor (minimum of the greater subtree) takes `node`'s place.
                    const successor = greater.getMin();

                    if (successor != greater) {
                        // Detach the successor (it has no lesser child), then give it `node`'s greater subtree.
                        self.transplant(successor, successor.has_greater);

                        successor.has_greater = greater;
                        greater.has_parent = successor;
                    }

                    self.transplant(node, successor);

                    successor.has_lesser = lesser;
                    lesser.has_parent = successor;
                } else {
                    self.transplant(node, lesser);
                }
            } else {
                self.transplant(node, node.has_greater);
            }

            // Copy the entry out before `destroy` overwrites the node.
            const removed: coral.KeyValuePair(Key, Value) = .{
                .key = node.key,
                .value = node.value,
            };

            self.free_nodes.destroy(node);

            return removed;
        }

        /// Puts `replacement` where `target` hangs in the tree: under `target`'s parent, or in the root slot when
        /// `target` has no parent. `target`'s own links are left untouched.
        fn transplant(self: *Self, target: *Node, replacement: ?*Node) void {
            if (target.has_parent) |parent| {
                if (parent.has_lesser == target) {
                    parent.has_lesser = replacement;
                } else {
                    parent.has_greater = replacement;
                }
            } else {
                self.has_root = replacement;
            }

            if (replacement) |node| {
                node.has_parent = target.has_parent;
            }
        }
    };
}
/// Three-way ordering result produced by a `Traits.compare` function.
///
/// Backed by `i2`, so each tag value is the sign of the comparison: `-1`, `0`, or `1`.
pub const Comparison = enum(i2) {
    /// The first operand orders before the second.
    lesser = -1,
    /// Both operands order the same.
    equal,
    /// The first operand orders after the second.
    greater,
};
/// Returns the comparator bundle that `Binary` uses to order keys of type `Key`.
///
/// `compare(a, b)` reports whether `a` is `.lesser`, `.equal`, or `.greater` relative to `b`.
pub fn Traits(comptime Key: type) type {
    const Comparator = fn (Key, Key) Comparison;

    return struct {
        compare: Comparator,
    };
}
/// Returns `Traits` that order values of the scalar type `Scalar`.
///
/// Comparison rules by type:
/// - Integers compare by value.
/// - Enums compare by their integer tag (`@intFromEnum`).
/// - Non-slice pointers compare by address (`@intFromPtr`).
///
/// Any other type, including slices, is rejected at compile time; use `sliceTraits` for slices.
pub fn scalarTraits(comptime Scalar: type) Traits(Scalar) {
    const is_scalar = switch (@typeInfo(Scalar)) {
        .int, .@"enum" => true,
        // Slices are pointers in `@typeInfo` but have no single address to order by.
        .pointer => |pointer| pointer.size != .slice,
        else => false,
    };

    if (!is_scalar) {
        @compileError(std.fmt.comptimePrint("parameter `Scalar` must be a scalar type, not {s}", .{
            @typeName(Scalar),
        }));
    }

    const traits = struct {
        fn compare(a: Scalar, b: Scalar) Comparison {
            // The switch operand is comptime-known, so only the prong for `Scalar` is analyzed.
            const ordering = switch (@typeInfo(Scalar)) {
                .@"enum" => std.math.order(@intFromEnum(a), @intFromEnum(b)),
                .pointer => std.math.order(@intFromPtr(a), @intFromPtr(b)),
                else => std.math.order(a, b),
            };

            return switch (ordering) {
                .lt => .lesser,
                .eq => .equal,
                .gt => .greater,
            };
        }
    };

    return .{
        .compare = traits.compare,
    };
}
/// Returns `Traits` that order slices of type `Slice` lexicographically, element by element, using `std.mem.order`.
///
/// `Slice` must be a slice type such as `[]const u8`. Other pointer kinds (`*T`, `[*]T`, `[*c]T`) and non-pointers are
/// rejected at compile time. The element type must support the comparisons `std.mem.order` performs.
pub fn sliceTraits(comptime Slice: type) Traits(Slice) {
    const is_slice = switch (@typeInfo(Slice)) {
        .pointer => |pointer| pointer.size == .slice,
        else => false,
    };

    if (!is_slice) {
        @compileError(std.fmt.comptimePrint("parameter `Slice` must be a slice type, not {s}", .{
            @typeName(Slice),
        }));
    }

    const Element = @typeInfo(Slice).pointer.child;

    const traits = struct {
        fn compare(a: Slice, b: Slice) Comparison {
            return switch (std.mem.order(Element, a, b)) {
                .lt => .lesser,
                .gt => .greater,
                .eq => .equal,
            };
        }
    };

    return .{
        .compare = traits.compare,
    };
}