@@ -126,7 +126,9 @@ async def single_agent_async[T: BaseModel](
126126 session_id = session .session_id ,
127127 )
128128
129- cohort_task = EveryrowTask (response_model = response_model , is_map = False , is_expand = return_table )
129+ cohort_task = EveryrowTask (
130+ response_model = response_model , is_map = False , is_expand = return_table
131+ )
130132 await cohort_task .submit (body , session .client )
131133 return cohort_task
132134
@@ -140,7 +142,9 @@ async def agent_map(
140142 response_model : type [BaseModel ] = DefaultAgentResponse ,
141143 return_table_per_row : bool = False ,
142144) -> TableResult :
143- cohort_task = await agent_map_async (task , session , input , effort_level , llm , response_model , return_table_per_row )
145+ cohort_task = await agent_map_async (
146+ task , session , input , effort_level , llm , response_model , return_table_per_row
147+ )
144148 result = await cohort_task .await_result (session .client )
145149 if isinstance (result , TableResult ):
146150 return result
@@ -240,7 +244,9 @@ async def agent_map_async(
240244 session_id = session .session_id ,
241245 )
242246
243- cohort_task = EveryrowTask (response_model = response_model , is_map = True , is_expand = return_table_per_row )
247+ cohort_task = EveryrowTask (
248+ response_model = response_model , is_map = True , is_expand = return_table_per_row
249+ )
244250 await cohort_task .submit (body , session .client )
245251 return cohort_task
246252
@@ -283,7 +289,9 @@ async def create_scalar_artifact(input: BaseModel, session: Session) -> UUID:
283289
284290
285291async def create_table_artifact (input : DataFrame , session : Session ) -> UUID :
286- payload = CreateGroupRequest (query = CreateGroupQueryParams (data_to_create = input .to_dict (orient = "records" )))
292+ payload = CreateGroupRequest (
293+ query = CreateGroupQueryParams (data_to_create = input .to_dict (orient = "records" ))
294+ )
287295 body = SubmitTaskBody (
288296 payload = payload ,
289297 session_id = session .session_id ,
@@ -371,12 +379,13 @@ async def merge_async(
371379 return cohort_task
372380
373381
374- async def rank (
382+ async def rank [ T : BaseModel ] (
375383 task : str ,
376384 session : Session ,
377385 input : DataFrame | UUID | TableResult ,
378386 field_name : str ,
379387 field_type : Literal ["float" , "int" , "str" , "bool" ] = "float" ,
388+ response_model : type [T ] | None = None ,
380389 ascending_order : bool = True ,
381390 preview : bool = False ,
382391) -> TableResult :
@@ -387,7 +396,8 @@ async def rank(
387396 session: The session to use
388397 input: The input table (DataFrame, UUID, or TableResult)
389398 field_name: The name of the field to extract and sort by
390- field_type: The type of the field (default: "float")
399+ field_type: The type of the field (default: "float", ignored if response_model is provided)
400+ response_model: Optional Pydantic model for the response schema
391401 ascending_order: If True, sort in ascending order
392402 preview: If True, process only the first few inputs
393403
@@ -400,6 +410,7 @@ async def rank(
400410 input = input ,
401411 field_name = field_name ,
402412 field_type = field_type ,
413+ response_model = response_model ,
403414 ascending_order = ascending_order ,
404415 preview = preview ,
405416 )
@@ -410,26 +421,33 @@ async def rank(
410421 raise EveryrowError ("Rank task did not return a table result" )
411422
412423
413- async def rank_async (
424+ async def rank_async [ T : BaseModel ] (
414425 task : str ,
415426 session : Session ,
416427 input : DataFrame | UUID | TableResult ,
417428 field_name : str ,
418429 field_type : Literal ["float" , "int" , "str" , "bool" ] = "float" ,
430+ response_model : type [T ] | None = None ,
419431 ascending_order : bool = True ,
420432 preview : bool = False ,
421- ) -> EveryrowTask [BaseModel ]:
433+ ) -> EveryrowTask [T ]:
422434 """Submit a rank task asynchronously."""
423435 input_artifact_id = await _process_agent_map_input (input , session )
424436
425- # Build response schema with single field
426- response_schema = {
427- "_model_name" : "RankResponse" ,
428- field_name : {
429- "type" : field_type ,
430- "optional" : False ,
431- },
432- }
437+ if response_model is not None :
438+ response_schema = _convert_pydantic_to_custom_schema (response_model )
439+ if field_name not in response_schema :
440+ raise ValueError (
441+ f"Field { field_name } not in response model { response_model .__name__ } "
442+ )
443+ else :
444+ response_schema = {
445+ "_model_name" : "RankResponse" ,
446+ field_name : {
447+ "type" : field_type ,
448+ "optional" : False ,
449+ },
450+ }
433451
434452 query = DeepRankPublicParams (
435453 task = task ,
@@ -448,7 +466,11 @@ async def rank_async(
448466 session_id = session .session_id ,
449467 )
450468
451- cohort_task = EveryrowTask (response_model = BaseModel , is_map = True , is_expand = False )
469+ cohort_task : EveryrowTask [T ] = EveryrowTask (
470+ response_model = response_model or BaseModel , # type: ignore[arg-type]
471+ is_map = True ,
472+ is_expand = False ,
473+ )
452474 await cohort_task .submit (body , session .client )
453475 return cohort_task
454476
@@ -625,7 +647,8 @@ async def derive(
625647 input_artifact_id = await _process_agent_map_input (input , session )
626648
627649 derive_expressions = [
628- DeriveExpression (column_name = col_name , expression = expr ) for col_name , expr in expressions .items ()
650+ DeriveExpression (column_name = col_name , expression = expr )
651+ for col_name , expr in expressions .items ()
629652 ]
630653
631654 query = DeriveQueryParams (expressions = derive_expressions )
0 commit comments