Skip to content

Commit 5afea0a

Browse files
authored
Merge pull request #1375 from Kaggle/lh/short
[b/331681978] Add Short URL matching to KaggleFileResolver
2 parents 052d182 + 1caa6b0 commit 5afea0a

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

patches/kaggle_module_resolver.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44

55
from tensorflow_hub import resolver
66

7-
url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$")
7+
short_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/(?P<framework>[^\\/]+)/(?P<variation>[^\\/]+)/(?P<version>[0-9]+)$")
8+
long_url_pattern = re.compile(r"https?://([a-z]+\.)?kaggle.com/models/(?P<owner>[^\\/]+)/(?P<model>[^\\/]+)/frameworks/(?P<framework>[^\\/]+)/variations/(?P<variation>[^\\/]+)/versions/(?P<version>[0-9]+)$")
89

910
def _is_on_kaggle_notebook():
1011
return os.getenv("KAGGLE_KERNEL_RUN_TYPE") != None and os.getenv("KAGGLE_USER_SECRETS_TOKEN") != None
1112

1213
def _is_kaggle_handle(handle):
13-
return url_pattern.match(handle) != None
14+
return long_url_pattern.match(handle) != None or short_url_pattern.match(handle) != None
1415

1516
class KaggleFileResolver(resolver.HttpResolverBase):
1617
def is_supported(self, handle):
1718
return _is_on_kaggle_notebook() and _is_kaggle_handle(handle)
1819

1920
def __call__(self, handle):
20-
m = url_pattern.match(handle)
21-
return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}")
21+
m = long_url_pattern.match(handle) or short_url_pattern.match(handle)
22+
return kagglehub.model_download(f"{m.group('owner')}/{m.group('model')}/{m.group('framework').lower()}/{m.group('variation')}/{m.group('version')}")

tests/test_kaggle_module_resolver.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,25 @@ def do_POST(self):
7979
self.wfile.write(bytes(f"Unhandled path: {self.path}", "utf-8"))
8080

8181
class TestKaggleModuleResolver(unittest.TestCase):
82-
def test_kaggle_resolver_succeeds(self):
82+
def test_kaggle_resolver_long_url_succeeds(self):
83+
model_url = "https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2"
8384
with create_test_server(KaggleJwtHandler) as addr:
8485
test_inputs = tf.ones([1,4])
85-
layer = hub.KerasLayer("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2")
86+
layer = hub.KerasLayer(model_url)
8687
self.assertEqual([1, 1], layer(test_inputs).shape)
88+
# Delete the files that were created in KaggleJwtHandler's do_POST method
89+
os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))
90+
os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")))
91+
92+
def test_kaggle_resolver_short_url_succeeds(self):
93+
model_url = "https://kaggle.com/models/foo/foomodule/TensorFlow2/barvar/2"
94+
with create_test_server(KaggleJwtHandler) as addr:
95+
test_inputs = tf.ones([1,4])
96+
layer = hub.KerasLayer(model_url)
97+
self.assertEqual([1, 1], layer(test_inputs).shape)
98+
# Delete the files that were created in KaggleJwtHandler's do_POST method
99+
os.unlink(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2"))
100+
os.rmdir(os.path.dirname(os.path.join(MOUNT_PATH, "foomodule/tensorflow2/barvar/2")))
87101

88102
def test_kaggle_resolver_not_attached_throws(self):
89103
with create_test_server(KaggleJwtHandler) as addr:

0 commit comments

Comments
 (0)