diff --git a/Libraries/MLXLMCommon/SwitchLayers.swift b/Libraries/MLXLMCommon/SwitchLayers.swift index 7cb112dda..bd6f719a7 100644 --- a/Libraries/MLXLMCommon/SwitchLayers.swift +++ b/Libraries/MLXLMCommon/SwitchLayers.swift @@ -713,6 +713,7 @@ public class SwitchGLU: Module, @unchecked Sendable { var usedSlots = Set() var missInfo = [(rangeIdx: Int, expertId: Int, bufferSlot: Int)]() + var slotExhausted = false for (ri, r) in ranges.enumerated() { if let slot = prevSlotMap[r.id], !usedSlots.contains(slot) { @@ -723,7 +724,12 @@ public class SwitchGLU: Module, @unchecked Sendable { usedSlots.insert(slot) } else { // MISS: find a free slot - let freeSlot = (0.. (assignments: [(rangeIdx: Int, slot: Int)], exhausted: Bool) { + var prevSlotMap = [Int: Int]() + for (slot, eid) in prevIds.enumerated() { + prevSlotMap[eid] = slot + } + + var usedSlots = Set() + var assignments = [(rangeIdx: Int, slot: Int)]() + var slotExhausted = false + + for (ri, r) in ranges.enumerated() { + if let slot = prevSlotMap[r.id], !usedSlots.contains(slot) { + // HIT + usedSlots.insert(slot) + assignments.append((ri, slot)) + } else { + // MISS — find a free slot + guard let freeSlot = (0.. [(rangeIdx: Int, slot: Int)] { + var prevSlotMap = [Int: Int]() + for (slot, eid) in prevIds.enumerated() { + prevSlotMap[eid] = slot + } + + var usedSlots = Set() + var assignments = [(rangeIdx: Int, slot: Int)]() + + for (ri, r) in ranges.enumerated() { + if let slot = prevSlotMap[r.id], !usedSlots.contains(slot) { + usedSlots.insert(slot) + assignments.append((ri, slot)) + } else { + // BUG: force-unwrap crashes when all slots consumed by hits + let freeSlot = (0.. maxBuffers") + XCTAssertEqual(assignments.count, 8, "Should have assigned 8 ranges before exhaustion") + } + + // ═══════════════════════════════════════════════════════════════════ + // MARK: - 2. 
Normal operation: hits + misses fit within maxBuffers + // ═══════════════════════════════════════════════════════════════════ + + /// Resolves eight ranges against prevIds 0–7: IDs 0–5 match a previous ID, 99 and 100 do not. + /// Expects no exhaustion, eight assignments, slot i for each matching ID i, and slot 6 or 7 for each non-match. + func testNormalHitMissResolution() { + let maxBuffers = 8 + let prevIds = [0, 1, 2, 3, 4, 5, 6, 7] + // 6 hits + 2 misses = 8 total, fits in maxBuffers + let ranges = [0, 1, 2, 3, 4, 5, 99, 100].enumerated().map { + ExpertRange(id: $0.element, start: $0.offset, end: $0.offset + 1) + } + + let (assignments, exhausted) = resolveSlots( + ranges: ranges, prevIds: prevIds, maxBuffers: maxBuffers + ) + + XCTAssertFalse(exhausted) + XCTAssertEqual(assignments.count, 8) + + // NOTE(review): a failed XCTAssertEqual does not stop the test by default, so the subscripts below would trap on a short result unless continueAfterFailure is disabled elsewhere — confirm. + // Verify hits got their original slots + for i in 0..<6 { + XCTAssertEqual(assignments[i].slot, i, "Expert \(i) should hit slot \(i)") + } + // Misses should get free slots 6 and 7 (each check accepts either slot; distinctness of the two is not asserted) + XCTAssertTrue([6, 7].contains(assignments[6].slot), "Miss expert 99 should get free slot") + XCTAssertTrue([6, 7].contains(assignments[7].slot), "Miss expert 100 should get free slot") + } + + // ═══════════════════════════════════════════════════════════════════ + // MARK: - 3. Edge case: all misses (no overlap with previous predictions) + // ═══════════════════════════════════════════════════════════════════ + + /// Resolves eight ranges with IDs 0–7 against prevIds 100–107, so no current ID matches a previous one. + /// Expects no exhaustion and eight assignments; the assigned slot values themselves are not checked. + func testAllMisses() { + let maxBuffers = 8 + let prevIds = [100, 101, 102, 103, 104, 105, 106, 107] + // All 8 current experts are completely different from prev + let ranges = [0, 1, 2, 3, 4, 5, 6, 7].enumerated().map { + ExpertRange(id: $0.element, start: $0.offset, end: $0.offset + 1) + } + + let (assignments, exhausted) = resolveSlots( + ranges: ranges, prevIds: prevIds, maxBuffers: maxBuffers + ) + + XCTAssertFalse(exhausted, "8 misses should fit in 8 slots") + XCTAssertEqual(assignments.count, 8) + } + + // ═══════════════════════════════════════════════════════════════════ + // MARK: - 4. 
Edge case: all hits (100% speculation accuracy) + // ═══════════════════════════════════════════════════════════════════ + + /// Resolves eight ranges whose IDs 0–7 equal prevIds position for position. + /// Expects no exhaustion, eight assignments, and assignment i carrying slot i for every i in 0..<8. + func testAllHits() { + let maxBuffers = 8 + let prevIds = [0, 1, 2, 3, 4, 5, 6, 7] + let ranges = [0, 1, 2, 3, 4, 5, 6, 7].enumerated().map { + ExpertRange(id: $0.element, start: $0.offset, end: $0.offset + 1) + } + + let (assignments, exhausted) = resolveSlots( + ranges: ranges, prevIds: prevIds, maxBuffers: maxBuffers + ) + + XCTAssertFalse(exhausted) + XCTAssertEqual(assignments.count, 8) + // NOTE(review): a failed count assertion does not stop the test by default, so assignments[i] below would trap on a short result unless continueAfterFailure is disabled elsewhere — confirm. + // Every expert should get its original slot + for i in 0..<8 { + XCTAssertEqual(assignments[i].slot, i) + } + } + + // ═══════════════════════════════════════════════════════════════════ + // MARK: - 5. Stress test: duplicate expert IDs in sorted ranges + // ═══════════════════════════════════════════════════════════════════ + + /// When idx is sorted, the same expert can appear in non-contiguous + /// ranges if the routing assigns it to tokens in different sorted + /// groups. The second occurrence of the same expertId is treated as + /// a miss (its slot was already claimed by the first occurrence). + /// + /// Scenario: prevIds 0–3 fill all four slots, ranges 0–3 each claim + /// their slot, and a fifth range repeats ID 0. The test expects + /// `exhausted` to be true with exactly four assignments returned. + func testDuplicateExpertInRangesExhaustsSlots() { + let maxBuffers = 4 + let prevIds = [0, 1, 2, 3] + // Expert 0 appears twice — second occurrence is a miss + let ranges = [ + ExpertRange(id: 0, start: 0, end: 1), + ExpertRange(id: 1, start: 1, end: 2), + ExpertRange(id: 2, start: 2, end: 3), + ExpertRange(id: 3, start: 3, end: 4), + ExpertRange(id: 0, start: 4, end: 5), // duplicate — miss + ] + + let (assignments, exhausted) = resolveSlots( + ranges: ranges, prevIds: prevIds, maxBuffers: maxBuffers + ) + + XCTAssertTrue(exhausted, "5 ranges with 4 slots must exhaust") + XCTAssertEqual(assignments.count, 4, "Should assign 4 before exhaustion") + } +}