"""OND protocol."""
from sail_on_client.agent.ond_agent import ONDAgent
from sail_on_client.harness.test_and_evaluation_harness import (
TestAndEvaluationHarnessType,
)
from sail_on_client.protocol.visual_protocol import VisualProtocol
from sail_on_client.utils.numpy_encoder import NumpyEncoder
from sail_on_client.protocol.ond_dataclasses import AlgorithmAttributes
from sail_on_client.protocol.ond_test import ONDTest
from sail_on_client.utils.decorators import skip_stage
import os
import json
import logging
from typing import Dict, List, Optional
log = logging.getLogger(__name__)
class ONDProtocol(VisualProtocol):
"""OND protocol."""
    def __init__(
self,
algorithms: Dict[str, ONDAgent],
dataset_root: str,
domain: str,
harness: TestAndEvaluationHarnessType,
save_dir: str,
seed: str,
test_ids: List[str],
baseline_class: str = "",
feature_extraction_only: bool = False,
has_baseline: bool = False,
has_reaction_baseline: bool = False,
        hints: Optional[List] = None,
is_eval_enabled: bool = False,
is_eval_roundwise_enabled: bool = False,
resume_session: bool = False,
        resume_session_ids: Optional[Dict] = None,
save_attributes: bool = False,
        saved_attributes: Optional[Dict] = None,
save_elementwise: bool = False,
save_features: bool = False,
feature_dir: str = "",
        skip_stages: Optional[List] = None,
use_feedback: bool = False,
feedback_type: str = "classification",
use_consolidated_features: bool = False,
use_saved_attributes: bool = False,
use_saved_features: bool = False,
) -> None:
"""
Construct OND protocol.
Args:
            algorithms: Dictionary of algorithms that are run by the protocol
baseline_class: Name of the baseline class
dataset_root: Root directory of the dataset
domain: Domain of the problem
save_dir: Directory where results are saved
seed: Seed for the experiments
feedback_type: Type of feedback
test_ids: List of tests
feature_extraction_only: Flag to only run feature extraction
has_baseline: Flag to check if the session has baseline
has_reaction_baseline: Flag to check if the session has reaction baseline
            hints: List of hints provided in the session
harness: A harness for test and evaluation
is_eval_enabled: Flag to check if evaluation is enabled in session
is_eval_roundwise_enabled: Flag to check if evaluation is enabled for rounds
resume_session: Flag to resume session
resume_session_ids: Dictionary for resuming sessions
save_attributes: Flag to save attributes
saved_attributes: Dictionary for attributes
save_elementwise: Flag to save features elementwise
save_features: Flag to save features
feature_dir: Directory to save features
skip_stages: List of stages that are skipped
            use_feedback: Flag to use feedback
            use_consolidated_features: Flag to use consolidated features
            use_saved_attributes: Flag to use saved attributes
            use_saved_features: Flag to use saved features
Returns:
None
"""
super().__init__(algorithms, harness)
self.baseline_class = baseline_class
self.dataset_root = dataset_root
self.domain = domain
self.feature_extraction_only = feature_extraction_only
self.feature_dir = feature_dir
self.feedback_type = feedback_type
self.has_baseline = has_baseline
self.has_reaction_baseline = has_reaction_baseline
if hints is None:
self.hints = []
else:
self.hints = hints
self.is_eval_enabled = is_eval_enabled
self.is_eval_roundwise_enabled = is_eval_roundwise_enabled
self.resume_session = resume_session
if resume_session_ids is None:
self.resume_session_ids = {}
else:
self.resume_session_ids = resume_session_ids
self.save_attributes = save_attributes
if saved_attributes is None:
self.saved_attributes = {}
else:
self.saved_attributes = saved_attributes
self.save_dir = save_dir
self.save_elementwise = save_elementwise
self.save_features = save_features
if skip_stages is None:
self.skip_stages = []
else:
self.skip_stages = skip_stages
self.seed = seed
self.test_ids = test_ids
self.use_consolidated_features = use_consolidated_features
self.use_feedback = use_feedback
self.use_saved_attributes = use_saved_attributes
self.use_saved_features = use_saved_features
    def get_config(self) -> Dict:
"""Get dictionary representation of the object."""
config = super().get_config()
config.update(
{
"baseline_class": self.baseline_class,
"dataset_root": self.dataset_root,
"domain": self.domain,
"feature_extraction_only": self.feature_extraction_only,
"feature_dir": self.feature_dir,
"feedback_type": self.feedback_type,
"has_baseline": self.has_baseline,
"has_reaction_baseline": self.has_reaction_baseline,
"hints": self.hints,
"is_eval_enabled": self.is_eval_enabled,
"is_eval_roundwise_enabled": self.is_eval_roundwise_enabled,
"resume_session": self.resume_session,
"resume_session_ids": self.resume_session_ids,
"save_attributes": self.save_attributes,
"saved_attributes": self.saved_attributes,
"save_dir": self.save_dir,
"save_elementwise": self.save_elementwise,
"save_features": self.save_features,
"skip_stages": self.skip_stages,
"seed": self.seed,
"test_ids": self.test_ids,
"use_feedback": self.use_feedback,
"use_saved_attributes": self.use_saved_attributes,
"use_saved_features": self.use_saved_features,
}
)
return config
    def create_algorithm_attributes(
self,
algorithm_name: str,
algorithm_param: Dict,
baseline_algorithm_name: str,
has_baseline: bool,
has_reaction_baseline: bool,
test_ids: List[str],
) -> AlgorithmAttributes:
"""
Create an instance of algorithm attributes.
Args:
algorithm_name: Name of the algorithm
algorithm_param: Parameters for the algorithm
baseline_algorithm_name: Name of the baseline algorithm
has_baseline: Flag to check if a baseline is present in the config
has_reaction_baseline: Flag to check if a reaction baseline is present in the config
            test_ids: List of tests
Returns:
An instance of AlgorithmAttributes
"""
algorithm_instance = self.algorithms[algorithm_name]
is_baseline = algorithm_name == baseline_algorithm_name
session_id = self.resume_session_ids.get(algorithm_name, "")
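        # An algorithm is treated as a (reaction) baseline only when its name
        # matches the configured baseline class; the detection threshold falls
        # back to 0.5 when the algorithm config does not provide one.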
return AlgorithmAttributes(
algorithm_name,
algorithm_param.get("detection_threshold", 0.5),
algorithm_instance,
has_baseline and is_baseline,
has_reaction_baseline and is_baseline,
algorithm_param.get("package_name", None),
algorithm_param,
session_id,
test_ids,
)
    def create_algorithm_session(
self,
algorithm_attributes: AlgorithmAttributes,
domain: str,
hints: List[str],
has_a_session: bool,
protocol_name: str,
) -> AlgorithmAttributes:
"""
Create/resume session for an algorithm.
Args:
algorithm_attributes: An instance of AlgorithmAttributes
domain: Domain for the algorithm
hints: List of hints used in the session
            has_a_session: Flag indicating that a session already exists and should be resumed
            protocol_name: Name of the protocol
Returns:
An AlgorithmAttributes object with updated session id or test id
"""
test_ids = algorithm_attributes.test_ids
named_version = algorithm_attributes.named_version()
detection_threshold = algorithm_attributes.detection_threshold
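        # Resume the existing session (dropping tests it already finished) or
        # request a fresh session from the harness and record its id.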
if has_a_session:
session_id = algorithm_attributes.session_id
finished_test = self.harness.resume_session(session_id)
algorithm_attributes.remove_completed_tests(finished_test)
log.info(f"Resumed session {session_id} for {algorithm_attributes.name}")
else:
session_id = self.harness.session_request(
test_ids,
protocol_name,
domain,
named_version,
hints,
detection_threshold,
)
algorithm_attributes.session_id = session_id
log.info(f"Created session {session_id} for {algorithm_attributes.name}")
return algorithm_attributes
def _find_baseline_session_id(
self, algorithms_attributes: List[AlgorithmAttributes]
) -> str:
"""
Find baseline session id based on the attributes of algorithms.
Args:
algorithms_attributes: List of algorithm attributes
Returns:
Baseline session id
"""
for algorithm_attributes in algorithms_attributes:
if (
algorithm_attributes.is_baseline
or algorithm_attributes.is_reaction_baseline
):
return algorithm_attributes.session_id
raise Exception(
"Failed to find baseline, this is required to compute reaction perfomance"
)
    @skip_stage("EvaluateAlgorithms")
def _evaluate_algorithms(
self,
algorithms_attributes: List[AlgorithmAttributes],
algorithm_scores: Dict,
save_dir: str,
) -> None:
"""
Evaluate algorithms after all tests have been submitted.
Args:
            algorithms_attributes: Attributes of all algorithms present in the config
algorithm_scores: Scores for round wise evaluation
save_dir: Directory where the scores are stored
Returns:
None
"""
if self.has_baseline or self.has_reaction_baseline:
baseline_session_id: Optional[str] = self._find_baseline_session_id(
algorithms_attributes
)
else:
baseline_session_id = None
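        # Baselines are only used as a reference session, so they are not
        # evaluated themselves.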
for algorithm_attributes in algorithms_attributes:
if (
algorithm_attributes.is_baseline
or algorithm_attributes.is_reaction_baseline
):
continue
session_id = algorithm_attributes.session_id
test_ids = algorithm_attributes.test_ids
algorithm_name = algorithm_attributes.name
test_scores = algorithm_scores[algorithm_name]
log.info(f"Started evaluating {algorithm_name}")
for test_id in test_ids:
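                # Ask the harness to score the test, then fold in the roundwise
                # scores collected while the test ran before saving to disk.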
score = self.harness.evaluate(
test_id, 0, session_id, baseline_session_id
)
score.update(test_scores[test_id])
with open(
os.path.join(save_dir, f"{test_id}_{algorithm_name}.json"), "w"
) as f: # type: ignore
log.info(f"Saving results in {save_dir}")
json.dump(score, f, indent=4, cls=NumpyEncoder) # type: ignore
log.info(f"Finished evaluating {algorithm_name}")
    def update_skip_stages(
self,
skip_stages: List[str],
is_eval_enabled: bool,
is_eval_roundwise_enabled: bool,
use_feedback: bool,
save_features: bool,
feature_extraction_only: bool,
) -> List[str]:
"""
Update skip stages based on the boolean values in config.
Args:
skip_stages: List of skip stages specified in the config
is_eval_enabled: Flag to enable evaluation
is_eval_roundwise_enabled: Flag to enable evaluation in every round
use_feedback: Flag to enable using feedback
save_features: Flag to enable saving features
feature_extraction_only: Flag to only run feature extraction
Returns:
            Updated list of skip stages
"""
if not is_eval_enabled:
skip_stages.append("EvaluateAlgorithms")
skip_stages.append("EvaluateRoundwise")
if not is_eval_roundwise_enabled:
skip_stages.append("EvaluateRoundwise")
if not use_feedback:
skip_stages.append("CreateFeedbackInstance")
skip_stages.append("NoveltyAdaptation")
if not save_features:
skip_stages.append("SaveFeatures")
if feature_extraction_only:
skip_stages.append("CreateFeedbackInstance")
skip_stages.append("WorldDetection")
skip_stages.append("NoveltyClassification")
skip_stages.append("NoveltyAdaptation")
skip_stages.append("NoveltyCharacterization")
return skip_stages
    def run_protocol(self, config: Dict) -> None:
"""
Run the protocol.
Args:
config: Parameters provided in the config
Returns:
None
"""
log.info("Starting OND")
self.skip_stages = self.update_skip_stages(
self.skip_stages,
self.is_eval_enabled,
self.is_eval_roundwise_enabled,
self.use_feedback,
self.save_features,
self.feature_extraction_only,
)
algorithms_attributes = []
# Populate most of the attributes for the algorithm
for algorithm_name in self.algorithms.keys():
algorithm_param = self.algorithms[algorithm_name].get_config()
algorithm_attributes = self.create_algorithm_attributes(
algorithm_name,
algorithm_param,
self.baseline_class,
self.has_baseline,
self.has_reaction_baseline,
self.test_ids,
)
log.info(f"Consolidating attributes for {algorithm_name}")
algorithms_attributes.append(algorithm_attributes)
        # Create sessions for all the algorithm instances and populate
        # session_id on the algorithm attributes
for idx, algorithm_attributes in enumerate(algorithms_attributes):
algorithms_attributes[idx] = self.create_algorithm_session(
algorithm_attributes,
self.domain,
self.hints,
self.resume_session,
"OND",
)
# Run tests for all the algorithms
algorithm_scores = {}
for algorithm_attributes in algorithms_attributes:
algorithm_name = algorithm_attributes.name
session_id = algorithm_attributes.session_id
test_ids = algorithm_attributes.test_ids
log.info(f"Starting session: {session_id} for algorithm: {algorithm_name}")
skip_stages = self.skip_stages.copy()
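            # The reaction baseline performs no detection or characterization,
            # so those stages are skipped for it.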
if algorithm_attributes.is_reaction_baseline:
skip_stages.append("WorldDetection")
skip_stages.append("NoveltyCharacterization")
ond_test = ONDTest(
algorithm_attributes,
self.dataset_root,
self.domain,
self.feedback_type,
self.feature_dir,
self.harness,
self.save_dir,
session_id,
skip_stages,
self.use_consolidated_features,
self.use_saved_features,
)
test_scores = {}
for test_id in test_ids:
log.info(f"Start test: {test_id}")
test_score = ond_test(test_id)
test_scores[test_id] = test_score
log.info(f"Test complete: {test_id}")
algorithm_scores[algorithm_name] = test_scores
# Evaluate algorithms
self._evaluate_algorithms(
algorithms_attributes, algorithm_scores, self.save_dir
)
# Terminate algorithms
for algorithm_attributes in algorithms_attributes:
algorithm_name = algorithm_attributes.name
session_id = algorithm_attributes.session_id
self.harness.terminate_session(session_id)
log.info(f"Session ended for {algorithm_name}: {session_id}")