DLrover support ps failure (#392)
* support ps failure

* add

* add test

* reformat

* reformat

* test worker failure

* pass test

* fix test case

* remove unnecessary log and fix unit test

* remove unnecessary log and fix unit test

* reformat

* fix system test

* remove kill thread util
hxdtest committed May 11, 2023
1 parent 7a166b3 commit e2027b3
Showing 11 changed files with 106 additions and 58 deletions.
46 changes: 6 additions & 40 deletions dlrover/examples/deepctr_manual_scale_job.yaml
@@ -11,14 +11,14 @@ spec:
replicaSpecs:
ps:
autoScale: False
replicas: 1
replicas: 2
template:
spec:
restartPolicy: Never
containers:
- name: main
# yamllint disable-line rule:line-length
image: registry.cn-hangzhou.aliyuncs.com/intell-ai/dlrover:deeprec_criteo_v1
image: registry.cn-hangzhou.aliyuncs.com/dlrover_deeprec/deeprec:v11
imagePullPolicy: Always
resources:
limits:
@@ -30,8 +30,7 @@
command:
- /bin/bash
- -c
- "pip install pyhocon \
&& cd /home/model_zoo/tf_estimator/criteo_deeprec \
- "cd /home/model_zoo/tf_estimator/criteo_deeprec \
&& python -m dlrover.trainer.entry.local_entry \
--platform=Kubernetes --conf=train_conf.TrainConf \
--enable_auto_scaling=True"
@@ -44,46 +43,14 @@
claimName: pvc-nas
worker:
autoScale: False
replicas: 1
replicas: 2
template:
spec:
restartPolicy: Never
containers:
- name: main
# yamllint disable-line rule:line-length
image: registry.cn-hangzhou.aliyuncs.com/intell-ai/dlrover:deeprec_criteo_v1
imagePullPolicy: Always
resources:
limits:
cpu: "0.5"
memory: 4Gi
requests:
cpu: "0.5"
memory: 4Gi
command:
- /bin/bash
- -c
- "pip install pyhocon \
&& cd /home/model_zoo/tf_estimator/criteo_deeprec \
&& python -m dlrover.trainer.entry.local_entry \
--platform=Kubernetes --conf=train_conf.TrainConf \
--enable_auto_scaling=True"
volumeMounts:
- name: pvc-nas
mountPath: /nas
volumes:
- name: pvc-nas
persistentVolumeClaim:
claimName: pvc-nas
evaluator:
replicas: 1
template:
spec:
restartPolicy: Never
containers:
- name: main
# yamllint disable-line rule:line-length
image: registry.cn-hangzhou.aliyuncs.com/intell-ai/dlrover:deeprec_criteo_v1
image: registry.cn-hangzhou.aliyuncs.com/dlrover_deeprec/deeprec:v11
imagePullPolicy: Always
resources:
limits:
@@ -112,6 +79,5 @@
restartPolicy: Never
containers:
- name: main
image: registry.cn-hangzhou.aliyuncs.com/intell-ai/dlrover:test
imagePullPolicy: Always
# yamllint disable-line rule:line-length
image: registry.cn-hangzhou.aliyuncs.com/dlrover_deeprec/deeprec:v0.1.0
6 changes: 6 additions & 0 deletions dlrover/trainer/constants/tf_constants.py
@@ -66,8 +66,14 @@ class TFConstants(object):
EnableDynamicSharding = Constant("enable_dynamic_sharding", True)
EnableIncrSavedModel = Constant("enable_incr_saved_model", False)
RelaunchForPs = Constant("relaunch_for_ps", False)
RelaunchForFailure = Constant("relaunch_for_failure", False)
SaveCheckpoint = Constant("save_checkpoint_for_ps", False)
CheckpointIncrementalSaveSecs = Constant(
"checkpoint_incremental_save_secs", None
)
KeepCheckpointMax = Constant("keep_checkpoint_max", 5)
DataShardClient = Constant("data_shard_client", None)
ExitRecoverableSession = Constant("exit_recoverable_session", None)
DataShardCheckpoint = Constant(
"data_shard_checkpoint", "data_shard_checkpoint.json"
)
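
The new entries follow the existing `Constant(name, default)` pattern, and the rest of this commit reads them through a process-wide dict. Below is a minimal sketch of that lookup pattern, using a simplified stand-in for the real `Constant` and `GlobalDict` helpers (the stand-ins are illustrative, not the dlrover.trainer implementation):

```python
# Simplified stand-in for DLRover's Constant/GlobalDict pattern, only to
# illustrate how the new flags are read elsewhere in this commit.
class Constant:
    def __init__(self, name, default=None):
        self.name = name
        self.default = default

    def __call__(self):
        # Calling the constant yields its default, e.g. RelaunchForFailure() -> False.
        return self.default


RelaunchForFailure = Constant("relaunch_for_failure", False)
SaveCheckpoint = Constant("save_checkpoint_for_ps", False)

global_dict = {}  # shared between the failover monitor and the session hooks

# The failover monitor sets a flag ...
global_dict[RelaunchForFailure.name] = True

# ... and a hook later reads it, falling back to the constant's default.
should_relaunch = global_dict.get(RelaunchForFailure.name, RelaunchForFailure())
print(should_relaunch)  # True
```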
4 changes: 4 additions & 0 deletions dlrover/trainer/tensorflow/executor/estimator_executor.py
@@ -29,6 +29,7 @@
ElasticDataShardReportHook,
)
from dlrover.trainer.tensorflow.hooks.global_step_hook import GlobalStepHook
from dlrover.trainer.tensorflow.util import common_util
from dlrover.trainer.tensorflow.util.data_mapping_util import data_mapping
from dlrover.trainer.tensorflow.util.dataset_util import DatasetUtil
from dlrover.trainer.tensorflow.util.estimator_util import (
@@ -162,6 +163,8 @@ def _prepare_estimator_config_and_params(self):
data_shard_client = self.train_dataset.reader.data_shard_client

if data_shard_client is not None:
global_dict = common_util.GlobalDict()
global_dict[TFConstants.DataShardClient.name] = data_shard_client
logger.info("appending ElasticDataShardReportHook")
shard_report_hook = ElasticDataShardReportHook(data_shard_client)
model_metric_report_hook = ReportModelMetricHook()
@@ -205,6 +208,7 @@ def _prepare_estimator_config_and_params(self):
def _prepare_train_dataset(self):
"""prepare_train_dataset"""
train_set = self._task_conf.get(TFConstants.TrainSet.name)
logger.info("Prepare training dataset with {}".format(train_set))
self.train_dataset = DatasetUtil.create(train_set)

def _prepare_eval_dataset(self):
4 changes: 2 additions & 2 deletions dlrover/trainer/tensorflow/failover/failover_client.py
@@ -66,8 +66,8 @@ def set_local_version(self, version=0):
logger.info("successfully set local version: %s.", version)

def get_training_ps_addr(self):
ps_nodes, _ = self._client.get_all_ps_nodes()
return [n.addr for n in ps_nodes]
ps_nodes, ps_failure = self._client.get_all_ps_nodes()
return [n.addr for n in ps_nodes], ps_failure

def init_version(self, version=0):
logger.info("initiating local and global version")
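
`get_training_ps_addr` now returns a `(addresses, ps_failure)` tuple instead of a bare address list, so every caller has to unpack two values. A small usage sketch with a stubbed client (the stub is hypothetical; only the return shape matches the change above):

```python
# Hypothetical stub mimicking the new return shape of get_training_ps_addr().
class StubFailoverClient:
    def get_training_ps_addr(self):
        # (live PS addresses, whether any PS node was reported as failed)
        return ["ps-0.default.svc:3333"], True


ps_addresses, ps_failure = StubFailoverClient().get_training_ps_addr()
if ps_failure:
    print("PS failure reported; remaining PS set:", ps_addresses)
```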
26 changes: 24 additions & 2 deletions dlrover/trainer/tensorflow/failover/tensorflow_failover.py
@@ -74,9 +74,13 @@ def _start_failover_monitor(self):
def monitor_fun():
logger.info("Successfully to start failover monitor!")
while True:
ps_address_changed, _ = self.ps_addresses_changed()
ps_address_changed, change_type = self.ps_addresses_changed()
if ps_address_changed:
self.refresh_env()
if change_type == "ps_failure":
self.exit_from_recoverable_session()
else:
self.info_cheif_do_checkpoints()
break
time.sleep(10)

@@ -92,7 +96,7 @@ def ps_addresses_changed(self):
"""
changed = False
changed_type = None
curr_address = self._failover_client.get_training_ps_addr()
curr_address, ps_failure = self._failover_client.get_training_ps_addr()
if "".join(curr_address) != "".join(self.curr_ps_address):
if len(curr_address) != len(self.curr_ps_address):
changed_type = "scaling"
@@ -103,6 +107,11 @@
self.curr_ps_address, curr_address
)
)
if ps_failure is True:
changed_type = "ps_failure"
logger.warning(
"ps failure happens, worker pod is going to exit"
)
self.curr_ps_address = curr_address
changed = True
return changed, changed_type
@@ -121,6 +130,19 @@ def refresh_env(self):
"successfully refresh TF_CONFIFG %s" % os.environ["TF_CONFIG"]
)

def exit_from_recoverable_session(self):
logger.info("exit_from_recoverable_session")
# TODO: when encountering ps failure, session will be hanged.
# we need to add grpc timeout
os._exit(2)

def set_training_thread(self, training_thread):
global_dict = common_util.GlobalDict()
global_dict[TFConstants.RelaunchForFailure.name] = True
self.training_thread = training_thread

def info_cheif_do_checkpoints(self):
global_dict = common_util.GlobalDict()
if self._is_chief:
# chief needs to do checkpoint and then
# set global_dict[TFConstants.SaveCheckpoint.name] = True
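
Taken together, the monitor now distinguishes a planned change (scaling or migration, where the chief is asked to checkpoint) from a PS failure (where the worker exits so it can be relaunched against a fresh PS). A condensed, self-contained sketch of that decision flow under the assumptions visible in the diff above (the handler is a placeholder for the real refresh/checkpoint/exit paths):

```python
import os


def classify_ps_change(old_addresses, new_addresses, ps_failure):
    """Mirror ps_addresses_changed: None, "scaling", "migrating" or "ps_failure"."""
    if "".join(new_addresses) == "".join(old_addresses):
        return None
    if ps_failure:
        return "ps_failure"
    if len(new_addresses) != len(old_addresses):
        return "scaling"
    return "migrating"


def handle_ps_change(change_type, is_chief):
    if change_type == "ps_failure":
        # The recoverable session may hang on a dead PS (no grpc timeout yet),
        # so the worker exits with a non-zero code and gets relaunched.
        os._exit(2)
    if change_type in ("scaling", "migrating") and is_chief:
        # Planned change: the chief saves a checkpoint before the cluster
        # spec is rebuilt and TF_CONFIG is refreshed.
        print("chief saves a checkpoint before TF_CONFIG is refreshed")


print(classify_ps_change(["ps-0:3333"], ["ps-0:3333", "ps-1:3333"], False))  # scaling
```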
14 changes: 8 additions & 6 deletions dlrover/trainer/tensorflow/reader/file_reader.py
@@ -19,18 +19,20 @@ class FileReader(ElasticReader):
def __init__(self, path=None, skip_header=True):
self._skip_header = skip_header
self._file_handler = open(path, "r")
self.data = self._file_handler.readlines()
self._file_name = path
super().__init__(
path=path,
)
self._data_nums = None

def count_data(self):
self.data = self._file_handler.readlines()
if self._skip_header:
self._data_nums = len(self.data) - 1
self.data = self.data[1:]
else:
self._data_nums = len(self.data)
if self._data_nums is None:
if self._skip_header:
self._data_nums = len(self.data) - 1
self.data = self.data[1:]
else:
self._data_nums = len(self.data)

def read_data_by_index_range(self, start_index, end_index):
for i in range(start_index, end_index):
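
With this change the file is read once in `__init__` and `count_data` becomes idempotent: the header row is stripped and the record count computed on the first call only, so repeated calls no longer re-slice the data. A minimal sketch of the same guard, independent of the DLRover reader classes:

```python
# Minimal illustration of the guarded count_data: header stripping and
# counting happen at most once, however often the method is called.
class SimpleFileReader:
    def __init__(self, path, skip_header=True):
        self._skip_header = skip_header
        with open(path, "r") as f:
            self.data = f.readlines()
        self._data_nums = None

    def count_data(self):
        if self._data_nums is None:
            if self._skip_header:
                self._data_nums = len(self.data) - 1
                self.data = self.data[1:]
            else:
                self._data_nums = len(self.data)
        return self._data_nums
```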
25 changes: 25 additions & 0 deletions dlrover/trainer/tensorflow/util/estimator_util.py
@@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json

from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.training import basic_session_run_hooks

@@ -64,11 +66,18 @@ def after_run(self, run_context, run_values):


def ck_after_run(self, run_context, run_values):
logger.info("save checkpoint session hook runs")

stale_global_step = run_values.results
global_dict = common_util.GlobalDict()
print(global_dict)
should_save_checkpoint = global_dict.get(
TFConstants.SaveCheckpoint.name, TFConstants.SaveCheckpoint()
)
data_shard_client = global_dict.get(
TFConstants.DataShardClient.name, TFConstants.DataShardClient()
)
data_shard_checkpoint = None
if should_save_checkpoint:
logger.info(
"Before saving checkpoint, cheif should wait for \
@@ -98,6 +107,10 @@ def ck_after_run(self, run_context, run_values):
self._timer.update_last_triggered_step(global_step)
if self._save(run_context.session, global_step):
run_context.request_stop()
if data_shard_checkpoint is not None:
data_shard_checkpoint = (
data_shard_client.get_shard_checkpoint()
)
elif self._incremental_save:
if (
self._incremental_timer.should_trigger_for_step(
@@ -126,6 +139,18 @@ def ck_after_run(self, run_context, run_values):
global_step,
self._incremental_save_path,
)
if data_shard_checkpoint is not None:
data_shard_checkpoint = (
data_shard_client.get_shard_checkpoint()
)
if data_shard_checkpoint is not None:
logger.info(
"data_shard_checkpoint for global step {} is {}".format(
global_step, data_shard_checkpoint
)
)
with open(TFConstants.DataShardCheckpoint(), "w") as f:
json.dump(data_shard_checkpoint, f)


def append_hooks(estimator_spec, key, params):
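
The patched checkpoint hook now also captures the data shard checkpoint from the shard client registered in the global dict and writes it to a small JSON file alongside the model checkpoint, so shard consumption can be restored together with the model. A standalone sketch of that save path (the shard-client stub and its payload are illustrative):

```python
import json


class StubShardClient:
    """Illustrative stand-in for the data shard client stored in the global dict."""

    def get_shard_checkpoint(self):
        return {"dataset": "train", "consumed_shards": 42}


def save_shard_checkpoint(shard_client, path="data_shard_checkpoint.json"):
    # Invoked right after a model checkpoint is written, so the shard state
    # on disk corresponds to a saved model state.
    shard_checkpoint = shard_client.get_shard_checkpoint()
    if shard_checkpoint is not None:
        with open(path, "w") as f:
            json.dump(shard_checkpoint, f)


save_shard_checkpoint(StubShardClient())
```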
13 changes: 13 additions & 0 deletions dlrover/trainer/tensorflow/util/tf_patch_util.py
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import time

from tensorflow.python.client import session
@@ -19,6 +20,7 @@
from tensorflow.python.training.monitored_session import _WrappedSession
from tensorflow_estimator.python.estimator.mode_keys import ModeKeys

from dlrover.trainer.constants.tf_constants import TFConstants
from dlrover.trainer.tensorflow.util import common_util
from dlrover.trainer.tensorflow.util.tf_version_util import (
is_tf_2,
@@ -304,6 +306,17 @@ def prepare_session_115(
max_wait_secs=max_wait_secs,
config=config,
)
global_dict = common_util.GlobalDict()
if is_loaded_from_checkpoint:
data_shard_client = global_dict.get(
TFConstants.DataShardClient.name, TFConstants.DataShardClient()
)
if data_shard_client is not None:
with open("data_shard_checkpoint.json", "r") as f:
data_shard_checkpoint = json.load(f)
data_shard_client.restore_shard_from_checkpoint(
data_shard_checkpoint
)
if not is_loaded_from_checkpoint:
if init_op is None and not init_fn and self._local_init_op is None:
raise RuntimeError(
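
On the restore side, the patched `prepare_session` reloads the shard state from the same JSON file whenever the session was initialized from a model checkpoint and hands it back to the shard client. A standalone sketch of that restore path (stub client and hard-coded file name as in the save sketch; the existence check is an added safeguard, not part of the diff):

```python
import json
import os


class StubShardClient:
    """Illustrative stand-in for the data shard client."""

    def restore_shard_from_checkpoint(self, checkpoint):
        print("restoring shard progress:", checkpoint)


def restore_shard_checkpoint(shard_client, loaded_from_checkpoint,
                             path="data_shard_checkpoint.json"):
    # Only restore shard progress when the model itself was restored from a
    # checkpoint, so data consumption stays in sync with the model state.
    if not loaded_from_checkpoint or shard_client is None:
        return
    if os.path.exists(path):
        with open(path, "r") as f:
            shard_client.restore_shard_from_checkpoint(json.load(f))


restore_shard_checkpoint(StubShardClient(), loaded_from_checkpoint=True)
```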
@@ -28,7 +28,7 @@ def init_version(self):
return

def get_training_ps_addr(self):
return ["web04-pod2.default.svc:5004"]
return ["web04-pod2.default.svc:5004"], False


class TensorflowFailoverTest(unittest.TestCase):
22 changes: 16 additions & 6 deletions dlrover/trainer/worker/tf_kubernetes_worker.py
@@ -69,14 +69,24 @@ def run(self):
run_thread = threading.Thread(target=self.run_ps)
else:
run_thread = threading.Thread(target=self.run_worker)
if hasattr(self, "tensorflow_failover"):
self.tensorflow_failover.set_training_thread(run_thread)
run_thread.start()
run_thread.join()
if not run_thread.is_alive() and global_dict.get(
TFConstants.RelaunchForPs.name, TFConstants.RelaunchForPs()
):
logger.info("ps is migrating or scaling")
if not run_thread.is_alive():
if global_dict.get(
TFConstants.RelaunchForPs.name, TFConstants.RelaunchForPs()
):
logger.info("ps is migrating or scaling")
elif global_dict.get(
TFConstants.RelaunchForFailure.name,
TFConstants.RelaunchForFailure(),
):
logger.info(
"worker encounters ps failure and restart thread"
)
else:
break
global_dict.clear()
self.init_executor(self._task_conf)
continue
else:
break
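
The worker's run loop now treats both `relaunch_for_ps` and the new `relaunch_for_failure` flag as a signal to clear the shared state, re-initialize the executor, and restart the training thread in-process; any other thread exit ends the loop. A simplified, self-contained sketch of that control flow (only the flag names come from the constants above, everything else is illustrative):

```python
import threading

RELAUNCH_FOR_PS = "relaunch_for_ps"
RELAUNCH_FOR_FAILURE = "relaunch_for_failure"


def run_worker_loop(train_once, global_dict):
    """Re-run `train_once` in a fresh thread while a relaunch flag is set."""
    while True:
        run_thread = threading.Thread(target=train_once)
        run_thread.start()
        run_thread.join()
        if global_dict.get(RELAUNCH_FOR_PS, False):
            print("ps is migrating or scaling, restarting training thread")
        elif global_dict.get(RELAUNCH_FOR_FAILURE, False):
            print("worker hit a ps failure, restarting training thread")
        else:
            break
        global_dict.clear()  # reset flags before the executor is re-initialized


# Tiny example: the first attempt flags a PS relaunch, the second finishes.
attempts = {"n": 0}
flags = {}


def train_once():
    attempts["n"] += 1
    if attempts["n"] == 1:
        flags[RELAUNCH_FOR_PS] = True


run_worker_loop(train_once, flags)
print("training attempts:", attempts["n"])  # 2
```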
2 changes: 1 addition & 1 deletion model_zoo/tf_estimator/criteo_deeprec/train_conf.py
@@ -54,7 +54,7 @@ class TrainConf(object):
train_set = {
"reader": FileReader("./data_kaggle_ad_ctr_train.csv"),
"columns": col,
"epoch": 100000,
"epoch": 10,
"batch_size": 32,
}
