diff --git a/elements/path.py b/elements/path.py index 0436685..53ef2ef 100644 --- a/elements/path.py +++ b/elements/path.py @@ -173,7 +173,7 @@ def open(self, mode='r'): return open(str(self), mode=mode) def absolute(self): - return type(self)(os.path.absolute(str(self))) + return type(self)(os.path.abspath(str(self))) def glob(self, pattern): for path in globlib.glob(f'{str(self)}/{pattern}', recursive=True): @@ -621,8 +621,9 @@ def read(self, size=None): raise io.UnsupportedOperation def write(self, b): - self.fp.write(b) - self.pos += len(b) + n = self.fp.write(b) + self.pos += n + return n def close(self): import google.cloud.exceptions @@ -645,6 +646,422 @@ def _wait_until_exists(self, blob, timeout=60): raise TimeoutError +S3_LOCK = threading.RLock() +S3_CLIENT = None +S3_RESOURCE = None + + +def s3_retry(max_attempts=3): + from botocore.config import Config + return Config(retries={'max_attempts': max_attempts, 'mode': 'adaptive'}) + + +class S3Path(Path): + + __slots__ = ('_path',) + + def __init__(self, path): + path = str(path) + super().__init__(path) + + @property + def _bucket_name(self): + return str(self)[5:].split('/', 1)[0] + + @property + def _key(self): + path = str(self)[5:] + return path.split('/', 1)[1] if '/' in path else None + + @property + def client(self): + global S3_CLIENT + if not S3_CLIENT: + with S3_LOCK: + if not S3_CLIENT: + import boto3 + S3_CLIENT = boto3.client('s3', config=s3_retry()) + return S3_CLIENT + + @property + def resource(self): + global S3_RESOURCE + if not S3_RESOURCE: + with S3_LOCK: + if not S3_RESOURCE: + import boto3 + S3_RESOURCE = boto3.resource('s3', config=s3_retry()) + return S3_RESOURCE + + @property + def size(self): + resp = self.client.head_object(Bucket=self._bucket_name, Key=self._key) + return resp['ContentLength'] + + def open(self, mode='r'): + assert self._key, 'is a directory' + if 'r' in mode: + return S3ReadFile(self) + elif 'a' in mode: + return S3AppendFile(self, mode) + else: + return S3WriteFile(self, mode) + + def read(self, mode='r'): + assert self._key, 'is a directory' + if mode == 'rb': + resp = self.client.get_object(Bucket=self._bucket_name, Key=self._key) + return resp['Body'].read() + elif mode == 'r': + return self.read('rb').decode('utf-8') + else: + raise NotImplementedError(mode) + + def write(self, content, mode='w'): + assert mode in 'w a wb ab'.split(), mode + if mode == 'a': + prefix = self.read('r') if self.isfile() else '' + content = prefix + content + if mode == 'ab': + prefix = self.read('rb') if self.isfile() else b'' + content = prefix + content + if isinstance(content, str): + content = content.encode('utf-8') + self.client.upload_fileobj( + io.BytesIO(content), self._bucket_name, self._key + ) + + def absolute(self): + return self + + def glob(self, pattern): + pattern = pattern.rstrip('/') + assert pattern + prefix = (self._key + '/') if self._key else '' + paginator = self.client.get_paginator('list_objects_v2') + if pattern == '**' or pattern == '**/*': + pages = paginator.paginate(Bucket=self._bucket_name, Prefix=prefix) + keys = [] + for page in pages: + for obj in page.get('Contents', []): + keys.append(obj['Key']) + # Include intermediate directories. + dirs = set() + for k in keys: + if k.startswith(prefix): + rel = k[len(prefix) :] + parts = rel.split('/') + for i in range(1, len(parts)): + dirs.add(prefix + '/'.join(parts[:i])) + results = sorted(set([k.rstrip('/') for k in keys] + list(dirs))) + elif ( + '**' not in pattern + and '*' not in pattern + and '?' not in pattern + and '[' not in pattern + ): + # Literal pattern. + target = prefix + pattern + results = [] + # Check as file. + try: + self.client.head_object(Bucket=self._bucket_name, Key=target) + results.append(target) + except self.client.exceptions.ClientError: + pass + # Check as directory prefix. + resp = self.client.list_objects_v2( + Bucket=self._bucket_name, Prefix=target + '/', MaxKeys=1 + ) + if resp.get('Contents'): + results.append(target) + elif '**' not in pattern and '/' not in pattern: + # Single-level glob (e.g. '*', '*.txt', 'foo*'). + # Use Delimiter='/' to only get immediate children. + pages = paginator.paginate( + Bucket=self._bucket_name, Prefix=prefix, Delimiter='/' + ) + keys = [] + dirs = [] + for page in pages: + for obj in page.get('Contents', []): + keys.append(obj['Key']) + for cp in page.get('CommonPrefixes', []): + dirs.append(cp['Prefix'].rstrip('/')) + all_paths = sorted(set([k.rstrip('/') for k in keys] + dirs)) + results = fnmatch.filter(all_paths, prefix + pattern) + else: + # General glob: list with common prefix, then filter. + # Find the longest literal prefix before any glob character. + literal = '' + for ch in pattern: + if ch in '*?[{': + break + literal += ch + pages = paginator.paginate( + Bucket=self._bucket_name, Prefix=prefix + literal + ) + keys = [] + for page in pages: + for obj in page.get('Contents', []): + keys.append(obj['Key']) + # Include intermediate directories. + dirs = set() + for k in keys: + if k.startswith(prefix): + rel = k[len(prefix) :] + parts = rel.split('/') + for i in range(1, len(parts)): + dirs.add(prefix + '/'.join(parts[:i])) + all_paths = sorted(set([k.rstrip('/') for k in keys] + list(dirs))) + results = fnmatch.filter(all_paths, prefix + pattern) + results = sorted(set(results)) + return [type(self)(f's3://{self._bucket_name}/{x}') for x in results] + + def exists(self): + return self.isfile() or self.isdir() + + def isfile(self): + if not self._key: + return False + try: + self.client.head_object(Bucket=self._bucket_name, Key=self._key) + return True + except self.client.exceptions.ClientError: + return False + + def isdir(self): + if not self._key: + # Bucket-level path, check if bucket exists. + try: + self.client.head_bucket(Bucket=self._bucket_name) + return True + except self.client.exceptions.ClientError: + return False + # Check if any objects exist under this prefix. + resp = self.client.list_objects_v2( + Bucket=self._bucket_name, Prefix=self._key + '/', MaxKeys=1 + ) + return bool(resp.get('Contents')) + + def mkdir(self, **kwargs): + assert kwargs.pop('parents', True) + assert kwargs.pop('exist_ok', True) + assert not kwargs, kwargs + if not self._key: + return + if self.exists(): + return + # Create a directory marker. + self.client.put_object( + Bucket=self._bucket_name, Key=self._key + '/', Body=b'' + ) + + def remove(self, recursive=False): + if recursive: + # Delete all objects under this prefix. + paginator = self.client.get_paginator('list_objects_v2') + prefix = self._key + '/' if self._key else '' + pages = paginator.paginate(Bucket=self._bucket_name, Prefix=prefix) + for page in pages: + objects = [{'Key': obj['Key']} for obj in page.get('Contents', [])] + if objects: + self.client.delete_objects( + Bucket=self._bucket_name, Delete={'Objects': objects} + ) + else: + if self._key: + self.client.delete_object(Bucket=self._bucket_name, Key=self._key) + # Also try deleting the directory marker. + try: + self.client.delete_object( + Bucket=self._bucket_name, Key=self._key + '/' + ) + except self.client.exceptions.ClientError: + pass + + def copy(self, dest, recursive=False): + dest = Path(dest) + if isinstance(dest, type(self)) and not recursive: + self.client.copy_object( + Bucket=dest._bucket_name, + Key=dest._key, + CopySource={'Bucket': self._bucket_name, 'Key': self._key}, + ) + else: + _copy_across_filesystems(self, dest, recursive) + + def move(self, dest, recursive=False): + dest = Path(dest) + if isinstance(dest, type(self)) and not recursive: + self.client.copy_object( + Bucket=dest._bucket_name, + Key=dest._key, + CopySource={'Bucket': self._bucket_name, 'Key': self._key}, + ) + self.client.delete_object(Bucket=self._bucket_name, Key=self._key) + else: + _copy_across_filesystems(self, dest, recursive) + self.remove() + + +class S3ReadFile: + + def __init__(self, s3path): + self.s3path = s3path + self.pos = 0 + self._size = None + + def __enter__(self): + return self + + def __exit__(self, *e): + self.close() + + def readable(self): + return True + + def writeable(self): + return False + + def seekable(self): + return True + + def tell(self): + return self.pos + + def _get_size(self): + if self._size is None: + self._size = self.s3path.size + return self._size + + def seek(self, pos, mode=os.SEEK_SET): + size = self._get_size() + if mode == os.SEEK_SET: + self.pos = pos + elif mode == os.SEEK_CUR: + self.pos += pos + elif mode == os.SEEK_END: + self.pos = size + pos + else: + raise NotImplementedError(mode) + assert 0 <= self.pos <= size, (self.pos, size) + + def read(self, size=None): + client = self.s3path.client + bucket = self.s3path._bucket_name + key = self.s3path._key + if size is None: + kwargs = {} + if self.pos > 0: + kwargs['Range'] = f'bytes={self.pos}-' + resp = client.get_object(Bucket=bucket, Key=key, **kwargs) + data = resp['Body'].read() + self.pos += len(data) + return data + file_size = self._get_size() + end = min(self.pos + size, file_size) + if self.pos >= end: + return b'' + resp = client.get_object( + Bucket=bucket, Key=key, Range=f'bytes={self.pos}-{end - 1}' + ) + data = resp['Body'].read() + self.pos = end + return data[:size] + + def close(self): + pass + + +class S3WriteFile: + + def __init__(self, s3path, mode='w'): + self.s3path = s3path + self.mode = mode + self.buffer = io.BytesIO() if 'b' in mode else io.StringIO() + + def __enter__(self): + return self + + def __exit__(self, *e): + self.close() + + def readable(self): + return False + + def writeable(self): + return True + + def seekable(self): + return False + + def tell(self): + return self.buffer.tell() + + def write(self, data): + return self.buffer.write(data) + + def close(self): + content = self.buffer.getvalue() + if isinstance(content, str): + content = content.encode('utf-8') + self.s3path.client.upload_fileobj( + io.BytesIO(content), self.s3path._bucket_name, self.s3path._key + ) + + +class S3AppendFile: + + def __init__(self, s3path, mode='a'): + self.s3path = s3path + self.mode = mode + self.buffer = io.BytesIO() if 'b' in mode else io.StringIO() + self.pos = s3path.size if s3path.isfile() else 0 + + def __enter__(self): + return self + + def __exit__(self, *e): + self.close() + + def readable(self): + return False + + def writeable(self): + return True + + def seekable(self): + return False + + def tell(self): + return self.pos + + def seek(self, pos, mode=os.SEEK_SET): + raise io.UnsupportedOperation + + def read(self, size=None): + raise io.UnsupportedOperation + + def write(self, data): + n = self.buffer.write(data) + self.pos += n + return n + + def close(self): + new_data = self.buffer.getvalue() + if isinstance(new_data, str): + new_data = new_data.encode('utf-8') + if self.s3path.isfile(): + existing = self.s3path.read('rb') + content = existing + new_data + else: + content = new_data + self.s3path.client.upload_fileobj( + io.BytesIO(content), self.s3path._bucket_name, self.s3path._key + ) + + def _copy_across_filesystems(source, dest, recursive): assert isinstance(source, Path), type(source) assert isinstance(dest, Path), type(dest) @@ -667,6 +1084,7 @@ def _copy_across_filesystems(source, dest, recursive): Path.filesystems = [ (GCSPath, lambda path: path.startswith('gs://')), + (S3Path, lambda path: path.startswith('s3://')), (TFPath, lambda path: path.startswith('/cns/')), (LocalPath, lambda path: True), ]