Skip to content

Commit b384ca9

Browse files
authored
Optional response model for Rank (#4)
Users can now optionally specify a response model for Rank when they want multiple output columns; since it is optional, they can still specify only a column name instead.
1 parent e1e8a48 commit b384ca9

4 files changed

Lines changed: 162 additions & 42 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ from everyrow_sdk.ops import rank
5757

5858
result = await rank(
5959
session=session,
60-
task="Rank organizations by their contribution to AI research",
60+
task="Score this organization by their contribution to AI research",
6161
input=dataframe,
6262
field_name="contribution_score",
6363
ascending_order=False,
@@ -164,7 +164,7 @@ from everyrow_sdk.ops import rank_async
164164

165165
task = await rank_async(
166166
session=session,
167-
task="Rank organizations",
167+
task="Score this organization",
168168
input=dataframe,
169169
field_name="score",
170170
)

examples/rank_example.py

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,82 @@
33
from textwrap import dedent
44

55
from pandas import DataFrame
6+
from pydantic import BaseModel, Field
67

78
from everyrow_sdk import create_client, create_session
89
from everyrow_sdk.ops import rank
910
from everyrow_sdk.session import Session
1011

1112

13+
class ContributionRanking(BaseModel):
14+
contribution_score: float = Field(
15+
description="Score from 0-100 reflecting contribution"
16+
)
17+
most_significant_contribution: str = Field(
18+
description="Single most significant contribution"
19+
)
20+
21+
1222
async def call_rank(session: Session):
1323
# Rank AI research organizations by their contributions to the field
1424
# This requires researching each org's publications, releases, and impact
1525
ai_research_orgs = DataFrame(
1626
[
17-
{"organization": "OpenAI", "type": "Private lab", "founded": 2015},
27+
{
28+
"organization": "OpenAI",
29+
"type": "Private lab",
30+
"founded": 2015,
31+
},
1832
{
1933
"organization": "Google DeepMind",
2034
"type": "Corporate lab",
2135
"founded": 2010,
2236
},
23-
{"organization": "Anthropic", "type": "Private lab", "founded": 2021},
24-
{"organization": "Meta FAIR", "type": "Corporate lab", "founded": 2013},
37+
{
38+
"organization": "Anthropic",
39+
"type": "Private lab",
40+
"founded": 2021,
41+
},
42+
{
43+
"organization": "Meta FAIR",
44+
"type": "Corporate lab",
45+
"founded": 2013,
46+
},
2547
{
2648
"organization": "Microsoft Research",
2749
"type": "Corporate lab",
2850
"founded": 1991,
2951
},
30-
{"organization": "Stanford HAI", "type": "Academic", "founded": 2019},
31-
{"organization": "MIT CSAIL", "type": "Academic", "founded": 2003},
52+
{
53+
"organization": "Stanford HAI",
54+
"type": "Academic",
55+
"founded": 2019,
56+
},
57+
{
58+
"organization": "MIT CSAIL",
59+
"type": "Academic",
60+
"founded": 2003,
61+
},
3262
{
3363
"organization": "Berkeley AI Research",
3464
"type": "Academic",
3565
"founded": 2010,
3666
},
37-
{"organization": "Mistral AI", "type": "Private lab", "founded": 2023},
38-
{"organization": "xAI", "type": "Private lab", "founded": 2023},
39-
{"organization": "Cohere", "type": "Private lab", "founded": 2019},
67+
{
68+
"organization": "Mistral AI",
69+
"type": "Private lab",
70+
"founded": 2023,
71+
},
72+
{
73+
"organization": "xAI",
74+
"type": "Private lab",
75+
"founded": 2023,
76+
},
77+
{
78+
"organization": "Cohere",
79+
"type": "Private lab",
80+
"founded": 2019,
81+
},
4082
{
4183
"organization": "Allen Institute for AI",
4284
"type": "Non-profit",
@@ -45,22 +87,26 @@ async def call_rank(session: Session):
4587
]
4688
)
4789

48-
result = await rank(
49-
session=session,
50-
task=dedent("""
51-
Rank these AI research organizations by their overall contribution to
52-
advancing large language models and generative AI in the past 2 years.
90+
task = dedent("""
91+
Score the given AI research organization by their overall contribution to
92+
advancing large language models and generative AI in the past 2 years.
93+
94+
Consider factors such as:
95+
- Influential model releases (both open and closed source)
96+
- Important research papers and technical breakthroughs
97+
- Impact on the broader AI ecosystem (open source contributions,
98+
techniques that others have adopted)
99+
- Novel capabilities introduced
53100
54-
Consider factors such as:
55-
- Influential model releases (both open and closed source)
56-
- Important research papers and technical breakthroughs
57-
- Impact on the broader AI ecosystem (open source contributions,
58-
techniques that others have adopted)
59-
- Novel capabilities introduced
101+
Assign a score from 0-100 reflecting their relative contribution,
102+
where 100 represents the most impactful organization.
103+
""")
60104

61-
Assign a score from 0-100 reflecting their relative contribution,
62-
where 100 represents the most impactful organization.
63-
"""),
105+
# Example 1: Basic ranking with a single score field
106+
print("Example 1: Basic ranking")
107+
result = await rank(
108+
session=session,
109+
task=task,
64110
input=ai_research_orgs,
65111
field_name="contribution_score",
66112
ascending_order=False,
@@ -69,6 +115,21 @@ async def call_rank(session: Session):
69115
print(result.data.to_string())
70116
print(f"\nArtifact ID: {result.artifact_id}")
71117

118+
# Example 2: Ranking with a custom response model for additional context
119+
print("\n" + "=" * 80)
120+
print("Example 2: Ranking with detailed response model")
121+
detailed_result = await rank(
122+
session=session,
123+
task=task + "\n\nAlso include their single most significant contribution.",
124+
input=ai_research_orgs,
125+
field_name="contribution_score",
126+
response_model=ContributionRanking,
127+
ascending_order=False,
128+
)
129+
print("Detailed Rankings with Context:")
130+
print(detailed_result.data.to_string())
131+
print(f"\nArtifact ID: {detailed_result.artifact_id}")
132+
72133

73134
async def main():
74135
async with create_client() as client:

src/everyrow_sdk/ops.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ async def single_agent_async[T: BaseModel](
126126
session_id=session.session_id,
127127
)
128128

129-
cohort_task = EveryrowTask(response_model=response_model, is_map=False, is_expand=return_table)
129+
cohort_task = EveryrowTask(
130+
response_model=response_model, is_map=False, is_expand=return_table
131+
)
130132
await cohort_task.submit(body, session.client)
131133
return cohort_task
132134

@@ -140,7 +142,9 @@ async def agent_map(
140142
response_model: type[BaseModel] = DefaultAgentResponse,
141143
return_table_per_row: bool = False,
142144
) -> TableResult:
143-
cohort_task = await agent_map_async(task, session, input, effort_level, llm, response_model, return_table_per_row)
145+
cohort_task = await agent_map_async(
146+
task, session, input, effort_level, llm, response_model, return_table_per_row
147+
)
144148
result = await cohort_task.await_result(session.client)
145149
if isinstance(result, TableResult):
146150
return result
@@ -240,7 +244,9 @@ async def agent_map_async(
240244
session_id=session.session_id,
241245
)
242246

243-
cohort_task = EveryrowTask(response_model=response_model, is_map=True, is_expand=return_table_per_row)
247+
cohort_task = EveryrowTask(
248+
response_model=response_model, is_map=True, is_expand=return_table_per_row
249+
)
244250
await cohort_task.submit(body, session.client)
245251
return cohort_task
246252

@@ -283,7 +289,9 @@ async def create_scalar_artifact(input: BaseModel, session: Session) -> UUID:
283289

284290

285291
async def create_table_artifact(input: DataFrame, session: Session) -> UUID:
286-
payload = CreateGroupRequest(query=CreateGroupQueryParams(data_to_create=input.to_dict(orient="records")))
292+
payload = CreateGroupRequest(
293+
query=CreateGroupQueryParams(data_to_create=input.to_dict(orient="records"))
294+
)
287295
body = SubmitTaskBody(
288296
payload=payload,
289297
session_id=session.session_id,
@@ -371,12 +379,13 @@ async def merge_async(
371379
return cohort_task
372380

373381

374-
async def rank(
382+
async def rank[T: BaseModel](
375383
task: str,
376384
session: Session,
377385
input: DataFrame | UUID | TableResult,
378386
field_name: str,
379387
field_type: Literal["float", "int", "str", "bool"] = "float",
388+
response_model: type[T] | None = None,
380389
ascending_order: bool = True,
381390
preview: bool = False,
382391
) -> TableResult:
@@ -387,7 +396,8 @@ async def rank(
387396
session: The session to use
388397
input: The input table (DataFrame, UUID, or TableResult)
389398
field_name: The name of the field to extract and sort by
390-
field_type: The type of the field (default: "float")
399+
field_type: The type of the field (default: "float", ignored if response_model is provided)
400+
response_model: Optional Pydantic model for the response schema
391401
ascending_order: If True, sort in ascending order
392402
preview: If True, process only the first few inputs
393403
@@ -400,6 +410,7 @@ async def rank(
400410
input=input,
401411
field_name=field_name,
402412
field_type=field_type,
413+
response_model=response_model,
403414
ascending_order=ascending_order,
404415
preview=preview,
405416
)
@@ -410,26 +421,33 @@ async def rank(
410421
raise EveryrowError("Rank task did not return a table result")
411422

412423

413-
async def rank_async(
424+
async def rank_async[T: BaseModel](
414425
task: str,
415426
session: Session,
416427
input: DataFrame | UUID | TableResult,
417428
field_name: str,
418429
field_type: Literal["float", "int", "str", "bool"] = "float",
430+
response_model: type[T] | None = None,
419431
ascending_order: bool = True,
420432
preview: bool = False,
421-
) -> EveryrowTask[BaseModel]:
433+
) -> EveryrowTask[T]:
422434
"""Submit a rank task asynchronously."""
423435
input_artifact_id = await _process_agent_map_input(input, session)
424436

425-
# Build response schema with single field
426-
response_schema = {
427-
"_model_name": "RankResponse",
428-
field_name: {
429-
"type": field_type,
430-
"optional": False,
431-
},
432-
}
437+
if response_model is not None:
438+
response_schema = _convert_pydantic_to_custom_schema(response_model)
439+
if field_name not in response_schema:
440+
raise ValueError(
441+
f"Field {field_name} not in response model {response_model.__name__}"
442+
)
443+
else:
444+
response_schema = {
445+
"_model_name": "RankResponse",
446+
field_name: {
447+
"type": field_type,
448+
"optional": False,
449+
},
450+
}
433451

434452
query = DeepRankPublicParams(
435453
task=task,
@@ -448,7 +466,11 @@ async def rank_async(
448466
session_id=session.session_id,
449467
)
450468

451-
cohort_task = EveryrowTask(response_model=BaseModel, is_map=True, is_expand=False)
469+
cohort_task: EveryrowTask[T] = EveryrowTask(
470+
response_model=response_model or BaseModel, # type: ignore[arg-type]
471+
is_map=True,
472+
is_expand=False,
473+
)
452474
await cohort_task.submit(body, session.client)
453475
return cohort_task
454476

@@ -625,7 +647,8 @@ async def derive(
625647
input_artifact_id = await _process_agent_map_input(input, session)
626648

627649
derive_expressions = [
628-
DeriveExpression(column_name=col_name, expression=expr) for col_name, expr in expressions.items()
650+
DeriveExpression(column_name=col_name, expression=expr)
651+
for col_name, expr in expressions.items()
629652
]
630653

631654
query = DeriveQueryParams(expressions=derive_expressions)

tests/test_ops.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from everyrow_sdk.ops import (
1717
agent_map,
1818
create_scalar_artifact,
19+
rank_async,
1920
single_agent,
2021
)
2122
from everyrow_sdk.result import ScalarResult, TableResult
@@ -316,3 +317,38 @@ async def test_agent_map_with_table_output(mocker, mock_session):
316317
assert isinstance(result, TableResult)
317318
assert len(result.data) == 2
318319
assert result.artifact_id == artifact_id
320+
321+
322+
@pytest.mark.asyncio
323+
async def test_rank_model_validation(mocker, mock_session) -> None:
324+
input_df = pd.DataFrame(
325+
[
326+
{"country": "China"},
327+
{"country": "India"},
328+
{"country": "Indonesia"},
329+
{"country": "Pakistan"},
330+
{"country": "USA"},
331+
],
332+
)
333+
334+
class ResponseModel(BaseModel):
335+
population_size: int
336+
337+
input_artifact_id = uuid.uuid4()
338+
# Mock create_table_artifact (called because input is DataFrame)
339+
mock_create_table = mocker.patch(
340+
"everyrow_sdk.ops.create_table_artifact", new_callable=AsyncMock
341+
)
342+
mock_create_table.return_value = input_artifact_id
343+
344+
with pytest.raises(
345+
ValueError,
346+
match="Field population not in response model ResponseModel",
347+
):
348+
await rank_async(
349+
task="Find the population of the given country",
350+
session=mock_session,
351+
input=input_df,
352+
field_name="population",
353+
response_model=ResponseModel,
354+
)

0 commit comments

Comments (0)