orf / datatables

MIT License
52 stars 23 forks source link

SQLAlchemy 2.x upgrade #20

Open mmssix opened 7 months ago

mmssix commented 7 months ago

I know this isn't the way it should be done, but this project is truly dead otherwise. To make it easy for others who need it to find, here is the library upgraded to use SQLAlchemy 2.

I don't know whether this is the correct way to code it, but it does seem to work.

from collections import defaultdict, namedtuple
import re
import inspect

# Parameter keys (as built by DataTable.query_into_dict; nested keys are
# joined with ".") whose non-integer values are converted to booleans by
# DataTable.coerce_value: the string "true" becomes True, anything else False.
BOOLEAN_FIELDS = (
    "search.regex", "searchable", "orderable", "regex"
)

# One output column:
#   name       - key used for this column in each output row
#   model_name - attribute on the model; may be dotted ("relationship.attr")
#   filter     - optional callable(instance) producing the value instead
DataColumn = namedtuple("DataColumn", ("name", "model_name", "filter"))

class DataTablesError(ValueError):
    """Raised for a missing or invalid DataTables request parameter.

    ``DataTable.json`` catches it and returns ``{"error": str(error)}``.
    """

class DataTable(object):
    """Server-side processing of a DataTables request over an SQLAlchemy query.

    The request parameters (``draw``, ``start``, ``length``, ``columns[...]``,
    ``order[...]`` and ``search[...]``) are turned into searching, ordering
    and slicing of ``query``; :meth:`json` builds the response payload.

    :param params: mapping of request parameter names to their values (the
        values are parsed as strings; presumably the raw request arguments).
    :param model: mapped class on which the columns' ``model_name``
        attributes are looked up.
    :param query: base query; one join per relationship referenced by a
        dotted ``model_name`` (``"relationship.attribute"``) is added to it.
    :param columns: iterable of column specifications, each one of:
        a ``DataColumn``; a string (used as both name and model_name);
        ``(name, model_name)``; ``(name, filter)`` with a callable filter;
        or ``(name, model_name, filter)``.
    :raises ValueError: if a tuple specification has neither 2 nor 3 items.
    """

    # Keys (as produced by query_into_dict) that hold free-text search input.
    # They are kept as raw strings: coercing "0" to the falsy int 0 would
    # silently skip that search, and "007" would lose its leading zeros.
    _RAW_STRING_FIELDS = ("value", "search.value")

    def __init__(self, params, model, query, columns):
        self.params = params
        self.model = model
        self.query = query
        self.data = {}
        self.columns = []
        self.columns_dict = {}
        self.search_func = lambda qs, s: qs
        self.column_search_func = lambda mc, qs, s: qs

        for col in columns:
            if isinstance(col, DataColumn):
                # Checked before ``tuple``: DataColumn is a namedtuple.
                column = col
            elif isinstance(col, tuple):
                # col is either 1. (name, model_name), 2. (name, filter) or 3. (name, model_name, filter)
                if len(col) == 3:
                    name, model_name, filter_func = col
                elif len(col) == 2:
                    # Work out the second argument. If it is a function then it's type 2, else it is type 1.
                    if callable(col[1]):
                        name, filter_func = col
                        model_name = name
                    else:
                        name, model_name = col
                        filter_func = None
                else:
                    raise ValueError("Columns must be a tuple of 2 to 3 elements")
                column = DataColumn(name=name, model_name=model_name, filter=filter_func)
            else:
                # It's just a string
                column = DataColumn(name=col, model_name=col, filter=None)

            # Register every column (DataColumn instances included) so that
            # _json can look it up by the name the client sends back.
            self.columns.append(column)
            self.columns_dict[column.name] = column

        # Join each relationship used by a dotted column exactly once; two
        # columns such as "address.city" and "address.zip" would otherwise
        # join the same relationship into the query twice.
        joined = set()
        for column in self.columns:
            if "." not in column.model_name:
                continue
            relationship_name = column.model_name.split(".")[0]
            if relationship_name in joined:
                continue
            joined.add(relationship_name)
            self.query = self.query.join(getattr(self.model, relationship_name))

    def query_into_dict(self, key_start):
        """Collect the ``key_start[...]`` request parameters into a dict.

        ``key_start[key]`` becomes ``{key: value}``,
        ``key_start[N][key]`` becomes ``{N: {key: value}}`` and
        ``key_start[N][key][subkey]`` becomes ``{N: {key: {subkey: value}}}``.
        Every value is passed through :meth:`coerce_value`; nested keys are
        given to it as ``"key.subkey"``.
        """
        returner = defaultdict(dict)

        # Matches key_start[number][key] with an optional [subkey] on the end.
        # Raw string so "\[" and "\d" reach the regex engine untouched.
        pattern = re.compile(
            r"{}(?:\[(\d+)\])?\[(\w+)\](?:\[(\w+)\])?".format(re.escape(key_start))
        )

        for param in self.params:
            match = pattern.match(param)
            if match is None:
                continue

            column_id, key, optional_subkey = match.groups()
            raw_value = self.params[param]

            if column_id is None:
                returner[key] = self.coerce_value(key, raw_value)
            elif optional_subkey is None:
                returner[int(column_id)][key] = self.coerce_value(key, raw_value)
            else:
                # Oh baby a triple
                subdict = returner[int(column_id)].setdefault(key, {})
                subdict[optional_subkey] = self.coerce_value(
                    "{}.{}".format(key, optional_subkey), raw_value
                )

        return dict(returner)

    @staticmethod
    def coerce_value(key, value):
        """Convert one request parameter value according to its key.

        Search text keys are returned unchanged; otherwise integer-looking
        values become ``int``, keys listed in ``BOOLEAN_FIELDS`` become
        ``value == "true"``, and anything else is returned unchanged.
        """
        if key in DataTable._RAW_STRING_FIELDS:
            return value
        try:
            return int(value)
        except (TypeError, ValueError):
            if key in BOOLEAN_FIELDS:
                return value == "true"

        return value

    def get_integer_param(self, param_name):
        """Return request parameter ``param_name`` as an int.

        :raises DataTablesError: if the parameter is missing or not an integer.
        """
        if param_name not in self.params:
            raise DataTablesError("Parameter {} is missing".format(param_name))

        try:
            return int(self.params[param_name])
        except (TypeError, ValueError) as err:
            raise DataTablesError("Parameter {} is invalid".format(param_name)) from err

    def add_data(self, **kwargs):
        """Register callables whose results are sent per row as ``DT_RowData``.

        Each keyword maps a ``DT_RowData`` key to a callable that receives the
        row instance (see :meth:`output_instance`).
        """
        self.data.update(kwargs)

    def json(self):
        """Return the response dict, or ``{"error": message}`` on a request error."""
        try:
            return self._json()
        except DataTablesError as e:
            return {
                "error": str(e)
            }

    def get_column(self, column):
        """Return the model attribute for ``column``.

        A dotted ``model_name`` is resolved through the relationship named by
        its first part; only the first two parts of the path are used.
        """
        if "." in column.model_name:
            column_path = column.model_name.split(".")
            relationship = getattr(self.model, column_path[0])
            return getattr(relationship.property.mapper.entity, column_path[1])
        return getattr(self.model, column.model_name)

    def searchable(self, func):
        """Set ``func(query, value)`` as the global search and return it.

        Returning ``func`` lets this be used as a decorator.
        """
        self.search_func = func
        return func

    def searchable_column(self, func):
        """Set ``func(model_column, query, value)`` as the per-column search.

        Returns ``func`` so this can be used as a decorator.
        """
        self.column_search_func = func
        return func

    def _lookup_column(self, column_name):
        """Return the registered column named ``column_name``.

        :raises DataTablesError: for a name the client sent that was never
            registered, so :meth:`json` reports it instead of a ``KeyError``.
        """
        try:
            return self.columns_dict[column_name]
        except KeyError:
            raise DataTablesError("Unknown column {}".format(column_name)) from None

    def _json(self):
        """Apply search, ordering and paging to the query and build the payload.

        :raises DataTablesError: on missing/invalid parameters, unknown
            columns, or ordering by a plain Python property.
        """
        draw = self.get_integer_param("draw")
        start = self.get_integer_param("start")
        length = self.get_integer_param("length")

        columns = self.query_into_dict("columns")
        ordering = self.query_into_dict("order")
        search = self.query_into_dict("search")

        query = self.query
        total_records = query.count()

        if callable(self.search_func) and search.get("value", None):
            query = self.search_func(query, search["value"])

        for column_data in columns.values():
            search_value = column_data["search"]["value"]
            if (
                not column_data["searchable"]
                or not search_value
                or not callable(self.column_search_func)
            ):
                continue

            column = self._lookup_column(column_data["data"])
            model_column = self.get_column(column)

            query = self.column_search_func(model_column, query, str(search_value))

        for order in ordering.values():
            direction, column_index = order["dir"], order["column"]

            if column_index not in columns:
                raise DataTablesError("Cannot order {}: column not found".format(column_index))

            if not columns[column_index]["orderable"]:
                continue

            column = self._lookup_column(columns[column_index]["data"])
            model_column = self.get_column(column)

            # A plain Python property has no SQL expression to order by.
            if isinstance(model_column, property):
                raise DataTablesError("Cannot order by column {} as it is a property".format(column.model_name))

            query = query.order_by(model_column.desc() if direction == "desc" else model_column.asc())

        filtered_records = query.count()

        # A non-positive length means "all rows"; otherwise fetch one page.
        if length > 0:
            query = query.slice(start, start + length)

        return {
            "draw": draw,
            "recordsTotal": total_records,
            "recordsFiltered": filtered_records,
            "data": [
                self.output_instance(instance) for instance in query.all()
            ],
        }

    def output_instance(self, instance):
        """Build one output row: column name -> value, plus ``DT_RowData``."""
        returner = {
            key.name: self.get_value(key, instance) for key in self.columns
        }

        if self.data:
            returner["DT_RowData"] = {
                k: v(instance) for k, v in self.data.items()
            }

        return returner

    def get_value(self, key, instance):
        """Return the value of column ``key`` for ``instance``.

        A dotted ``model_name`` is followed attribute by attribute. If the
        column has a filter, its result is used instead of the attribute. An
        attribute holding a mapped instance (one carrying
        ``_sa_instance_state``) is turned into a dict of the attributes in its
        ``__dict__``. A routine (method/function) is called and its result
        returned.
        """
        attr = key.model_name
        if "." in attr:
            tmp_list = attr.split(".")
            attr = tmp_list[-1]
            for sub in tmp_list[:-1]:
                instance = getattr(instance, sub)

        if key.filter is not None:
            r = key.filter(instance)
        else:
            r = getattr(instance, attr)
            # Only mapped instances are converted. Calling vars() on any
            # object with a __dict__ would also turn methods and functions
            # into {} so they would never be called below.
            # NOTE: vars() only sees attributes present in the instance
            # __dict__, i.e. those already loaded.
            if hasattr(r, "_sa_instance_state"):
                r = {
                    name: getattr(r, name)
                    for name in list(vars(r))
                    if name != "_sa_instance_state"
                }

        return r() if inspect.isroutine(r) else r