Source code for jobman.dao.orm

import json
import logging
import textwrap


[docs]class ORM(object):
[docs] class UpdateError(Exception): pass
[docs] class InsertError(Exception): pass
[docs] class IntegrityError(Exception): pass
def __init__(self, name=None, fields=None, json=json, table_prefix=None, logger=None): self.name = name self.fields = fields self.json = json self.table_prefix = table_prefix self.logger = logger or logging
[docs] def create_table(self, connection=None): create_statement = ( 'CREATE TABLE IF NOT EXISTS {table} ({column_defs})' ).format( table=self.table, column_defs=( ",\n".join([ self._generate_column_def(field=field, field_def=field_def) for field, field_def in self.fields.items()]) ) ) connection.execute(create_statement)
@property def table(self): table_prefix = self.table_prefix or '' return table_prefix + self.name def _generate_column_def(self, field=None, field_def=None): column_def = '{field} {type}'.format( field=field, type=self._get_column_type(field_type=field_def['type']) ) if field_def.get('primary_key'): column_def += ' PRIMARY KEY' return column_def def _get_column_type(self, field_type=None): column_type = field_type if field_type == 'JSON': column_type = 'TEXT' return column_type
[docs] def save_object(self, obj=None, connection=None, replace=True): saved_record = self._save_record(record=self._obj_to_record(obj=obj), connection=connection, replace=replace) return self._record_to_obj(record=saved_record)
def _obj_to_record(self, obj=None, fields=None): record = {} for field, field_def in self.fields.items(): record[field] = self._obj_val_to_record_val( field_def=field_def, value=obj.get(field)) return record def _obj_val_to_record_val(self, field_def=None, value=None): if field_def.get('type') == 'JSON': value = self._serialize_json_value(value) return value def _serialize_json_value(self, value=None): return self.json.dumps(value) def _save_record(self, record=None, connection=None, replace=None): fields = sorted(self.fields.keys()) values = [] for field in fields: field_def = self.fields[field] if field_def.get('default') and record.get(field) is None: record[field] = field_def['default']() if field_def.get('auto_update'): record[field] = field_def['auto_update'](record=record) values.append(record.get(field)) self.execute_insert_or_replace(fields=fields, values=values, replace=replace, connection=connection) return record
[docs] def execute_insert_or_replace(self, fields=None, values=None, replace=None, connection=None): replace_sql = '' if replace: replace_sql = 'OR REPLACE' statement = textwrap.dedent( ''' INSERT {replace_sql} INTO {table} ({csv_fields}) VALUES ({csv_placeholders}) ''' ).strip().format( replace_sql=replace_sql, table=self.table, csv_fields=(','.join(fields)), csv_placeholders=(','.join(['?' for field in fields])) ) try: connection.execute(statement, values) except connection.IntegrityError as exc: raise self.IntegrityError() from exc except Exception as exc: raise self.InsertError() from exc
[docs] def query_objects(self, query=None, connection=None): query = query or {} self._validate_query(query=query) records = self._query_records(query=query, connection=connection) return [self._record_to_obj(record=record) for record in records]
def _validate_query(self, query=None): filterable_fields = self._get_filterable_fields() for _filter in (query or {}).get('filters', []): if _filter['field'] not in filterable_fields: raise Exception("Unknown filter field '{field}'.".format( field=_filter['field'])) def _get_filterable_fields(self): return [field for field, field_def in self.fields.items() if not field_def.get('unfilterable')] def _query_records(self, query=None, connection=None): return self._execute_query(query=query, connection=connection) def _execute_query(self, query=None, connection=None): args = [] statement = 'SELECT {fields} FROM {table}'.format( fields=query.get('fields', '*'), table=self.table ) where_section = self._get_where_section(query=query) if where_section.get('content'): statement += '\nWHERE ' + where_section['content'] args.extend(where_section['args']) limit_section = self._get_limit_section(query=query) order_by_section = self._get_order_by_section(query=query) if order_by_section.get('content'): statement += '\n' + order_by_section['content'] if limit_section.get('content'): statement += '\n' + limit_section['content'] return connection.execute(statement, args) def _get_where_section(self, query=None): clauses = [] args = [] for _filter in query.get('filters', []): where_item = self._filter_to_where_item(_filter=_filter) clauses.append(where_item['clause']) args.extend(where_item['args']) return {'content': ' AND '.join(clauses), 'args': args} def _filter_to_where_item(self, _filter=None): op = _filter['op'] negation = '' if op.startswith('!'): negation = 'NOT' op = op.lstrip('! ') if op == 'IN': args = _filter.get('arg', []) clause_rhs = '({placeholders})'.format( placeholders=(', '.join(['?' for v in args])) ) else: args = [_filter['arg']] clause_rhs = '?' where_item = { 'clause': '{negation} {field} {op} {rhs}'.format( negation=negation, field=_filter['field'], op=op, rhs=clause_rhs, ).lstrip(), 'args': self._format_args(args=args, _filter=_filter) } # Hack for handling != filters for null values. # NOT = value will return true if value is null :/ if op == '=' and negation: where_item['clause'] = '({clause} OR {field} IS NULL)'.format( clause=where_item['clause'], field=_filter['field']) return where_item def _format_args(self, args=None, _filter=None): return [ self._format_value_for_field(value=arg, field_key=_filter['field']) for arg in args ] def _format_value_for_field(self, value=None, field_key=None): field = self.fields[field_key] if field['type'] == 'JSON': value = self._serialize_json_value(value) return value def _get_field_for_filter(self, _filter=None): return self.fields[_filter['field']] def _get_order_by_section(self, query=None): content = '' order_by = query.get('order_by') if order_by: content = 'ORDER BY' if not isinstance(order_by, list): order_by = [order_by] content += ', '.join([ '{field} {direction}'.format( field=order_by_spec['field'], direction=order_by_spec['direction'] ) for order_by_spec in order_by ]) return {'content': content, 'args': []} def _get_limit_section(self, query=None): content = '' if 'limit' in query: content = 'LIMIT %s' % query['limit'] return {'content': content, 'args': []} def _record_to_obj(self, record=None): obj = {} for field, field_def in self.fields.items(): obj[field] = self._record_val_to_obj_val(field_def=field_def, value=record[field]) return obj def _record_val_to_obj_val(self, field_def=None, value=None): if field_def.get('type') == 'JSON': value = self._deserialize_json_value(value) return value def _deserialize_json_value(self, serialized_value=None): if serialized_value is None or serialized_value == '': return None return self.json.loads(serialized_value)
[docs] def update_objects(self, updates=None, query=None, connection=None): self._validate_query(query=query) result = self._update_records(updates=updates, query=query, connection=connection) return result
def _update_records(self, updates=None, query=None, connection=None): args = [] updates_section = self._get_updates_section(updates=updates) args.extend(updates_section['args']) where_section = self._get_where_section(query=query) args.extend(where_section['args']) where_content = where_section['content'] if where_content: where_content = ' WHERE ' + where_content statement = textwrap.dedent( ''' UPDATE {table} SET {updates_content} {where_content} ''' ).lstrip().format( table=self.table, updates_content=updates_section['content'], where_content=where_content, ) try: cursor = connection.execute(statement, args) return {'rowcount': cursor.rowcount} except Exception as exc: raise self.UpdateError() from exc def _get_updates_section(self, updates=None): set_items = [] args = [] for field_key, value in updates.items(): set_items.append('%s = ?' % field_key) args.append(self._format_value_for_field(value=value, field_key=field_key)) return {'content': ', '.join(set_items), 'args': args}