blob: e458e74fee102920caafa41ac17ed3d71736ea8b [file] [log] [blame]
"""Read and write for the RecordIO data format."""
from __future__ import absolute_import
from collections import namedtuple
import ctypes
import struct
import numbers
import numpy as np
from .base import _LIB
from .base import RecordIOHandle
from .base import check_call
from .base import c_str
try:
import cv2
except ImportError:
cv2 = None
class MXRecordIO(object):
    """Read/write RecordIO format data.

    Parameters
    ----------
    uri : string
        uri path to recordIO file.
    flag : string
        "r" for reading or "w" writing.
    """
    def __init__(self, uri, flag):
        # Set the open-state first: everything below can raise, and
        # ``__del__`` -> ``close`` must still find a consistent object.
        self.is_open = False
        self.uri = c_str(uri)
        self.handle = RecordIOHandle()
        self.flag = flag
        self.open()

    def open(self):
        """Open record file.

        Creates a writer handle for flag "w" or a reader handle for flag "r".

        Raises
        ------
        ValueError
            If ``self.flag`` is neither "w" nor "r".
        """
        if self.flag == "w":
            check_call(_LIB.MXRecordIOWriterCreate(self.uri, ctypes.byref(self.handle)))
            self.writable = True
        elif self.flag == "r":
            check_call(_LIB.MXRecordIOReaderCreate(self.uri, ctypes.byref(self.handle)))
            self.writable = False
        else:
            raise ValueError("Invalid flag %s"%self.flag)
        self.is_open = True

    def __del__(self):
        self.close()

    def close(self):
        """Close record file. Does nothing if the file is not open."""
        # ``getattr`` keeps this safe on a partially constructed instance
        # (e.g. ``__init__`` raised before ``is_open`` was assigned), which
        # would otherwise raise AttributeError from inside ``__del__``.
        if not getattr(self, 'is_open', False):
            return
        if self.writable:
            check_call(_LIB.MXRecordIOWriterFree(self.handle))
        else:
            check_call(_LIB.MXRecordIOReaderFree(self.handle))
        self.is_open = False

    def reset(self):
        """Reset pointer to first item. If record is opened with 'w',
        this will truncate the file to empty."""
        self.close()
        self.open()

    def write(self, buf):
        """Write a string buffer as a record.

        Parameters
        ----------
        buf : string (python2), bytes (python3)
            Buffer to write.
        """
        assert self.writable
        check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
                                                    ctypes.c_char_p(buf),
                                                    ctypes.c_size_t(len(buf))))

    def read(self):
        """Read a record as string.

        Returns
        ----------
        buf : string
            Buffer read, or None when the library returns a null pointer
            (presumably once no further record is available).
        """
        assert not self.writable
        buf = ctypes.c_char_p()
        size = ctypes.c_size_t()
        check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,
                                                   ctypes.byref(buf),
                                                   ctypes.byref(size)))
        if buf:
            # Reinterpret the pointer as a fixed-size char array so that the
            # full ``size`` bytes are copied out, embedded NULs included.
            buf = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char*size.value))
            return buf.contents.raw
        else:
            return None
class MXIndexedRecordIO(MXRecordIO):
    """Read/write RecordIO format data supporting random access.

    Parameters
    ----------
    idx_path : str
        Path to index file.
    uri : str
        Path to record file. Only support file types that are seekable.
    flag : str
        'w' for write or 'r' for read
    key_type : type
        Data type for keys.
    """
    def __init__(self, idx_path, uri, flag, key_type=int):
        self.idx_path = idx_path
        self.idx = {}
        self.keys = []
        self.key_type = key_type
        self.fidx = None
        super(MXIndexedRecordIO, self).__init__(uri, flag)

    def open(self):
        """Open the record file and its index file.

        In read mode the index file is parsed into ``self.idx`` (key ->
        position) and ``self.keys`` (keys in file order). Each index line
        holds a key and a position separated by a tab.
        """
        super(MXIndexedRecordIO, self).open()
        self.idx = {}
        self.keys = []
        self.fidx = open(self.idx_path, self.flag)
        if not self.writable:
            for line in self.fidx:
                fields = line.strip().split('\t')
                key = self.key_type(fields[0])
                self.idx[key] = int(fields[1])
                self.keys.append(key)

    def close(self):
        """Close the record file and the index file.

        Does nothing if the record file is not open.
        """
        if not getattr(self, 'is_open', False):
            return
        try:
            super(MXIndexedRecordIO, self).close()
        finally:
            # ``fidx`` is still None when opening the index file failed after
            # the record file was opened; the index file must also be closed
            # even if releasing the record handle raised.
            if self.fidx is not None:
                self.fidx.close()
                self.fidx = None

    def seek(self, idx):
        """Move the read head to the record stored under key ``idx``.

        Raises
        ------
        KeyError
            If ``idx`` is not present in the loaded index.
        """
        assert not self.writable
        pos = ctypes.c_size_t(self.idx[idx])
        check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos))

    def tell(self):
        """Query current write head position."""
        assert self.writable
        pos = ctypes.c_size_t()
        check_call(_LIB.MXRecordIOWriterTell(self.handle, ctypes.byref(pos)))
        return pos.value

    def read_idx(self, idx):
        """Read record with index."""
        self.seek(idx)
        return self.read()

    def write_idx(self, idx, buf):
        """Write record with index.

        The write position before the record is written is stored under
        ``key_type(idx)`` both in the index file and in the in-memory index.
        """
        key = self.key_type(idx)
        pos = self.tell()
        self.write(buf)
        self.fidx.write('%s\t%d\n'%(str(key), pos))
        self.idx[key] = pos
        self.keys.append(key)
# Binary layout of the fixed-size record header, in ``struct`` notation:
# unsigned int, float, and two unsigned long longs. No byte-order prefix is
# given, so native byte order and alignment apply.
_IR_FORMAT = 'IfQQ'
# Number of bytes occupied by one packed header.
_IR_SIZE = struct.Struct(_IR_FORMAT).size
# Header fields, in the same order as the entries of ``_IR_FORMAT``.
IRHeader = namedtuple('HEADER', 'flag label id id2')
def pack(header, s):
    """Pack a string into MXImageRecord.

    Parameters
    ----------
    header : IRHeader
        Header of the image record.
        ``header.label`` can be a number or an array.
    s : str
        string to pack

    Returns
    -------
    s : str
        The packed header followed by the (possibly label-prefixed) payload.
    """
    header = IRHeader(*header)
    if isinstance(header.label, numbers.Number):
        # A scalar label fits in the header itself; flag 0 marks that case.
        header = header._replace(flag=0)
    else:
        # An array label is stored as float32 values in front of the payload;
        # the flag field carries the number of values.
        label = np.asarray(header.label, dtype=np.float32)
        header = header._replace(flag=label.size, label=0)
        # ``tobytes`` is the supported spelling of the deprecated ``tostring``.
        s = label.tobytes() + s
    return struct.pack(_IR_FORMAT, *header) + s
def unpack(s):
    """Unpack a MXImageRecord to string.

    Parameters
    ----------
    s : str
        String buffer from ``MXRecordIO.read``.

    Returns
    -------
    header : IRHeader
        Header of the image record. When the stored flag is positive,
        ``header.label`` is a float32 array with ``flag`` elements.
    s : str
        Unpacked string.
    """
    header = IRHeader(*struct.unpack(_IR_FORMAT, s[:_IR_SIZE]))
    s = s[_IR_SIZE:]
    if header.flag > 0:
        # A positive flag is the count of float32 label values stored ahead
        # of the payload. ``frombuffer`` replaces the deprecated binary mode
        # of ``fromstring``; the copy keeps the label writable and
        # independent of the input buffer, as ``fromstring`` did.
        label = np.frombuffer(s, np.float32, header.flag).copy()
        header = header._replace(label=label)
        s = s[header.flag*4:]
    return header, s
def unpack_img(s, iscolor=-1):
    """Unpack a MXImageRecord to image.

    Parameters
    ----------
    s : str
        String buffer from ``MXRecordIO.read``.
    iscolor : int
        image format option for ``cv2.imdecode``.

    Returns
    -------
    header : IRHeader
        Header of the image record.
    img : numpy.ndarray
        Unpacked image.
    """
    header, s = unpack(s)
    # ``frombuffer`` replaces the deprecated binary mode of ``fromstring``;
    # the resulting byte view is only handed to the decoder.
    img = np.frombuffer(s, dtype=np.uint8)
    assert cv2 is not None
    img = cv2.imdecode(img, iscolor)
    return header, img
def pack_img(header, img, quality=95, img_fmt='.jpg'):
    """Pack an image into ``MXImageRecord``.

    Parameters
    ----------
    header : IRHeader
        Header of the image record.
        ``header.label`` can be a number or an array.
    img : numpy.ndarray
        image to pack
    quality : int
        Quality for JPEG encoding in range 1-100, or compression for PNG encoding in range 1-9.
    img_fmt : str
        Encoding of the image (.jpg for JPEG, .png for PNG).

    Returns
    -------
    s : str
        The packed string.
    """
    assert cv2 is not None
    jpg_formats = ['.JPG', '.JPEG']
    png_formats = ['.PNG']
    # For any other extension no encoder parameters are passed.
    encode_params = None
    fmt_upper = img_fmt.upper()
    if fmt_upper in jpg_formats:
        encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
    elif fmt_upper in png_formats:
        encode_params = [cv2.IMWRITE_PNG_COMPRESSION, quality]
    ret, buf = cv2.imencode(img_fmt, img, encode_params)
    assert ret, 'failed to encode image'
    # ``tobytes`` is the supported spelling of the deprecated ``tostring``.
    return pack(header, buf.tobytes())