Coverage for backpack/config/config.py: 32%
80 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-30 23:12 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-30 23:12 +0000
1''' This module defines :class:`~backpack.config.ConfigBase`, a base class for Panorama application
2configurations. The class offers two basic functionalities:
4- parse parameters from the Panorama application input ports (``panoramasdk.node.inputs``)
5- generate configuration file snippets for ``graph.json`` and ``package.json`` in your Panorama
6 project. For more information about how to use the CLI, refer to :meth:`~backpack.config.tool`.
7'''
import dataclasses
import textwrap
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar

from .serde import ConfigSerDeBase
T = TypeVar('T', bound='ConfigBase')


class ConfigBase:
    ''' Base class for configuration structures.

    Subclasses must be also dataclasses. Every leaf field of the (possibly nested) dataclass
    becomes one Panorama application parameter. A field whose type is itself a dataclass is
    walked recursively, and the parameter name is the field path joined with underscores
    (e.g. ``inner.a`` becomes ``inner_a``).

    Field metadata keys read by this class:

    - ``type``: explicit Panorama interface type, taking precedence over :attr:`TYPE_MAP`.
    - ``serde``: a :class:`ConfigSerDeBase` subclass used to serialize the default value and
      to deserialize the value read from the input port.
    - ``__doc__`` or ``doc``: parameter description (``__doc__`` wins if both are present).
    '''

    # Maps Python field types to Panorama parameter interface type names.
    TYPE_MAP = {
        int: 'int32',
        float: 'float32',
        str: 'string',
        bool: 'boolean'
    }

    def __init__(self) -> None:
        assert dataclasses.is_dataclass(self), 'ConfigBase instances must be also dataclasses.'

    @staticmethod
    def _get_param_name(full_path: Sequence[str] = ()) -> str:
        ''' Returns the flat parameter name: the path components joined with underscores. '''
        return '_'.join(full_path)

    @staticmethod
    def _get_param_type(field: dataclasses.Field) -> str:
        ''' Returns the Panorama interface type name of a field.

        Args:
            field (dataclasses.Field): The dataclass field of the parameter.

        Returns:
            The ``type`` metadata entry if present, otherwise the :attr:`TYPE_MAP` entry of
            the field's Python type.

        Raises:
            ValueError: If no interface type can be determined for the field.
        '''
        if 'type' in field.metadata:
            typename = field.metadata['type']
        else:
            typename = ConfigBase.TYPE_MAP.get(field.type)
        if typename is None:
            raise ValueError(f'Field has unsupported type: {field}')
        return typename

    @staticmethod
    def _get_param_serde(field: dataclasses.Field) -> Optional[Type['ConfigSerDeBase']]:
        ''' Returns the serializer/deserializer class stored in the ``serde`` metadata, if any. '''
        return field.metadata.get('serde')

    @staticmethod
    def _get_param_default(
        field: dataclasses.Field,
        serde_metadata: Mapping[str, Any]
    ) -> Any:
        ''' Returns the default value of a field.

        Args:
            field (dataclasses.Field): The dataclass field of the parameter.
            serde_metadata (Mapping[str, Any]): Passed to the field's serde, if it has one.

        Returns:
            ``None`` if the field has neither ``default`` nor ``default_factory``. Otherwise
            the default value, passed through ``serde.serialize`` when the field has a serde.
            Without a serde the raw value is returned, so it is not necessarily a string.
        '''
        if field.default is not dataclasses.MISSING:
            default = field.default
        elif field.default_factory is not dataclasses.MISSING:
            default = field.default_factory()
        else:
            default = None
        serde = ConfigBase._get_param_serde(field)
        if serde is not None and default is not None:
            default = serde.serialize(default, metadata=serde_metadata)
        return default

    @staticmethod
    def _get_param_doc(field: dataclasses.Field) -> Optional[str]:
        ''' Returns the description of a field as a single line.

        The ``__doc__`` metadata entry takes precedence over ``doc``. Every run of whitespace
        (including newlines and indentation) is collapsed to one space, and the result is
        stripped.

        Returns:
            The cleaned description, or ``None`` if the field has no description.
        '''
        doc = field.metadata.get('__doc__', field.metadata.get('doc'))
        if doc is None:
            return None
        # Descriptions are often written as indented multi-line strings. Flatten them so they
        # fit into a JSON description or a single markdown table cell.
        return ' '.join(doc.split())

    def _param_walker(
        self, _current_path: Sequence[str] = ()
    ) -> Iterator[Tuple[List[str], dataclasses.Field]]:
        ''' Recursively walks all parameters in the config structure.

        Args:
            _current_path (Sequence[str]): The current path in the config structure. This is an
                internal recursion parameter and users should always leave the default empty
                value.

        Returns:
            A generator that yields a tuple for each leaf parameter, in field declaration
            order. The tuple consists of the following values:
            - full_path (List[str]): The full path of the parameter in the hierarchy
            - field (dataclasses.Field): The original dataclass field of the parameter
        '''
        for fld in dataclasses.fields(self):
            # Build a new list so any Sequence (list or tuple) is accepted as the prefix.
            current_path = [*_current_path, fld.name]
            if dataclasses.is_dataclass(fld.type):
                # Nested config: descend into the current value of this field.
                nested = getattr(self, fld.name)
                yield from nested._param_walker(_current_path=current_path)
            else:
                yield (current_path, fld)

    def get_panorama_definitions(
        self, serde_metadata: Optional[Mapping[str, Any]] = None
    ) -> List[Mapping[str, Any]]:
        ''' Generate the ``nodeGraph.nodes`` snippet in ``graph.json``.

        Args:
            serde_metadata (Mapping[str, Any]): Passed to the serdes that serialize default
                values. Defaults to a fresh empty mapping on every call.

        Returns:
            A list of dictionaries containing the application parameter node definitions.
        '''
        if serde_metadata is None:
            serde_metadata = {}
        return [
            {
                'name': ConfigBase._get_param_name(full_path=full_path),
                'interface': ConfigBase._get_param_type(field=field),
                'value': ConfigBase._get_param_default(field=field, serde_metadata=serde_metadata),
                'overridable': True,
                'decorator': {
                    'title': ConfigBase._get_param_name(full_path=full_path),
                    'description': ConfigBase._get_param_doc(field=field)
                }
            }
            for (full_path, field) in self._param_walker()
        ]

    def get_panorama_edges(self, code_node_name: str) -> List[Mapping[str, str]]:
        ''' Generate the ``nodeGraph.edges`` snippet in ``graph.json``.

        Args:
            code_node_name (str): Name of the application code node. Each edge's consumer is
                ``<code_node_name>.<parameter name>``.

        Returns:
            A list of dictionaries containing the application edge definitions.
        '''
        return [
            {
                "producer": ConfigBase._get_param_name(full_path=full_path),
                "consumer": code_node_name + "." + ConfigBase._get_param_name(full_path=full_path)
            }
            for (full_path, _) in self._param_walker()
        ]

    def get_panorama_app_interface(self) -> List[Mapping[str, str]]:
        ''' Generate the application interface snippet in app node ``package.json``.

        Returns:
            A list of dictionaries containing the elements of the application interface definition.
        '''
        return [
            {
                "name": ConfigBase._get_param_name(full_path=full_path),
                "type": ConfigBase._get_param_type(field=field)
            }
            for (full_path, field) in self._param_walker()
        ]

    def get_panorama_markdown_doc(
        self, serde_metadata: Optional[Mapping[str, Any]] = None
    ) -> str:
        ''' Generates a markdown table of the parameters that can be used in documentation.

        Args:
            serde_metadata (Mapping[str, Any]): Passed to the serdes that serialize default
                values. Defaults to a fresh empty mapping on every call.

        Returns:
            Markdown formatted text containing the parameter documentation. A parameter
            without a description gets an empty description cell.
        '''
        if serde_metadata is None:
            serde_metadata = {}
        header = (
            '| name | type | default | description |\n'
            '|------|---------|---------|-------------|\n'
        )
        rows = []
        for (full_path, field) in self._param_walker():
            doc = ConfigBase._get_param_doc(field=field)
            # An unescaped pipe would open a new table cell and break the row.
            description = '' if doc is None else doc.replace('|', '\\|')
            rows.append(
                f'| {ConfigBase._get_param_name(full_path=full_path)} '
                f'| {ConfigBase._get_param_type(field=field)} '
                f'| {ConfigBase._get_param_default(field=field, serde_metadata=serde_metadata)} '
                f'| {description} |'
            )
        return header + '\n'.join(rows)

    @classmethod
    def from_panorama_params(cls: Type[T],
                             inputs: 'panoramasdk.port',  # type: ignore
                             serde_metadata: Optional[Mapping[str, Any]] = None
                             ) -> T:
        ''' Parses the config values from AWS Panorama input parameters.

        A new Config object is created with the default values. If a particular value is
        found in the input parameter, its value will override the default value.

        Args:
            inputs (panoramasdk.port): The input port of the Panorama application node.
            serde_metadata (Mapping[str, Any]): Passed to the serdes that deserialize the
                input values. Defaults to a fresh empty mapping on every call.

        Returns:
            The config instance filled with the parameter values read from the input port.
        '''
        if serde_metadata is None:
            serde_metadata = {}
        result = cls()
        for (full_path, fld) in result._param_walker():
            name = ConfigBase._get_param_name(full_path=full_path)
            if not hasattr(inputs, name):
                # Keep the default value for parameters not present on the input port.
                continue
            value = getattr(inputs, name).get()
            serde = ConfigBase._get_param_serde(field=fld)
            if serde is not None:
                value = serde.deserialize(value, metadata=serde_metadata)
            # Walk down to the (possibly nested) config object that owns the leaf field.
            owner = result
            for name_part in full_path[:-1]:
                owner = getattr(owner, name_part)
            setattr(owner, full_path[-1], value)
        return result

    def asdict(self) -> Dict[str, Any]:
        ''' Returns the config as a (nested) dictionary, via :func:`dataclasses.asdict`. '''
        return dataclasses.asdict(self)