blob: 78939760c42b9d8a029f24bb532122ead4d04b61 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""""Defines abstractions around compiler artifacts produced in compiling micro TVM binaries."""
import io
import os
import json
import shutil
import tarfile
class ArtifactFileNotFoundError(Exception):
"""Raised when an artifact file cannot be found on disk."""
class ArtifactBadSymlinkError(Exception):
"""Raised when an artifact symlink points outside the base directory."""
class ArtifactBadArchiveError(Exception):
"""Raised when an artifact archive is malformed."""
class Artifact:
"""Describes a compiler artifact and defines common logic to archive it for transport."""
# A version number written to the archive.
ENCODING_VERSION = 1
# A unique string identifying the type of artifact in an archive. Subclasses must redefine this
# variable.
ARTIFACT_TYPE = None
@classmethod
def unarchive(cls, archive_path, base_dir):
"""Unarchive an artifact into base_dir.
Parameters
----------
archive_path : str
Path to the archive file.
base_dir : str
Path to a non-existent, empty directory under which the artifact will live.
Returns
-------
Artifact :
The unarchived artifact.
"""
if os.path.exists(base_dir):
raise ValueError(f"base_dir exists: {base_dir}")
base_dir_parent, base_dir_name = os.path.split(base_dir)
temp_dir = os.path.join(base_dir_parent, f"__tvm__{base_dir_name}")
os.mkdir(temp_dir)
try:
with tarfile.open(archive_path) as tar_f:
tar_f.extractall(temp_dir)
temp_dir_contents = os.listdir(temp_dir)
if len(temp_dir_contents) != 1:
raise ArtifactBadArchiveError(
"Expected exactly 1 subdirectory at root of archive, got "
f"{temp_dir_contents!r}"
)
metadata_path = os.path.join(temp_dir, temp_dir_contents[0], "metadata.json")
if not metadata_path:
raise ArtifactBadArchiveError("No metadata.json found in archive")
with open(metadata_path) as metadata_f:
metadata = json.load(metadata_f)
version = metadata.get("version")
if version != cls.ENCODING_VERSION:
raise ArtifactBadArchiveError(
f"archive version: expect {cls.EXPECTED_VERSION}, found {version}"
)
os.rename(os.path.join(temp_dir, temp_dir_contents[0]), base_dir)
artifact_cls = cls
for sub_cls in cls.__subclasses__():
if sub_cls.ARTIFACT_TYPE is not None and sub_cls.ARTIFACT_TYPE == metadata.get(
"artifact_type"
):
artifact_cls = sub_cls
break
return artifact_cls.from_unarchived(
base_dir, metadata["labelled_files"], metadata["metadata"]
)
finally:
shutil.rmtree(temp_dir)
@classmethod
def from_unarchived(cls, base_dir, labelled_files, metadata):
return cls(base_dir, labelled_files, metadata)
def __init__(self, base_dir, labelled_files, metadata):
"""Create a new artifact.
Parameters
----------
base_dir : str
The path to a directory on disk which contains all the files in this artifact.
labelled_files : Dict[str, str]
A dict mapping a file label to the relative paths of the files that carry that label.
metadata : Dict
A dict containing artitrary JSON-serializable key-value data describing the artifact.
"""
self.base_dir = os.path.realpath(base_dir)
self.labelled_files = labelled_files
self.metadata = metadata
for label, files in labelled_files.items():
for f in files:
f_path = os.path.join(self.base_dir, f)
if not os.path.lexists(f_path):
raise ArtifactFileNotFoundError(f"{f} (label {label}): not found at {f_path}")
if os.path.islink(f_path):
link_path = os.path.readlink(f_path)
if os.path.isabs(link_path):
link_fullpath = link_path
else:
link_fullpath = os.path.join(os.path.dirname(f_path), link_path)
link_fullpath = os.path.realpath(link_fullpath)
if not link_fullpath.startswith(self.base_dir):
raise ArtifactBadSymlinkError(
f"{f} (label {label}): symlink points outside artifact tree"
)
def abspath(self, rel_path):
"""Return absolute path to the member with the given relative path."""
return os.path.join(self.base_dir, rel_path)
def label(self, label):
"""Return a list of relative paths to files with the given label."""
return self.labelled_files[label]
def label_abspath(self, label):
return [self.abspath(p) for p in self.labelled_files[label]]
def archive(self, archive_path):
"""Create a relocatable tar archive of the artifacts.
Parameters
----------
archive_path : str
Path to the tar file to create. Or, path to a directory, under which a tar file will be
created named {base_dir}.tar.
Returns
-------
str :
The value of archive_path, after potentially making the computation describe above.
"""
if os.path.isdir(archive_path):
archive_path = os.path.join(archive_path, f"{os.path.basename(self.base_dir)}.tar")
archive_name = os.path.splitext(os.path.basename(archive_path))[0]
with tarfile.open(archive_path, "w") as tar_f:
def _add_file(name, data, f_type):
tar_info = tarfile.TarInfo(name=name)
tar_info.type = f_type
data_bytes = bytes(data, "utf-8")
tar_info.size = len(data)
tar_f.addfile(tar_info, io.BytesIO(data_bytes))
_add_file(
f"{archive_name}/metadata.json",
json.dumps(
{
"version": self.ENCODING_VERSION,
"labelled_files": self.labelled_files,
"metadata": self.metadata,
},
indent=2,
sort_keys=True,
),
tarfile.REGTYPE,
)
for dir_path, _, files in os.walk(self.base_dir):
for f in files:
file_path = os.path.join(dir_path, f)
archive_file_path = os.path.join(
archive_name, os.path.relpath(file_path, self.base_dir)
)
if not os.path.islink(file_path):
tar_f.add(file_path, archive_file_path, recursive=False)
continue
link_path = os.readlink(file_path)
if not os.path.isabs(link_path):
tar_f.add(file_path, archive_file_path, recursive=False)
continue
relpath = os.path.relpath(link_path, os.path.dirname(file_path))
_add_file(archive_file_path, relpath, tarfile.LNKTYPE)
return archive_path