diff --git a/Source/MLX/IO.swift b/Source/MLX/IO.swift index acdb590d..519096dc 100644 --- a/Source/MLX/IO.swift +++ b/Source/MLX/IO.swift @@ -36,7 +36,9 @@ public func save(array: MLXArray, url: URL, stream: StreamOrDevice = .default) t switch url.pathExtension { case "npy": _ = try withError { - mlx_save(path.cString(using: .utf8), array.ctx) + _ = evalLock.withLock { + mlx_save(path.cString(using: .utf8), array.ctx) + } } default: @@ -72,7 +74,9 @@ public func save( switch url.pathExtension { case "safetensors": _ = try withError { - mlx_save_safetensors(path.cString(using: .utf8), mlx_arrays, mlx_metadata) + _ = evalLock.withLock { + mlx_save_safetensors(path.cString(using: .utf8), mlx_arrays, mlx_metadata) + } } default: @@ -283,7 +287,11 @@ public func saveToData( let writer = new_mlx_io_writer_dataIO() defer { mlx_io_writer_free(writer) } - mlx_save_safetensors_writer(writer, mlx_arrays, mlx_metadata) + _ = evalLock.withLock { + _ = evalLock.withLock { + mlx_save_safetensors_writer(writer, mlx_arrays, mlx_metadata) + } + } return getData(writer) }