# Source code for ai4drpm.services.engine.token_usage_service
import logging
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
# [docs]  (Sphinx "view source" link artifact)
def extract_token_usage_from_pipeline_result(
    pipeline_result: Dict[str, Any],
    component_names: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Extract token usage information from a Haystack pipeline execution result.

    Haystack generators (OpenAIGenerator, etc.) return token usage in their
    output's `meta` list. Each meta dict contains a `usage` dict with:
    - prompt_tokens
    - completion_tokens
    - total_tokens

    Args:
        pipeline_result: The result dict from Pipeline.run()
        component_names: Optional list of component names to extract from.
            If None, extracts from all components with usage data.

    Returns:
        List of dicts, each containing:
        - component_name: Name of the generator component
        - model: Model name (if available)
        - prompt_tokens: Input tokens
        - completion_tokens: Output tokens
        - total_tokens: Total tokens
        - finish_reason: Why generation stopped (if available)

    Example:
        pipeline_result = {
            "generator": {
                "replies": ["..."],
                "meta": [{
                    "model": "mistral-large-latest",
                    "usage": {
                        "prompt_tokens": 1500,
                        "completion_tokens": 500,
                        "total_tokens": 2000
                    }
                }]
            }
        }
        usage_list = extract_token_usage_from_pipeline_result(pipeline_result)
        # Returns: [{
        #     "component_name": "generator",
        #     "model": "mistral-large-latest",
        #     "prompt_tokens": 1500,
        #     "completion_tokens": 500,
        #     "total_tokens": 2000
        # }]
    """
    usage_records: List[Dict[str, Any]] = []
    for component_name, component_output in pipeline_result.items():
        # Skip if component_names filter is set and this component isn't in it
        if component_names and component_name not in component_names:
            continue
        # Only dict-shaped component outputs can carry a "meta" list
        if not isinstance(component_output, dict):
            continue
        meta_list = component_output.get("meta", [])
        if not isinstance(meta_list, list):
            continue
        for meta in meta_list:
            if not isinstance(meta, dict):
                continue
            usage = meta.get("usage")
            if not isinstance(usage, dict):
                continue
            # Extract token counts; missing keys count as zero
            prompt_tokens = usage.get("prompt_tokens", 0)
            completion_tokens = usage.get("completion_tokens", 0)
            total_tokens = usage.get("total_tokens", 0)
            # Skip if no tokens recorded
            if total_tokens == 0 and prompt_tokens == 0 and completion_tokens == 0:
                continue
            record = {
                "component_name": component_name,
                "model": meta.get("model"),
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
                "finish_reason": meta.get("finish_reason"),
            }
            usage_records.append(record)
            # Lazy %-style args: the message is only formatted when DEBUG
            # logging is enabled, unlike an eagerly-built f-string.
            logger.debug(
                "Extracted token usage from '%s': %s total (%s prompt, %s completion)",
                component_name,
                total_tokens,
                prompt_tokens,
                completion_tokens,
            )
    return usage_records
# [docs]  (Sphinx "view source" link artifact)
def aggregate_token_usage(usage_records: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Aggregate multiple token usage records into totals.

    Args:
        usage_records: List of usage dicts from
            extract_token_usage_from_pipeline_result

    Returns:
        Dict with aggregated totals:
        - prompt_tokens: Total prompt tokens
        - completion_tokens: Total completion tokens
        - total_tokens: Total tokens
        - models: Sorted list of unique model names used
        - component_count: Number of records with usage
    """
    # sorted() gives a deterministic order; iterating a raw set of strings
    # varies across interpreter runs due to string hash randomization.
    models = sorted({r.get("model") for r in usage_records if r.get("model")})
    # sum()/len() of an empty list already yield the all-zero result, so no
    # special case for empty input is needed.
    return {
        "prompt_tokens": sum(r.get("prompt_tokens", 0) for r in usage_records),
        "completion_tokens": sum(r.get("completion_tokens", 0) for r in usage_records),
        "total_tokens": sum(r.get("total_tokens", 0) for r in usage_records),
        "models": models,
        "component_count": len(usage_records),
    }
# [docs]  (Sphinx "view source" link artifact)
def infer_provider_from_model(model: str) -> str:
    """
    Infer the provider name from the model name.

    Args:
        model: Model name (e.g., "mistral-large-latest", "gpt-4")

    Returns:
        Provider name (e.g., "mistral", "openai"), or "unknown" when the
        name is empty or matches no known prefix.
    """
    if not model:
        return "unknown"
    # Prefix → provider lookup table; str.startswith accepts a tuple of
    # alternatives, so each provider is checked in a single call.
    prefix_table = (
        (("mistral", "codestral"), "mistral"),
        (("gpt", "o1"), "openai"),
        (("claude",), "anthropic"),
        (("llama",), "meta"),
        (("gemini",), "google"),
    )
    normalized = model.lower()
    for prefixes, provider in prefix_table:
        if normalized.startswith(prefixes):
            return provider
    return "unknown"