Merged · 17 commits
- b80ef61 · feat(agents): generator-as-protocol-adapter and tool coercion (Feb 26, 2026)
- 723c08b · refactor(agents): replace mixin-based retry/rate-limiting with middle… (Feb 27, 2026)
- 5b4199a · chore(agents): bump Python to 3.12+, fix tests for middleware API, ad… (Feb 27, 2026)
- 0e25fe8 · test(agents): restore sleep-time assertions in retry tests via asynci… (Feb 27, 2026)
- 166601d · test(agents): tighten serialization test assertions per review feedback (Feb 27, 2026)
- 198bdb9 · Merge branch 'main' into feat/generator-middleware-pipeline (Hartorn, Mar 2, 2026)
- 297665c · Merge remote-tracking branch 'origin/feat/generator-middleware-pipeli… (Hartorn, Mar 2, 2026)
- 29a8240 · refactor(agents): extract _call_model template method in BaseGenerator (Hartorn, Mar 2, 2026)
- 3c5c709 · Merge branch 'main' into feat/generator-middleware-pipeline (mattbit, Mar 3, 2026)
- 2457430 · refactor(agents): add built-in retry/rate-limiter slots with convenie… (Hartorn, Mar 4, 2026)
- 5e69aed · Merge branch 'main' into feat/generator-middleware-pipeline (Hartorn, Mar 9, 2026)
- 191dab6 · Merge branch 'feat/generator-middleware-pipeline' of github.com:Giska… (Hartorn, Mar 9, 2026)
- 8278eea · fix(agents): move param validation inside try block and add _call_mod… (Hartorn, Mar 9, 2026)
- 7c87de0 · Merge branch 'main' of github.com:Giskard-AI/giskard-oss into refacto… (Hartorn, Mar 11, 2026)
- 4548a22 · Merge branch 'main' into refactor/generator-as-protocol-adapter (mattbit, Mar 11, 2026)
- e23c289 · refactor(generators): move serialization into LiteLLMGenerator, add G… (Hartorn, Mar 11, 2026)
- 74d7e80 · fix(tests): resolve basedpyright warnings in test files (Hartorn, Mar 11, 2026)
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -35,6 +35,7 @@ repos:
rev: v1.5.0
hooks:
- id: detect-secrets
language_version: python3.12
args: ["--baseline", ".secrets.baseline", "--exclude-secrets", "YOUR_API_KEY"]
exclude: |
(?x)^(
25 changes: 25 additions & 0 deletions libs/giskard-agents/.cursor/rules/generator-adapter.mdc
@@ -0,0 +1,25 @@
---
description: Generator as protocol adapter — architectural rules for provider translation in giskard-agents.
alwaysApply: true
---

# Generator as Protocol Adapter

## Core Principles

1. **Message is the canonical internal format.** `role="tool"`, `tool_call_id`, `tool_calls` are the project's standard fields. Never reshape `Message` to accommodate a specific provider.

2. **Tool.run() returns `str`.** The tool serializes its own output. The workflow wraps the string into a `Message`. The workflow never calls `json.dumps` on tool results.

3. **The generator owns all provider translation.** Each generator subclass is responsible for converting between internal `Message` objects and whatever wire format the provider expects. Serialization is internal to the subclass — `BaseGenerator` does not impose any wire format.

## Rules

- **Do not call provider APIs (litellm, anthropic, openai, etc.) from workflow, tool, or chat code.** Provider calls belong exclusively in generator subclasses.
- **Do not import provider libraries outside of generator modules.** The `generators/` package is the only place that should depend on provider SDKs.
- **`_complete` is the template method.** It merges params via `GenerationParams.merge()`, calls the abstract `_call_model`, and wraps the result in a `Response`. Do not override `_complete` in production subclasses — override `_call_model` instead.
- **`_call_model` receives internal types.** Signature: `_call_model(messages: list[Message], params: GenerationParams) -> tuple[Message, FinishReason]`. The base class never touches wire formats.
- **Serialization is the subclass's concern.** Each subclass handles translating `Message`/`Tool` to its provider's wire format inside `_call_model` (e.g. `LiteLLMGenerator._serialize_messages`). Do not add serialization methods to `BaseGenerator`.
- **Param merging lives on `GenerationParams.merge()`.** Scalar fields from overrides replace the base; tools are concatenated. Do not reimplement merge logic elsewhere.
- **Adding a new provider?** Subclass `BaseGenerator` and implement `_call_model`. Handle all serialization/deserialization inside the subclass. Do not modify `Message`, `Tool`, or workflow code.
- **Cross-cutting concerns use middleware.** Retries, rate limiting, logging, and similar concerns are implemented as `CompletionMiddleware` subclasses, not generator mixins. See `generators/middleware.py`.
11 changes: 1 addition & 10 deletions libs/giskard-agents/src/giskard/agents/chat.py
@@ -1,6 +1,5 @@
from typing import Any, Generic, Literal, Type, TypeVar
from typing import Generic, Literal, Type, TypeVar

from litellm import Message as LiteLLMMessage
from pydantic import BaseModel, Field

from .context import RunContext
@@ -42,14 +41,6 @@ class Message(BaseModel):
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None

def to_litellm(self) -> dict[str, Any]:
msg = self.model_dump(include={"role", "content", "tool_calls", "tool_call_id"})
return msg

@classmethod
def from_litellm(cls, msg: LiteLLMMessage | dict[str, Any]):
return cls.model_validate(msg.model_dump()) # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]

def parse(self, model_type: type[T]) -> T:
return model_type.model_validate_json(self.content) # pyright: ignore[reportArgumentType]

2 changes: 2 additions & 0 deletions libs/giskard-agents/src/giskard/agents/generators/__init__.py
@@ -1,3 +1,4 @@
from ._types import FinishReason
from .base import BaseGenerator, GenerationParams, Response
from .litellm_generator import LiteLLMGenerator, LiteLLMRetryMiddleware
from .middleware import (
@@ -11,6 +12,7 @@
Generator = LiteLLMGenerator

__all__ = [
"FinishReason",
"Generator",
"GenerationParams",
"Response",
32 changes: 29 additions & 3 deletions libs/giskard-agents/src/giskard/agents/generators/_types.py
@@ -11,12 +11,14 @@
from ..chat import Message
from ..tools import Tool

type FinishReason = (
Literal["stop", "length", "tool_calls", "content_filter", "null"] | None
)


class Response(BaseModel):
message: Message
finish_reason: (
Literal["stop", "length", "tool_calls", "content_filter", "null"] | None
)
finish_reason: FinishReason


class GenerationParams(BaseModel):
@@ -35,3 +37,27 @@ class GenerationParams(BaseModel):
response_format: type[BaseModel] | None = Field(default=None)
tools: list[Tool] = Field(default_factory=list)
timeout: float | int | None = Field(default=None)

def merge(self, overrides: "GenerationParams | None") -> "GenerationParams":
"""Return a copy with *overrides*' explicitly-set fields applied on top.

Scalar fields from *overrides* replace the base values (only if set).
Tools are concatenated.

Parameters
----------
overrides : GenerationParams or None
Per-call overrides. Only explicitly-set fields take effect.
If None, returns an unmodified copy.

Returns
-------
GenerationParams
A new instance with merged values.
"""
if overrides is None:
return self.model_copy()
updates = overrides.model_dump(exclude={"tools"}, exclude_unset=True)
merged = self.model_copy(update=updates)
merged.tools = self.tools + overrides.tools
return merged
40 changes: 37 additions & 3 deletions libs/giskard-agents/src/giskard/agents/generators/base.py
@@ -7,7 +7,7 @@
from pydantic import Field

from ..chat import Message, Role
from ._types import GenerationParams, Response
from ._types import FinishReason, GenerationParams, Response
from .middleware import (
CompletionMiddleware,
NextFn,
@@ -22,17 +22,51 @@

@discriminated_base
class BaseGenerator(Discriminated, ABC):
"""Base class for all generators."""
"""Base class for all generators.

Each subclass is responsible for translating between the internal
``Message`` / ``Tool`` objects and whatever wire format its provider
expects. Workflow, tool, and chat code work exclusively with
``Message`` objects and never call provider APIs directly.
"""

params: GenerationParams = Field(default_factory=GenerationParams)
retry_policy: RetryPolicy | None = Field(default=None)
rate_limiter: BaseRateLimiter | None = Field(default=None)
middlewares: list[CompletionMiddleware] = Field(default_factory=list)

# -- Completion pipeline -----------------------------------------------

@abstractmethod
async def _call_model(
self,
messages: list[Message],
params: GenerationParams,
) -> tuple[Message, FinishReason]:
"""Call the provider and return the response as an internal Message.

Subclasses handle all serialization/deserialization internally.

Parameters
----------
messages : list[Message]
Conversation messages in internal format.
params : GenerationParams
Merged generation parameters (including tools).

Returns
-------
tuple[Message, FinishReason]
The assistant message and the finish reason.
"""
raise NotImplementedError

async def _complete(
self, messages: list[Message], params: GenerationParams | None = None
) -> Response: ...
) -> Response:
merged = self.params.merge(params)
message, finish_reason = await self._call_model(messages, merged)
return Response(message=message, finish_reason=finish_reason)

async def complete(
self,
libs/giskard-agents/src/giskard/agents/generators/litellm_generator.py
@@ -1,12 +1,13 @@
from typing import cast, override
from typing import Any, cast, override

from litellm import Choices, ModelResponse, acompletion
from litellm import Message as LiteLLMMessage
from litellm import _should_retry as litellm_should_retry
from pydantic import Field

from ..chat import Message
from .base import BaseGenerator, GenerationParams, Response
from ..tools import Tool
from ._types import FinishReason, GenerationParams
from .base import BaseGenerator
from .middleware import CompletionMiddleware, RetryMiddleware, RetryPolicy


@@ -34,30 +35,49 @@ def _create_retry_middleware(self) -> LiteLLMRetryMiddleware | None:
return None
return LiteLLMRetryMiddleware(retry_policy=self.retry_policy)

@override
async def _complete(
self, messages: list[Message], params: GenerationParams | None = None
) -> Response:
params_ = self.params.model_dump(exclude={"tools"})
def _serialize_tools(self, tools: list[Tool]) -> list[dict[str, Any]]:
"""Convert ``Tool`` objects to the OpenAI function-calling format."""
return [
{
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.parameters_schema,
},
}
for t in tools
]

def _serialize_messages(self, messages: list[Message]) -> list[dict[str, Any]]:
"""Convert ``Message`` objects to LiteLLM's dict format."""
return [
m.model_dump(include={"role", "content", "tool_calls", "tool_call_id"})
for m in messages
]

if params is not None:
params_.update(params.model_dump(exclude={"tools"}, exclude_unset=True))
def _deserialize_response(self, raw: Any) -> Message:
"""Convert a LiteLLM response object into an internal ``Message``."""
data = raw if isinstance(raw, dict) else raw.model_dump()
return Message.model_validate(data)

tools = self.params.tools + (params.tools if params is not None else [])
if tools:
params_["tools"] = [t.to_litellm_function() for t in tools]
@override
async def _call_model(
self,
messages: list[Message],
params: GenerationParams,
) -> tuple[Message, FinishReason]:
wire_messages = self._serialize_messages(messages)
wire_params = params.model_dump(exclude={"tools"})
wire_tools = self._serialize_tools(params.tools) if params.tools else []
if wire_tools:
wire_params["tools"] = wire_tools

response = cast(
ModelResponse,
await acompletion(
messages=[m.to_litellm() for m in messages],
model=self.model,
**params_,
),
await acompletion(messages=wire_messages, model=self.model, **wire_params),
)

choice = cast(Choices, response.choices[0])
return Response(
message=Message.from_litellm(cast(LiteLLMMessage, choice.message)),
finish_reason=choice.finish_reason, # pyright: ignore[reportArgumentType]
)
message = self._deserialize_response(choice.message)
return message, choice.finish_reason # pyright: ignore[reportReturnType]
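
The tool serialization the diff moves into `LiteLLMGenerator` targets the OpenAI function-calling format. A standalone sketch of that shape (the `get_weather` tool and its schema are hypothetical example values):

```python
from typing import Any


def serialize_tool(
    name: str, description: str, parameters_schema: dict[str, Any]
) -> dict[str, Any]:
    # Mirrors the structure LiteLLMGenerator._serialize_tools produces:
    # the OpenAI function-calling wire format.
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters_schema,
        },
    }


wire = serialize_tool(
    "get_weather",
    "Look up current weather for a city.",
    {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)
```

Because this lives inside the generator subclass, `Tool` itself no longer needs a `to_litellm_function` method, which is exactly what the tool.py diff below removes.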
75 changes: 34 additions & 41 deletions libs/giskard-agents/src/giskard/agents/tools/tool.py
@@ -1,6 +1,7 @@
"""Core tool functionality for Giskard Agents."""

import inspect
import json
from typing import Any, Callable, Literal, TypeVar

import logfire_api as logfire
@@ -139,11 +140,23 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
@logfire.instrument("tool.run")
async def run(
self, arguments: dict[str, Any], ctx: RunContext | None = None
) -> Any:
"""Run the tool's function asynchronously.
) -> str:
"""Run the tool's function and return a serialized string result.

This method handles both sync and async functions by awaiting async
functions. Errors are handled based on ``self.catch``.
Executes the underlying function (sync or async), handles errors via
``self.catch``, and serializes the result to a string suitable for use
as ``Message.content``.

Input coercion: if a ``_params_model`` is available (set by
``from_callable``), raw dict arguments are validated through the
Pydantic model so that e.g. nested dicts become typed ``BaseModel``
instances before the function is called.

Output serialization: if a ``_return_adapter`` is available, the result
is serialized to JSON-safe Python via ``TypeAdapter.dump_python`` (this
handles ``BaseModel``, ``datetime``, ``UUID``, ``list[BaseModel]``,
etc.). String results are returned as-is; everything else is
``json.dumps``'d.

Parameters
----------
@@ -154,28 +167,25 @@ async def run(

Returns
-------
Any
The result of calling the function.
str
The serialized result of calling the function.
"""

# Coerce dict arguments into typed objects via the Pydantic params model.
# We use getattr() instead of model_dump() to preserve coerced types
# (e.g. a raw dict becomes a BaseModel instance). Extra keys that are
# not in model_fields are dropped (Pydantic defaults to extra='ignore').
if self._params_model is not None:
validated = self._params_model.model_validate(arguments)
arguments = {
name: getattr(validated, name)
for name in arguments
if name in self._params_model.model_fields
}

# Inject the context after coercion (RunContext is excluded from the model)
if ctx and self.run_context_param:
arguments = arguments.copy()
arguments[self.run_context_param] = ctx

try:
# Coerce dict arguments into typed objects via the Pydantic params model.
# Extra keys not in model_fields are dropped (Pydantic extra='ignore').
if self._params_model is not None:
validated = self._params_model.model_validate(arguments)
arguments = {
name: getattr(validated, name)
for name in self._params_model.model_fields
if name in arguments
}

if ctx and self.run_context_param:
arguments = arguments.copy()
arguments[self.run_context_param] = ctx

res = self.fn(**arguments)
if inspect.isawaitable(res):
res = await res
Expand All @@ -192,24 +202,7 @@ async def run(
if self._return_adapter is not None:
res = self._return_adapter.dump_python(res, mode="json")

return res

def to_litellm_function(self) -> dict[str, Any]:
"""Convert the tool to a LiteLLM function format.

Returns
-------
dict[str, Any]
A dictionary in the LiteLLM function format.
"""
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters_schema,
},
}
return res if isinstance(res, str) else json.dumps(res)


class ToolMethod:
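
The new return rule at the end of `Tool.run` (`res if isinstance(res, str) else json.dumps(res)`) can be sketched in isolation. This stand-in skips the Pydantic `TypeAdapter.dump_python` step the PR applies first, using `default=str` as a rough substitute for making values like `datetime` JSON-safe:

```python
import json
from datetime import date


def serialize_result(res) -> str:
    """Mimic Tool.run's output rule: strings pass through unchanged,
    everything else is JSON-encoded."""
    if isinstance(res, str):
        return res
    # The real implementation first runs non-str results through a
    # pydantic TypeAdapter so BaseModel / datetime / UUID become
    # JSON-safe; `default=str` is a crude stand-in for that here.
    return json.dumps(res, default=str)


serialize_result("already text")       # passthrough, no extra quoting
serialize_result({"ok": True})         # dict becomes a JSON object string
serialize_result([date(2026, 3, 11)])  # date rendered via default=str
```

The string passthrough is what lets the workflow use the result directly as `Message.content` without double-encoding.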
4 changes: 2 additions & 2 deletions libs/giskard-agents/src/giskard/agents/workflow.py
@@ -169,14 +169,14 @@ async def _run_tools(self, chat: Chat[Any]) -> AsyncGenerator[Message, None]:
continue # TODO: raise an error?

tool = self._workflow.tools[tool_call.function.name]
tool_response = await tool.run(
tool_content = await tool.run(
json.loads(tool_call.function.arguments),
ctx=chat.context,
)
yield Message(
role="tool",
tool_call_id=tool_call.id,
content=json.dumps(tool_response),
content=tool_content,
)

async def _run_completion(self, chat: Chat[Any]) -> Message:
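
The workflow change above removes a double-encoding bug: the tool now returns an already-serialized string, so the workflow stops calling `json.dumps` on it a second time. A self-contained sketch with plain dicts standing in for the project's `Message`/`ToolCall` types (the `get_weather` tool is hypothetical):

```python
import json


def get_weather(city: str) -> str:
    # Per the adapter guidelines, the tool serializes its own output.
    return json.dumps({"city": city, "temp_c": 21})


tool_call = {
    "id": "call_1",
    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
}

args = json.loads(tool_call["function"]["arguments"])
tool_content = get_weather(**args)  # already a str: no second json.dumps
tool_message = {
    "role": "tool",
    "tool_call_id": tool_call["id"],
    "content": tool_content,
}
```

With the old code path, wrapping the result in `json.dumps` again would have produced a doubly-quoted JSON string as the message content.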