|
2 | 2 |
|
3 | 3 | import os
|
4 | 4 | import threading
|
| 5 | +import json |
5 | 6 |
|
6 | 7 | import tensorflow as tf
|
7 | 8 | import tensorflow_hub as hub
|
8 | 9 |
|
9 | 10 | from http.server import BaseHTTPRequestHandler, HTTPServer
|
10 | 11 | from test.support.os_helper import EnvironmentVarGuard
|
| 12 | +from contextlib import contextmanager |
| 13 | +from kagglehub.exceptions import BackendError |
11 | 14 |
|
| 15 | +MOUNT_PATH = "/kaggle/input" |
12 | 16 |
|
13 |
| -class TestKaggleModuleResolver(unittest.TestCase): |
14 |
| - class HubHTTPHandler(BaseHTTPRequestHandler): |
15 |
| - def do_GET(self): |
16 |
| - self.send_response(200) |
17 |
| - self.send_header('Content-Type', 'application/gzip') |
18 |
| - self.end_headers() |
| 17 | +@contextmanager |
| 18 | +def create_test_server(handler_class): |
| 19 | + hostname = 'localhost' |
| 20 | + port = 8080 |
| 21 | + addr = f"http://{hostname}:{port}" |
19 | 22 |
|
20 |
| - with open('/input/tests/data/model.tar.gz', 'rb') as model_archive: |
21 |
| - self.wfile.write(model_archive.read()) |
| 23 | + # Simulates we are inside a Kaggle environment. |
| 24 | + env = EnvironmentVarGuard() |
| 25 | + env.set('KAGGLE_KERNEL_RUN_TYPE', 'Interactive') |
| 26 | + env.set('KAGGLE_USER_SECRETS_TOKEN', 'foo jwt token') |
| 27 | + env.set('KAGGLE_DATA_PROXY_TOKEN', 'foo proxy token') |
| 28 | + env.set('KAGGLE_DATA_PROXY_URL', addr) |
22 | 29 |
|
23 |
| - def _test_client(self, client_func, handler): |
24 |
| - with HTTPServer(('localhost', 8080), handler) as test_server: |
| 30 | + with env: |
| 31 | + with HTTPServer((hostname, port), handler_class) as test_server: |
25 | 32 | threading.Thread(target=test_server.serve_forever).start()
|
26 | 33 |
|
27 | 34 | try:
|
28 |
| - client_func() |
| 35 | + yield addr |
29 | 36 | finally:
|
30 | 37 | test_server.shutdown()
|
31 | 38 |
|
32 |
| - def test_kaggle_resolver_succeeds(self): |
33 |
| - # Simulates we are inside a Kaggle environment. |
34 |
| - env = EnvironmentVarGuard() |
35 |
| - env.set('KAGGLE_CONTAINER_NAME', 'foo') |
36 |
| - # Attach model to right directory. |
37 |
| - os.makedirs('/kaggle/input/foomodule/tensorflow2/barvar') |
38 |
| - os.symlink('/input/tests/data/saved_model/', '/kaggle/input/foomodule/tensorflow2/barvar/2', target_is_directory=True) |
| 39 | +class HubHTTPHandler(BaseHTTPRequestHandler): |
| 40 | + def do_GET(self): |
| 41 | + self.send_response(200) |
| 42 | + self.send_header('Content-Type', 'application/gzip') |
| 43 | + self.end_headers() |
| 44 | + |
| 45 | + with open('/input/tests/data/model.tar.gz', 'rb') as model_archive: |
| 46 | + self.wfile.write(model_archive.read()) |
39 | 47 |
|
40 |
| - with env: |
| 48 | +class KaggleJwtHandler(BaseHTTPRequestHandler): |
| 49 | + def do_POST(self): |
| 50 | + if self.path.endswith("AttachDatasourceUsingJwtRequest"): |
| 51 | + content_length = int(self.headers["Content-Length"]) |
| 52 | + request = json.loads(self.rfile.read(content_length)) |
| 53 | + model_ref = request["modelRef"] |
| 54 | + |
| 55 | + self.send_response(200) |
| 56 | + self.send_header("Content-type", "application/json") |
| 57 | + self.end_headers() |
| 58 | + |
| 59 | + if model_ref['ModelSlug'] == 'unknown': |
| 60 | + self.wfile.write(bytes(json.dumps({ |
| 61 | + "wasSuccessful": False, |
| 62 | + }), "utf-8")) |
| 63 | + return |
| 64 | + |
| 65 | + # Load the files |
| 66 | + mount_slug = f"{model_ref['ModelSlug']}/{model_ref['Framework']}/{model_ref['InstanceSlug']}/{model_ref['VersionNumber']}" |
| 67 | + os.makedirs(os.path.dirname(os.path.join(MOUNT_PATH, mount_slug))) |
| 68 | + os.symlink('/input/tests/data/saved_model/', os.path.join(MOUNT_PATH, mount_slug), target_is_directory=True) |
| 69 | + |
| 70 | + # Return the response |
| 71 | + self.wfile.write(bytes(json.dumps({ |
| 72 | + "wasSuccessful": True, |
| 73 | + "result": { |
| 74 | + "mountSlug": mount_slug, |
| 75 | + }, |
| 76 | + }), "utf-8")) |
| 77 | + else: |
| 78 | + self.send_response(404) |
| 79 | + self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8")) |
| 80 | + |
| 81 | +class TestKaggleModuleResolver(unittest.TestCase): |
| 82 | + def test_kaggle_resolver_succeeds(self): |
| 83 | + with create_test_server(KaggleJwtHandler) as addr: |
41 | 84 | test_inputs = tf.ones([1,4])
|
42 | 85 | layer = hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2")
|
43 | 86 | self.assertEqual([1, 1], layer(test_inputs).shape)
|
44 | 87 |
|
45 | 88 | def test_kaggle_resolver_not_attached_throws(self):
|
46 |
| - # Simulates we are inside a Kaggle environment. |
47 |
| - env = EnvironmentVarGuard() |
48 |
| - env.set('KAGGLE_CONTAINER_NAME', 'foo') |
49 |
| - with env: |
50 |
| - with self.assertRaisesRegex(RuntimeError, '.*attach.*'): |
51 |
| - hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2") |
| 89 | + with create_test_server(KaggleJwtHandler) as addr: |
| 90 | + with self.assertRaises(BackendError): |
| 91 | + hub.KerasLayer("https://kaggle.com/models/foo/unknown/frameworks/TensorFlow2/variations/barvar/versions/2") |
52 | 92 |
|
53 | 93 | def test_http_resolver_succeeds(self):
|
54 |
| - def call_hub(): |
| 94 | + with create_test_server(HubHTTPHandler) as addr: |
55 | 95 | test_inputs = tf.ones([1,4])
|
56 |
| - layer = hub.KerasLayer('http://localhost:8080/model.tar.gz') |
| 96 | + layer = hub.KerasLayer(f'{addr}/model.tar.gz') |
57 | 97 | self.assertEqual([1, 1], layer(test_inputs).shape)
|
58 | 98 |
|
59 |
| - self._test_client(call_hub, TestKaggleModuleResolver.HubHTTPHandler) |
60 |
| - |
61 | 99 | def test_local_path_resolver_succeeds(self):
|
62 | 100 | test_inputs = tf.ones([1,4])
|
63 | 101 | layer = hub.KerasLayer('/input/tests/data/saved_model')
|
|
0 commit comments