Fix update race condition in Heron Tracker (#3830)
Co-authored-by: Nicholas Nezis <nicholas.nezis@gmail.com>
diff --git a/heron/tools/tracker/src/python/topology.py b/heron/tools/tracker/src/python/topology.py
index 817c350..b49ea55 100644
--- a/heron/tools/tracker/src/python/topology.py
+++ b/heron/tools/tracker/src/python/topology.py
@@ -22,7 +22,9 @@
import dataclasses
import json
import string
+import threading
+from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from copy import deepcopy
import networkx
@@ -220,6 +222,7 @@
execution_state: Optional[ExecutionState_pb]
# pylint: disable=too-many-instance-attributes
+@dataclass(init=False)
class Topology:
"""
Class Topology
@@ -241,8 +244,16 @@
self.id: Optional[int] = None
self.tracker_config: Config = tracker_config
# this maps pb2 structs to structures returned via API endpoints
- # it is repopulated every time one of the pb2 roperties is updated
+ # it is repopulated every time one of the pb2 properties is updated
self.info: Optional[TopologyInfo] = None
+ self.lock = threading.RLock()
+
+ def __eq__(self, o):
+ return isinstance(o, Topology) \
+ and o.name == self.name \
+ and o.state_manager_name == self.state_manager_name \
+ and o.cluster == self.cluster \
+ and o.environ == self.environ
@staticmethod
def _render_extra_links(extra_links, topology, execution_state: ExecutionState_pb) -> None:
@@ -589,30 +600,31 @@
scheduler_location=...,
) -> None:
"""Atomically update this instance to avoid inconsistent reads/writes from other threads."""
- t_state = TopologyState(
- physical_plan=self.physical_plan if physical_plan is ... else physical_plan,
- packing_plan=self.packing_plan if packing_plan is ... else packing_plan,
- execution_state=self.execution_state if execution_state is ... else execution_state,
- tmanager=self.tmanager if tmanager is ... else tmanager,
- scheduler_location=(
- self.scheduler_location
- if scheduler_location is ... else
- scheduler_location
- ),
- )
- if t_state.physical_plan:
- id_ = t_state.physical_plan.topology.id
- elif t_state.packing_plan:
- id_ = t_state.packing_plan.id
- else:
- id_ = None
+ with self.lock:
+ t_state = TopologyState(
+ physical_plan=self.physical_plan if physical_plan is ... else physical_plan,
+ packing_plan=self.packing_plan if packing_plan is ... else packing_plan,
+ execution_state=self.execution_state if execution_state is ... else execution_state,
+ tmanager=self.tmanager if tmanager is ... else tmanager,
+ scheduler_location=self.scheduler_location \
+ if scheduler_location is ... else scheduler_location,
+ )
+ if t_state.physical_plan:
+ id_ = t_state.physical_plan.topology.id
+ elif t_state.packing_plan:
+ id_ = t_state.packing_plan.id
+ else:
+ id_ = None
- info = self._rebuild_info(t_state)
- update = dataclasses.asdict(t_state)
- update["info"] = info
- update["id"] = id_
- # atomic update using python GIL
- self.__dict__.update(update)
+ info = self._rebuild_info(t_state)
+ update = dataclasses.asdict(t_state)
+ update["info"] = info
+ update["id"] = id_
+ if t_state.execution_state:
+ update["cluster"] = t_state.execution_state.cluster
+ update["environ"] = t_state.execution_state.environ
+      # atomic update performed while holding self.lock (GIL alone is not
+      # sufficient, since multiple attributes are read and written here)
+ self.__dict__.update(update)
def set_physical_plan(self, physical_plan: PhysicalPlan_pb) -> None:
""" set physical plan """
diff --git a/heron/tools/tracker/src/python/tracker.py b/heron/tools/tracker/src/python/tracker.py
index 3d20ad4..ce06a54 100644
--- a/heron/tools/tracker/src/python/tracker.py
+++ b/heron/tools/tracker/src/python/tracker.py
@@ -19,6 +19,7 @@
# under the License.
''' tracker.py '''
+import threading
import sys
from functools import partial
@@ -42,12 +43,13 @@
by handlers.
"""
- __slots__ = ["topologies", "config", "state_managers"]
+ __slots__ = ["topologies", "config", "state_managers", "lock"]
def __init__(self, config: Config):
self.config = config
self.topologies: List[Topology] = []
self.state_managers = []
+ self.lock = threading.RLock()
def sync_topologies(self) -> None:
"""
@@ -64,17 +66,17 @@
def on_topologies_watch(state_manager: StateManager, topologies: List[str]) -> None:
"""watch topologies"""
topologies = set(topologies)
- Log.info("State watch triggered for topologies.")
- Log.debug("Topologies: %s", topologies)
+ Log.debug("Received topologies: %s, %s", state_manager.name, topologies)
cached_names = {t.name for t in self.get_stmgr_topologies(state_manager.name)}
- Log.debug("Existing topologies: %s", cached_names)
- for name in cached_names - topologies:
- Log.info("Removing topology: %s in rootpath: %s",
- name, state_manager.rootpath)
- self.remove_topology(name, state_manager.name)
+ Log.debug(f"Existing topologies: {state_manager.name}, {cached_names}")
+ for name in cached_names:
+ if name not in topologies:
+ Log.info(f"Removing topology: {name} in rootpath: {state_manager.rootpath}")
+ self.remove_topology(name, state_manager.name)
- for name in topologies - cached_names:
- self.add_new_topology(state_manager, name)
+ for name in topologies:
+ if name not in cached_names:
+ self.add_new_topology(state_manager, name)
for state_manager in self.state_managers:
# The callback function with the bound state_manager as first variable
@@ -114,17 +116,16 @@
"""
return [t for t in self.topologies if t.state_manager_name == name]
- def add_new_topology(self, state_manager, topology_name: str) -> None:
+ def add_new_topology(self, state_manager: StateManager, topology_name: str) -> None:
"""
Adds a topology in the local cache, and sets a watch
on any changes on the topology.
"""
topology = Topology(topology_name, state_manager.name, self.config)
- Log.info("Adding new topology: %s, state_manager: %s",
- topology_name, state_manager.name)
- # populate the cache before making it addressable in the topologies to
- # avoid races due to concurrent execution
- self.topologies.append(topology)
+ with self.lock:
+ if topology not in self.topologies:
+ Log.info(f"Adding new topology: {topology_name}, state_manager: {state_manager.name}")
+ self.topologies.append(topology)
# Set watches on the pplan, execution_state, tmanager and scheduler_location.
state_manager.get_pplan(topology_name, topology.set_physical_plan)
@@ -137,14 +138,15 @@
"""
Removes the topology from the local cache.
"""
- self.topologies = [
- topology
- for topology in self.topologies
- if not (
- topology.name == topology_name
- and topology.state_manager_name == state_manager_name
- )
- ]
+ with self.lock:
+ self.topologies = [
+ topology
+ for topology in self.topologies
+ if not (
+ topology.name == topology_name
+ and topology.state_manager_name == state_manager_name
+ )
+ ]
def filtered_topologies(
self,