blob: 20baade16b4f5f42e04891e2bdcd93bfb50b7d33 [file] [log] [blame]
# Copyright 2015 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
import contextlib
import os
import tarfile
import zipfile
from .common import safe_mkdtemp
class Archiver(object):
    """Unpack tar and zip archives, choosing the handler by filename extension."""

    class Error(Exception):
        """Base class for all errors raised by Archiver."""

    class UnpackError(Error):
        """Raised when a recognized archive cannot be opened or extracted."""

    class InvalidArchive(Error):
        """Raised when a filename has no recognized archive extension."""

    # Maps a filename extension to a pair:
    #   (callable that opens the archive, exception it raises for a bad archive).
    EXTENSIONS = {
        '.tar': (tarfile.TarFile.open, tarfile.ReadError),
        '.tar.gz': (tarfile.TarFile.open, tarfile.ReadError),
        '.tar.bz2': (tarfile.TarFile.open, tarfile.ReadError),
        '.tgz': (tarfile.TarFile.open, tarfile.ReadError),
        '.zip': (zipfile.ZipFile, zipfile.BadZipfile),
    }

    @classmethod
    def first_nontrivial_dir(cls, path):
        """Return the first directory at or below `path` that is not a bare wrapper.

        A directory counts as a wrapper when its only entry is another
        directory; such wrappers are descended through until a directory with
        zero entries, several entries, or a single non-directory entry is found.

        :param path: Directory to start from.
        :returns: The path of the first non-wrapper directory.
        """
        current = path
        while True:
            entries = os.listdir(current)
            if len(entries) != 1:
                return current
            only_child = os.path.join(current, entries[0])
            if not os.path.isdir(only_child):
                return current
            current = only_child

    @classmethod
    def get_extension(cls, filename):
        """Return the supported archive extension that `filename` ends with.

        :param filename: Name or path of the candidate archive.
        :returns: A key of `EXTENSIONS`, or None if no extension matches.
        """
        # Try longer extensions first so the answer does not depend on dict
        # iteration order should one registered extension ever be a suffix of
        # another.
        for ext in sorted(cls.EXTENSIONS, key=len, reverse=True):
            if filename.endswith(ext):
                return ext
        return None

    @classmethod
    def unpack(cls, filename, location=None):
        """Extract the archive `filename` and return its first non-trivial directory.

        :param filename: Path of the archive; its extension selects the handler.
        :param location: Directory to extract into. When falsy, a directory is
          obtained from `safe_mkdtemp()`.
        :returns: `first_nontrivial_dir` applied to the extraction directory.
        :raises Archiver.InvalidArchive: if the extension is not in `EXTENSIONS`.
        :raises Archiver.UnpackError: if the handler reports a malformed archive.
        """
        # Validate the format before touching the filesystem so that an
        # unsupported filename does not allocate a temporary directory.
        ext = cls.get_extension(filename)
        if ext is None:
            raise cls.InvalidArchive('Unknown archive format: %s' % filename)
        archive_class, error_class = cls.EXTENSIONS[ext]

        path = location or safe_mkdtemp()
        try:
            with contextlib.closing(archive_class(filename)) as package:
                # NOTE(review): tarfile's extractall does not sanitize member
                # paths on older Python versions; only pass archives from
                # trusted sources, or validate members before extraction.
                package.extractall(path=path)
        except error_class:
            raise cls.UnpackError('Could not extract %s' % filename)
        return cls.first_nontrivial_dir(path)