Low-level wrapper для Python должен являться биндингом имеющегося API в Питон, учитывающим паттерны программирования в Питоне. При этом он не должен являться самостоятельной библиотекой, требующей поддержки и диктующей свои правила, как это происходит с существующим Low-level Python API.
В идеале:
Предлагаемое решение:
Обертка C-API для питона должна генерироваться автоматически по достаточно простому описанию. Обертка выполняется в виде одного небольшого модуля (не более 1к строк, лучше — меньше), содержащего класс LibArtm.
На основе Low-level wrapper будет строиться удобная python-библиотека для тематического моделирования (реализация ArtmModel).
# Location of the libartm shared library; this value is handed to LibArtm(...),
# which loads it with ctypes.CDLL.
# NOTE(review): hard-coded, machine-specific path - adjust for the local build.
ARTM_LIBRARY_PATH = '/home/romovpa/bigartm/build/src/artm/libartm.so'
import ctypes
from artm import messages_pb2
# Non-error return codes. LibArtm._check_error treats every code below -1 as an
# error and raises for it.
ARTM_SUCCESS = 0
ARTM_STILL_WORKING = -1
class ArtmException(Exception):
    """Base class for every error reported by the libartm wrapper.

    Derives from ``Exception`` (not ``BaseException``) so that generic
    ``except Exception`` handlers see library errors, while things like
    ``KeyboardInterrupt`` stay separate.
    """


class InternalError(ArtmException):
    """Raised for libartm error code -2."""


class ArgumentOutOfRangeException(ArtmException):
    """Raised for libartm error code -3."""


class InvalidMasterIdException(ArtmException):
    """Raised for libartm error code -4."""


class CorruptedMessageException(ArtmException):
    """Raised for libartm error code -5."""


class InvalidOperationException(ArtmException):
    """Raised for libartm error code -6."""


class DiskReadException(ArtmException):
    """Raised for libartm error code -7."""


class DiskWriteException(ArtmException):
    """Raised for libartm error code -8."""
# Lookup table from a negative libartm return code to the exception class
# that is raised for it.
ARTM_EXCEPTION_BY_CODE = {
    -2: InternalError,
    -3: ArgumentOutOfRangeException,
    -4: InvalidMasterIdException,
    -5: CorruptedMessageException,
    -6: InvalidOperationException,
    -7: DiskReadException,
    -8: DiskWriteException,
}
import ctypes
from google import protobuf
from artm import messages_pb2
class LibArtm(object):
    """Thin ctypes wrapper exposing the functions of libartm as methods.

    Any attribute that is not defined on the instance is looked up in the
    loaded shared library and returned as a Python callable that converts its
    arguments to C arguments and turns error codes into exceptions.
    """

    # Types accepted as C string arguments (``basestring`` exists only on
    # Python 2).
    try:
        _STRING_TYPES = (basestring,)
    except NameError:
        _STRING_TYPES = (str, bytes)

    def __init__(self, lib_name):
        """Load the shared library located at ``lib_name``."""
        self.cdll = ctypes.CDLL(lib_name)

    def __getattr__(self, name):
        """Return a wrapped library function called ``name``.

        Raises:
            AttributeError: if the library has no such function, or if the
                library has not been loaded yet.
        """
        # __getattr__ is only called for missing attributes; reading
        # ``self.cdll`` here while it is unset would recurse forever.
        try:
            cdll = self.__dict__['cdll']
        except KeyError:
            raise AttributeError(name)
        # getattr never returns None for a missing symbol, it raises.
        try:
            func = getattr(cdll, name)
        except AttributeError:
            raise AttributeError('%s is not a function of libartm' % name)
        return self._wrap_call(func)

    def _check_error(self, error_code):
        """Raise the exception matching ``error_code`` if it is below -1.

        Codes 0 and -1 are not errors and return silently. The exception
        class is taken from ARTM_EXCEPTION_BY_CODE, falling back to
        RuntimeError for unknown codes.
        """
        if error_code >= -1:
            return
        # Use this instance's library, not a module-level ``lib`` object.
        get_last_error = self.cdll.ArtmGetLastErrorMessage
        get_last_error.restype = ctypes.c_char_p
        error_message = get_last_error() or b''
        if not isinstance(error_message, str):
            # Python 3: c_char_p results are bytes.
            error_message = error_message.decode('utf-8', 'replace')
        # remove exception name from error message
        error_message = error_message.split(':', 1)[-1].strip()
        exception_class = ARTM_EXCEPTION_BY_CODE.get(error_code, RuntimeError)
        raise exception_class(error_message)

    def _wrap_call(self, func):
        """Wrap the ctypes function ``func`` into a Python-friendly callable.

        Strings become C string buffers, protobuf messages become a
        ``(length, buffer)`` pair of arguments, everything else is passed
        through unchanged. The integer result is checked with _check_error
        and returned.
        """
        def artm_api_call(*args):
            cargs = []
            for arg in args:
                if isinstance(arg, self._STRING_TYPES):
                    if not isinstance(arg, bytes):
                        # create_string_buffer needs bytes
                        arg = arg.encode('utf-8')
                    cargs.append(ctypes.create_string_buffer(arg))
                elif isinstance(arg, protobuf.message.Message):
                    message_str = arg.SerializeToString()
                    cargs += [len(message_str), ctypes.create_string_buffer(message_str)]
                else:
                    cargs.append(arg)
            result = func(*cargs)
            self._check_error(result)
            return result
        return artm_api_call

    def _copy_request_result(self, length):
        """Copy the pending request result of ``length`` bytes into a buffer.

        Returns the ctypes string buffer filled by ArtmCopyRequestResult.
        """
        message_blob = ctypes.create_string_buffer(length)
        # Call the raw library function; it needs no argument conversion.
        error_code = self.cdll.ArtmCopyRequestResult(length, message_blob)
        self._check_error(error_code)
        return message_blob
# Demo: load the library and create a master component and a model.
lib = LibArtm(ARTM_LIBRARY_PATH)
config = messages_pb2.MasterComponentConfig()
# NOTE(review): the meaning of processors_count = -1 is not visible here - confirm.
config.processors_count = -1
# The wrapper serializes a protobuf message into (length, buffer) C arguments
# and returns the integer result of the C call, used below as the master id.
master_id = lib.ArtmCreateMasterComponent(config)
lib.ArtmCreateModel(master_id, messages_pb2.ModelConfig())
0
Модификация обертки, приведенной выше, решает следующие задачи:
- ArtmCreateMasterComponent возвращает int
- ArtmRequestThetaMatrix возвращают сообщение
- None (вместо ARTM_SUCCESS)
- dict
# (Tail of the preceding prose item, which was fused onto the class line:)
# ...which is smartly converted into the corresponding message
# (much more convenient in Python).
class CallSpec(object):
    """Declarative description of a single libartm API function.

    Attributes:
        name: name of the function, used to look it up in the library.
        arguments: list of ``(argument_name, argument_type)`` pairs.
        result_type: value of the ``result`` parameter; when not None the
            wrapper assigns it to the ctypes function's ``restype`` and
            returns the call result.
        request_type: value of the ``request`` parameter; when not None the
            wrapper fetches the request result via ArtmCopyRequestResult.
    """

    def __init__(self, name, arguments, result=None, request=None):
        self.name, self.arguments = name, arguments
        self.result_type, self.request_type = result, request
Список ARTM_API нужно будет редактировать при изменении low-level API. Заметьте: необходимо будет минимальное число правок, в отличие от существующего artm.library.
Некоторые функции из API объявляются специальными, для них по понятной причине не делается обертка:
ArtmGetLastErrorMessage
ArtmCopyRequestResult
# Specification of every wrapped libartm function, one CallSpec per function.
# ArtmGetLastErrorMessage and ArtmCopyRequestResult are deliberately absent:
# LibArtm calls them directly on the loaded library.
ARTM_API = [
    CallSpec(
        'ArtmCreateMasterComponent',
        [('config', messages_pb2.MasterComponentConfig)],
        result=ctypes.c_int,
    ),
    CallSpec(
        'ArtmReconfigureMasterComponent',
        [('master_id', int), ('config', messages_pb2.MasterComponentConfig)],
    ),
    CallSpec(
        'ArtmDisposeMasterComponent',
        [('master_id', int)],
    ),
    CallSpec(
        'ArtmCreateModel',
        [('master_id', int), ('config', messages_pb2.ModelConfig)],
    ),
    CallSpec(
        'ArtmReconfigureModel',
        [('master_id', int), ('config', messages_pb2.ModelConfig)],
    ),
    CallSpec(
        'ArtmDisposeModel',
        [('master_id', int), ('name', str)],
    ),
    CallSpec(
        'ArtmCreateRegularizer',
        [('master_id', int), ('config', messages_pb2.RegularizerConfig)],
    ),
    CallSpec(
        'ArtmReconfigureRegularizer',
        [('master_id', int), ('config', messages_pb2.RegularizerConfig)],
    ),
    CallSpec(
        'ArtmDisposeRegularizer',
        [('master_id', int), ('name', str)],
    ),
    CallSpec(
        'ArtmCreateDictionary',
        [('master_id', int), ('config', messages_pb2.DictionaryConfig)],
    ),
    CallSpec(
        'ArtmReconfigureDictionary',
        [('master_id', int), ('config', messages_pb2.DictionaryConfig)],
    ),
    CallSpec(
        'ArtmDisposeDictionary',
        [('master_id', int), ('name', str)],
    ),
    CallSpec(
        'ArtmAddBatch',
        [('master_id', int), ('args', messages_pb2.AddBatchArgs)],
    ),
    CallSpec(
        'ArtmInvokeIteration',
        [('master_id', int), ('args', messages_pb2.InvokeIterationArgs)],
    ),
    CallSpec(
        'ArtmSynchronizeModel',
        [('master_id', int), ('args', messages_pb2.SynchronizeModelArgs)],
    ),
    CallSpec(
        'ArtmInitializeModel',
        [('master_id', int), ('args', messages_pb2.InitializeModelArgs)],
    ),
    CallSpec(
        'ArtmExportModel',
        [('master_id', int), ('args', messages_pb2.ExportModelArgs)],
    ),
    CallSpec(
        'ArtmImportModel',
        [('master_id', int), ('args', messages_pb2.ImportModelArgs)],
    ),
    CallSpec(
        'ArtmWaitIdle',
        [('master_id', int), ('args', messages_pb2.WaitIdleArgs)],
    ),
    CallSpec(
        'ArtmOverwriteTopicModel',
        [('master_id', int), ('model', messages_pb2.TopicModel)],
    ),
    CallSpec(
        'ArtmRequestThetaMatrix',
        [('master_id', int), ('args', messages_pb2.GetThetaMatrixArgs)],
        request=messages_pb2.ThetaMatrix,
    ),
    CallSpec(
        'ArtmRequestTopicModel',
        [('master_id', int), ('args', messages_pb2.GetTopicModelArgs)],
        request=messages_pb2.TopicModel,
    ),
    CallSpec(
        'ArtmRequestRegularizerState',
        [('master_id', int), ('name', str)],
        request=messages_pb2.RegularizerInternalState,
    ),
    CallSpec(
        'ArtmRequestScore',
        [('master_id', int), ('args', messages_pb2.GetScoreValueArgs)],
        request=messages_pb2.ScoreData,
    ),
    CallSpec(
        'ArtmRequestParseCollection',
        [('args', messages_pb2.CollectionParserConfig)],
        request=messages_pb2.DictionaryConfig,
    ),
    CallSpec(
        'ArtmRequestLoadDictionary',
        [('filename', str)],
        request=messages_pb2.DictionaryConfig,
    ),
    CallSpec(
        'ArtmRequestLoadBatch',
        [('filename', str)],
        request=messages_pb2.Batch,
    ),
    # Listed once: the original repeated this entry verbatim.
    CallSpec(
        'ArtmSaveBatch',
        [('filename', str), ('batch', messages_pb2.Batch)],
    ),
]
# Example spec for a Request* function. The message type goes into the
# ``request`` keyword; passed positionally it would land in ``result`` and be
# used as the ctypes return type.
spec = CallSpec(
    'ArtmRequestScore',
    [('master_id', int), ('args', messages_pb2.GetScoreValueArgs)],
    request=messages_pb2.ScoreData,
)
spec.arguments
[('master_id', int), ('args', artm.messages_pb2.GetScoreValueArgs)]
from types import StringTypes
import logging
def dict_to_message(record, message_type):
    """Convert a dict to a protobuf message.

    Args:
        record: dict mapping field names to values. Nested dicts fill
            sub-messages; lists fill repeated fields (a list of dicts adds one
            sub-message per element, any other list is appended as-is).
        message_type: callable returning a new, empty message instance.

    Returns:
        The populated message instance.

    Raises:
        TypeError: if a scalar field cannot be set because the message has no
            such attribute.
    """
    def parse_list(values, message):
        # Nothing to add for an empty list; values[0] below would raise
        # IndexError on it.
        if not values:
            return
        if isinstance(values[0], dict):
            for item in values:
                parse_dict(item, message.add())
        else:
            message.extend(values)

    def parse_dict(values, message):
        # items() works on both Python 2 and Python 3 dicts.
        for key, value in values.items():
            if isinstance(value, dict):
                parse_dict(value, getattr(message, key))
            elif isinstance(value, list):
                parse_list(value, getattr(message, key))
            else:
                try:
                    setattr(message, key, value)
                except AttributeError:
                    raise TypeError('Cannot convert dict to protobuf message {message_type}: bad field "{field}"'.format(
                        message_type=str(message_type),
                        field=key,
                    ))

    message = message_type()
    parse_dict(record, message)
    return message
# Demo: build a config message by assigning its fields directly, then
# serialize it (the cell output follows on the next line).
config = messages_pb2.MasterComponentConfig()
config.processor_queue_max_size = 123
config.processors_count = 10
config.SerializeToString()
'0\n8{'
# Demo: the same config built from a plain dict; the recorded output shows it
# serializes to the same bytes as the hand-built message.
config = dict_to_message({'processor_queue_max_size': 123, 'processors_count': 10}, messages_pb2.MasterComponentConfig)
config.SerializeToString()
'0\n8{'
dict_to_message({}, messages_pb2.MasterComponentConfig)
<artm.messages_pb2.MasterComponentConfig at 0x4d2f938>
dict_to_message({'sdfsdfsdf': 24134}, messages_pb2.MasterComponentConfig)
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-84-00039db0cadd> in <module>() ----> 1 dict_to_message({'sdfsdfsdf': 24134}, messages_pb2.MasterComponentConfig) <ipython-input-80-667e9fb90971> in dict_to_message(record, message_type) 29 30 message = message_type() ---> 31 parse_dict(record, message) 32 33 return message <ipython-input-80-667e9fb90971> in parse_dict(values, message) 25 raise TypeError('Cannot convert dict to protobuf message {message_type}: bad field "{field}"'.format( 26 message_type=str(message_type), ---> 27 field=k, 28 )) 29 TypeError: Cannot convert dict to protobuf message <class 'artm.messages_pb2.MasterComponentConfig'>: bad field "sdfsdfsdf"
LibArtm при помощи __getattr__ автоматически создает обертки над вызовами API
import ctypes
from google import protobuf
from artm import messages_pb2
class LibArtm(object):
    """ctypes wrapper over libartm driven by the ARTM_API specification.

    Every function listed in ARTM_API is available as a method. Arguments are
    validated against the spec, dicts are converted to protobuf messages, and
    error codes are turned into exceptions.
    """

    # Types accepted as C string arguments (``basestring`` exists only on
    # Python 2).
    try:
        _STRING_TYPES = (basestring,)
    except NameError:
        _STRING_TYPES = (str, bytes)

    def __init__(self, lib_name):
        """Load the shared library at ``lib_name`` and index ARTM_API by name."""
        self.cdll = ctypes.CDLL(lib_name)
        self._spec_by_name = {spec.name: spec for spec in ARTM_API}

    def __getattr__(self, name):
        """Return the wrapped API function called ``name``.

        Raises:
            AttributeError: if ``name`` is not described in the specification,
                or if the instance has not been initialised yet.
        """
        # __getattr__ is only called for missing attributes; reading
        # ``self._spec_by_name`` or ``self.cdll`` here while they are unset
        # would recurse forever.
        try:
            spec_by_name = self.__dict__['_spec_by_name']
            cdll = self.__dict__['cdll']
        except KeyError:
            raise AttributeError(name)
        spec = spec_by_name.get(name)
        if spec is None:
            raise AttributeError('%s is not a function of libartm' % name)
        func = getattr(cdll, name)
        return self._wrap_call(func, spec)

    def _check_error(self, error_code):
        """Raise the exception matching ``error_code`` if it is below -1.

        Codes 0 and -1 are not errors and return silently. The exception
        class is taken from ARTM_EXCEPTION_BY_CODE, falling back to
        RuntimeError for unknown codes.
        """
        if error_code >= -1:
            return
        # Use this instance's library, not a module-level ``lib`` object.
        get_last_error = self.cdll.ArtmGetLastErrorMessage
        get_last_error.restype = ctypes.c_char_p
        error_message = get_last_error() or b''
        if not isinstance(error_message, str):
            # Python 3: c_char_p results are bytes.
            error_message = error_message.decode('utf-8', 'replace')
        # remove exception name from error message
        error_message = error_message.split(':', 1)[-1].strip()
        exception_class = ARTM_EXCEPTION_BY_CODE.get(error_code, RuntimeError)
        raise exception_class(error_message)

    def _copy_request_result(self, length):
        """Copy the pending request result of ``length`` bytes into a buffer.

        Returns the ctypes string buffer filled by ArtmCopyRequestResult.
        """
        message_blob = ctypes.create_string_buffer(length)
        # ArtmCopyRequestResult is not part of ARTM_API, so call the raw
        # library function directly.
        error_code = self.cdll.ArtmCopyRequestResult(length, message_blob)
        self._check_error(error_code)
        return message_blob

    def _wrap_call(self, func, spec):
        """Wrap ctypes function ``func`` according to its CallSpec ``spec``.

        The returned callable:
          * checks the number and the types of the arguments (TypeError);
          * converts dicts to protobuf messages, messages to a
            ``(length, buffer)`` pair and strings to C string buffers;
          * raises via _check_error when the call reports an error;
          * returns the parsed ``spec.request_type`` message for request
            functions, the raw result when ``spec.result_type`` is set, and
            None otherwise.
        """
        string_types = self._STRING_TYPES

        def artm_api_call(*args):
            # check the number of arguments
            n_args_given = len(args)
            n_args_takes = len(spec.arguments)
            if n_args_given != n_args_takes:
                raise TypeError('{func_name} takes {n_takes} argument ({n_given} given)'.format(
                    func_name=spec.name,
                    n_takes=n_args_takes,
                    n_given=n_args_given,
                ))

            cargs = []
            for arg_index, (arg, (arg_name, arg_type)) in enumerate(zip(args, spec.arguments)):
                # try to cast argument to the required type: dict -> protobuf message
                if isinstance(arg, dict) and issubclass(arg_type, protobuf.message.Message):
                    arg_casted = dict_to_message(arg, arg_type)
                else:
                    arg_casted = arg

                # check argument type; any string type satisfies a string argument
                is_string = issubclass(arg_type, string_types)
                expected_type = string_types if is_string else arg_type
                if not isinstance(arg_casted, expected_type):
                    raise TypeError('Argument {arg_index} ({arg_name}) should have type {arg_type} but {given_type} given'.format(
                        arg_index=arg_index,
                        arg_name=arg_name,
                        arg_type=str(arg_type),
                        given_type=str(type(arg)),
                    ))
                arg = arg_casted

                # construct c-style arguments
                if is_string:
                    if not isinstance(arg, bytes):
                        # create_string_buffer needs bytes
                        arg = arg.encode('utf-8')
                    cargs.append(ctypes.create_string_buffer(arg))
                elif issubclass(arg_type, protobuf.message.Message):
                    message_str = arg.SerializeToString()
                    cargs += [len(message_str), ctypes.create_string_buffer(message_str)]
                else:
                    cargs.append(arg)

            # make api call
            if spec.result_type is not None:
                func.restype = spec.result_type
            result = func(*cargs)
            self._check_error(result)

            # return result value
            if spec.request_type is not None:
                # For request functions the call result is the length of the
                # serialized message waiting to be copied out.
                message_blob = self._copy_request_result(length=result)
                message = spec.request_type()
                message.ParseFromString(message_blob.raw)
                return message
            if spec.result_type is not None:
                return result
            return None

        return artm_api_call
LibArtm
lib = LibArtm(ARTM_LIBRARY_PATH)
master_id = lib.ArtmCreateMasterComponent({'cache_theta': True, 'processors_count': 10})
master_id
({'cache_theta': True, 'processors_count': 10},) [('config', <class 'artm.messages_pb2.MasterComponentConfig'>)] [((0, {'cache_theta': True, 'processors_count': 10}), ('config', <class 'artm.messages_pb2.MasterComponentConfig'>))]
6
master_id
lib.ArtmCreateModel({'topics_count': 20, 'class_id': ['words', 'labels'], 'class_weight': [1, 0.2]})
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-88-805156b9ba3f> in <module>() ----> 1 lib.ArtmCreateModel({'topics_count': 20, 'class_id': ['words', 'labels'], 'class_weight': [1, 0.2]}) <ipython-input-85-95367b57b175> in artm_api_call(*args) 46 func_name=spec.name, 47 n_takes=n_args_takes, ---> 48 n_given=n_args_given, 49 )) 50 TypeError: ArtmCreateModel takes 2 argument (1 given)
lib.ArtmCreateModel(master_id, {'topics_count': 20, 'class_id': ['words', 'labels'], 'class_weight': [1, 0.2]})
(6, {'class_id': ['words', 'labels'], 'topics_count': 20, 'class_weight': [1, 0.2]}) [('master_id', <type 'int'>), ('config', <class 'artm.messages_pb2.ModelConfig'>)] [((0, 6), ('master_id', <type 'int'>)), ((1, {'class_id': ['words', 'labels'], 'topics_count': 20, 'class_weight': [1, 0.2]}), ('config', <class 'artm.messages_pb2.ModelConfig'>))]
lib.ArtmRequestTopicModel(master_id, {})
(6, {}) [('master_id', <type 'int'>), ('args', <class 'artm.messages_pb2.GetTopicModelArgs'>)] [((0, 6), ('master_id', <type 'int'>)), ((1, {}), ('args', <class 'artm.messages_pb2.GetTopicModelArgs'>))]
--------------------------------------------------------------------------- InvalidOperationException Traceback (most recent call last) <ipython-input-90-dc313b07a517> in <module>() ----> 1 lib.ArtmRequestTopicModel(master_id, {}) <ipython-input-85-95367b57b175> in artm_api_call(*args) 87 func.restype = spec.result_type 88 result = func(*cargs) ---> 89 self._check_error(result) 90 91 # return result value <ipython-input-85-95367b57b175> in _check_error(self, error_code) 26 exception_class = ARTM_EXCEPTION_BY_CODE.get(error_code) 27 if exception_class is not None: ---> 28 raise exception_class(error_message) 29 else: 30 raise RuntimeError(error_message) InvalidOperationException: Topic model does not exist