Source code for aioorm.fields.many_to_many_fields

import re

from peewee import Proxy, ForeignKeyField, SelectQuery, Model, SQL
from playhouse.fields import (ManyToManyField, ManyToManyQuery,
                              ManyToManyFieldDescriptor, DeferredThroughModel)

from ..model import AioModel
from ..query import AioSelectQuery


class AioManyToManyField(ManyToManyField):
    """Async-aware variant of ``playhouse.fields.ManyToManyField``.

    Installs an ``AioManyToManyFieldDescriptor`` on the owning model, creates
    its backref as another ``AioManyToManyField`` and, when no through model
    was supplied, generates one derived from ``AioModel``.
    """

    def _get_descriptor(self):
        """Return the descriptor placed on the model class for this field."""
        return AioManyToManyFieldDescriptor(self)

    def add_to_class(self, model_class, name):
        """Register this field on *model_class* as attribute *name*.

        Registration is postponed while the through model is still a
        ``Proxy`` (re-run once the proxy is initialized) or a
        ``DeferredThroughModel`` (which is told about this field instead).
        Otherwise the descriptor is installed and, on the forward side only,
        a backref field is attached to the related model.
        """
        through = self._through_model

        if isinstance(through, Proxy):
            def _on_proxy_ready(resolved_through):
                # Swap in the real model, then retry the registration.
                self._through_model = resolved_through
                self.add_to_class(model_class, name)
            through.attach_callback(_on_proxy_ready)
            return

        if isinstance(through, DeferredThroughModel):
            through.set_field(model_class, self, name)
            return

        self.name = name
        self.model_class = model_class
        if not self.verbose_name:
            # e.g. "related_tags" -> "Related Tags"
            self.verbose_name = re.sub('_+', ' ', name).title()
        setattr(model_class, name, self._get_descriptor())

        # A backref must not create a backref of its own.
        if self._is_backref:
            return

        reverse_field = AioManyToManyField(
            self.model_class,
            through_model=self._through_model,
            _is_backref=True)
        reverse_name = self._related_name or model_class._meta.name + 's'
        reverse_field.add_to_class(self.rel_model, reverse_name)

    def get_through_model(self):
        """Return the intermediate model, generating an ``AioModel`` lazily.

        The generated model has one ``ForeignKeyField`` per side, a table
        named ``<lhs_table>_<rhs_table>_through`` and a unique index over the
        two foreign keys. It is cached on ``self._through_model``.
        """
        if self._through_model:
            return self._through_model

        lhs, rhs = self.get_models()
        lhs_name, rhs_name = lhs._meta.name, rhs._meta.name

        class Meta:
            database = self.model_class._meta.database
            db_table = '%s_%s_through' % (lhs._meta.db_table,
                                          rhs._meta.db_table)
            # Unique composite index: each (lhs, rhs) pair is stored once.
            indexes = (((lhs_name, rhs_name), True),)
            validate_backrefs = False

        namespace = {
            lhs_name: ForeignKeyField(rel_model=lhs),
            rhs_name: ForeignKeyField(rel_model=rhs),
            'Meta': Meta,
        }
        model_name = '%s%sThrough' % (lhs.__name__, rhs.__name__)
        self._through_model = type(model_name, (AioModel,), namespace)
        return self._through_model
class AioManyToManyFieldDescriptor(ManyToManyFieldDescriptor):
    """Descriptor that exposes a many-to-many relation as an async query."""

    def __get__(self, instance, instance_type=None):
        """Return the related-rows query for *instance*.

        When accessed on the model class itself (``instance is None``), the
        underlying field object is returned instead.
        """
        if instance is not None:
            return (AioManyToManyQuery(instance, self, self.rel_model)
                    .select()
                    .join(self.through_model)
                    .join(self.model_class)
                    .where(self.src_fk == instance))
        return self.field

    def __set__(self, instance, value):
        """Reject direct assignment to the relation.

        Replacing the links requires awaiting database writes, which a
        synchronous ``__set__`` cannot do.

        Raises:
            NotImplementedError: always; callers should use
                ``await instance.<field>.set(value)`` instead.
        """
        raise NotImplementedError(
            'Assigning to many-to-many field %r is not supported because it '
            'requires awaiting database writes; use '
            '"await instance.%s.set(value)" instead.'
            % (self.field.name, self.field.name))
class AioManyToManyQuery(AioSelectQuery, ManyToManyQuery):
    """Async query over the rows linked to one instance via a many-to-many field.

    Besides selecting the related rows, it provides coroutine helpers that
    modify the through table for the instance this query was built for.
    """

    def _id_list(self, model_or_id_list):
        """Convert a sequence of model instances and/or raw ids to a list of ids.

        Each element is checked on its own, so mixed sequences work and an
        empty sequence yields ``[]``. The original version only inspected the
        first element, which raised ``IndexError`` on an empty sequence.
        """
        return [item.get_id() if isinstance(item, Model) else item
                for item in model_or_id_list]

    async def set(self, value):
        """Replace all existing links with *value* (same forms as ``add``).

        NOTE(review): ``clear`` and the insert run as separate statements, and
        no transaction is opened here. Wrap the call in one if atomicity
        matters.
        """
        await self.add(value, clear_existing=True)

    async def add(self, value, clear_existing=False):
        """Link *value* to this query's instance.

        Args:
            value: a model instance or primary-key value, a list/tuple of
                those, or a ``SelectQuery`` whose selected rows' primary keys
                are linked.
            clear_existing: if true, remove every current link first.
        """
        if clear_existing:
            await self.clear()

        fd = self._field_descriptor

        if isinstance(value, SelectQuery):
            # INSERT ... SELECT pairing this instance's id (emitted as a SQL
            # literal) with each primary key produced by *value*.
            # NOTE(review): str(id) is inlined into SQL, which is only safe
            # for numeric primary keys. Confirm for string keys.
            query = value.select(
                SQL(str(self._instance.get_id())),
                fd.rel_model._meta.primary_key)
            await fd.through_model.insert_from(
                fields=[fd.src_fk, fd.dest_fk],
                query=query).execute()
            return

        if not isinstance(value, (list, tuple)):
            value = [value]
        if not value:
            return

        src_id = self._instance.get_id()
        rows = [{fd.src_fk.name: src_id, fd.dest_fk.name: rel_id}
                for rel_id in self._id_list(value)]
        await fd.through_model.insert_many(rows).execute()

    async def remove(self, value):
        """Unlink *value* from this query's instance.

        Args:
            value: same forms as accepted by ``add``.

        Returns:
            The result of the through-table DELETE, or ``None`` when *value*
            is an empty list/tuple.
        """
        fd = self._field_descriptor
        src_id = self._instance.get_id()

        if isinstance(value, SelectQuery):
            targets = value.select(value.model_class._meta.primary_key)
        else:
            if not isinstance(value, (list, tuple)):
                value = [value]
            if not value:
                return None
            targets = self._id_list(value)

        return await (fd.through_model
                      .delete()
                      .where((fd.dest_fk << targets) &
                             (fd.src_fk == src_id))
                      .execute())

    async def clear(self):
        """Delete every through-table row that points at this instance."""
        fd = self._field_descriptor
        return await (fd.through_model
                      .delete()
                      .where(fd.src_fk == self._instance)
                      .execute())