# Source code for sail_on_client.checkpointer
"""Checkpoint to save and restore attributes."""
import logging
import torch
import os
import pickle as pkl
from torch import Tensor
from typing import Dict, Any
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
class Checkpointer(object):
    """Checkpoint object to save and restore attributes.

    All configuration is read from the ``toolset`` dictionary given to the
    constructor. The keys accessed by this class are ``"test_id"``,
    ``"saved_attributes"``, ``"save_elementwise"``, ``"attributes"``,
    ``"dataset"``, ``"round_id"``, ``"use_saved_attributes"`` and
    ``"save_dir"``.
    """

    # ``detector`` is read and re-assigned by the methods below but is never
    # initialised in ``__init__``; presumably the class this is mixed into
    # provides it -- confirm against the callers.
    detector: Any

    def __init__(self, toolset: Dict) -> None:
        """
        Initialize.

        Args:
            toolset: Dictionary with parameter for the mixin

        Returns:
            None
        """
        self.toolset = toolset

    def _save_elementwise_attribute(
        self, detector: Any, attribute: str, attribute_dict: Dict
    ) -> Dict:
        """
        Private method to save attributes element wise.

        The current value of ``attribute`` on ``detector`` is merged into the
        entry of ``attribute_dict`` keyed by ``toolset["test_id"]``: dicts are
        updated, lists and tuples are extended, tensors are concatenated with
        ``torch.cat`` and any other value overwrites the entry.

        Args:
            detector: Instance of novelty detector
            attribute: Name of the detector attribute that needs to be saved
            attribute_dict: A dictionary containing attribute value for other tests

        Returns:
            Updated attribute dictionary (the same object that was passed in)
        """
        attribute_val = getattr(detector, attribute)
        test_id = self.toolset["test_id"]
        seen_before = test_id in attribute_dict

        if isinstance(attribute_val, dict):
            if seen_before:
                attribute_dict[test_id].update(attribute_val)
            else:
                # Store a shallow copy: keeping the detector's own dict would
                # let later ``update`` calls silently mutate the detector.
                attribute_dict[test_id] = attribute_val.copy()
        elif isinstance(attribute_val, list):
            if seen_before:
                attribute_dict[test_id].extend(attribute_val)
            else:
                # Same aliasing concern as above; without the copy a second
                # save of the same list object would extend it with itself.
                attribute_dict[test_id] = attribute_val.copy()
        elif isinstance(attribute_val, tuple):
            if seen_before:
                # Tuples are immutable, so build a new one.
                attribute_dict[test_id] = tuple(attribute_dict[test_id]) + tuple(
                    attribute_val
                )
            else:
                attribute_dict[test_id] = attribute_val
        elif isinstance(attribute_val, Tensor):
            if seen_before:
                attribute_dict[test_id] = torch.cat(
                    [attribute_dict[test_id], attribute_val]
                )
            else:
                attribute_dict[test_id] = attribute_val
        else:
            log.info(
                "Treating attribute value as a single element rather than an iterable"
            )
            attribute_dict[test_id] = attribute_val
        return attribute_dict

    def save_attributes(self, step_descriptor: str) -> None:
        """
        Save attribute for a detector.

        The attribute names listed under
        ``toolset["saved_attributes"][step_descriptor]`` are read from
        ``self.detector`` and accumulated into ``toolset["attributes"]``.

        Args:
            step_descriptor: String describing steps for protocol

        Returns:
            None

        Raises:
            NotImplementedError: If the detector has a requested attribute but
                ``toolset["save_elementwise"]`` is falsy.
        """
        attributes = self.toolset["saved_attributes"].get(step_descriptor, [])
        save_elementwise = self.toolset["save_elementwise"]
        attribute_dict = self.toolset["attributes"]

        if not attributes:
            log.info("No attributes found for %s", step_descriptor)
        for attribute in attributes:
            if not hasattr(self.detector, attribute):
                log.warning("Detector does not have %s attribute", attribute)
                continue
            if not save_elementwise:
                raise NotImplementedError(
                    "Saving attributes for an entire round is not supported"
                )
            attribute_dict[attribute] = self._save_elementwise_attribute(
                self.detector, attribute, attribute_dict.get(attribute, {})
            )
        self.toolset["attributes"] = attribute_dict

    def _restore_elementwise_attribute(
        self, detector: Any, attribute_name: str, attribute_val: Dict
    ) -> Any:
        """
        Private method to restore attributes element wise.

        The file at ``toolset["dataset"]`` is read line by line to obtain the
        ids for the current round. Dict values are restricted to those ids;
        lists, tuples and tensors are sliced to the window
        ``[round_id * n, (round_id + 1) * n)`` where ``n`` is the number of
        ids; any other value is used as is.

        Args:
            detector: Instance of novelty detector
            attribute_name: Name of the detector attribute that needs to be restored
            attribute_val: A dictionary containing value for attributes

        Returns:
            Detector with updated value for attributes

        Raises:
            KeyError: If ``attribute_val`` is a dict that lacks one of the ids
                listed in the dataset file.
        """
        # Use a context manager so the file handle is closed deterministically.
        with open(self.toolset["dataset"], "r") as dataset_file:
            dataset_ids = [line.strip() for line in dataset_file]
        round_id = self.toolset["round_id"]
        round_len = len(dataset_ids)

        round_attribute_val: Any
        if isinstance(attribute_val, dict):
            round_attribute_val = {
                dataset_id: attribute_val[dataset_id] for dataset_id in dataset_ids
            }
        elif isinstance(attribute_val, (list, tuple, Tensor)):
            start = round_id * round_len
            round_attribute_val = attribute_val[start : start + round_len]
        else:
            log.info(
                "Treating attribute value as a single element rather than an iterable."
            )
            round_attribute_val = attribute_val
        setattr(detector, attribute_name, round_attribute_val)
        return detector

    def restore_attributes(self, step_descriptor: str) -> None:
        """
        Restore attribute for a detector.

        Nothing is restored unless ``toolset["use_saved_attributes"]`` is
        truthy and ``step_descriptor`` is a key of
        ``toolset["saved_attributes"]``. If ``toolset["save_dir"]`` is a
        directory the pickle ``<test_id>_attribute.pkl`` inside it is loaded,
        otherwise ``save_dir`` itself is opened as the pickle file.

        Args:
            step_descriptor: String describing steps for protocol

        Returns:
            None

        Raises:
            NotImplementedError: If ``toolset["save_elementwise"]`` is falsy
                and there is at least one attribute to restore.
        """
        if not (
            self.toolset["use_saved_attributes"]
            and step_descriptor in self.toolset["saved_attributes"]
        ):
            log.info("No attributes found for %s.", step_descriptor)
            return

        attributes = self.toolset["saved_attributes"][step_descriptor]
        save_elementwise = self.toolset["save_elementwise"]
        save_dir = self.toolset["save_dir"]
        test_id = self.toolset["test_id"]

        if os.path.isdir(save_dir):
            attribute_file = os.path.join(save_dir, f"{test_id}_attribute.pkl")
        else:
            attribute_file = save_dir
        # SECURITY: unpickling can execute arbitrary code; only point
        # ``save_dir`` at files produced by a trusted run.
        with open(attribute_file, "rb") as pickle_file:
            attribute_val = pkl.load(pickle_file)

        for attribute in attributes:
            if not save_elementwise:
                raise NotImplementedError(
                    "Restoring attributes for an entire round is not supported."
                )
            self.detector = self._restore_elementwise_attribute(
                self.detector, attribute, attribute_val[attribute][test_id]
            )