Skip to content

Commit

Permalink
Merge pull request apache#11935 [BEAM-9577] Remove use of legacy arti…
Browse files Browse the repository at this point in the history
…fact service in Python.
  • Loading branch information
robertwb committed Jun 10, 2020
2 parents aabe31c + 647d21d commit 57487dc
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 863 deletions.
303 changes: 34 additions & 269 deletions sdks/python/apache_beam/runners/portability/artifact_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,11 @@
import tempfile
import threading
import typing
import zipfile
from io import BytesIO
from typing import Callable
from typing import Iterator

import grpc
from future.moves.urllib.request import urlopen
from google.protobuf import json_format

from apache_beam.io import filesystems
from apache_beam.portability import common_urns
Expand All @@ -55,272 +52,6 @@
from typing import Iterable
from typing import MutableMapping

# The legacy artifact staging and retrieval services.


class AbstractArtifactService(
beam_artifact_api_pb2_grpc.LegacyArtifactStagingServiceServicer,
beam_artifact_api_pb2_grpc.LegacyArtifactRetrievalServiceServicer):

_DEFAULT_CHUNK_SIZE = 2 << 20 # 2mb

def __init__(self, root, chunk_size=None):
self._root = root
self._chunk_size = chunk_size or self._DEFAULT_CHUNK_SIZE

def _sha256(self, string):
return hashlib.sha256(string.encode('utf-8')).hexdigest()

def _join(self, *args):
# type: (*str) -> str
raise NotImplementedError(type(self))

def _dirname(self, path):
# type: (str) -> str
raise NotImplementedError(type(self))

def _temp_path(self, path):
# type: (str) -> str
return path + '.tmp'

def _open(self, path, mode):
raise NotImplementedError(type(self))

def _rename(self, src, dest):
# type: (str, str) -> None
raise NotImplementedError(type(self))

def _delete(self, path):
# type: (str) -> None
raise NotImplementedError(type(self))

def _artifact_path(self, retrieval_token, name):
# type: (str, str) -> str
return self._join(self._dirname(retrieval_token), self._sha256(name))

def _manifest_path(self, retrieval_token):
# type: (str) -> str
return retrieval_token

def _get_manifest_proxy(self, retrieval_token):
# type: (str) -> beam_artifact_api_pb2.ProxyManifest
with self._open(self._manifest_path(retrieval_token), 'r') as fin:
return json_format.Parse(
fin.read().decode('utf-8'), beam_artifact_api_pb2.ProxyManifest())

def retrieval_token(self, staging_session_token):
# type: (str) -> str
return self._join(
self._root, self._sha256(staging_session_token), 'MANIFEST')

def PutArtifact(self, request_iterator, context=None):
# type: (...) -> beam_artifact_api_pb2.PutArtifactResponse
first = True
for request in request_iterator:
if first:
first = False
metadata = request.metadata.metadata
retrieval_token = self.retrieval_token(
request.metadata.staging_session_token)
artifact_path = self._artifact_path(retrieval_token, metadata.name)
temp_path = self._temp_path(artifact_path)
fout = self._open(temp_path, 'w')
hasher = hashlib.sha256()
else:
hasher.update(request.data.data)
fout.write(request.data.data)
fout.close()
data_hash = hasher.hexdigest()
if metadata.sha256 and metadata.sha256 != data_hash:
self._delete(temp_path)
raise ValueError(
'Bad metadata hash: %s vs %s' % (metadata.sha256, data_hash))
self._rename(temp_path, artifact_path)
return beam_artifact_api_pb2.PutArtifactResponse()

def CommitManifest(self,
request, # type: beam_artifact_api_pb2.CommitManifestRequest
context=None):
# type: (...) -> beam_artifact_api_pb2.CommitManifestResponse
retrieval_token = self.retrieval_token(request.staging_session_token)
proxy_manifest = beam_artifact_api_pb2.ProxyManifest(
manifest=request.manifest,
location=[
beam_artifact_api_pb2.ProxyManifest.Location(
name=metadata.name,
uri=self._artifact_path(retrieval_token, metadata.name))
for metadata in request.manifest.artifact
])
with self._open(self._manifest_path(retrieval_token), 'w') as fout:
fout.write(json_format.MessageToJson(proxy_manifest).encode('utf-8'))
return beam_artifact_api_pb2.CommitManifestResponse(
retrieval_token=retrieval_token)

def GetManifest(self,
request, # type: beam_artifact_api_pb2.GetManifestRequest
context=None):
# type: (...) -> beam_artifact_api_pb2.GetManifestResponse
return beam_artifact_api_pb2.GetManifestResponse(
manifest=self._get_manifest_proxy(request.retrieval_token).manifest)

def GetArtifact(self,
request, # type: beam_artifact_api_pb2.LegacyGetArtifactRequest
context=None):
# type: (...) -> Iterator[beam_artifact_api_pb2.ArtifactChunk]
for artifact in self._get_manifest_proxy(request.retrieval_token).location:
if artifact.name == request.name:
with self._open(artifact.uri, 'r') as fin:
# This value is not emitted, but lets us yield a single empty
# chunk on an empty file.
chunk = b'1'
while chunk:
chunk = fin.read(self._chunk_size)
yield beam_artifact_api_pb2.ArtifactChunk(data=chunk)
break
else:
raise ValueError('Unknown artifact: %s' % request.name)


class ZipFileArtifactService(AbstractArtifactService):
"""Stores artifacts in a zip file.
This is particularly useful for storing artifacts as part of an UberJar for
submitting to an upstream runner's cluster.
Writing to zip files requires Python 3.6+.
"""
def __init__(self, path, internal_root, chunk_size=None):
if sys.version_info < (3, 6):
raise RuntimeError(
'Writing to zip files requires Python 3.6+, '
'but current version is %s' % sys.version)
super(ZipFileArtifactService, self).__init__(internal_root, chunk_size)
self._zipfile = zipfile.ZipFile(path, 'a')
self._lock = threading.Lock()

def _join(self, *args):
# type: (*str) -> str
return '/'.join(args)

def _dirname(self, path):
# type: (str) -> str
return path.rsplit('/', 1)[0]

def _temp_path(self, path):
# type: (str) -> str
return path # ZipFile offers no move operation.

def _rename(self, src, dest):
# type: (str, str) -> None
assert src == dest

def _delete(self, path):
# type: (str) -> None
# ZipFile offers no delete operation: https://bugs.python.org/issue6818
pass

def _open(self, path, mode):
if path.startswith('/'):
raise ValueError(
'ZIP file entry %s invalid: '
'path must not contain a leading slash.' % path)
return self._zipfile.open(path, mode, force_zip64=True)

def PutArtifact(self, request_iterator, context=None):
# ZipFile only supports one writable channel at a time.
with self._lock:
return super(ZipFileArtifactService,
self).PutArtifact(request_iterator, context)

def CommitManifest(self, request, context=None):
# ZipFile only supports one writable channel at a time.
with self._lock:
return super(ZipFileArtifactService,
self).CommitManifest(request, context)

def GetManifest(self, request, context=None):
# ZipFile appears to not be threadsafe on some platforms.
with self._lock:
return super(ZipFileArtifactService, self).GetManifest(request, context)

def GetArtifact(self, request, context=None):
# ZipFile appears to not be threadsafe on some platforms.
with self._lock:
for chunk in super(ZipFileArtifactService, self).GetArtifact(request,
context):
yield chunk

def close(self):
self._zipfile.close()


class BeamFilesystemArtifactService(AbstractArtifactService):
def _join(self, *args):
# type: (*str) -> str
return filesystems.FileSystems.join(*args)

def _dirname(self, path):
# type: (str) -> str
return filesystems.FileSystems.split(path)[0]

def _rename(self, src, dest):
# type: (str, str) -> None
filesystems.FileSystems.rename([src], [dest])

def _delete(self, path):
# type: (str) -> None
filesystems.FileSystems.delete([path])

def _open(self, path, mode='r'):
dir = self._dirname(path)
if not filesystems.FileSystems.exists(dir):
try:
filesystems.FileSystems.mkdirs(dir)
except Exception:
pass

if 'w' in mode:
return filesystems.FileSystems.create(path)
else:
return filesystems.FileSystems.open(path)


# The dependency-aware artifact staging and retrieval services.


class _QueueIter(object):

_END = object()

def __init__(self):
self._queue = queue.Queue()

def put(self, item):
self._queue.put(item)

def done(self):
self._queue.put(self._END)
self._queue.put(StopIteration)

def abort(self, exn=None):
if exn is None:
exn = sys.exc_info()[1]
self._queue.put(self._END)
self._queue.put(exn)

def __iter__(self):
return self

def __next__(self):
item = self._queue.get()
if item is self._END:
raise self._queue.get()
else:
return item

if sys.version_info < (3, ):
next = __next__


class ArtifactRetrievalService(
beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer):
Expand Down Expand Up @@ -589,3 +320,37 @@ def sha256(path):
for block in iter(lambda: fin.read(4 << 20), b''):
hasher.update(block)
return hasher.hexdigest()


class _QueueIter(object):

_END = object()

def __init__(self):
self._queue = queue.Queue()

def put(self, item):
self._queue.put(item)

def done(self):
self._queue.put(self._END)
self._queue.put(StopIteration)

def abort(self, exn=None):
if exn is None:
exn = sys.exc_info()[1]
self._queue.put(self._END)
self._queue.put(exn)

def __iter__(self):
return self

def __next__(self):
item = self._queue.get()
if item is self._END:
raise self._queue.get()
else:
return item

if sys.version_info < (3, ):
next = __next__
Loading

0 comments on commit 57487dc

Please sign in to comment.