Skip to content

Commit 83789d0

Browse files
authored
Use kagglehub inside Kaggle Notebook for tensorflow_hub (#1342)
Currently, if you use tensorflow_hub with a kaggle.com URL inside a Kaggle notebook and the model isn't attached, you get an error message telling you to attach the model. With this change, we call `kagglehub` instead, and the model is attached for the user if it is missing. http://b/311184735
1 parent d1d0e2e commit 83789d0

File tree

2 files changed

+71
-37
lines changed

2 files changed

+71
-37
lines changed

patches/kaggle_module_resolver.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import os
22
import re
3+
import kagglehub
34

45
from tensorflow_hub import resolver
56

6-
url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/[^\\/]+/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$")
7+
url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$")
78

89
def _is_on_kaggle_notebook():
9-
return os.getenv("KAGGLE_CONTAINER_NAME") != None
10+
return os.getenv("KAGGLE_KERNEL_RUN_TYPE") != None
1011

1112
def _is_kaggle_handle(handle):
1213
return url_pattern.match(handle) != None
@@ -17,9 +18,4 @@ def is_supported(self, handle):
1718

1819
def __call__(self, handle):
1920
m = url_pattern.match(handle)
20-
local_path = f"/kaggle/input/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}"
21-
if not os.path.exists(local_path):
22-
# TODO(b/268256777) Attach model & wait until ready instead.
23-
raise RuntimeError(f"You have to attach the '{handle}' model to your Kaggle notebook.")
24-
25-
return local_path
21+
return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}")

tests/test_kaggle_module_resolver.py

+67-29
Original file line numberDiff line numberDiff line change
@@ -2,62 +2,100 @@
22

33
import os
44
import threading
5+
import json
56

67
import tensorflow as tf
78
import tensorflow_hub as hub
89

910
from http.server import BaseHTTPRequestHandler, HTTPServer
1011
from test.support.os_helper import EnvironmentVarGuard
12+
from contextlib import contextmanager
13+
from kagglehub.exceptions import BackendError
1114

15+
MOUNT_PATH = "/kaggle/input"
1216

13-
class TestKaggleModuleResolver(unittest.TestCase):
14-
class HubHTTPHandler(BaseHTTPRequestHandler):
15-
def do_GET(self):
16-
self.send_response(200)
17-
self.send_header('Content-Type', 'application/gzip')
18-
self.end_headers()
17+
@contextmanager
18+
def create_test_server(handler_class):
19+
hostname = 'localhost'
20+
port = 8080
21+
addr = f"http://{hostname}:{port}"
1922

20-
with open('/input/tests/data/model.tar.gz', 'rb') as model_archive:
21-
self.wfile.write(model_archive.read())
23+
# Simulates we are inside a Kaggle environment.
24+
env = EnvironmentVarGuard()
25+
env.set('KAGGLE_KERNEL_RUN_TYPE', 'Interactive')
26+
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foo jwt token')
27+
env.set('KAGGLE_DATA_PROXY_TOKEN', 'foo proxy token')
28+
env.set('KAGGLE_DATA_PROXY_URL', addr)
2229

23-
def _test_client(self, client_func, handler):
24-
with HTTPServer(('localhost', 8080), handler) as test_server:
30+
with env:
31+
with HTTPServer((hostname, port), handler_class) as test_server:
2532
threading.Thread(target=test_server.serve_forever).start()
2633

2734
try:
28-
client_func()
35+
yield addr
2936
finally:
3037
test_server.shutdown()
3138

32-
def test_kaggle_resolver_succeeds(self):
33-
# Simulates we are inside a Kaggle environment.
34-
env = EnvironmentVarGuard()
35-
env.set('KAGGLE_CONTAINER_NAME', 'foo')
36-
# Attach model to right directory.
37-
os.makedirs('/kaggle/input/foomodule/tensorflow2/barvar')
38-
os.symlink('/input/tests/data/saved_model/', '/kaggle/input/foomodule/tensorflow2/barvar/2', target_is_directory=True)
39+
class HubHTTPHandler(BaseHTTPRequestHandler):
40+
def do_GET(self):
41+
self.send_response(200)
42+
self.send_header('Content-Type', 'application/gzip')
43+
self.end_headers()
44+
45+
with open('/input/tests/data/model.tar.gz', 'rb') as model_archive:
46+
self.wfile.write(model_archive.read())
3947

40-
with env:
48+
class KaggleJwtHandler(BaseHTTPRequestHandler):
49+
def do_POST(self):
50+
if self.path.endswith("AttachDatasourceUsingJwtRequest"):
51+
content_length = int(self.headers["Content-Length"])
52+
request = json.loads(self.rfile.read(content_length))
53+
model_ref = request["modelRef"]
54+
55+
self.send_response(200)
56+
self.send_header("Content-type", "application/json")
57+
self.end_headers()
58+
59+
if model_ref['ModelSlug'] == 'unknown':
60+
self.wfile.write(bytes(json.dumps({
61+
"wasSuccessful": False,
62+
}), "utf-8"))
63+
return
64+
65+
# Load the files
66+
mount_slug = f"{model_ref['ModelSlug']}/{model_ref['Framework']}/{model_ref['InstanceSlug']}/{model_ref['VersionNumber']}"
67+
os.makedirs(os.path.dirname(os.path.join(MOUNT_PATH, mount_slug)))
68+
os.symlink('/input/tests/data/saved_model/', os.path.join(MOUNT_PATH, mount_slug), target_is_directory=True)
69+
70+
# Return the response
71+
self.wfile.write(bytes(json.dumps({
72+
"wasSuccessful": True,
73+
"result": {
74+
"mountSlug": mount_slug,
75+
},
76+
}), "utf-8"))
77+
else:
78+
self.send_response(404)
79+
self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8"))
80+
81+
class TestKaggleModuleResolver(unittest.TestCase):
82+
def test_kaggle_resolver_succeeds(self):
83+
with create_test_server(KaggleJwtHandler) as addr:
4184
test_inputs = tf.ones([1,4])
4285
layer = hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2")
4386
self.assertEqual([1, 1], layer(test_inputs).shape)
4487

4588
def test_kaggle_resolver_not_attached_throws(self):
46-
# Simulates we are inside a Kaggle environment.
47-
env = EnvironmentVarGuard()
48-
env.set('KAGGLE_CONTAINER_NAME', 'foo')
49-
with env:
50-
with self.assertRaisesRegex(RuntimeError, '.*attach.*'):
51-
hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2")
89+
with create_test_server(KaggleJwtHandler) as addr:
90+
with self.assertRaises(BackendError):
91+
hub.KerasLayer("https://kaggle.com/models/foo/unknown/frameworks/TensorFlow2/variations/barvar/versions/2")
5292

5393
def test_http_resolver_succeeds(self):
54-
def call_hub():
94+
with create_test_server(HubHTTPHandler) as addr:
5595
test_inputs = tf.ones([1,4])
56-
layer = hub.KerasLayer('http://localhost:8080/model.tar.gz')
96+
layer = hub.KerasLayer(f'{addr}/model.tar.gz')
5797
self.assertEqual([1, 1], layer(test_inputs).shape)
5898

59-
self._test_client(call_hub, TestKaggleModuleResolver.HubHTTPHandler)
60-
6199
def test_local_path_resolver_succeeds(self):
62100
test_inputs = tf.ones([1,4])
63101
layer = hub.KerasLayer('/input/tests/data/saved_model')

0 commit comments

Comments
 (0)