serializers.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from __future__ import print_function, unicode_literals
  2. import inspect
  3. from rest_framework.exceptions import ValidationError
  4. from rest_framework.serializers import ListSerializer
  5. __all__ = [
  6. 'BulkListSerializer',
  7. 'BulkSerializerMixin',
  8. ]
  9. class BulkSerializerMixin(object):
  10. def to_internal_value(self, data):
  11. ret = super(BulkSerializerMixin, self).to_internal_value(data)
  12. id_attr = getattr(self.Meta, 'update_lookup_field', 'id')
  13. request_method = getattr(getattr(self.context.get('view'), 'request'), 'method', '')
  14. # add update_lookup_field field back to validated data
  15. # since super by default strips out read-only fields
  16. # hence id will no longer be present in validated_data
  17. if all((isinstance(self.root, BulkListSerializer),
  18. id_attr,
  19. request_method in ('PUT', 'PATCH'))):
  20. id_field = self.fields[id_attr]
  21. id_value = id_field.get_value(data)
  22. ret[id_attr] = id_value
  23. return ret
  24. class BulkListSerializer(ListSerializer):
  25. update_lookup_field = 'id'
  26. def update(self, queryset, all_validated_data):
  27. id_attr = getattr(self.child.Meta, 'update_lookup_field', 'id')
  28. all_validated_data_by_id = {
  29. i.pop(id_attr): i
  30. for i in all_validated_data
  31. }
  32. if not all((bool(i) and not inspect.isclass(i)
  33. for i in all_validated_data_by_id.keys())):
  34. raise ValidationError('')
  35. # since this method is given a queryset which can have many
  36. # model instances, first find all objects to update
  37. # and only then update the models
  38. objects_to_update = queryset.filter(**{
  39. '{}__in'.format(id_attr): all_validated_data_by_id.keys(),
  40. })
  41. if len(all_validated_data_by_id) != objects_to_update.count():
  42. raise ValidationError('Could not find all objects to update.')
  43. updated_objects = []
  44. for obj in objects_to_update:
  45. obj_id = getattr(obj, id_attr)
  46. obj_validated_data = all_validated_data_by_id.get(obj_id)
  47. # use model serializer to actually update the model
  48. # in case that method is overwritten
  49. updated_objects.append(self.child.update(obj, obj_validated_data))
  50. return updated_objects