Source code for ai4drpm.services.engine.token_usage_service

import logging
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


def extract_token_usage_from_pipeline_result(
    pipeline_result: Dict[str, Any],
    component_names: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Extract token usage information from a Haystack pipeline execution result.

    Haystack generators (OpenAIGenerator, etc.) return token usage in their
    output's `meta` list. Each meta dict contains a `usage` dict with:
    - prompt_tokens
    - completion_tokens
    - total_tokens

    Args:
        pipeline_result: The result dict from Pipeline.run()
        component_names: Optional list of component names to extract from.
            If None, extracts from all components with usage data.

    Returns:
        List of dicts, each containing:
        - component_name: Name of the generator component
        - model: Model name (if available)
        - prompt_tokens: Input tokens
        - completion_tokens: Output tokens
        - total_tokens: Total tokens

    Example:
        pipeline_result = {
            "generator": {
                "replies": ["..."],
                "meta": [{
                    "model": "mistral-large-latest",
                    "usage": {
                        "prompt_tokens": 1500,
                        "completion_tokens": 500,
                        "total_tokens": 2000
                    }
                }]
            }
        }
        usage_list = extract_token_usage_from_pipeline_result(pipeline_result)
        # Returns: [{
        #     "component_name": "generator",
        #     "model": "mistral-large-latest",
        #     "prompt_tokens": 1500,
        #     "completion_tokens": 500,
        #     "total_tokens": 2000
        # }]
    """
    usage_records = []

    for component_name, component_output in pipeline_result.items():
        # Skip if component_names filter is set and this component isn't in it
        if component_names and component_name not in component_names:
            continue

        # Skip if not a dict output
        if not isinstance(component_output, dict):
            continue

        # Look for meta list with usage data
        meta_list = component_output.get("meta", [])
        if not isinstance(meta_list, list):
            continue

        for meta in meta_list:
            if not isinstance(meta, dict):
                continue

            usage = meta.get("usage")
            if not isinstance(usage, dict):
                continue

            # Extract token counts
            prompt_tokens = usage.get("prompt_tokens", 0)
            completion_tokens = usage.get("completion_tokens", 0)
            total_tokens = usage.get("total_tokens", 0)

            # Skip if no tokens recorded
            if total_tokens == 0 and prompt_tokens == 0 and completion_tokens == 0:
                continue

            record = {
                "component_name": component_name,
                "model": meta.get("model"),
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
                "finish_reason": meta.get("finish_reason"),
            }
            usage_records.append(record)

            logger.debug(
                f"Extracted token usage from '{component_name}': "
                f"{total_tokens} total ({prompt_tokens} prompt, {completion_tokens} completion)"
            )

    return usage_records
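

# Usage sketch (illustrative only, not part of the module): extracting usage
# right after Pipeline.run(). The pipeline wiring, input names, and the
# "generator" component name are assumptions for the example; any generator
# that reports a `usage` dict in its `meta` output is picked up the same way.
#
#     result = pipeline.run({"prompt_builder": {"query": "Summarise the report"}})
#     records = extract_token_usage_from_pipeline_result(
#         result, component_names=["generator"]
#     )
#     for record in records:
#         print(record["component_name"], record["model"], record["total_tokens"])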


def aggregate_token_usage(usage_records: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Aggregate multiple token usage records into totals.

    Args:
        usage_records: List of usage dicts from extract_token_usage_from_pipeline_result

    Returns:
        Dict with aggregated totals:
        - prompt_tokens: Total prompt tokens
        - completion_tokens: Total completion tokens
        - total_tokens: Total tokens
        - models: List of unique model names used
        - component_count: Number of usage records aggregated
    """
    if not usage_records:
        return {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "models": [],
            "component_count": 0,
        }

    total_prompt = sum(r.get("prompt_tokens", 0) for r in usage_records)
    total_completion = sum(r.get("completion_tokens", 0) for r in usage_records)
    total_tokens = sum(r.get("total_tokens", 0) for r in usage_records)
    models = list(set(r.get("model") for r in usage_records if r.get("model")))

    return {
        "prompt_tokens": total_prompt,
        "completion_tokens": total_completion,
        "total_tokens": total_tokens,
        "models": models,
        "component_count": len(usage_records),
    }
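

# Usage sketch (illustrative only): rolling per-component records up into a
# single total, e.g. before logging or persisting a cost entry. The variable
# names below are assumptions for the example.
#
#     records = extract_token_usage_from_pipeline_result(pipeline_result)
#     totals = aggregate_token_usage(records)
#     logger.info(
#         "Pipeline used %s tokens across %s component outputs (models: %s)",
#         totals["total_tokens"], totals["component_count"], totals["models"],
#     )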


def infer_provider_from_model(model: str) -> str:
    """
    Infer the provider name from the model name.

    Args:
        model: Model name (e.g., "mistral-large-latest", "gpt-4")

    Returns:
        Provider name (e.g., "mistral", "openai")
    """
    if not model:
        return "unknown"

    model_lower = model.lower()

    if model_lower.startswith("mistral") or model_lower.startswith("codestral"):
        return "mistral"
    elif model_lower.startswith("gpt") or model_lower.startswith("o1"):
        return "openai"
    elif model_lower.startswith("claude"):
        return "anthropic"
    elif model_lower.startswith("llama"):
        return "meta"
    elif model_lower.startswith("gemini"):
        return "google"
    else:
        return "unknown"
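

# Usage sketch (illustrative only): mapping aggregated usage to providers,
# e.g. for per-provider cost reporting. Assumes the `totals` dict produced by
# aggregate_token_usage() above.
#
#     providers = {infer_provider_from_model(m) for m in totals["models"]}
#     # e.g. {"mistral"} for "mistral-large-latest"; an unrecognized model
#     # name maps to "unknown"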