1818
1919from genkit .blocks .embedding import embedder_action_metadata , create_embedder_ref
2020from genkit .core .action import ActionMetadata , Action
21- from genkit .core .typing import EmbedderOptions , EmbedderSupports , EmbedderRef , EmbedRequest , EmbedResponse , Embedding , Part
21+ from genkit .core .typing import (
22+ EmbedderOptions ,
23+ EmbedderSupports ,
24+ EmbedderRef ,
25+ EmbedRequest ,
26+ EmbedResponse ,
27+ Embedding ,
28+ Part ,
29+ )
2230from genkit .blocks .document import Document
2331from genkit .core .schema import to_json_schema
2432from genkit .core .action .types import ActionResponse
2836from pydantic import BaseModel
2937
3038
31- # def test_embedder_action_metadata():
32- # """Test for embedder_action_metadata."""
33- # action_metadata = embedder_action_metadata(
34- # name='test_model',
35- # info={'label': 'test_label'},
36- # config_schema=None,
37- # )
38- #
39- # assert isinstance(action_metadata, ActionMetadata)
40- # assert action_metadata.input_json_schema is not None
41- # assert action_metadata.output_json_schema is not None
42- # assert action_metadata.metadata == {'embedder': {'customOptions': None, 'label': 'test_label'}}
43-
4439def test_embedder_action_metadata ():
4540 """Test for embedder_action_metadata with basic options."""
4641 options = EmbedderOptions (label = 'Test Embedder' , dimensions = 128 )
@@ -60,8 +55,10 @@ def test_embedder_action_metadata():
6055 }
6156 }
6257
58+
6359def test_embedder_action_metadata_with_supports_and_config_schema ():
6460 """Test for embedder_action_metadata with supports and config_schema."""
61+
6562 class CustomConfig (BaseModel ):
6663 param1 : str
6764 param2 : int
@@ -70,7 +67,7 @@ class CustomConfig(BaseModel):
7067 label = 'Advanced Embedder' ,
7168 dimensions = 256 ,
7269 supports = EmbedderSupports (input = ['text' , 'image' ], multilingual = True ),
73- config_schema = to_json_schema (CustomConfig ) # Pass the JSON schema directly
70+ config_schema = to_json_schema (CustomConfig ), # Pass the JSON schema directly
7471 )
7572 action_metadata = embedder_action_metadata (
7673 name = 'advanced_model' ,
@@ -94,19 +91,22 @@ class CustomConfig(BaseModel):
9491 'required' : ['param1' , 'param2' ],
9592 }
9693
94+
9795def test_embedder_action_metadata_no_options ():
9896 """Test embedder_action_metadata when no options are provided."""
9997 action_metadata = embedder_action_metadata (name = 'default_model' )
10098 assert isinstance (action_metadata , ActionMetadata )
10199 assert action_metadata .metadata == {'embedder' : {'customOptions' : None , 'dimensions' : None }}
102100
101+
103102def test_create_embedder_ref_basic ():
104103 """Test basic creation of EmbedderRef."""
105104 ref = create_embedder_ref ('my-embedder' )
106105 assert ref .name == 'my-embedder'
107106 assert ref .config is None
108107 assert ref .version is None
109108
109+
110110def test_create_embedder_ref_with_config ():
111111 """Test creation of EmbedderRef with configuration."""
112112 config = {'temperature' : 0.5 , 'max_tokens' : 100 }
@@ -115,13 +115,15 @@ def test_create_embedder_ref_with_config():
115115 assert ref .config == config
116116 assert ref .version is None
117117
118+
118119def test_create_embedder_ref_with_version ():
119120 """Test creation of EmbedderRef with a version."""
120121 ref = create_embedder_ref ('versioned-embedder' , version = 'v1.0' )
121122 assert ref .name == 'versioned-embedder'
122123 assert ref .config is None
123124 assert ref .version == 'v1.0'
124125
126+
125127def test_create_embedder_ref_with_config_and_version ():
126128 """Test creation of EmbedderRef with both config and version."""
127129 config = {'task_type' : 'retrieval' }
@@ -130,8 +132,10 @@ def test_create_embedder_ref_with_config_and_version():
130132 assert ref .config == config
131133 assert ref .version == 'beta'
132134
135+
133136class MockGenkitRegistry :
134137 """A mock registry to simulate action lookup."""
138+
135139 def __init__ (self ):
136140 self .actions = {}
137141
@@ -145,7 +149,7 @@ def register_action(self, name, kind, fn, metadata, description):
145149 async def mock_arun_side_effect (request , * args , ** kwargs ):
146150 # Call the actual (fake) embedder function directly
147151 embed_response = await fn (request )
148- return ActionResponse (response = embed_response , trace_id = " mock_trace_id" )
152+ return ActionResponse (response = embed_response , trace_id = ' mock_trace_id' )
149153
150154 mock_action .arun = AsyncMock (side_effect = mock_arun_side_effect )
151155 self .actions [(kind , name )] = mock_action
@@ -154,62 +158,52 @@ async def mock_arun_side_effect(request, *args, **kwargs):
154158 def lookup_action (self , kind , name ):
155159 return self .actions .get ((kind , name ))
156160
161+
157162@pytest .fixture
158163def mock_genkit_instance ():
159164 """Fixture for a Genkit instance with a mock registry."""
160165 registry = MockGenkitRegistry ()
161- genkit_instance = Genkit () # Create Genkit instance without registry argument
162- genkit_instance .registry = registry # Assign the mock registry to the instance
166+ genkit_instance = Genkit ()
167+ genkit_instance .registry = registry
163168 return genkit_instance , registry
164169
170+
165171@pytest .mark .asyncio
166172async def test_embed_with_embedder_ref (mock_genkit_instance ):
167173 """Test the embed method using EmbedderRef."""
168174 genkit_instance , registry = mock_genkit_instance
169175
170- # Define an embedder in the mock registry
171176 async def fake_embedder_fn (request : EmbedRequest ) -> EmbedResponse :
172177 return EmbedResponse (embeddings = [Embedding (embedding = [1.0 , 2.0 , 3.0 ])])
173178
174179 embedder_options = EmbedderOptions (
175180 label = 'Fake Embedder' ,
176181 dimensions = 3 ,
177182 supports = EmbedderSupports (input = ['text' ]),
178- config_schema = {'type' : 'object' , 'properties' : {'param' : {'type' : 'string' }}}
183+ config_schema = {'type' : 'object' , 'properties' : {'param' : {'type' : 'string' }}},
179184 )
180185 registry .register_action (
181186 name = 'my-plugin/my-embedder' ,
182187 kind = 'embedder' ,
183188 fn = fake_embedder_fn ,
184189 metadata = embedder_action_metadata ('my-plugin/my-embedder' , options = embedder_options ).metadata ,
185- description = 'A fake embedder for testing'
190+ description = 'A fake embedder for testing' ,
186191 )
187192
188- # Create an EmbedderRef
189- embedder_ref = create_embedder_ref (
190- 'my-plugin/my-embedder' ,
191- config = {'param' : 'value' },
192- version = 'v1'
193- )
193+ embedder_ref = create_embedder_ref ('my-plugin/my-embedder' , config = {'param' : 'value' }, version = 'v1' )
194194
195195 documents = [Document .from_text ('hello world' )]
196196
197- # Call the embed method
198197 response = await genkit_instance .embed (
199- embedder = embedder_ref ,
200- documents = documents ,
201- options = {'additional_option' : True }
198+ embedder = embedder_ref , documents = documents , options = {'additional_option' : True }
202199 )
203200
204- # Assertions
205201 assert response .embeddings [0 ].embedding == [1.0 , 2.0 , 3.0 ]
206202
207- # Verify that lookup_action was called correctly
208203 embed_action = registry .lookup_action ('embedder' , 'my-plugin/my-embedder' )
209204 assert embed_action is not None
210205 embed_action .arun .assert_called_once ()
211206
212- # Get the EmbedRequest passed to the mock action
213207 called_request = embed_action .arun .call_args [0 ][0 ]
214208 assert isinstance (called_request , EmbedRequest )
215209 assert called_request .input == documents
@@ -231,27 +225,26 @@ async def fake_embedder_fn(request: EmbedRequest) -> EmbedResponse:
231225 kind = 'embedder' ,
232226 fn = fake_embedder_fn ,
233227 metadata = embedder_action_metadata ('another-embedder' , options = embedder_options ).metadata ,
234- description = 'Another fake embedder'
228+ description = 'Another fake embedder' ,
235229 )
236230
237231 documents = [Document .from_text ('test text' )]
238232
239233 response = await genkit_instance .embed (
240- embedder = 'another-embedder' ,
241- documents = documents ,
242- options = {'custom_setting' : 'high' }
234+ embedder = 'another-embedder' , documents = documents , options = {'custom_setting' : 'high' }
243235 )
244236
245237 assert response .embeddings [0 ].embedding == [4.0 , 5.0 , 6.0 ]
246238 embed_action = registry .lookup_action ('embedder' , 'another-embedder' )
247239 called_request = embed_action .arun .call_args [0 ][0 ]
248240 assert called_request .options == {'custom_setting' : 'high' }
249241
242+
250243@pytest .mark .asyncio
251244async def test_embed_missing_embedder_raises_error (mock_genkit_instance ):
252245 """Test that embedding with a missing embedder raises an error."""
253246 genkit_instance , _ = mock_genkit_instance
254247 documents = [Document .from_text ('some text' )]
255248
256- with pytest .raises (ValueError , match = " Embedder must be specified as a string name or an EmbedderRef." ):
249+ with pytest .raises (ValueError , match = ' Embedder must be specified as a string name or an EmbedderRef.' ):
257250 await genkit_instance .embed (documents = documents )
0 commit comments