@@ -325,62 +325,137 @@ def test_decoupled_execute_cancel(self):
         self.assertIn("[execute_cancel] Request cancelled at ", log_text)
 
     def test_decoupled_bls_cancel(self):
-        model_name = "decoupled_bls_cancel"
+        model_names = ["decoupled_bls_cancel", "decoupled_bls_async_cancel"]
         input_value = 1
         max_sum_value = 10
+        ignore_cancel = False
         user_data = UserData()
+        for model_name in model_names:
+            with self._shm_leak_detector.Probe() as shm_probe:
+                with grpcclient.InferenceServerClient(
+                    f"{_tritonserver_ipaddr}:8001"
+                ) as client:
+                    client.start_stream(callback=partial(callback, user_data))
+                    input_data = np.array([input_value], dtype=np.int32)
+                    max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                    ignore_cancel_data = np.array([ignore_cancel], dtype=np.bool_)
+                    inputs = [
+                        grpcclient.InferInput(
+                            "INPUT",
+                            input_data.shape,
+                            np_to_triton_dtype(input_data.dtype),
+                        ),
+                        grpcclient.InferInput(
+                            "MAX_SUM",
+                            max_sum_data.shape,
+                            np_to_triton_dtype(max_sum_data.dtype),
+                        ),
+                        grpcclient.InferInput(
+                            "IGNORE_CANCEL",
+                            ignore_cancel_data.shape,
+                            np_to_triton_dtype(ignore_cancel_data.dtype),
+                        ),
+                    ]
+                    inputs[0].set_data_from_numpy(input_data)
+                    inputs[1].set_data_from_numpy(max_sum_data)
+                    inputs[2].set_data_from_numpy(ignore_cancel_data)
+                    client.async_stream_infer(model_name, inputs)
+
+                # Check the results of the decoupled model using BLS
+                def check_result(result):
+                    # Make sure the result is not an exception
+                    self.assertIsNot(type(result), InferenceServerException)
+                    is_cancelled = result.as_numpy("IS_CANCELLED")
+                    self.assertTrue(
+                        is_cancelled[0],
+                        "error: expected the request to be cancelled",
+                    )
 
-        with self._shm_leak_detector.Probe() as shm_probe:
-            with grpcclient.InferenceServerClient(
-                f"{_tritonserver_ipaddr}:8001"
-            ) as client:
-                client.start_stream(callback=partial(callback, user_data))
-                input_data = np.array([input_value], dtype=np.int32)
-                max_sum_data = np.array([max_sum_value], dtype=np.int32)
-                inputs = [
-                    grpcclient.InferInput(
-                        "INPUT", input_data.shape, np_to_triton_dtype(input_data.dtype)
-                    ),
-                    grpcclient.InferInput(
-                        "MAX_SUM",
-                        max_sum_data.shape,
-                        np_to_triton_dtype(max_sum_data.dtype),
-                    ),
-                ]
-                inputs[0].set_data_from_numpy(input_data)
-                inputs[1].set_data_from_numpy(max_sum_data)
-                client.async_stream_infer(model_name, inputs)
+                    sum_data = result.as_numpy("SUM")
+                    self.assertIsNotNone(sum_data, "error: expected 'SUM'")
+                    self.assertTrue(
+                        np.array_equal(sum_data, max_sum_data),
+                        "error: expected output {} to match input {}".format(
+                            sum_data, max_sum_data
+                        ),
+                    )
 
-            # Check the results of the decoupled model using BLS
-            def check_result(result):
-                # Make sure the result is not an exception
-                self.assertIsNot(type(result), InferenceServerException)
+                result = user_data._completed_requests.get()
+                check_result(result)
 
-                sum_data = result.as_numpy("SUM")
-                self.assertIsNotNone(sum_data, "error: expected 'SUM'")
-                self.assertTrue(
-                    np.array_equal(sum_data, max_sum_data),
-                    "error: expected output {} to match input {}".format(
-                        sum_data, max_sum_data
-                    ),
-                )
+    def test_decoupled_bls_ignore_cancel(self):
+        model_names = ["decoupled_bls_cancel", "decoupled_bls_async_cancel"]
+        input_value = 1
+        max_sum_value = 10
+        ignore_cancel = True
+        user_data = UserData()
+        for model_name in model_names:
+            with self._shm_leak_detector.Probe() as shm_probe:
+                with grpcclient.InferenceServerClient(
+                    f"{_tritonserver_ipaddr}:8001"
+                ) as client:
+                    client.start_stream(callback=partial(callback, user_data))
+                    input_data = np.array([input_value], dtype=np.int32)
+                    max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                    ignore_cancel_data = np.array([ignore_cancel], dtype=np.bool_)
+                    inputs = [
+                        grpcclient.InferInput(
+                            "INPUT",
+                            input_data.shape,
+                            np_to_triton_dtype(input_data.dtype),
+                        ),
+                        grpcclient.InferInput(
+                            "MAX_SUM",
+                            max_sum_data.shape,
+                            np_to_triton_dtype(max_sum_data.dtype),
+                        ),
+                        grpcclient.InferInput(
+                            "IGNORE_CANCEL",
+                            ignore_cancel_data.shape,
+                            np_to_triton_dtype(ignore_cancel_data.dtype),
+                        ),
+                    ]
+                    inputs[0].set_data_from_numpy(input_data)
+                    inputs[1].set_data_from_numpy(max_sum_data)
+                    inputs[2].set_data_from_numpy(ignore_cancel_data)
+                    client.async_stream_infer(model_name, inputs)
+
+                # Check the results of the decoupled model using BLS
+                def check_result(result):
+                    # Make sure the result is not an exception
+                    self.assertIsNot(type(result), InferenceServerException)
+                    is_cancelled = result.as_numpy("IS_CANCELLED")
+                    self.assertFalse(
+                        is_cancelled[0],
+                        "error: expected the request not being cancelled",
+                    )
 
-            result = user_data._completed_requests.get()
-            check_result(result)
+                    sum_data = result.as_numpy("SUM")
+                    self.assertIsNotNone(sum_data, "error: expected 'SUM'")
+                    self.assertTrue(
+                        sum_data > max_sum_data,
+                        "error: expected sum_data {} to be greater than max_sum_data {}".format(
+                            sum_data, max_sum_data
+                        ),
+                    )
+
+                result = user_data._completed_requests.get()
+                check_result(result)
 
-    def test_decoupled_bls_async_cancel(self):
-        model_name = "decoupled_bls_async_cancel"
+    def test_decoupled_bls_cancel_after_completion(self):
+        model_name = "decoupled_bls_cancel_after_complete"
         input_value = 1
         max_sum_value = 10
+        ignore_cancel = False
         user_data = UserData()
-
         with self._shm_leak_detector.Probe() as shm_probe:
             with grpcclient.InferenceServerClient(
                 f"{_tritonserver_ipaddr}:8001"
             ) as client:
                 client.start_stream(callback=partial(callback, user_data))
                 input_data = np.array([input_value], dtype=np.int32)
                 max_sum_data = np.array([max_sum_value], dtype=np.int32)
+                ignore_cancel_data = np.array([ignore_cancel], dtype=np.bool_)
                 inputs = [
                     grpcclient.InferInput(
                         "INPUT", input_data.shape, np_to_triton_dtype(input_data.dtype)
@@ -390,15 +465,25 @@ def test_decoupled_bls_async_cancel(self):
                         max_sum_data.shape,
                         np_to_triton_dtype(max_sum_data.dtype),
                     ),
+                    grpcclient.InferInput(
+                        "IGNORE_CANCEL",
+                        ignore_cancel_data.shape,
+                        np_to_triton_dtype(ignore_cancel_data.dtype),
+                    ),
                 ]
                 inputs[0].set_data_from_numpy(input_data)
                 inputs[1].set_data_from_numpy(max_sum_data)
+                inputs[2].set_data_from_numpy(ignore_cancel_data)
                 client.async_stream_infer(model_name, inputs)
 
             # Check the results of the decoupled model using BLS
             def check_result(result):
                 # Make sure the result is not an exception
                 self.assertIsNot(type(result), InferenceServerException)
+                is_cancelled = result.as_numpy("IS_CANCELLED")
+                self.assertTrue(
+                    is_cancelled[0], "error: expected the request to be cancelled"
+                )
 
                 sum_data = result.as_numpy("SUM")
                 self.assertIsNotNone(sum_data, "error: expected 'SUM'")
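Taken together, the two BLS cancel tests pin down the contract of the new IGNORE_CANCEL input: with IGNORE_CANCEL set to False the BLS-issued request is expected to be cancelled (IS_CANCELLED comes back true and SUM stops at MAX_SUM), while with it set to True cancellation is ignored (IS_CANCELLED is false and SUM runs past MAX_SUM). The sketch below is not part of the diff; it only restates that contract as a minimal standalone client, assuming a Triton server at localhost:8001 with the decoupled_bls_cancel model from this test suite loaded.

# Rough standalone sketch (not part of this PR) of the client-side contract the
# tests above exercise. Model name, tensor names, and dtypes come from the test;
# the server address and everything else here is illustrative.
from functools import partial
from queue import Queue

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

completed = Queue()

def callback(queue, result, error):
    # Stream callback: store either the error or the result for later inspection.
    queue.put(error if error is not None else result)

with grpcclient.InferenceServerClient("localhost:8001") as client:
    client.start_stream(callback=partial(callback, completed))
    input_data = np.array([1], dtype=np.int32)
    max_sum_data = np.array([10], dtype=np.int32)
    # False -> the BLS request should be cancelled; True -> cancellation ignored.
    ignore_cancel_data = np.array([False], dtype=np.bool_)
    inputs = []
    for name, data in (
        ("INPUT", input_data),
        ("MAX_SUM", max_sum_data),
        ("IGNORE_CANCEL", ignore_cancel_data),
    ):
        infer_input = grpcclient.InferInput(
            name, data.shape, np_to_triton_dtype(data.dtype)
        )
        infer_input.set_data_from_numpy(data)
        inputs.append(infer_input)
    client.async_stream_infer("decoupled_bls_cancel", inputs)

result = completed.get()
# With IGNORE_CANCEL=False the tests expect IS_CANCELLED to be true and SUM to
# stop at MAX_SUM; with True, IS_CANCELLED is false and SUM exceeds MAX_SUM.
print(result.as_numpy("IS_CANCELLED"), result.as_numpy("SUM"))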