|
|
from google import genai |
|
|
from google.genai import types |
|
|
import os |
|
|
|
|
|
|
|
|
# Module-level singleton google-genai client.
# Stays None until initialize() succeeds; generate_content() lazily
# initializes it on first use.
client = None
|
|
|
|
|
def initialize():
    """Initialize the module-level Google Generative AI client.

    Resolves the API key from the ``GEMINI_API_KEY`` environment variable,
    falling back to ``GOOGLE_API_KEY``, then constructs a ``genai.Client``
    and stores it in the module-level ``client`` global.

    Raises:
        ValueError: If neither environment variable is set.
        Exception: Re-raised unchanged if client construction fails
            (after logging the error to stdout).
    """
    global client

    # Prefer GEMINI_API_KEY; fall back to GOOGLE_API_KEY (empty strings
    # are treated the same as unset).
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("Neither GEMINI_API_KEY nor GOOGLE_API_KEY environment variable is set.")

    try:
        client = genai.Client(api_key=api_key)
        print("Google Generative AI client initialized.")
    except Exception as e:
        # Log the failure for visibility, then propagate to the caller.
        print(f"Error initializing Google Generative AI client: {e}")
        raise
|
|
|
|
|
def generate_content(prompt: str, model_name: str = None, allow_fallbacks: bool = True, generation_config: dict = None) -> str:
    """Generate text from a prompt via the Google Generative AI client.

    Lazily initializes the module-level client if it has not been set up.

    Args:
        prompt: The prompt to send to the model.
        model_name: The model to use (e.g. "gemini-2.0-flash",
            "gemini-1.5-flash"). Falls back to "gemini-2.0-flash-lite"
            when None.
        allow_fallbacks: Unused by the underlying
            ``client.models.generate_content`` call; retained for
            signature compatibility with callers (e.g. agent.py).
        generation_config: Optional dict of generation parameters.
            Only 'temperature' and 'max_output_tokens' are honored;
            any other keys are silently ignored.

    Returns:
        The generated text content.

    Raises:
        RuntimeError: If the client cannot be initialized.
        Exception: Re-raised unchanged if generation fails (after
            logging the error to stdout).
    """
    global client
    if client is None:
        # Best-effort lazy initialization on first use.
        print("Client not initialized. Attempting to initialize now...")
        initialize()
        if client is None:
            raise RuntimeError("Google Generative AI client is not initialized. Call initialize() first.")

    target_model = model_name or "gemini-2.0-flash-lite"

    # Translate the plain-dict config into the SDK's config object,
    # keeping only the parameters this wrapper supports.
    config_obj = None
    if generation_config:
        params = {
            key: generation_config[key]
            for key in ('temperature', 'max_output_tokens')
            if key in generation_config
        }
        if params:
            config_obj = types.GenerateContentConfig(**params)

    try:
        response = client.models.generate_content(
            model=target_model,
            contents=[prompt],
            config=config_obj,
        )
        return response.text
    except Exception as e:
        # Log for visibility, then propagate so callers can handle it.
        print(f"Error during content generation: {e}")
        raise
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: initialize the client from environment variables
    # and run two sample generations against the live API.
    try:
        initialize()
        if client:
            sample_prompt = "Explain how AI works in a few words"
            print(f"Sending prompt: '{sample_prompt}'")
            # First call exercises the generation_config path
            # (temperature + output-token cap).
            config = {'temperature': 0.7, 'max_output_tokens': 50}
            generated_text = generate_content(sample_prompt, generation_config=config)
            print("\nGenerated text:")
            print(generated_text)

            # Second call exercises an explicit model name with default
            # generation settings.
            sample_prompt_2 = "What is the capital of France?"
            print(f"\nSending prompt: '{sample_prompt_2}'")
            generated_text_2 = generate_content(sample_prompt_2, model_name="gemini-2.0-flash-lite")
            print("\nGenerated text:")
            print(generated_text_2)
    except Exception as e:
        # Broad catch is acceptable here: this is a demo entry point and we
        # just want a readable message instead of a traceback.
        print(f"An error occurred: {e}")