"""Implementation of T&E Harness for PAR Server."""
import requests
import json
import io
import os
import traceback
import logging
from sail_on_client.harness.test_and_evaluation_harness import TestAndEvaluationHarness
from typing import Any, Dict, Union, List
from requests import Response
from sail_on_client.errors import ApiError, RoundError
from tenacity import (
retry,
stop_after_attempt,
wait_fixed,
before_sleep_log,
)
from json import JSONDecodeError
log = logging.getLogger(__name__)


class ParHarness(TestAndEvaluationHarness):
    """Harness for PAR server."""

    def __init__(self, url: str, save_directory: str) -> None:
        """
        Initialize a client connection object.

        Args:
            url: URL for the server
            save_directory: A directory to save files

        Returns:
            None
        """
        TestAndEvaluationHarness.__init__(self)
        self.url = url
        self.save_directory = save_directory

    def get_config(self) -> Dict:
        """JSON compliant representation of the object."""
        return {"url": self.url, "save_directory": self.save_directory}

    def _check_response(self, response: Response) -> None:
        """
        Parse and raise errors present in the server response.

        Args:
            response: The response object obtained from the server

        Returns:
            None
        """
        if response.status_code != 200:
            try:
                response_json = response.json()
                # Find the appropriate error class based on the status code.
                for subclass in ApiError.error_classes():
                    if subclass.error_code == response.status_code:
                        raise subclass(
                            response_json["reason"],
                            response_json["message"],
                            response_json["stack_trace"],
                        )
            except JSONDecodeError:
                log.exception(f"Server Error: {traceback.format_exc()}")
                exit(1)
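
    # _check_response assumes the PAR server returns only status codes that
    # are registered on an ApiError subclass; a non-200 status that matches
    # no registered error code falls through without raising.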

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_fixed(2),
        reraise=True,
        before_sleep=before_sleep_log(log, logging.INFO),
    )
    def test_ids_request(
        self,
        protocol: str,
        domain: str,
        detector_seed: str,
        test_assumptions: str = "{}",
    ) -> str:
        """
        Request test identifiers as part of a series of individual tests.

        Args:
            protocol: String indicating which protocol is being evaluated
            domain: Problem domain for the tests
            detector_seed: A seed provided by the novelty detector
            test_assumptions: Path to a file with assumptions used by the
                detector, or a literal JSON string (the default is an empty
                JSON object)

        Returns:
            Filename of a file containing the test ids
        """
        payload = {
            "protocol": protocol,
            "domain": domain,
            "detector_seed": detector_seed,
        }
        # ``test_assumptions`` may be a path to a JSON file or, as with the
        # default value, a literal JSON string.
        if os.path.isfile(test_assumptions):
            with open(test_assumptions, "r") as f:
                contents = f.read()
        else:
            contents = test_assumptions
        response = requests.get(
            f"{self.url}/test/ids",
            files={  # type: ignore
                "test_requirements": io.StringIO(json.dumps(payload)),
                "test_assumptions": io.StringIO(contents),
            },
        )
        self._check_response(response)
        filename = os.path.abspath(
            os.path.join(
                self.save_directory, f"{protocol}.{domain}.{detector_seed}.csv"
            )
        )
        with open(filename, "w") as f:
            f.write(response.content.decode("utf-8"))
        return filename
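
    # Illustrative call (the URL, directory, and seed are placeholders):
    #
    #     harness = ParHarness("http://localhost:3307", "/tmp/results")
    #     id_file = harness.test_ids_request("OND", "image_classification", "1234")
    #
    # The server's response is saved verbatim; one test id per line is the
    # assumed layout of the resulting file.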

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_fixed(2),
        reraise=True,
        before_sleep=before_sleep_log(log, logging.INFO),
    )
    def session_request(
        self,
        test_ids: list,
        protocol: str,
        domain: str,
        novelty_detector_version: str,
        hints: list,
        detection_threshold: float,
    ) -> str:
        """
        Create a new session to evaluate the detector using an empirical protocol.

        Args:
            test_ids: List of tests being evaluated in this session
            protocol: String indicating which protocol is being evaluated
            domain: String indicating which domain is being evaluated
            novelty_detector_version: The novelty detector being evaluated
            hints: Hints used for the session
            detection_threshold: Detection threshold for the session

        Returns:
            A session identifier provided by the server
        """
        payload = {
            "protocol": protocol,
            "novelty_detector_version": novelty_detector_version,
            "domain": domain,
            "hints": hints,
            "detection_threshold": detection_threshold,
        }
        ids = "\n".join(test_ids) + "\n"
        response = requests.post(
            f"{self.url}/session",
            files={"test_ids": ids, "configuration": io.StringIO(json.dumps(payload))},  # type: ignore
        )
        self._check_response(response)
        return response.json()["session_id"]

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_fixed(2),
        reraise=True,
        before_sleep=before_sleep_log(log, logging.INFO),
    )
    def resume_session(self, session_id: str) -> List[str]:
        """
        Get finished tests from an existing session.

        Args:
            session_id: Id of a session that was started but not terminated

        Returns:
            List of tests finished in the session
        """
        params: Dict[str, str] = {"session_id": session_id}
        response = requests.get(f"{self.url}/session/latest", params=params)
        self._check_response(response)
        return response.json()["finished_tests"]

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_fixed(2),
        reraise=True,
        before_sleep=before_sleep_log(log, logging.INFO),
    )
    def dataset_request(self, test_id: str, round_id: int, session_id: str) -> str:
        """
        Request data for evaluation.

        Args:
            test_id: The test being evaluated at this moment
            round_id: The sequential number of the round being evaluated
            session_id: The identifier provided by the server for a single experiment

        Returns:
            Filename of a file containing a list of image files (including full path for each)
        """
        params: Dict[str, Union[str, int]] = {
            "session_id": session_id,
            "test_id": test_id,
            "round_id": round_id,
        }
        response = requests.get(
            f"{self.url}/session/dataset",
            params=params,
        )
        if response.status_code == 204:
            raise RoundError("End of Dataset", "The entire dataset has been requested")
        self._check_response(response)
        filename = os.path.abspath(
            os.path.join(self.save_directory, f"{session_id}.{test_id}.{round_id}.csv")
        )
        with open(filename, "wb") as f:
            f.write(response.content)
        return filename
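
    # The docstring above describes the saved file as a list of image paths;
    # assuming one path per line, a consumer might read it as:
    #
    #     with open(filename, "r") as f:
    #         image_paths = [line.strip() for line in f if line.strip()]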

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_fixed(2),
        reraise=True,
        before_sleep=before_sleep_log(log, logging.INFO),
    )
    def get_feedback_request(
        self,
        feedback_ids: list,
        feedback_type: str,
        test_id: str,
        round_id: int,
        session_id: str,
    ) -> str:
        """
        Get labels from the server for one or more provided example ids.

        Args:
            feedback_ids: List of media ids for which feedback is required
            feedback_type: Protocol constant with one of the values: label, detection, characterization
            test_id: The id of the test currently being evaluated
            round_id: The sequential number of the round being evaluated
            session_id: The id provided by a server denoting a session

        Returns:
            Path to a file containing the requested feedback
        """
        params: Dict[str, Union[str, int]] = {
            "feedback_ids": "|".join(feedback_ids),
            "session_id": session_id,
            "test_id": test_id,
            "feedback_type": feedback_type,
        }
        response = requests.get(
            f"{self.url}/session/feedback",
            params=params,
        )
        self._check_response(response)
        filename = os.path.abspath(
            os.path.join(
                self.save_directory,
                f"{session_id}.{test_id}.{round_id}_{feedback_type}.csv",
            )
        )
        with open(filename, "wb") as f:
            f.write(response.content)
        return filename
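
    # Illustrative call requesting label feedback (media ids are placeholders):
    #
    #     feedback_file = harness.get_feedback_request(
    #         ["image_0001.png", "image_0002.png"], "label",
    #         test_id, round_id, session_id,
    #     )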

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_fixed(2),
        reraise=True,
        before_sleep=before_sleep_log(log, logging.INFO),
    )
    def post_results(
        self, result_files: Dict[str, str], test_id: str, round_id: int, session_id: str
    ) -> None:
        """
        Post client detector predictions for the dataset.

        Args:
            result_files: A dictionary of results with protocol constant as key and file path as value
            test_id: The id of the test currently being evaluated
            round_id: The sequential number of the round being evaluated
            session_id: The id provided by a server denoting a session

        Returns:
            None
        """
        payload = {
            "session_id": session_id,
            "test_id": test_id,
            "round_id": round_id,
            "result_types": "|".join(result_files.keys()),
        }
files = {"test_identification": io.StringIO(json.dumps(payload))}
if len(result_files.keys()) == 0:
raise Exception("Must provide at least one result file")
for r_type in result_files:
with open(result_files[r_type], "r") as f:
contents = f.read()
files[f"{r_type}_file"] = io.StringIO(contents)
response = requests.post(f"{self.url}/session/results", files=files) # type: ignore
self._check_response(response)
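
    # Illustrative call (the result type and path are placeholders; valid keys
    # are whatever protocol constants the server accepts):
    #
    #     harness.post_results(
    #         {"detection": "/tmp/out/detection.csv"}, test_id, round_id, session_id
    #     )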

    def evaluate_round_wise(
        self,
        test_id: str,
        round_id: int,
        session_id: str,
    ) -> Dict[str, Any]:
        """
        Get results for round(s).

        Args:
            test_id: The id of the test currently being evaluated
            round_id: The sequential number of the round being evaluated
            session_id: The id provided by a server denoting a session

        Raises:
            NotImplementedError: Round-wise accuracy computation is not
                supported by the PAR server.
        """
        raise NotImplementedError("Round wise accuracy computation is not supported")

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_fixed(2),
        reraise=True,
        before_sleep=before_sleep_log(log, logging.INFO),
    )
    def evaluate(
        self,
        test_id: str,
        round_id: int,
        session_id: str,
        baseline_session_id: Optional[str] = None,
    ) -> Dict:
        """
        Get results for test(s).

        Args:
            test_id: The id of the test currently being evaluated
            round_id: The sequential number of the round being evaluated
            session_id: The id provided by a server denoting a session
            baseline_session_id: The id of a baseline session (not used by this harness)

        Returns:
            Dictionary containing the evaluation results
        """
        params: Dict[str, Union[str, int]] = {
            "session_id": session_id,
            "test_id": test_id,
            "round_id": round_id,
        }
        response = requests.get(
            f"{self.url}/session/evaluations",
            params=params,
        )
        self._check_response(response)
        return response.json()
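
    # The structure of the returned dictionary is defined by the PAR server's
    # /session/evaluations endpoint and is not specified by this client.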

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_fixed(2),
        reraise=True,
        before_sleep=before_sleep_log(log, logging.INFO),
    )
    def complete_test(self, session_id: str, test_id: str) -> None:
        """
        Mark test as completed.

        Args:
            session_id: The id of the session currently being evaluated
            test_id: The id of the test currently being evaluated

        Returns:
            None
        """
        response = requests.delete(
            f"{self.url}/test",
            params={"test_id": test_id, "session_id": session_id},
        )
        self._check_response(response)

    @retry(
        stop=stop_after_attempt(5),
        wait=wait_fixed(2),
        reraise=True,
        before_sleep=before_sleep_log(log, logging.INFO),
    )
    def terminate_session(self, session_id: str) -> None:
        """
        Terminate the session after the evaluation for the protocol is complete.

        Args:
            session_id: The id provided by a server denoting a session

        Returns:
            None
        """
        response = requests.delete(
            f"{self.url}/session", params={"session_id": session_id}
        )
        self._check_response(response)
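

# A minimal end-to-end sketch of the harness lifecycle. All values below are
# illustrative placeholders (URL, directory, protocol, domain, detector name,
# threshold), not defaults provided by the server:
#
#     harness = ParHarness(url="http://localhost:3307", save_directory="/tmp/par")
#     id_file = harness.test_ids_request("OND", "image_classification", "1234")
#     with open(id_file, "r") as f:
#         test_ids = [line.strip() for line in f if line.strip()]
#     session_id = harness.session_request(
#         test_ids, "OND", "image_classification", "detector_v1",
#         hints=[], detection_threshold=0.5,
#     )
#     for test_id in test_ids:
#         round_id = 0
#         while True:
#             try:
#                 dataset = harness.dataset_request(test_id, round_id, session_id)
#             except RoundError:
#                 break  # server returned 204: no more rounds for this test
#             # ... run the detector on ``dataset``, write result files, and
#             # post them with harness.post_results(...)
#             round_id += 1
#         harness.complete_test(session_id, test_id)
#     harness.terminate_session(session_id)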