158 lines
5.8 KiB
Python
158 lines
5.8 KiB
Python
import os
|
|
import tempfile
|
|
import zipfile
|
|
|
|
from azure.common import AzureMissingResourceHttpError
|
|
try:
|
|
from azure.storage.blob import BlobService
|
|
except ImportError:
|
|
from azure.storage.blob import BlockBlobService as BlobService
|
|
from shutil import unpack_archive
|
|
from threading import Event
|
|
|
|
"""TODOS:
|
|
- use Azure snapshots instead of hacky backups
|
|
"""
|
|
|
|
|
|
def fixed_list_blobs(service, *args, **kwargs):
    """Return the names of every blob reported by `service.list_blobs`.

    By default a single `list_blobs` call may only return a subset of the
    results. This function keeps requesting pages, feeding each page's
    `next_marker` back in as the `marker` keyword argument, until the
    service reports that there is nothing left.

    Args:
        service: object exposing `list_blobs(*args, **kwargs)`; the value it
            returns must be iterable over objects with a `name` attribute
            and must carry a `next_marker` attribute.
        *args: positional arguments forwarded unchanged to `list_blobs`.
        **kwargs: keyword arguments forwarded to `list_blobs`. Any `marker`
            supplied by the caller is overwritten.

    Returns:
        A list with the `name` of every blob, in the order received.
    """
    names = []
    marker = None
    while True:
        kwargs['marker'] = marker
        page = service.list_blobs(*args, **kwargs)
        names.extend(blob.name for blob in page)
        marker = page.next_marker
        # Both an empty string and None mean "no more pages". Treating None
        # as "not started yet" here would re-request the first page forever.
        if not marker:
            break
    return names
|
|
|
|
|
|
def make_archive(source_path, dest_path):
    """Write an uncompressed zip of `source_path` to `dest_path`.

    `source_path` may be a single file or a directory (trailing path
    separators are ignored). Member names inside the archive are taken
    relative to the parent directory of `source_path`, so the archive
    always contains the top-level file or directory name itself. For a
    directory, an entry is written for every directory visited by
    `os.walk` as well as for every file.
    """
    root = source_path.rstrip(os.path.sep)
    parent = os.path.dirname(root)

    def member_name(path):
        # Name of `path` as stored in the archive.
        return os.path.relpath(path, parent)

    with zipfile.ZipFile(dest_path, "w", compression=zipfile.ZIP_STORED) as archive:
        if not os.path.isdir(root):
            archive.write(root, member_name(root))
            return

        for folder, _subfolders, entries in os.walk(root):
            # Explicit directory entry, so empty directories are kept too.
            archive.write(folder, member_name(folder))
            for entry in entries:
                full_path = os.path.join(folder, entry)
                archive.write(full_path, member_name(full_path))
|
|
|
|
|
|
class Container(object):
    """Handle on one Azure blob container used to store zip archives.

    `put` zips a local file or directory and uploads it as a single blob;
    `get` downloads such a blob and unpacks it again. Before overwriting a
    blob, `put` asks the service to copy the current version to
    `<blob_name>.backup`, and `get` tries that backup when the primary blob
    is missing or empty.
    """

    # Blob service clients shared by all instances, keyed by account name.
    services = {}

    def __init__(self, account_name, account_key, container_name, maybe_create=False):
        """Bind to `container_name` in the storage account `account_name`.

        The service client is created once per account name and reused.
        If `maybe_create` is true the container is created when missing.
        """
        self._account_name = account_name
        self._container_name = container_name
        if account_name not in Container.services:
            Container.services[account_name] = BlobService(account_name, account_key)
        self._service = Container.services[account_name]
        if maybe_create:
            self._service.create_container(self._container_name, fail_on_exist=False)

    def _backup_blob(self, blob_name):
        """Ask the service to copy `blob_name` to `blob_name + ".backup"`.

        A missing source blob is not an error: there is nothing to back up.
        """
        copy_source = "https://{}.blob.core.windows.net/{}/{}".format(
            self._account_name,
            self._container_name,
            blob_name
        )
        try:
            if hasattr(self._service, "put_block_blob_from_path"):
                # Legacy BlobService names the source `x_ms_copy_source`.
                self._service.copy_blob(
                    container_name=self._container_name,
                    blob_name=blob_name + ".backup",
                    x_ms_copy_source=copy_source
                )
            else:
                # BlockBlobService renamed the argument to `copy_source`.
                self._service.copy_blob(
                    container_name=self._container_name,
                    blob_name=blob_name + ".backup",
                    copy_source=copy_source
                )
        except AzureMissingResourceHttpError:
            pass

    def _upload_file(self, file_path, blob_name, progress_callback):
        """Upload the local file `file_path` as block blob `blob_name`."""
        if hasattr(self._service, "put_block_blob_from_path"):
            self._service.put_block_blob_from_path(
                container_name=self._container_name,
                blob_name=blob_name,
                file_path=file_path,
                max_connections=4,
                progress_callback=progress_callback,
                max_retries=10)
        else:
            # BlockBlobService exposes the upload as `create_blob_from_path`,
            # which takes no `max_retries` argument.
            self._service.create_blob_from_path(
                container_name=self._container_name,
                blob_name=blob_name,
                file_path=file_path,
                max_connections=4,
                progress_callback=progress_callback)

    def put(self, source_path, blob_name, callback=None):
        """Upload a file or directory from `source_path` to azure blob `blob_name`.

        The source is zipped into a temporary archive which is uploaded as a
        single blob. If a blob called `blob_name` already exists, a copy to
        `blob_name + ".backup"` is requested first.

        Upload progress can be traced by an optional callback, called as
        `callback(current, total)`.
        """
        upload_done = Event()

        def progress_callback(current, total):
            if callback:
                callback(current, total)
            if current >= total:
                upload_done.set()

        # Attempt to make backup if an existing version is already available
        self._backup_blob(blob_name)

        with tempfile.TemporaryDirectory() as td:
            arcpath = os.path.join(td, "archive.zip")
            make_archive(source_path, arcpath)
            self._upload_file(arcpath, blob_name, progress_callback)
            # Keep the temporary archive alive until the progress callback
            # has reported that the whole file was transferred.
            upload_done.wait()

    def get(self, dest_path, blob_name, callback=None):
        """Download a file or directory from azure blob `blob_name` to `dest_path`.

        Warning! If directory is downloaded the `dest_path` is the parent directory.

        The blob `blob_name` is tried first, then `blob_name + ".backup"`.
        A candidate is skipped when it does not exist or when its reported
        content length is zero.

        Download progress can be traced by an optional callback, called as
        `callback(current, total)`.

        Returns:
            True if an archive was downloaded and unpacked into `dest_path`,
            False if neither the blob nor its backup could be used.
        """
        download_done = Event()

        def progress_callback(current, total):
            if callback:
                callback(current, total)
            if current >= total:
                download_done.set()

        with tempfile.TemporaryDirectory() as td:
            arcpath = os.path.join(td, "archive.zip")
            for backup_blob_name in [blob_name, blob_name + '.backup']:
                try:
                    properties = self._service.get_blob_properties(
                        blob_name=backup_blob_name,
                        container_name=self._container_name
                    )
                    if hasattr(properties, 'properties'):
                        # Annoyingly, Azure has changed the API and this now returns a blob
                        # instead of its properties with up-to-date azure package.
                        blob_size = properties.properties.content_length
                    else:
                        blob_size = properties['content-length']
                    if int(blob_size) > 0:
                        self._service.get_blob_to_path(
                            container_name=self._container_name,
                            blob_name=backup_blob_name,
                            file_path=arcpath,
                            max_connections=4,
                            progress_callback=progress_callback)
                        # Only unpack once the transfer has been reported as
                        # complete, so a partial archive is never read.
                        download_done.wait()
                        unpack_archive(arcpath, dest_path)
                        return True
                except AzureMissingResourceHttpError:
                    pass
        return False

    def list(self, prefix=None):
        """List the names of all blobs in the container.

        If `prefix` is given it is forwarded to the service's blob listing.
        """
        return fixed_list_blobs(self._service, self._container_name, prefix=prefix)

    def exists(self, blob_name):
        """Returns true if `blob_name` exists in container."""
        try:
            self._service.get_blob_properties(
                blob_name=blob_name,
                container_name=self._container_name
            )
            return True
        except AzureMissingResourceHttpError:
            return False
|