@@ -79,11 +79,25 @@ def do_POST(self):
79
79
self .wfile .write (bytes (f"Unhandled path: { self .path } " , "utf-8" ))
80
80
81
81
class TestKaggleModuleResolver (unittest .TestCase ):
82
- def test_kaggle_resolver_succeeds (self ):
82
+ def test_kaggle_resolver_long_url_succeeds (self ):
83
+ model_url = "https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2"
83
84
with create_test_server (KaggleJwtHandler ) as addr :
84
85
test_inputs = tf .ones ([1 ,4 ])
85
- layer = hub .KerasLayer ("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2" )
86
+ layer = hub .KerasLayer (model_url )
86
87
self .assertEqual ([1 , 1 ], layer (test_inputs ).shape )
88
+ # Delete the files that were created in KaggleJwtHandler's do_POST method
89
+ os .unlink (os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" ))
90
+ os .rmdir (os .path .dirname (os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" )))
91
+
92
+ def test_kaggle_resolver_short_url_succeeds (self ):
93
+ model_url = "https://kaggle.com/models/foo/foomodule/TensorFlow2/barvar/2"
94
+ with create_test_server (KaggleJwtHandler ) as addr :
95
+ test_inputs = tf .ones ([1 ,4 ])
96
+ layer = hub .KerasLayer (model_url )
97
+ self .assertEqual ([1 , 1 ], layer (test_inputs ).shape )
98
+ # Delete the files that were created in KaggleJwtHandler's do_POST method
99
+ os .unlink (os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" ))
100
+ os .rmdir (os .path .dirname (os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" )))
87
101
88
102
def test_kaggle_resolver_not_attached_throws (self ):
89
103
with create_test_server (KaggleJwtHandler ) as addr :
0 commit comments