|
@@ -1,4 +1,5 @@
|
|
|
from __future__ import print_function, unicode_literals
|
|
|
+from rest_framework.exceptions import ValidationError
|
|
|
from rest_framework.serializers import ListSerializer
|
|
|
|
|
|
|
|
@@ -32,7 +33,7 @@ class BulkSerializerMixin(object):
|
|
|
class BulkListSerializer(ListSerializer):
|
|
|
update_lookup_field = 'id'
|
|
|
|
|
|
- def update(self, instances, all_validated_data):
|
|
|
+ def update(self, queryset, all_validated_data):
|
|
|
id_attr = getattr(self.child.Meta, 'update_lookup_field', 'id')
|
|
|
|
|
|
all_validated_data_by_id = {
|
|
@@ -40,9 +41,19 @@ class BulkListSerializer(ListSerializer):
|
|
|
for i in all_validated_data
|
|
|
}
|
|
|
|
|
|
+ # since this method is given a queryset which can have many
|
|
|
+ # model instances, first find all objects to update
|
|
|
+ # and only then update the models
|
|
|
+ objects_to_update = queryset.filter(**{
|
|
|
+ '{}__in'.format(id_attr): all_validated_data_by_id.keys(),
|
|
|
+ })
|
|
|
+
|
|
|
+ if len(all_validated_data_by_id) != objects_to_update.count():
|
|
|
+ raise ValidationError('Could not find find all objects to update.')
|
|
|
+
|
|
|
updated_objects = []
|
|
|
|
|
|
- for obj in instances:
|
|
|
+ for obj in objects_to_update:
|
|
|
obj_id = getattr(obj, id_attr, None)
|
|
|
obj_validated_data = all_validated_data_by_id.get(obj_id)
|
|
|
if obj_id and obj_validated_data:
|