9
9
import ExecuTorchLLM
10
10
import XCTest
11
11
12
extension UIImage {
  /// Converts this image into an `Image` holding 8-bit RGB data in planar order.
  ///
  /// The backing `CGImage` is rendered into an RGBA bitmap (8 bits per component,
  /// premultiplied alpha last, big-endian byte order). The alpha byte is then
  /// dropped and the remaining channels are rearranged from interleaved
  /// (`RGBA RGBA …`) to planar (`RR… GG… BB…`) layout.
  ///
  /// - Returns: An `Image` built from `width * height * 3` bytes with `channels: 3`,
  ///   where `width` and `height` are the pixel dimensions of the backing `CGImage`.
  /// - Note: Traps if the image has no `CGImage` backing or if the bitmap context
  ///   cannot be created.
  /// - Note: NOTE(review): pixels are read from `cgImage` as stored; `imageOrientation`
  ///   is not consulted here — confirm callers do not rely on orientation being applied.
  func asImage() -> Image {
    guard let cgImage = self.cgImage else {
      fatalError("UIImage has no CGImage backing; cannot extract pixel data")
    }
    let width = cgImage.width
    let height = cgImage.height
    let pixelCount = width * height
    let bytesPerPixel = 4
    let bytesPerRow = bytesPerPixel * width
    var pixelBytes = [UInt8](repeating: 0, count: pixelCount * bytesPerPixel)

    // The context stores the data pointer and writes through it during `draw`, so
    // the array storage must stay pinned for the context's whole lifetime. Passing
    // `&pixelBytes` directly would only guarantee validity for the initializer call.
    pixelBytes.withUnsafeMutableBytes { buffer in
      guard let context = CGContext(
        data: buffer.baseAddress,
        width: width,
        height: height,
        bitsPerComponent: 8,
        bytesPerRow: bytesPerRow,
        space: CGColorSpaceCreateDeviceRGB(),
        bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
      ) else {
        fatalError("Failed to create a \(width)x\(height) RGBA bitmap context")
      }
      context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
    }

    // Interleaved RGBA -> planar RGB: one full plane per channel, alpha discarded.
    var rgbBytes = [UInt8](repeating: 0, count: pixelCount * 3)
    for i in 0..<pixelCount {
      let pixelOffset = i * bytesPerPixel
      rgbBytes[i] = pixelBytes[pixelOffset]
      rgbBytes[i + pixelCount] = pixelBytes[pixelOffset + 1]
      rgbBytes[i + pixelCount * 2] = pixelBytes[pixelOffset + 2]
    }
    return Image(data: Data(rgbBytes), width: width, height: height, channels: 3)
  }
}
41
+
12
42
class MultimodalRunnerTest : XCTestCase {
13
43
func test( ) {
14
44
let bundle = Bundle ( for: type ( of: self ) )
15
45
guard let modelPath = bundle. path ( forResource: " llava " , ofType: " pte " ) ,
16
- let tokenizerPath = bundle. path ( forResource: " tokenizer " , ofType: " bin " ) else {
46
+ let tokenizerPath = bundle. path ( forResource: " tokenizer " , ofType: " bin " ) ,
47
+ let imagePath = bundle. path ( forResource: " IMG_0005 " , ofType: " JPG " ) ,
48
+ let image = UIImage ( contentsOfFile: imagePath) else {
17
49
XCTFail ( " Couldn't find model or tokenizer files " )
18
50
return
19
51
}
@@ -22,12 +54,15 @@ class MultimodalRunnerTest: XCTestCase {
22
54
var text = " "
23
55
24
56
do {
25
- try runner. generate ( [ MultimodalInput ( " hello " ) ] , sequenceLength: 2 ) { token in
57
+ try runner. generate ( [
58
+ MultimodalInput ( " What's this? " ) ,
59
+ MultimodalInput ( image. asImage ( ) ) ,
60
+ ] , sequenceLength: 2 ) { token in
26
61
text += token
27
62
}
28
63
} catch {
29
64
XCTFail ( " Failed to generate text with error \( error) " )
30
65
}
31
- XCTAssertEqual ( " hello, " , text. lowercased ( ) )
66
+ XCTAssertTrue ( text. lowercased ( ) . contains ( " waterfall " ) )
32
67
}
33
68
}
0 commit comments