Spaces:
Runtime error
Runtime error
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
| import logging | |
| from abc import ABC, abstractmethod | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Literal, Optional | |
| from camel.agents import ChatAgent | |
logger = logging.getLogger(__name__)


class BaseBenchmark(ABC):
    r"""Base class for benchmarks.

    Concrete benchmarks must implement :meth:`download`, :meth:`load` and
    :meth:`run`. Because those methods are abstract, this class cannot be
    instantiated directly.

    Attributes:
        name (str): Name of the benchmark.
        data_dir (Path): Path to the data directory.
        save_to (str): Path to save the results.
        processes (int): Number of processes to use for parallel
            processing. (default: :obj:`1`)
    """

    def __init__(
        self, name: str, data_dir: str, save_to: str, processes: int = 1
    ):
        r"""Initialize the benchmark.

        The data directory is created (including missing parents) when it
        does not exist yet.

        Args:
            name (str): Name of the benchmark.
            data_dir (str): Path to the data directory.
            save_to (str): Path to save the results.
            processes (int): Number of processes to use for parallel
                processing. (default: :obj:`1`)

        Raises:
            NotADirectoryError: If ``data_dir`` exists but is not a
                directory.
        """
        self.name = name
        self.data_dir = Path(data_dir)
        self.processes = processes
        self.save_to = save_to

        if not self.data_dir.exists():
            logger.info(
                f"Data directory {data_dir} does not exist. Creating it."
            )
            self.data_dir.mkdir(parents=True, exist_ok=True)

        # Checked after the optional mkdir so that an existing regular file
        # at this path is rejected as well.
        if not self.data_dir.is_dir():
            raise NotADirectoryError(
                f"Data directory {data_dir} is not a directory"
            )

        # Loaded data, keyed by split name ("train", "valid", "test").
        self._data: Dict[str, List[Dict[str, Any]]] = dict()
        # Result records; returned by `results()`.
        self._results: List[Dict[str, Any]] = []

    @abstractmethod
    def download(self) -> "BaseBenchmark":
        r"""Download the benchmark data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @abstractmethod
    def load(self, force_download: bool = False) -> "BaseBenchmark":
        r"""Load the benchmark data.

        :meth:`train`, :meth:`valid` and :meth:`test` read the loaded data
        from ``self._data`` under the keys ``"train"``, ``"valid"`` and
        ``"test"``.

        Args:
            force_download (bool): Whether to force download the data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    def _get_split(self, split: str) -> List[Dict[str, Any]]:
        r"""Return one data split, loading the data first if necessary.

        Args:
            split (str): Key of the split inside ``self._data``.

        Returns:
            List[Dict[str, Any]]: The data stored under ``split``.

        Raises:
            KeyError: If the loaded data has no entry for ``split``.
        """
        # An empty mapping means nothing has been loaded yet.
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data[split]

    def train(self) -> List[Dict[str, Any]]:
        r"""Get the training data.

        Returns:
            List[Dict[str, Any]]: The training data.
        """
        return self._get_split("train")

    def valid(self) -> List[Dict[str, Any]]:
        r"""Get the validation data.

        Returns:
            List[Dict[str, Any]]: The validation data.
        """
        return self._get_split("valid")

    def test(self) -> List[Dict[str, Any]]:
        r"""Get the test data.

        Returns:
            List[Dict[str, Any]]: The test data.
        """
        return self._get_split("test")

    @abstractmethod
    def run(
        self,
        # Quoted, like the "BaseBenchmark" forward references, so the
        # annotation is not evaluated when the class body executes.
        agent: "ChatAgent",
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The chat agent.
            on (str): The data split to run the benchmark on.
            randomize (bool): Whether to randomize the data.
            subset (int): The subset of the data to run the benchmark on.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    def results(self) -> List[Dict[str, Any]]:
        r"""Get the results.

        Returns:
            List[Dict[str, Any]]: The results.
        """
        return self._results