Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
// Integration packages inject their own Downloader and TokenizerLoader, then call
// these functions which run the test and throw on failure.

import CoreImage
import Foundation
import MLX
import MLXEmbedders
import MLXLLM
import MLXLMCommon
import MLXVLM

#if canImport(CoreImage)
import CoreImage
#endif

// Both MLXLMCommon and MLXEmbedders define ModelContainer.
public typealias LLModelContainer = MLXLMCommon.ModelContainer
public typealias EmbeddingModelContainer = MLXEmbedders.EmbedderModelContainer
Expand Down Expand Up @@ -176,18 +179,23 @@ public enum ChatSessionTests {
}

public static func visionModel(container: LLModelContainer) async throws {
let session = ChatSession(container, generateParameters: generateParameters)
let redImage = CIImage(color: .red).cropped(
to: CGRect(x: 0, y: 0, width: 100, height: 100))
#if canImport(CoreImage)
let session = ChatSession(container, generateParameters: generateParameters)
let redImage = CIImage(color: .red).cropped(
to: CGRect(x: 0, y: 0, width: 100, height: 100))

let result = try await streamAndCollect(
session.streamResponse(
to: "What color is this image? Reply with just the color name.",
image: .ciImage(redImage)), label: "Vision")
try check(
result.lowercased().contains("red"),
"Expected 'red' in response, got: \(result)"
)
let result = try await streamAndCollect(
session.streamResponse(
to: "What color is this image? Reply with just the color name.",
image: .ciImage(redImage)), label: "Vision")
try check(
result.lowercased().contains("red"),
"Expected 'red' in response, got: \(result)"
)
#else
fatalError(
"Vision model test requires CoreImage, which is not available on this platform.")
#endif
}

public static func streamDetailsWithTools(container: LLModelContainer) async throws {
Expand Down
5 changes: 4 additions & 1 deletion Libraries/MLXLMCommon/ChatSession.swift
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// Copyright © 2025 Apple Inc.

import CoreGraphics
import Foundation
import MLX

#if canImport(CoreGraphics)
import CoreGraphics
#endif

/// Configuration for speculative decoding in a `ChatSession`.
///
/// Speculative decoding uses a small draft model to propose candidate tokens
Expand Down
22 changes: 22 additions & 0 deletions Libraries/MLXLMCommon/Linux/CoreGraphics.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright © 2026 Apple Inc.

#if !canImport(CoreGraphics)

/// Stand-in for CoreGraphics' `CGFloat`.
///
/// Only compiled when CoreGraphics cannot be imported (see the enclosing
/// `#if !canImport(CoreGraphics)`).
public typealias CGFloat = Double

/// Minimal stand-in for CoreGraphics' `CGSize`: a width/height pair.
///
/// Conforms to `Equatable` so that code comparing sizes compiles the same way
/// with this shim as it does with the CoreGraphics type.
///
/// NOTE(review): Foundation on Linux is believed to export its own
/// `CGFloat`/`CGSize`; confirm that clients importing both Foundation and this
/// module do not hit ambiguous-type errors.
public struct CGSize: Sendable, Equatable {
    /// Horizontal extent.
    public var width: CGFloat

    /// Vertical extent.
    public var height: CGFloat

    /// Creates a size from floating-point dimensions.
    ///
    /// - Parameters:
    ///   - width: The horizontal extent.
    ///   - height: The vertical extent.
    public init(width: CGFloat, height: CGFloat) {
        self.width = width
        self.height = height
    }

    /// Creates a size from integer dimensions, converting each to `CGFloat`.
    ///
    /// - Parameters:
    ///   - width: The horizontal extent.
    ///   - height: The vertical extent.
    public init(width: Int, height: Int) {
        self.init(width: CGFloat(width), height: CGFloat(height))
    }
}

#endif
10 changes: 10 additions & 0 deletions Libraries/MLXLMCommon/Linux/CoreMedia.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright © 2026 Apple Inc.

#if !canImport(CoreMedia)

/// Minimal stand-in for CoreMedia's `CMTime`.
///
/// Only compiled when CoreMedia cannot be imported (see the enclosing
/// `#if !canImport(CoreMedia)`). It stores the two raw fields and provides no
/// arithmetic or comparison.
public struct CMTime: Sendable {
    /// Raw time value.
    /// NOTE(review): presumably a count of `timescale` units, as in CoreMedia — confirm.
    public var value: Int64

    /// Raw timescale.
    /// NOTE(review): presumably units per second, as in CoreMedia — confirm.
    public var timescale: Int32

    /// Creates a time from its raw fields.
    ///
    /// Declared explicitly because the synthesized memberwise initializer of a
    /// public struct is only `internal`, which would leave other modules unable
    /// to construct a `CMTime` when this shim is in use.
    ///
    /// - Parameters:
    ///   - value: The raw time value.
    ///   - timescale: The raw timescale.
    public init(value: Int64, timescale: Int32) {
        self.value = value
        self.timescale = timescale
    }
}

#endif
29 changes: 29 additions & 0 deletions Libraries/MLXLMCommon/Linux/Logger.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright © 2026 Apple Inc.

#if canImport(os)

import os

typealias Logger = os.Logger

#else

/// Fallback logger compiled when the `os` module cannot be imported.
///
/// Takes the same `subsystem`/`category` construction arguments as the
/// `os.Logger` alias used in the other branch, and writes each message with
/// `print`, prefixed by its level and `subsystem.category`.
final class Logger: Sendable {
    /// Identifier printed before the category in every line.
    private let subsystem: String

    /// Identifier printed after the subsystem in every line.
    private let category: String

    /// Creates a logger that tags its output with `subsystem` and `category`.
    init(subsystem: String, category: String) {
        self.subsystem = subsystem
        self.category = category
    }

    /// Prints `message` tagged as `INFO`.
    func info(_ message: String) {
        let line = "[INFO] [\(subsystem).\(category)] \(message)"
        print(line)
    }

    /// Prints `message` tagged as `ERROR`.
    func error(_ message: String) {
        let line = "[ERROR] [\(subsystem).\(category)] \(message)"
        print(line)
    }
}

#endif
13 changes: 13 additions & 0 deletions Libraries/MLXLMCommon/Linux/String+Linux.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright © 2026 Apple Inc.

import Foundation

#if os(Linux)

extension String {
    /// Creates a string from `resource` without performing any lookup.
    ///
    /// The result is exactly the text that was passed in.
    ///
    /// - Parameter resource: The text to use verbatim.
    public init(localized resource: String) {
        self.init(resource)
    }
}

#endif
1 change: 0 additions & 1 deletion Libraries/MLXLMCommon/ParoQuant/ParoQuantLoader.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import Foundation
import MLX
import MLXNN
import os

private let logger = Logger(subsystem: "mlx-swift-lm", category: "paroquant")

Expand Down
181 changes: 108 additions & 73 deletions Libraries/MLXLMCommon/UserInput.swift
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
// Copyright © 2024 Apple Inc.

@preconcurrency import AVFoundation
import CoreImage
import Foundation
import MLX

#if canImport(AVFoundation)
@preconcurrency import AVFoundation
#endif
#if canImport(CoreImage)
import CoreImage
#endif

public typealias Message = [String: any Sendable]

/// Container for raw user input.
Expand Down Expand Up @@ -40,101 +45,131 @@ public struct UserInput {
}

public struct VideoFrame {
public let frame: CIImage
public let image: Image
public let timeStamp: CMTime

public init(frame: CIImage, timeStamp: CMTime) {
self.frame = frame
public init(image: Image, timeStamp: CMTime) {
self.image = image
self.timeStamp = timeStamp
}

#if canImport(CoreImage)

@available(
*, deprecated,
message: "Use init(image:, timeStamp:) instead"
)
public init(frame: CIImage, timeStamp: CMTime) {
self.image = .ciImage(frame)
self.timeStamp = timeStamp
}

@available(
*, deprecated,
message: "Use image.asCIImage()"
)
public var frame: CIImage {
return try! image.asCIImage()
}

#endif
}

/// Representation of a video resource.
public enum Video {
case avAsset(AVAsset)
#if canImport(AVFoundation)
case avAsset(AVAsset)
#endif
case url(URL)
/// Useful for decoded frames held in memory
case frames([VideoFrame])

@available(
*, deprecated,
message: "Use MediaProcessing.asProcessedSequence() with the Video directly"
)
public func asAVAsset() -> AVAsset {
switch self {
case .avAsset(let asset):
return asset
case .url(let url):
return AVAsset(url: url)
case .frames:
fatalError(
"calling asAVAsset() on Video Input with VideoFames provided is unsupported and deprecated - please use MediaProcessing.asProcessedSequence() instead"
)
#if canImport(AVFoundation)
@available(
*, deprecated,
message: "Use MediaProcessing.asProcessedSequence() with the Video directly"
)
public func asAVAsset() -> AVAsset {
switch self {
case .avAsset(let asset):
return asset
case .url(let url):
return AVAsset(url: url)
case .frames:
fatalError(
"calling asAVAsset() on Video Input with VideoFames provided is unsupported and deprecated - please use MediaProcessing.asProcessedSequence() instead"
)
}
}
}
#endif
}

/// Representation of an image resource.
public enum Image {
case ciImage(CIImage)
#if canImport(CoreImage)
case ciImage(CIImage)
#endif
case url(URL)
case array(MLXArray)

public func asCIImage() throws -> CIImage {
switch self {
case .ciImage(let image):
return image

case .url(let url):
if let image = CIImage(contentsOf: url) {
#if canImport(CoreImage)
public func asCIImage() throws -> CIImage {
switch self {
case .ciImage(let image):
return image
}
throw UserInputError.unableToLoad(url)

case .array(let array):
guard array.ndim == 3 else {
throw UserInputError.arrayError("array must have 3 dimensions: \(array.ndim)")
}

var array = array

// convert to 0 .. 255
if array.max().item(Float.self) <= 1.0 {
array = array * 255
}

// planar -> pixels
switch array.dim(0) {
case 3, 4:
// channels first (planar)
array = array.transposed(1, 2, 0)
default:
break
case .url(let url):
if let image = CIImage(contentsOf: url) {
return image
}
throw UserInputError.unableToLoad(url)

case .array(let array):
guard array.ndim == 3 else {
throw UserInputError.arrayError(
"array must have 3 dimensions: \(array.ndim)")
}

var array = array

// convert to 0 .. 255
if array.max().item(Float.self) <= 1.0 {
array = array * 255
}

// planar -> pixels
switch array.dim(0) {
case 3, 4:
// channels first (planar)
array = array.transposed(1, 2, 0)
default:
break
}

// 4 components per pixel
switch array.dim(-1) {
case 3:
// pad to 4 bytes per pixel
array = padded(array, widths: [0, 0, [0, 1]], value: MLXArray(255))
case 4:
// good
break
default:
throw UserInputError.arrayError(
"channel dimension must be last and 3/4: \(array.shape)")
}

let arrayData = array.asData()
let (H, W, _) = array.shape3
let cs = CGColorSpace(name: CGColorSpace.sRGB)!

return CIImage(
bitmapData: arrayData.data, bytesPerRow: W * 4,
size: .init(width: W, height: H),
format: .RGBA8, colorSpace: cs)
}

// 4 components per pixel
switch array.dim(-1) {
case 3:
// pad to 4 bytes per pixel
array = padded(array, widths: [0, 0, [0, 1]], value: MLXArray(255))
case 4:
// good
break
default:
throw UserInputError.arrayError(
"channel dimension must be last and 3/4: \(array.shape)")
}

let arrayData = array.asData()
let (H, W, _) = array.shape3
let cs = CGColorSpace(name: CGColorSpace.sRGB)!

return CIImage(
bitmapData: arrayData.data, bytesPerRow: W * 4,
size: .init(width: W, height: H),
format: .RGBA8, colorSpace: cs)
}
}
#endif
}

/// Representation of processing to apply to media.
Expand Down
8 changes: 4 additions & 4 deletions Libraries/MLXVLM/MediaProcessing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,8 @@ public enum MediaProcessing {
case .success(requestedTime: _, let image, actualTime: let actual):
let ciImage = CIImage(
cgImage: image, options: [.colorSpace: CGColorSpace(name: CGColorSpace.sRGB)!])
let frame = try frameProcessing(.init(frame: ciImage, timeStamp: actual))
ciImages.append(frame.frame)
let frame = try frameProcessing(.init(image: .ciImage(ciImage), timeStamp: actual))
ciImages.append(try frame.image.asCIImage())
timestamps.append(frame.timeStamp)
case .failure(requestedTime: _, _):
break
Expand Down Expand Up @@ -511,8 +511,8 @@ public enum MediaProcessing {
if let targetIndex {
let videoFrame = videoFrames[targetIndex]
let frame = try frameProcessing(
.init(frame: videoFrame.frame, timeStamp: videoFrame.timeStamp))
ciImages.append(frame.frame)
.init(image: videoFrame.image, timeStamp: videoFrame.timeStamp))
ciImages.append(try frame.image.asCIImage())
timestamps.append(frame.timeStamp)
}
}
Expand Down
Loading