# blob: b8604024efa484f0f7a5341c69e9e4654c02d37a [file] [log] [blame]
# coding: utf-8
"""Attribute scoping support for symbolic API."""
from __future__ import absolute_import
from .base import string_types
class AttrScope(object):
    """Attribute manager for scoping.

    Intended to be used as a context manager. While a scope is active it is
    installed as ``AttrScope.current`` and its attributes are layered on top
    of the attributes of the scope that was active when it was entered.
    Leaving the scope reinstates the previous scope and restores this
    object's own attributes, so the same object can be reused (or re-entered
    while already active) without inheriting stale attributes.

    User can also inherit this object to change naming behavior.

    Parameters
    ----------
    kwargs
        The attributes to set for all symbol creations in the scope.
        Every value must be a string.

    Raises
    ------
    ValueError
        If any attribute value is not a string.
    """
    # Scope currently in effect. Must hold an AttrScope instance before any
    # scope is entered, because ``__enter__`` reads ``current._attr``.
    current = None

    # Stack of ``(previous _old_scope, own _attr)`` pairs, one per active
    # ``with`` on this object. An immutable tuple is used so this class-level
    # default is safe to share and also covers subclasses whose ``__init__``
    # does not call the base implementation.
    _saved = ()

    def __init__(self, **kwargs):
        self._old_scope = None
        if not all(isinstance(value, string_types) for value in kwargs.values()):
            raise ValueError("Attributes need to be string")
        self._attr = kwargs

    def get(self, attr):
        """
        Get the attribute dict given the attribute set by the symbol.

        Parameters
        ----------
        attr : dict of string to string
            The attribute passed in by user during symbol creation.
            May be ``None`` or empty.

        Returns
        -------
        attr : dict of string to string
            Updated attributes to add other scope related attributes.
            When the scope has attributes, this is a new dict holding the
            scope attributes overridden by ``attr``. Otherwise it is ``attr``
            itself when ``attr`` is non-empty, or a new empty dict.
        """
        if not self._attr:
            return attr if attr else {}
        combined = self._attr.copy()
        if attr:
            # Attributes given explicitly take precedence over scope ones.
            combined.update(attr)
        return combined

    def __enter__(self):
        """Make this scope current, inheriting the enclosing scope's attributes.

        Returns
        -------
        AttrScope
            This scope object.
        """
        # pylint: disable=protected-access
        enclosing = AttrScope.current
        # Build the merged view first so that a failure here leaves both this
        # object and ``AttrScope.current`` untouched.
        merged = enclosing._attr.copy()
        merged.update(self._attr)
        # Remember what to put back on exit: the previous bookkeeping value
        # and this scope's own (un-merged) attributes.
        self._saved += ((self._old_scope, self._attr),)
        self._old_scope = enclosing
        self._attr = merged
        AttrScope.current = self
        return self

    def __exit__(self, ptype, value, trace):
        """Reinstate the scope that was current when this one was entered.

        Raises
        ------
        AssertionError
            If called without a matching ``__enter__``.
        """
        if not self._saved or self._old_scope is None:
            # Raised explicitly (instead of via ``assert``) so the check is
            # still performed when Python runs with -O.
            raise AssertionError("AttrScope.__exit__ called without a matching __enter__")
        AttrScope.current = self._old_scope
        # Undo the merge done in ``__enter__`` so attributes inherited from
        # the enclosing scope do not leak into later uses of this object.
        self._old_scope, self._attr = self._saved[-1]
        self._saved = self._saved[:-1]
# Install a default scope with no attributes as the initial current scope, so
# that AttrScope.__enter__ always finds an enclosing scope whose `_attr` it can
# copy (the class body itself only sets `current = None`).
AttrScope.current = AttrScope()