From 7e453cccbf5f07368d09b3925bfe471534478a85 Mon Sep 17 00:00:00 2001 From: timur Date: Wed, 13 Mar 2024 18:09:42 +0100 Subject: [PATCH 01/35] Update readme --- README.md | 57 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index bc282fb4be6..22cfc560414 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,11 @@
- Customer-centric e-commerce on a modern stack + Truly composable and open commerce
- A headless, GraphQL commerce platform delivering ultra-fast, dynamic, personalized shopping experiences.
Beautiful online stores, anywhere, on any device. + The API-only, headless, GraphQL-first, composable e-commerce platform that puts developers first.

@@ -60,29 +60,56 @@ ## What makes Saleor special? -Saleor is a rapidly-growing open-source e-commerce platform that serves high-volume companies. Designed from the ground up to be extensible, headless, and composable. +- **Technology-agnostic** - no monolithic plugin architecture or technology lock-in. -Learn more about [architecture](https://docs.saleor.io/docs/3.x/overview/architecture). +- **GraphQL only** - Not afterthought API design or fragmentation across different styles of API. -## Features +- **Headless and API only** - APIs is the only way to interact, configure or extend backend. + +- **Truly open licence** - no community edition or commercial limitations, single version of Saleor. + +- **Cloud native** - battle tested on global brands. + +- **Native-multichannel**: Per [channel](https://docs.saleor.io/docs/3.x/developer/channels) control of pricing, currencies, stock, product, and more + + +## Why API-only Architecture? + +Saleor's API-first extensibility provides powerful tools for developers to extend backend using [webhooks](https://docs.saleor.io/docs/3.x/developer/extending/webhooks/overview), attributes, [metadata](https://docs.saleor.io/docs/3.x/api-usage/metadata), [apps](https://docs.saleor.io/docs/3.x/developer/extending/apps/overview), [subscription queries](https://docs.saleor.io/docs/3.x/developer/extending/webhooks/subscription-webhook-payloads), [API extensions](https://docs.saleor.io/docs/3.x/developer/extending/webhooks/synchronous-events/overview), [dashboard iframes](https://docs.saleor.io/docs/3.x/developer/extending/apps/overview). -- **Headless / API first**: Build mobile apps, custom storefronts, POS, automation, etc -- **Extensible**: Build anything with webhooks, apps, metadata, and attributes -- [**App Store**](https://github.com/saleor/apps): Leverage a collection of built-in integrations -- **GraphQL API**: Get many resources in a single request and [more](https://graphql.org/) -- **Multichannel**: Per channel control of pricing, currencies, stock, product, and more +Compared to traditional plugin architectures (monoliths) it provides the following benefits: + +* Less downtime as apps can be deployed independently. +* Reliability and performance - custom logic is separated from the core. +* Simplified upgrade paths - eliminates incompatibility conflicts between extensions. +* Technology-agnostic - works with any technology, stack or language. +* Parallel development - easier to collaborate than with monolithic core. +* Simplified debugging - easier to narrow down bugs in independent services. +* Scalability - each extension or service can be scaled independently. + +### What are the tradeoffs? +If you are a single developer working with small business that doesn't have high traffic or critical need for 24/7 availability, +using service oriented approach might feel more complex compared to traditional Wordpress or Magento approach that provides language specific framework, runtime, database schema, aspect oriented programming and other tools to quickstart. + +However, if you deploy on a daily basis, reliability and uptime is critical, +you need to collaborate with other developers, or you have non-trivial requirements you might be in the right place. + +## Features - **Enterprise ready**: Secure, scalable, and stable. Battle-tested by big brands -- **CMS**: Content is king, that's why we have a kingdom built-in - **Dashboard**: User-friendly, fast, and productive. (Decoupled project [repo](https://github.com/saleor/saleor-dashboard) ) - **Global by design** Multi-currency, multi-language, multi-warehouse, tutti multi! -- **Orders**: A comprehensive system for orders, dispatch, and refunds +- **CMS**: Content is king, that's why we have a kingdom built-in +- **Product management**: advance model for large and complex catalogs +- **Orders**: Flexible order model, split-payments, multi-warehouse, returns and more +- **Customers**: order history and preferences +- **Promotion engine**: sales, vouchers, cart rules, giftcards +- **Payment orchestration**: multi-gateway, extensible payment API, flexible flows - **Cart**: Advanced payment and tax options, with full control over discounts and promotions - **Payments**: Flexible API architecture allows integration of any payment method +- **Translations**: Fully translatable catalog - **SEO**: Packed with features that get stores to a wider audience -- **Cloud**: Optimized for deployments using Docker +- **Apps**: Custom dashboard apps -Saleor is free and always will be. -Help us out… If you love free stuff and great software, give us a star! 🌟 ![Saleor Dashboard - Modern UI for managing your e-commerce](https://user-images.githubusercontent.com/9268745/224249510-d3c7658e-6d5c-42c5-b4fb-93eaf65a5335.png) From a29910d1d921f97c5afbd35db56111b975df9912 Mon Sep 17 00:00:00 2001 From: timur Date: Thu, 14 Mar 2024 11:09:38 +0100 Subject: [PATCH 02/35] Update intro --- README.md | 47 +++++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 22cfc560414..7da6570b610 100644 --- a/README.md +++ b/README.md @@ -8,17 +8,17 @@
- Truly composable and open commerce + Commerce that works with your language and stack
- The API-only, headless, GraphQL-first, composable e-commerce platform that puts developers first. + GraphQL native, API-only platform for scalable composable commerce.

- Join our active, engaged community:
+ Join our community:
Website | Twitter @@ -29,7 +29,7 @@
- Blog + Blog | Subscribe to newsletter
@@ -64,13 +64,13 @@ - **GraphQL only** - Not afterthought API design or fragmentation across different styles of API. -- **Headless and API only** - APIs is the only way to interact, configure or extend backend. +- **Headless and API only** - APIs are the only way to interact, configure, or extend the backend. -- **Truly open licence** - no community edition or commercial limitations, single version of Saleor. +- **Open source** - a single version of Saleor without feature fragmentation or commercial limitations. - **Cloud native** - battle tested on global brands. -- **Native-multichannel**: Per [channel](https://docs.saleor.io/docs/3.x/developer/channels) control of pricing, currencies, stock, product, and more +- **Native-multichannel** - Per [channel](https://docs.saleor.io/docs/3.x/developer/channels) control of pricing, currencies, stock, product, and more. ## Why API-only Architecture? @@ -79,17 +79,16 @@ Saleor's API-first extensibility provides powerful tools for developers to exten Compared to traditional plugin architectures (monoliths) it provides the following benefits: -* Less downtime as apps can be deployed independently. +* There is less downtime as apps are deployed independently. * Reliability and performance - custom logic is separated from the core. * Simplified upgrade paths - eliminates incompatibility conflicts between extensions. -* Technology-agnostic - works with any technology, stack or language. -* Parallel development - easier to collaborate than with monolithic core. +* Technology-agnostic - works with any technology, stack, or language. +* Parallel development - easier to collaborate than with a monolithic core. * Simplified debugging - easier to narrow down bugs in independent services. -* Scalability - each extension or service can be scaled independently. +* Scalability - extensions and apps can be scaled independently. ### What are the tradeoffs? -If you are a single developer working with small business that doesn't have high traffic or critical need for 24/7 availability, -using service oriented approach might feel more complex compared to traditional Wordpress or Magento approach that provides language specific framework, runtime, database schema, aspect oriented programming and other tools to quickstart. +If you are a single developer working with a small business that doesn't have high traffic or a critical need for 24/7 availability, using a service-oriented approach might feel more complex compared to the traditional WordPress or Magento approach that provides a language-specific framework, runtime, database schema, aspect-oriented programming, and other tools to a quick start. However, if you deploy on a daily basis, reliability and uptime is critical, you need to collaborate with other developers, or you have non-trivial requirements you might be in the right place. @@ -98,17 +97,17 @@ you need to collaborate with other developers, or you have non-trivial requireme - **Enterprise ready**: Secure, scalable, and stable. Battle-tested by big brands - **Dashboard**: User-friendly, fast, and productive. (Decoupled project [repo](https://github.com/saleor/saleor-dashboard) ) - **Global by design** Multi-currency, multi-language, multi-warehouse, tutti multi! -- **CMS**: Content is king, that's why we have a kingdom built-in -- **Product management**: advance model for large and complex catalogs -- **Orders**: Flexible order model, split-payments, multi-warehouse, returns and more -- **Customers**: order history and preferences -- **Promotion engine**: sales, vouchers, cart rules, giftcards -- **Payment orchestration**: multi-gateway, extensible payment API, flexible flows -- **Cart**: Advanced payment and tax options, with full control over discounts and promotions -- **Payments**: Flexible API architecture allows integration of any payment method -- **Translations**: Fully translatable catalog -- **SEO**: Packed with features that get stores to a wider audience -- **Apps**: Custom dashboard apps +- **CMS**: Manage product or marketing content. +- **Product management**: A rich content model for large and complex catalogs. +- **Orders**: Flexible order model, split payments, multi-warehouse, returns, and more. +- **Customers**: Order history and preferences. +- **Promotion engine**: Sales, vouchers, cart rules, giftcards. +- **Payment orchestration**: multi-gateway, extensible payment API, flexible flows. +- **Cart**: Advanced payment and tax options, with full control over discounts and promotions. +- **Payments**: Flexible API architecture allows integration of any payment method. +- **Translations**: Fully translatable catalog. +- **SEO**: Unlimited SEO freedom with headless architecture. +- **Apps**: Extend dashboard via iframe with any web stack. ![Saleor Dashboard - Modern UI for managing your e-commerce](https://user-images.githubusercontent.com/9268745/224249510-d3c7658e-6d5c-42c5-b4fb-93eaf65a5335.png) From b84d767a1aee8e0ed059f8908fb3e9def07c08b7 Mon Sep 17 00:00:00 2001 From: timur Date: Thu, 14 Mar 2024 11:11:34 +0100 Subject: [PATCH 03/35] Remove survey --- README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/README.md b/README.md index 7da6570b610..a308ae51394 100644 --- a/README.md +++ b/README.md @@ -172,13 +172,6 @@ If nothing grabs your attention, check [our roadmap](https://github.com/orgs/sal Get more details in our [Contributing Guide](https://docs.saleor.io/docs/developer/community/contributing). -## Your feedback - -Do you use Saleor as an e-commerce platform? -Fill out this short survey and help us grow. It will take just a minute, but means a lot! - -[Take a survey](https://mirumee.typeform.com/to/sOIJbJ) - ## License Disclaimer: Everything you see here is open and free to use as long as you comply with the [license](https://github.com/saleor/saleor/blob/master/LICENSE). There are no hidden charges. We promise to do our best to fix bugs and improve the code. From fa6e1a4ced6ec679826ad861d886413c359e460a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Wiaduch?= <2can@users.noreply.github.com> Date: Mon, 25 Mar 2024 10:26:08 +0100 Subject: [PATCH 04/35] Fix typos (#15666) --- saleor/graphql/account/types.py | 2 +- saleor/graphql/app/types.py | 2 +- saleor/graphql/attribute/descriptions.py | 2 +- .../channel/mutations/channel_create.py | 2 +- saleor/graphql/channel/types.py | 2 +- saleor/graphql/core/types/common.py | 4 ++-- .../external_notification_trigger.py | 2 +- saleor/graphql/schema.graphql | 24 +++++++++---------- saleor/graphql/shop/types.py | 2 +- 9 files changed, 21 insertions(+), 21 deletions(-) diff --git a/saleor/graphql/account/types.py b/saleor/graphql/account/types.py index ebf62b0600a..f2e8637964e 100644 --- a/saleor/graphql/account/types.py +++ b/saleor/graphql/account/types.py @@ -731,7 +731,7 @@ class ChoiceValue(graphene.ObjectType): "\n\nMany fields in the JSON refer to address fields by one-letter " "abbreviations. These are defined as follows:\n\n" "- `N`: Name\n" - "- `O`: Organisation\n" + "- `O`: Organization\n" "- `A`: Street Address Line(s)\n" "- `D`: Dependent locality (may be an inner-city district or a suburb)\n" "- `C`: City or Locality\n" diff --git a/saleor/graphql/app/types.py b/saleor/graphql/app/types.py index 36de29bb9c7..58861dccc69 100644 --- a/saleor/graphql/app/types.py +++ b/saleor/graphql/app/types.py @@ -545,7 +545,7 @@ class App(ModelObjectType[models.App]): ) version = graphene.String(description="Version number of the app.") access_token = graphene.String( - description="JWT token used to authenticate by thridparty app." + description="JWT token used to authenticate by third-party app." ) author = graphene.String( description=("The App's author name." + ADDED_IN_313 + PREVIEW_FEATURE) diff --git a/saleor/graphql/attribute/descriptions.py b/saleor/graphql/attribute/descriptions.py index e313204b35b..7a938fd06b1 100644 --- a/saleor/graphql/attribute/descriptions.py +++ b/saleor/graphql/attribute/descriptions.py @@ -47,7 +47,7 @@ class AttributeValueDescriptions: + RICH_CONTENT ) PLAIN_TEXT = ( - "Represents the text of the attribute value, plain text without formating." + "Represents the text of the attribute value, plain text without formatting." ) BOOLEAN = "Represents the boolean value of the attribute value." DATE = "Represents the date value of the attribute value." diff --git a/saleor/graphql/channel/mutations/channel_create.py b/saleor/graphql/channel/mutations/channel_create.py index 3f40ad844d7..c4488f40f43 100644 --- a/saleor/graphql/channel/mutations/channel_create.py +++ b/saleor/graphql/channel/mutations/channel_create.py @@ -91,7 +91,7 @@ class OrderSettingsInput(BaseInputObjectType): automatically_fulfill_non_shippable_gift_card = graphene.Boolean( required=False, description="When enabled, all non-shippable gift card orders " - "will be fulfilled automatically. By defualt set to True.", + "will be fulfilled automatically. By default set to True.", ) expire_orders_after = Minute( required=False, diff --git a/saleor/graphql/channel/types.py b/saleor/graphql/channel/types.py index 71af5463a48..7238fa33b58 100644 --- a/saleor/graphql/channel/types.py +++ b/saleor/graphql/channel/types.py @@ -259,7 +259,7 @@ class OrderSettings(ObjectType): allow_unpaid_orders = graphene.Boolean( required=True, description=( - "Determine if it is possible to place unpdaid order by calling " + "Determine if it is possible to place unpaid order by calling " "`checkoutComplete` mutation." + ADDED_IN_315 + PREVIEW_FEATURE ), ) diff --git a/saleor/graphql/core/types/common.py b/saleor/graphql/core/types/common.py index f8715a54e21..253adb83ec4 100644 --- a/saleor/graphql/core/types/common.py +++ b/saleor/graphql/core/types/common.py @@ -259,7 +259,7 @@ class CheckoutError(Error): code = CheckoutErrorCode(description="The error code.", required=True) variants = NonNullList( graphene.ID, - description="List of varint IDs which causes the error.", + description="List of variant IDs which causes the error.", required=False, ) lines = NonNullList( @@ -408,7 +408,7 @@ class PermissionGroupError(Error): ) channels = NonNullList( graphene.ID, - description="List of chnnels IDs which causes the error.", + description="List of channels IDs which causes the error.", required=False, ) diff --git a/saleor/graphql/notifications/mutations/external_notification_trigger.py b/saleor/graphql/notifications/mutations/external_notification_trigger.py index e4ed876f923..8875888b571 100644 --- a/saleor/graphql/notifications/mutations/external_notification_trigger.py +++ b/saleor/graphql/notifications/mutations/external_notification_trigger.py @@ -32,7 +32,7 @@ class ExternalNotificationTriggerInput(graphene.InputObjectType): extra_payload = JSONString( description=( "Additional payload that will be merged with " - "the one based on the bussines object ID." + "the one based on the business object ID." ) ) external_event_type = graphene.String( diff --git a/saleor/graphql/schema.graphql b/saleor/graphql/schema.graphql index 28142631a1c..d5c1b8ce2e2 100644 --- a/saleor/graphql/schema.graphql +++ b/saleor/graphql/schema.graphql @@ -3147,7 +3147,7 @@ type App implements Node & ObjectWithMetadata @doc(category: "Apps") { """Version number of the app.""" version: String - """JWT token used to authenticate by thridparty app.""" + """JWT token used to authenticate by third-party app.""" accessToken: String """ @@ -5615,7 +5615,7 @@ type OrderSettings { deleteExpiredOrdersAfter: Day! """ - Determine if it is possible to place unpdaid order by calling `checkoutComplete` mutation. + Determine if it is possible to place unpaid order by calling `checkoutComplete` mutation. Added in Saleor 3.15. @@ -6602,7 +6602,7 @@ type AttributeValue implements Node @doc(category: "Attributes") { richText: JSONString """ - Represents the text of the attribute value, plain text without formating. + Represents the text of the attribute value, plain text without formatting. """ plainText: String @@ -10188,7 +10188,7 @@ type Shop implements ObjectWithMetadata { enableAccountConfirmationByEmail: Boolean """ - Determines if user can login without confirmation when `enableAccountConfrimation` is enabled. + Determines if user can login without confirmation when `enableAccountConfirmation` is enabled. Added in Saleor 3.15. @@ -15015,7 +15015,7 @@ type AddressValidationData @doc(category: "Users") { Many fields in the JSON refer to address fields by one-letter abbreviations. These are defined as follows: - `N`: Name - - `O`: Organisation + - `O`: Organization - `A`: Street Address Line(s) - `D`: Dependent locality (may be an inner-city district or a suburb) - `C`: City or Locality @@ -15033,7 +15033,7 @@ type AddressValidationData @doc(category: "Users") { Many fields in the JSON refer to address fields by one-letter abbreviations. These are defined as follows: - `N`: Name - - `O`: Organisation + - `O`: Organization - `A`: Street Address Line(s) - `D`: Dependent locality (may be an inner-city district or a suburb) - `C`: City or Locality @@ -27552,7 +27552,7 @@ input ExternalNotificationTriggerInput { ids: [ID!]! """ - Additional payload that will be merged with the one based on the bussines object ID. + Additional payload that will be merged with the one based on the business object ID. """ extraPayload: JSONString @@ -28864,7 +28864,7 @@ type CheckoutError @doc(category: "Checkout") { """The error code.""" code: CheckoutErrorCode! - """List of varint IDs which causes the error.""" + """List of variant IDs which causes the error.""" variants: [ID!] """List of line Ids which cause the error.""" @@ -29567,7 +29567,7 @@ input OrderSettingsInput @doc(category: "Orders") { automaticallyConfirmAllNewOrders: Boolean """ - When enabled, all non-shippable gift card orders will be fulfilled automatically. By defualt set to True. + When enabled, all non-shippable gift card orders will be fulfilled automatically. By default set to True. """ automaticallyFulfillNonShippableGiftCard: Boolean @@ -29926,7 +29926,7 @@ input AttributeValueCreateInput @doc(category: "Attributes") { richText: JSONString """ - Represents the text of the attribute value, plain text without formating. + Represents the text of the attribute value, plain text without formatting. DEPRECATED: this field will be removed in Saleor 4.0.The plain text attribute hasn't got predefined value, so can be specified only from instance that supports the given attribute. """ @@ -30055,7 +30055,7 @@ input AttributeValueUpdateInput @doc(category: "Attributes") { richText: JSONString """ - Represents the text of the attribute value, plain text without formating. + Represents the text of the attribute value, plain text without formatting. DEPRECATED: this field will be removed in Saleor 4.0.The plain text attribute hasn't got predefined value, so can be specified only from instance that supports the given attribute. """ @@ -31889,7 +31889,7 @@ type PermissionGroupError @doc(category: "Users") { """List of user IDs which causes the error.""" users: [ID!] - """List of chnnels IDs which causes the error.""" + """List of channels IDs which causes the error.""" channels: [ID!] } diff --git a/saleor/graphql/shop/types.py b/saleor/graphql/shop/types.py index 3a8cea914af..bf91ae960ff 100644 --- a/saleor/graphql/shop/types.py +++ b/saleor/graphql/shop/types.py @@ -324,7 +324,7 @@ class Shop(graphene.ObjectType): graphene.Boolean, description=( "Determines if user can login without confirmation when " - "`enableAccountConfrimation` is enabled." + ADDED_IN_315 + "`enableAccountConfirmation` is enabled." + ADDED_IN_315 ), permissions=[SitePermissions.MANAGE_SETTINGS], ) From e2b69f3e106c9f36dd1d17a7c928c692c4ee0ec8 Mon Sep 17 00:00:00 2001 From: Patryk Zawadzki <81205+patrys@users.noreply.github.com> Date: Mon, 25 Mar 2024 11:01:55 +0100 Subject: [PATCH 05/35] Use more readable version of "sleep forever" (#15683) --- .devcontainer/docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 1b4a693ea8c..981924e4532 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -6,7 +6,7 @@ services: build: context: .. dockerfile: .devcontainer/Dockerfile - command: /bin/sh -c "while sleep 1000; do :; done" + command: sleep infinity env_file: - common.env - backend.env From a6693fdc3b1773dc09ddc4dbaf96eed3eb3200e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Mon, 25 Mar 2024 13:16:52 +0100 Subject: [PATCH 06/35] Add allow_writer context manager (#15651) * Add allow_writer context manager and middlewares to use it * Add allow_writer usages; add allow_writer_in_context context manager * Remove test settings * Add more usages of allow_writer * Add docstrings * Fix queries count in tests * Add tests --- saleor/checkout/calculations.py | 26 ++-- saleor/checkout/payment_utils.py | 4 +- saleor/checkout/utils.py | 6 +- saleor/core/db/connection.py | 131 ++++++++++++++++++ saleor/core/db/tests/__init__.py | 0 saleor/core/db/tests/test_connection.py | 88 ++++++++++++ saleor/discount/utils.py | 25 ++-- .../mutations/authentication/set_password.py | 2 + .../benchmark/test_checkout_mutations.py | 2 +- saleor/graphql/core/connection.py | 17 ++- saleor/graphql/core/dataloaders.py | 6 +- saleor/graphql/core/mutations.py | 3 + saleor/graphql/meta/mutations/base.py | 2 + .../graphql/product/dataloaders/products.py | 7 +- .../product/mutations/digital_contents.py | 2 + saleor/order/calculations.py | 68 ++++----- saleor/settings.py | 6 + 17 files changed, 329 insertions(+), 66 deletions(-) create mode 100644 saleor/core/db/connection.py create mode 100644 saleor/core/db/tests/__init__.py create mode 100644 saleor/core/db/tests/test_connection.py diff --git a/saleor/checkout/calculations.py b/saleor/checkout/calculations.py index 3493f023aff..92b1070af1f 100644 --- a/saleor/checkout/calculations.py +++ b/saleor/checkout/calculations.py @@ -7,6 +7,7 @@ from prices import Money, TaxedMoney from ..checkout import base_calculations +from ..core.db.connection import allow_writer from ..core.prices import quantize_price from ..core.taxes import TaxData, TaxEmptyData, zero_money, zero_taxed_money from ..discount.utils import ( @@ -325,18 +326,19 @@ def _fetch_checkout_prices_if_expired( checkout.price_expiration = timezone.now() + settings.CHECKOUT_PRICES_TTL - checkout.save( - update_fields=checkout_update_fields, - using=settings.DATABASE_CONNECTION_DEFAULT_NAME, - ) - checkout.lines.bulk_update( - [line_info.line for line_info in lines], - [ - "total_price_net_amount", - "total_price_gross_amount", - "tax_rate", - ], - ) + with allow_writer(): + checkout.save( + update_fields=checkout_update_fields, + using=settings.DATABASE_CONNECTION_DEFAULT_NAME, + ) + checkout.lines.bulk_update( + [line_info.line for line_info in lines], + [ + "total_price_net_amount", + "total_price_gross_amount", + "tax_rate", + ], + ) return checkout_info, lines diff --git a/saleor/checkout/payment_utils.py b/saleor/checkout/payment_utils.py index 0b2a6a96944..48d0043b9dd 100644 --- a/saleor/checkout/payment_utils.py +++ b/saleor/checkout/payment_utils.py @@ -7,6 +7,7 @@ from django.db.models import Exists, Q from prices import Money +from ..core.db.connection import allow_writer from ..core.taxes import zero_money from ..payment.models import TransactionItem from . import CheckoutAuthorizeStatus, CheckoutChargeStatus @@ -115,7 +116,8 @@ def update_checkout_payment_statuses( fields_to_update.append("charge_status") if fields_to_update: fields_to_update.append("last_change") - checkout.save(update_fields=fields_to_update) + with allow_writer(): + checkout.save(update_fields=fields_to_update) def update_refundable_for_checkout(checkout_pk): diff --git a/saleor/checkout/utils.py b/saleor/checkout/utils.py index be07aab2b73..9d8f97082c3 100644 --- a/saleor/checkout/utils.py +++ b/saleor/checkout/utils.py @@ -13,6 +13,7 @@ from prices import Money from ..account.models import User +from ..core.db.connection import allow_writer from ..core.exceptions import ProductNotPublished from ..core.taxes import zero_taxed_money from ..core.utils.promo_code import ( @@ -960,9 +961,7 @@ def cancel_active_payments(checkout: Checkout): def is_shipping_required(lines: Iterable["CheckoutLineInfo"]): """Check if shipping is required for given checkout lines.""" - return any( - line_info.product.product_type.is_shipping_required for line_info in lines - ) + return any(line_info.product_type.is_shipping_required for line_info in lines) def validate_variants_in_checkout_lines(lines: Iterable["CheckoutLineInfo"]): @@ -1012,6 +1011,7 @@ def delete_external_shipping_id(checkout: Checkout, save: bool = False): metadata.save(update_fields=["private_metadata"]) +@allow_writer() def get_or_create_checkout_metadata(checkout: "Checkout") -> CheckoutMetadata: if hasattr(checkout, "metadata_storage"): return checkout.metadata_storage diff --git a/saleor/core/db/connection.py b/saleor/core/db/connection.py new file mode 100644 index 00000000000..e359caa6878 --- /dev/null +++ b/saleor/core/db/connection.py @@ -0,0 +1,131 @@ +import logging +import traceback +from contextlib import contextmanager + +from django.conf import settings +from django.core.management.color import color_style +from django.db import connections +from django.db.backends.base.base import BaseDatabaseWrapper + +from ...graphql.core.context import SaleorContext, get_database_connection_name + +logger = logging.getLogger(__name__) + +writer = settings.DATABASE_CONNECTION_DEFAULT_NAME +replica = settings.DATABASE_CONNECTION_REPLICA_NAME + +# Limit the number of frames in the traceback in `log_writer_usage_middleware` to avoid +# excessive log size. +TRACEBACK_LIMIT = 20 + +UNSAFE_WRITER_ACCESS_MSG = ( + "Unsafe access to the writer DB detected. Call `using()` on the `QuerySet` " + "to utilize a replica DB, or employ the `allow_writer` context manager to " + "explicitly permit access to the writer." +) + + +class UnsafeDBUsageError(Exception): + pass + + +class UnsafeWriterAccessError(UnsafeDBUsageError): + pass + + +class UnsafeReplicaUsageError(UnsafeDBUsageError): + pass + + +@contextmanager +def allow_writer(): + """Context manager that allows write access to the default database connection. + + This context manager works in conjunction with the `restrict_writer_middleware` and + `log_writer_usage_middleware` middlewares. If any of these middlewares are enabled, + use the `allow_writer` context manager to allow write access to the default + database. Otherwise an error will be raised or a log message will be emitted. + """ + + default_connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + # Check if we are already in an allow_writer block. If so we don't need to do + # anything and we don't have to close access to the writer at the end. + in_allow_writer_block = getattr(default_connection, "_allow_writer", False) + if not in_allow_writer_block: + setattr(default_connection, "_allow_writer", True) + try: + yield + finally: + if not in_allow_writer_block: + # Close writer access when exiting the outermost allow_writer block. + setattr(default_connection, "_allow_writer", False) + + +@contextmanager +def allow_writer_in_context(context: SaleorContext): + """Context manager that allows write access to the default database connection in a context (SaleorContext). + + This is a helper context manager that conditionally allows write access based on the + database connection name in the given context. + """ + conn_name = get_database_connection_name(context) + if conn_name == settings.DATABASE_CONNECTION_DEFAULT_NAME: + with allow_writer(): + yield + else: + yield + + +def restrict_writer_middleware(get_response): + """Middleware that restricts write access to the default database connection. + + This middleware will raise an error if a write operation is attempted on the default + database connection. To allow writes, use the `allow_writer` context manager. Make + sure that writer is not used accidentally and always use the `using` queryset method + with proper database connection name. + """ + + def middleware(request): + with connections[writer].execute_wrapper(_restrict_writer): + with connections[replica].execute_wrapper(_restrict_writer): + return get_response(request) + + return middleware + + +def _restrict_writer(execute, sql, params, many, context): + conn: BaseDatabaseWrapper = context["connection"] + allow_writer = getattr(conn, "_allow_writer", False) + if conn.alias == writer and not allow_writer: + raise UnsafeWriterAccessError(f"{UNSAFE_WRITER_ACCESS_MSG} SQL: {sql}") + return execute(sql, params, many, context) + + +def log_writer_usage_middleware(get_response): + """Middleware that logs write access to the default database connection. + + This is similar to the `restrict_writer_middleware` middleware, but instead of + raising an error, it logs a message when a write operation is attempted on the + default database connection. + """ + + def middleware(request): + with connections[writer].execute_wrapper(_log_writer_usage): + return get_response(request) + + return middleware + + +def _log_writer_usage(execute, sql, params, many, context): + conn: BaseDatabaseWrapper = context["connection"] + allow_writer = getattr(conn, "_allow_writer", False) + if conn.alias == writer and not allow_writer: + stack_trace = traceback.extract_stack(limit=TRACEBACK_LIMIT) + error_msg = color_style().NOTICE(UNSAFE_WRITER_ACCESS_MSG) + log_msg = ( + f"{error_msg} SQL: {sql} \n" + f"Traceback: \n{''.join(traceback.format_list(stack_trace))}" + ) + logger.error(log_msg) + return execute(sql, params, many, context) diff --git a/saleor/core/db/tests/__init__.py b/saleor/core/db/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/saleor/core/db/tests/test_connection.py b/saleor/core/db/tests/test_connection.py new file mode 100644 index 00000000000..6137d3af642 --- /dev/null +++ b/saleor/core/db/tests/test_connection.py @@ -0,0 +1,88 @@ +from unittest.mock import patch + +import pytest +from django.db import connections + +from ....graphql.context import SaleorContext +from ....tests.models import Book +from ..connection import ( + UnsafeWriterAccessError, + _log_writer_usage, + _restrict_writer, + allow_writer, + allow_writer_in_context, +) + + +def test_allow_writer(settings): + default_connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + assert not getattr(default_connection, "_allow_writer", False) + + with allow_writer(): + assert hasattr(default_connection, "_allow_writer") + assert default_connection._allow_writer + + +def test_allow_writer_yield_exception(settings): + default_connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + def example_function(): + raise Exception() + + try: + with allow_writer(): + example_function() + except Exception: + pass + + assert hasattr(default_connection, "_allow_writer") + assert not default_connection._allow_writer + + +def test_allow_writer_in_context_writer(settings): + context = SaleorContext() + context.allow_replica = False + + with allow_writer_in_context(context): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + assert hasattr(connection, "_allow_writer") + assert connection._allow_writer + + +@patch("saleor.core.db.connection.get_database_connection_name") +def test_allow_writer_in_context_replica(mocked_get_database_connection_name, settings): + mocked_get_database_connection_name.return_value = "replica" + + context = SaleorContext() + context.allow_replica = True + + with allow_writer_in_context(context): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + assert not getattr(connection, "_allow_writer") + + +def test_restrict_writer_raises_error(settings): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + with pytest.raises(UnsafeWriterAccessError): + with connection.execute_wrapper(_restrict_writer): + Book.objects.first() + + +def test_restrict_writer_in_allow_writer(settings): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + with connection.execute_wrapper(_restrict_writer): + with allow_writer(): + Book.objects.first() + + +def test_log_writer_usage(settings, caplog): + connection = connections[settings.DATABASE_CONNECTION_DEFAULT_NAME] + + with connection.execute_wrapper(_log_writer_usage): + Book.objects.first() + + assert caplog.records + msg = caplog.records[0].getMessage() + assert "Unsafe access to the writer DB detected" in msg diff --git a/saleor/discount/utils.py b/saleor/discount/utils.py index cf27fb685ff..4c21523cfd4 100644 --- a/saleor/discount/utils.py +++ b/saleor/discount/utils.py @@ -21,6 +21,7 @@ ) from ..checkout.fetch import CheckoutLineInfo, find_checkout_line_info from ..checkout.models import Checkout, CheckoutLine +from ..core.db.connection import allow_writer from ..core.exceptions import InsufficientStock from ..core.taxes import zero_money from ..core.utils.promo_code import InvalidPromoCode @@ -431,14 +432,17 @@ def create_discount_objects_for_catalogue_promotions( line_discounts_to_update.append(discount_to_update) - if line_discounts_to_create: - CheckoutLineDiscount.objects.bulk_create(line_discounts_to_create) - if line_discounts_to_update and updated_fields: - CheckoutLineDiscount.objects.bulk_update( - line_discounts_to_update, updated_fields - ) - if line_discount_ids_to_remove: - CheckoutLineDiscount.objects.filter(id__in=line_discount_ids_to_remove).delete() + with allow_writer(): + if line_discounts_to_create: + CheckoutLineDiscount.objects.bulk_create(line_discounts_to_create) + if line_discounts_to_update and updated_fields: + CheckoutLineDiscount.objects.bulk_update( + line_discounts_to_update, updated_fields + ) + if line_discount_ids_to_remove: + CheckoutLineDiscount.objects.filter( + id__in=line_discount_ids_to_remove + ).delete() def _get_discount_amount( @@ -640,7 +644,8 @@ def _set_checkout_base_prices(checkout_info, lines_info): if is_update_needed: checkout.base_subtotal = subtotal checkout.base_total = total - checkout.save(update_fields=["base_total_amount", "base_subtotal_amount"]) + with allow_writer(): + checkout.save(update_fields=["base_total_amount", "base_subtotal_amount"]) def _clear_checkout_discount( @@ -833,6 +838,7 @@ def _create_or_update_checkout_discount( ) +@allow_writer() def _handle_order_promotion( checkout: "Checkout", checkout_info: "CheckoutInfo", @@ -882,6 +888,7 @@ def delete_gift_line(checkout: "Checkout", lines_info: Iterable["CheckoutLineInf lines_info.remove(gift_line_info) # type: ignore[attr-defined] +@allow_writer() def _handle_gift_reward( checkout: "Checkout", checkout_info: "CheckoutInfo", diff --git a/saleor/graphql/account/mutations/authentication/set_password.py b/saleor/graphql/account/mutations/authentication/set_password.py index c319421f0e2..2ef7414ce8f 100644 --- a/saleor/graphql/account/mutations/authentication/set_password.py +++ b/saleor/graphql/account/mutations/authentication/set_password.py @@ -6,6 +6,7 @@ from .....account import events as account_events from .....account import models from .....account.error_codes import AccountErrorCode +from .....core.db.connection import allow_writer from .....order.utils import match_orders_with_new_user from ....core import ResolveInfo from ....core.context import disallow_replica_in_context @@ -35,6 +36,7 @@ class Meta: error_type_field = "account_errors" @classmethod + @allow_writer() def mutate( # type: ignore[override] cls, root, info: ResolveInfo, /, *, email, password, token ): diff --git a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py index 7dc79c2d0e4..d112bbcf12e 100644 --- a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py +++ b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py @@ -562,7 +562,7 @@ def test_create_checkout_with_order_promotion( } # when - with django_assert_num_queries(89): + with django_assert_num_queries(88): response = user_api_client.post_graphql(MUTATION_CHECKOUT_CREATE, variables) # then diff --git a/saleor/graphql/core/connection.py b/saleor/graphql/core/connection.py index d109f749486..d50b1f020cd 100644 --- a/saleor/graphql/core/connection.py +++ b/saleor/graphql/core/connection.py @@ -345,13 +345,16 @@ def create_connection_slice( ) args["sort_by"] = sort_by - slice = connection_from_queryset_slice( - queryset, - args, - connection_type, - edge_type or connection_type.Edge, - pageinfo_type or graphene.relay.PageInfo, - ) + from ...core.db.connection import allow_writer_in_context + + with allow_writer_in_context(info.context): + slice = connection_from_queryset_slice( + queryset, + args, + connection_type, + edge_type or connection_type.Edge, + pageinfo_type or graphene.relay.PageInfo, + ) if isinstance(iterable, ChannelQsContext): edges_with_context = [] diff --git a/saleor/graphql/core/dataloaders.py b/saleor/graphql/core/dataloaders.py index 48b464f87e1..329470986c7 100644 --- a/saleor/graphql/core/dataloaders.py +++ b/saleor/graphql/core/dataloaders.py @@ -7,6 +7,7 @@ from promise import Promise from promise.dataloader import DataLoader as BaseLoader +from ...core.db.connection import allow_writer_in_context from ...thumbnail.models import Thumbnail from ...thumbnail.utils import get_thumbnail_format from . import SaleorContext @@ -47,7 +48,10 @@ def batch_load_fn( # pylint: disable=method-hidden ) as scope: span = scope.span span.set_tag(opentracing.tags.COMPONENT, "dataloaders") - results = self.batch_load(keys) + + with allow_writer_in_context(self.context): + results = self.batch_load(keys) + if not isinstance(results, Promise): return Promise.resolve(results) return results diff --git a/saleor/graphql/core/mutations.py b/saleor/graphql/core/mutations.py index 023e65772e6..5e5e3f46747 100644 --- a/saleor/graphql/core/mutations.py +++ b/saleor/graphql/core/mutations.py @@ -27,6 +27,7 @@ from graphene.types.mutation import MutationOptions from graphql.error import GraphQLError +from ...core.db.connection import allow_writer from ...core.error_codes import MetadataErrorCode from ...core.exceptions import PermissionDenied from ...core.utils.events import call_event @@ -511,6 +512,7 @@ def check_permissions( return one_of_permissions_or_auth_filter_required(context, all_permissions) @classmethod + @allow_writer() def mutate(cls, root, info: ResolveInfo, **data): disallow_replica_in_context(info.context) setup_context_user(info.context) @@ -1043,6 +1045,7 @@ def perform_mutation( # type: ignore[override] return count, errors @classmethod + @allow_writer() def mutate(cls, root, info: ResolveInfo, **data): disallow_replica_in_context(info.context) setup_context_user(info.context) diff --git a/saleor/graphql/meta/mutations/base.py b/saleor/graphql/meta/mutations/base.py index 0ecd4926a4b..e125f013d37 100644 --- a/saleor/graphql/meta/mutations/base.py +++ b/saleor/graphql/meta/mutations/base.py @@ -5,6 +5,7 @@ from ....checkout import models as checkout_models from ....checkout.models import Checkout from ....core import models +from ....core.db.connection import allow_writer from ....core.error_codes import MetadataErrorCode from ....core.exceptions import PermissionDenied from ....discount import models as discount_models @@ -168,6 +169,7 @@ def check_permissions(cls, context, permissions=None, **data): return super().check_permissions(context, permissions) @classmethod + @allow_writer() def mutate(cls, root, info: ResolveInfo, **data): try: type_name, object_pk = cls.get_object_type_name_and_pk(data) diff --git a/saleor/graphql/product/dataloaders/products.py b/saleor/graphql/product/dataloaders/products.py index d599332f5f1..6053cd93ead 100644 --- a/saleor/graphql/product/dataloaders/products.py +++ b/saleor/graphql/product/dataloaders/products.py @@ -4,6 +4,7 @@ from django.db.models import F +from ....core.db.connection import allow_writer_in_context from ....product import ProductMediaTypes from ....product.models import ( Category, @@ -19,7 +20,10 @@ VariantChannelListingPromotionRule, VariantMedia, ) -from ...core.dataloaders import BaseThumbnailBySizeAndFormatLoader, DataLoader +from ...core.dataloaders import ( + BaseThumbnailBySizeAndFormatLoader, + DataLoader, +) ProductIdAndChannelSlug = tuple[int, str] VariantIdAndChannelSlug = tuple[int, str] @@ -556,6 +560,7 @@ class ProductTypeByProductIdLoader(DataLoader): context_key = "producttype_by_product_id" def batch_load(self, keys): + @allow_writer_in_context(self.context) def with_products(products): product_ids = {p.id for p in products} product_types_map = ( diff --git a/saleor/graphql/product/mutations/digital_contents.py b/saleor/graphql/product/mutations/digital_contents.py index f4e14145797..b114de20329 100644 --- a/saleor/graphql/product/mutations/digital_contents.py +++ b/saleor/graphql/product/mutations/digital_contents.py @@ -1,6 +1,7 @@ import graphene from django.core.exceptions import ValidationError +from ....core.db.connection import allow_writer from ....core.exceptions import PermissionDenied from ....permission.enums import ProductPermissions from ....product import models @@ -172,6 +173,7 @@ class Meta: permissions = (ProductPermissions.MANAGE_PRODUCTS,) @classmethod + @allow_writer() def mutate( # type: ignore[override] cls, root, info: ResolveInfo, /, *, variant_id: str ): diff --git a/saleor/order/calculations.py b/saleor/order/calculations.py index 8d649bd324d..bc23505a2d3 100644 --- a/saleor/order/calculations.py +++ b/saleor/order/calculations.py @@ -7,6 +7,7 @@ from django.db.models import prefetch_related_objects from prices import Money, TaxedMoney +from ..core.db.connection import allow_writer from ..core.prices import quantize_price from ..core.taxes import TaxData, TaxEmptyData, TaxError, zero_taxed_money from ..discount import DiscountType @@ -61,40 +62,44 @@ def fetch_order_prices_if_expired( _update_order_discount_for_voucher(order) _recalculate_prices( - order, manager, lines, database_connection_name=database_connection_name + order, + manager, + lines, + database_connection_name=database_connection_name, ) order.subtotal = get_subtotal(lines, order.currency) with transaction.atomic(savepoint=False): - order.save( - update_fields=[ - "subtotal_net_amount", - "subtotal_gross_amount", - "total_net_amount", - "total_gross_amount", - "undiscounted_total_net_amount", - "undiscounted_total_gross_amount", - "shipping_price_net_amount", - "shipping_price_gross_amount", - "shipping_tax_rate", - "should_refresh_prices", - "tax_error", - ] - ) - order.lines.bulk_update( - lines, - [ - "unit_price_net_amount", - "unit_price_gross_amount", - "undiscounted_unit_price_net_amount", - "undiscounted_unit_price_gross_amount", - "total_price_net_amount", - "total_price_gross_amount", - "undiscounted_total_price_net_amount", - "undiscounted_total_price_gross_amount", - "tax_rate", - ], - ) + with allow_writer(): + order.save( + update_fields=[ + "subtotal_net_amount", + "subtotal_gross_amount", + "total_net_amount", + "total_gross_amount", + "undiscounted_total_net_amount", + "undiscounted_total_gross_amount", + "shipping_price_net_amount", + "shipping_price_gross_amount", + "shipping_tax_rate", + "should_refresh_prices", + "tax_error", + ] + ) + order.lines.bulk_update( + lines, + [ + "unit_price_net_amount", + "unit_price_gross_amount", + "undiscounted_unit_price_net_amount", + "undiscounted_unit_price_gross_amount", + "total_price_net_amount", + "total_price_gross_amount", + "undiscounted_total_price_net_amount", + "undiscounted_total_price_gross_amount", + "tax_rate", + ], + ) return order, lines @@ -102,7 +107,8 @@ def fetch_order_prices_if_expired( def _update_order_discount_for_voucher(order: Order): """Create or delete OrderDiscount instances.""" if not order.voucher_id: - order.discounts.filter(type=DiscountType.VOUCHER).delete() + with allow_writer(): + order.discounts.filter(type=DiscountType.VOUCHER).delete() elif ( order.voucher_id diff --git a/saleor/settings.py b/saleor/settings.py index 77ed50553e9..3c70df8960d 100644 --- a/saleor/settings.py +++ b/saleor/settings.py @@ -231,6 +231,12 @@ def get_url_from_env(name, *, schemes=None) -> Optional[str]: "saleor.core.middleware.jwt_refresh_token_middleware", ] +ENABLE_RESTRICT_WRITER_MIDDLEWARE = get_bool_from_env( + "ENABLE_RESTRICT_WRITER_MIDDLEWARE", False +) +if ENABLE_RESTRICT_WRITER_MIDDLEWARE: + MIDDLEWARE = ["saleor.core.db.connection.log_writer_usage_middleware"] + MIDDLEWARE + INSTALLED_APPS = [ # External apps that need to go before django's "storages", From e0db9b3ddb9ae603252f76bda29d5a27f27043cb Mon Sep 17 00:00:00 2001 From: Maciej Korycinski Date: Mon, 25 Mar 2024 17:23:30 +0100 Subject: [PATCH 07/35] Fix incorrect assigment of the shipping address (#15694) --- .../checkout/mutations/checkout_delivery_method_update.py | 4 ++-- .../tests/mutations/test_checkout_delivery_method_update.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/saleor/graphql/checkout/mutations/checkout_delivery_method_update.py b/saleor/graphql/checkout/mutations/checkout_delivery_method_update.py index b07aae60497..2b785d976aa 100644 --- a/saleor/graphql/checkout/mutations/checkout_delivery_method_update.py +++ b/saleor/graphql/checkout/mutations/checkout_delivery_method_update.py @@ -256,8 +256,8 @@ def _update_delivery_method( and collection_point.click_and_collect_option == WarehouseClickAndCollectOption.LOCAL_STOCK ): - checkout.shipping_address = collection_point.address - checkout_info.shipping_address = collection_point.address + checkout.shipping_address = collection_point.address.get_copy() + checkout_info.shipping_address = checkout.shipping_address checkout_fields_to_update += ["shipping_address"] invalidate_prices_updated_fields = invalidate_checkout( checkout_info, lines, manager, save=False diff --git a/saleor/graphql/checkout/tests/mutations/test_checkout_delivery_method_update.py b/saleor/graphql/checkout/tests/mutations/test_checkout_delivery_method_update.py index 6a9f7f01f94..c806867976d 100644 --- a/saleor/graphql/checkout/tests/mutations/test_checkout_delivery_method_update.py +++ b/saleor/graphql/checkout/tests/mutations/test_checkout_delivery_method_update.py @@ -553,6 +553,7 @@ def test_checkout_delivery_method_update_valid_method_not_all_shipping_data_for_ ) errors = data["errors"] assert checkout.shipping_address == delivery_method.address + assert checkout.shipping_address_id != delivery_method.address.id assert not errors assert getattr(checkout, attribute_name) == delivery_method From cb8a346d9872a6776e87e21aeab033335834a82b Mon Sep 17 00:00:00 2001 From: Maciej Korycinski Date: Mon, 25 Mar 2024 19:26:35 +0100 Subject: [PATCH 08/35] Save discounted_price for ProduvtVariantBulkUpdate (#15669) --- .../product_variant_bulk_update.py | 7 ++++++- .../test_product_variant_bulk_update.py | 19 +++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/saleor/graphql/product/bulk_mutations/product_variant_bulk_update.py b/saleor/graphql/product/bulk_mutations/product_variant_bulk_update.py index c527059cccb..7a4824d644c 100644 --- a/saleor/graphql/product/bulk_mutations/product_variant_bulk_update.py +++ b/saleor/graphql/product/bulk_mutations/product_variant_bulk_update.py @@ -691,7 +691,12 @@ def save_variants(cls, variants_data_with_errors_list): models.ProductVariantChannelListing.objects.bulk_create(listings_to_create) models.ProductVariantChannelListing.objects.bulk_update( listings_to_update, - fields=["price_amount", "cost_price_amount", "preorder_quantity_threshold"], + fields=[ + "price_amount", + "discounted_price_amount", + "cost_price_amount", + "preorder_quantity_threshold", + ], ) warehouse_models.Stock.objects.filter(id__in=stocks_to_remove).delete() models.ProductVariantChannelListing.objects.filter( diff --git a/saleor/graphql/product/tests/mutations/test_product_variant_bulk_update.py b/saleor/graphql/product/tests/mutations/test_product_variant_bulk_update.py index 70c5f60b006..9decc564a61 100644 --- a/saleor/graphql/product/tests/mutations/test_product_variant_bulk_update.py +++ b/saleor/graphql/product/tests/mutations/test_product_variant_bulk_update.py @@ -410,9 +410,7 @@ def test_product_variant_bulk_update_channel_listings_input( variant_id = graphene.Node.to_global_id("ProductVariant", variant.pk) ProductChannelListing.objects.create(product=product, channel=channel_PLN) - existing_variant_listing = variant.channel_listings.exclude( - channel=channel_PLN - ).last() + existing_variant_listing = variant.channel_listings.get() assert variant.channel_listings.count() == 1 product_id = graphene.Node.to_global_id("Product", product.pk) @@ -452,8 +450,21 @@ def test_product_variant_bulk_update_channel_listings_input( ) get_graphql_content(response, ignore_errors=True) - existing_variant_listing.refresh_from_db() # then + existing_variant_listing.refresh_from_db() + assert ( + existing_variant_listing.price_amount == new_price_for_existing_variant_listing + ) + assert ( + existing_variant_listing.discounted_price_amount + == new_price_for_existing_variant_listing + ) + new_variant_listing = variant.channel_listings.get(channel=channel_PLN) + assert new_variant_listing.price_amount == not_existing_variant_listing_price + assert ( + new_variant_listing.discounted_price_amount + == not_existing_variant_listing_price + ) # only promotions with created channel will be marked as dirty second_promotion_rule.refresh_from_db() From 779934f3c3466f26c79a8200489fee18a05165eb Mon Sep 17 00:00:00 2001 From: Piotr Zabieglik <55899043+zedzior@users.noreply.github.com> Date: Mon, 25 Mar 2024 21:19:10 +0100 Subject: [PATCH 09/35] Add `translatable_content` to translation types. (#15652) * Add category_id field to CategoryTranslatableContent. * Add translatable content to CategoryTranslation type. * Replace CategoryTranslatableContent.ID with category ID. * Handle collections. * Correct ID in TranslatableContent types to point to translatable object. * Test translation query with new ID. * Add translatable_content to translation types. * Split TRANSLATION_CREATED query; extend translation query tests. * Split TRANSLATION_UPDATED query; extend translation query tests. * Move queries from fixture file to subscription_queries. * Remove comments. * Update changelog. * Add translated object id to translatable content types; keep unique id. * Update changelog. * Update subscription queries. * Adjust saleor/graphql/translations/tests/test_translations.py. * Adjust saleor/plugins/webhook/tests/subscription_webhooks/test_create_deliveries_for_translation_subscription.py. * Update and add missing labels pointing to 3.14 version. --- CHANGELOG.md | 1 + saleor/graphql/schema.graphql | 696 +++++++++++------- .../translations/tests/test_translations.py | 84 ++- saleor/graphql/translations/types.py | 236 +++++- .../tests/subscription_webhooks/fixtures.py | 200 +++++ .../subscription_queries.py | 213 ++++++ ...deliveries_for_translation_subscription.py | 472 ++++++++++-- 7 files changed, 1564 insertions(+), 338 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 69054996092..4ce1ddd81e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ All notable, unreleased changes to this project will be documented in this file. ### Breaking changes ### GraphQL API +- Add `translatableContent` to all translation types; add translated object id to all translatable content types - #15617 by @zedzior - Add a `taxConfiguration` to a `Channel` - #15610 by @Air-t diff --git a/saleor/graphql/schema.graphql b/saleor/graphql/schema.graphql index d5c1b8ce2e2..800908aa588 100644 --- a/saleor/graphql/schema.graphql +++ b/saleor/graphql/schema.graphql @@ -4204,6 +4204,13 @@ type ShippingMethodTranslation implements Node @doc(category: "Shipping") { Rich text format. For reference see https://editorjs.io/ """ description: JSONString + + """ + Represents the shipping method fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: ShippingMethodTranslatableContent } type LanguageDisplay { @@ -4997,6 +5004,44 @@ enum LanguageCodeEnum { ZU_ZA } +""" +Represents shipping method's original translatable fields and related translations. +""" +type ShippingMethodTranslatableContent implements Node @doc(category: "Shipping") { + """The ID of the shipping method translatable content.""" + id: ID! + + """ + The ID of the shipping method to translate. + + Added in Saleor 3.14. + """ + shippingMethodId: ID! + + """Shipping method name to translate.""" + name: String! + + """ + Shipping method description to translate. + + Rich text format. For reference see https://editorjs.io/ + """ + description: JSONString + + """Returns translated shipping method fields for the given language code.""" + translation( + """A language code to return the translation for shipping method.""" + languageCode: LanguageCodeEnum! + ): ShippingMethodTranslation + + """ + Shipping method are the methods you'll use to get customer's orders to them. They are directly exposed to the customers. + + Requires one of the following permissions: MANAGE_SHIPPING. + """ + shippingMethod: ShippingMethodType @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") +} + """Represents shipping method channel listing.""" type ShippingMethodChannelListing implements Node @doc(category: "Shipping") { """The ID of shipping method channel listing.""" @@ -6643,6 +6688,103 @@ type AttributeValueTranslation implements Node @doc(category: "Attributes") { """Translated plain text attribute value .""" plainText: String + + """ + Represents the attribute value fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: AttributeValueTranslatableContent +} + +""" +Represents attribute value's original translatable fields and related translations. +""" +type AttributeValueTranslatableContent implements Node @doc(category: "Attributes") { + """The ID of the attribute value translatable content.""" + id: ID! + + """ + The ID of the attribute value to translate. + + Added in Saleor 3.14. + """ + attributeValueId: ID! + + """Name of the attribute value to translate.""" + name: String! + + """ + Attribute value. + + Rich text format. For reference see https://editorjs.io/ + """ + richText: JSONString + + """Attribute plain text value.""" + plainText: String + + """Returns translated attribute value fields for the given language code.""" + translation( + """A language code to return the translation for attribute value.""" + languageCode: LanguageCodeEnum! + ): AttributeValueTranslation + + """Represents a value of an attribute.""" + attributeValue: AttributeValue @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") + + """ + Associated attribute that can be translated. + + Added in Saleor 3.9. + """ + attribute: AttributeTranslatableContent +} + +""" +Represents attribute's original translatable fields and related translations. +""" +type AttributeTranslatableContent implements Node @doc(category: "Attributes") { + """The ID of the attribute translatable content.""" + id: ID! + + """ + The ID of the attribute to translate. + + Added in Saleor 3.14. + """ + attributeId: ID! + + """Name of the attribute to translate.""" + name: String! + + """Returns translated attribute fields for the given language code.""" + translation( + """A language code to return the translation for attribute.""" + languageCode: LanguageCodeEnum! + ): AttributeTranslation + + """Custom attribute of a product.""" + attribute: Attribute @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") +} + +"""Represents attribute translations.""" +type AttributeTranslation implements Node @doc(category: "Attributes") { + """The ID of the attribute translation.""" + id: ID! + + """Translation language.""" + language: LanguageDisplay! + + """Translated attribute name.""" + name: String! + + """ + Represents the attribute fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: AttributeTranslatableContent } type File { @@ -6682,18 +6824,6 @@ input AttributeValueFilterInput @doc(category: "Attributes") { slugs: [String!] } -"""Represents attribute translations.""" -type AttributeTranslation implements Node @doc(category: "Attributes") { - """The ID of the attribute translation.""" - id: ID! - - """Translation language.""" - language: LanguageDisplay! - - """Translated attribute name.""" - name: String! -} - type ProductTypeCountableConnection @doc(category: "Products") { """Pagination data for this connection.""" pageInfo: PageInfo! @@ -7468,6 +7598,60 @@ type CategoryTranslation implements Node @doc(category: "Products") { Rich text format. For reference see https://editorjs.io/ """ descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") + + """ + Represents the category fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: CategoryTranslatableContent +} + +""" +Represents category original translatable fields and related translations. +""" +type CategoryTranslatableContent implements Node @doc(category: "Products") { + """The ID of the category translatable content.""" + id: ID! + + """ + The ID of the category to translate. + + Added in Saleor 3.14. + """ + categoryId: ID! + + """SEO title to translate.""" + seoTitle: String + + """SEO description to translate.""" + seoDescription: String + + """Name of the category translatable content.""" + name: String! + + """ + Category description to translate. + + Rich text format. For reference see https://editorjs.io/ + """ + description: JSONString + + """ + Description of the category. + + Rich text format. For reference see https://editorjs.io/ + """ + descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") + + """Returns translated category fields for the given language code.""" + translation( + """A language code to return the translation for category.""" + languageCode: LanguageCodeEnum! + ): CategoryTranslation + + """Represents a single category of products.""" + category: Category @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") } """Represents a version of a product such as different size or color.""" @@ -7929,6 +8113,43 @@ type ProductVariantTranslation implements Node @doc(category: "Products") { """Translated product variant name.""" name: String! + + """ + Represents the product variant fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: ProductVariantTranslatableContent +} + +""" +Represents product variant's original translatable fields and related translations. +""" +type ProductVariantTranslatableContent implements Node @doc(category: "Products") { + """The ID of the product variant translatable content.""" + id: ID! + + """ + The ID of the product variant to translate. + + Added in Saleor 3.14. + """ + productVariantId: ID! + + """Name of the product variant to translate.""" + name: String! + + """Returns translated product variant fields for the given language code.""" + translation( + """A language code to return the translation for product variant.""" + languageCode: LanguageCodeEnum! + ): ProductVariantTranslation + + """Represents a version of a product such as different size or color.""" + productVariant: ProductVariant @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") + + """List of product variant attribute values that can be translated.""" + attributeValues: [AttributeValueTranslatableContent!]! } """Represents digital content associated with a product variant.""" @@ -8363,225 +8584,134 @@ type CollectionTranslation implements Node @doc(category: "Products") { Rich text format. For reference see https://editorjs.io/ """ descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") + + """ + Represents the collection fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: CollectionTranslatableContent } -"""Represents collection channel listing.""" -type CollectionChannelListing implements Node @doc(category: "Products") { - """The ID of the collection channel listing.""" +""" +Represents collection's original translatable fields and related translations. +""" +type CollectionTranslatableContent implements Node @doc(category: "Products") { + """The ID of the collection translatable content.""" id: ID! - publicationDate: Date @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `publishedAt` field to fetch the publication date.") """ - The collection publication date. + The ID of the collection to translate. - Added in Saleor 3.3. + Added in Saleor 3.14. """ - publishedAt: DateTime - - """Indicates if the collection is published in the channel.""" - isPublished: Boolean! + collectionId: ID! - """The channel to which the collection belongs.""" - channel: Channel! -} + """SEO title to translate.""" + seoTitle: String -"""Represents product translations.""" -type ProductTranslation implements Node @doc(category: "Products") { - """The ID of the product translation.""" - id: ID! + """SEO description to translate.""" + seoDescription: String - """Translation language.""" - language: LanguageDisplay! - - """Translated SEO title.""" - seoTitle: String - - """Translated SEO description.""" - seoDescription: String - - """Translated product name.""" - name: String + """Collection's name to translate.""" + name: String! """ - Translated description of the product. + Collection's description to translate. Rich text format. For reference see https://editorjs.io/ """ description: JSONString """ - Translated description of the product. + Description of the collection. Rich text format. For reference see https://editorjs.io/ """ descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") -} - -type WarehouseCountableConnection @doc(category: "Products") { - """Pagination data for this connection.""" - pageInfo: PageInfo! - edges: [WarehouseCountableEdge!]! - - """A total count of items in the collection.""" - totalCount: Int -} - -type WarehouseCountableEdge @doc(category: "Products") { - """The item at the end of the edge.""" - node: Warehouse! - - """A cursor for use in pagination.""" - cursor: String! -} - -input WarehouseFilterInput @doc(category: "Products") { - clickAndCollectOption: WarehouseClickAndCollectOptionEnum - metadata: [MetadataFilter!] - search: String - ids: [ID!] - isPrivate: Boolean - channels: [ID!] - slugs: [String!] -} - -input WarehouseSortingInput @doc(category: "Products") { - """Specifies the direction in which to sort warehouses.""" - direction: OrderDirection! - """Sort warehouses by the selected field.""" - field: WarehouseSortField! -} + """Returns translated collection fields for the given language code.""" + translation( + """A language code to return the translation for collection.""" + languageCode: LanguageCodeEnum! + ): CollectionTranslation -enum WarehouseSortField @doc(category: "Products") { - """Sort warehouses by name.""" - NAME + """Represents a collection of products.""" + collection: Collection @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") } -type TranslatableItemConnection { - """Pagination data for this connection.""" - pageInfo: PageInfo! - edges: [TranslatableItemEdge!]! +"""Represents collection channel listing.""" +type CollectionChannelListing implements Node @doc(category: "Products") { + """The ID of the collection channel listing.""" + id: ID! + publicationDate: Date @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `publishedAt` field to fetch the publication date.") - """A total count of items in the collection.""" - totalCount: Int -} + """ + The collection publication date. + + Added in Saleor 3.3. + """ + publishedAt: DateTime -type TranslatableItemEdge { - """The item at the end of the edge.""" - node: TranslatableItem! + """Indicates if the collection is published in the channel.""" + isPublished: Boolean! - """A cursor for use in pagination.""" - cursor: String! + """The channel to which the collection belongs.""" + channel: Channel! } -union TranslatableItem = ProductTranslatableContent | CollectionTranslatableContent | CategoryTranslatableContent | AttributeTranslatableContent | AttributeValueTranslatableContent | ProductVariantTranslatableContent | PageTranslatableContent | ShippingMethodTranslatableContent | VoucherTranslatableContent | MenuItemTranslatableContent | PromotionTranslatableContent | PromotionRuleTranslatableContent | SaleTranslatableContent - -""" -Represents product's original translatable fields and related translations. -""" -type ProductTranslatableContent implements Node @doc(category: "Products") { - """The ID of the product translatable content.""" +"""Represents product translations.""" +type ProductTranslation implements Node @doc(category: "Products") { + """The ID of the product translation.""" id: ID! - """SEO title to translate.""" + """Translation language.""" + language: LanguageDisplay! + + """Translated SEO title.""" seoTitle: String - """SEO description to translate.""" + """Translated SEO description.""" seoDescription: String - """Product's name to translate.""" - name: String! + """Translated product name.""" + name: String """ - Product's description to translate. + Translated description of the product. Rich text format. For reference see https://editorjs.io/ """ description: JSONString """ - Description of the product. + Translated description of the product. Rich text format. For reference see https://editorjs.io/ """ descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") - """Returns translated product fields for the given language code.""" - translation( - """A language code to return the translation for product.""" - languageCode: LanguageCodeEnum! - ): ProductTranslation - - """Represents an individual item for sale in the storefront.""" - product: Product @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") - - """List of product attribute values that can be translated.""" - attributeValues: [AttributeValueTranslatableContent!]! -} - -""" -Represents attribute value's original translatable fields and related translations. -""" -type AttributeValueTranslatableContent implements Node @doc(category: "Attributes") { - """The ID of the attribute value translatable content.""" - id: ID! - - """Name of the attribute value to translate.""" - name: String! - """ - Attribute value. + Represents the product fields to translate. - Rich text format. For reference see https://editorjs.io/ - """ - richText: JSONString - - """Attribute plain text value.""" - plainText: String - - """Returns translated attribute value fields for the given language code.""" - translation( - """A language code to return the translation for attribute value.""" - languageCode: LanguageCodeEnum! - ): AttributeValueTranslation - - """Represents a value of an attribute.""" - attributeValue: AttributeValue @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") - - """ - Associated attribute that can be translated. - - Added in Saleor 3.9. + Added in Saleor 3.14. """ - attribute: AttributeTranslatableContent + translatableContent: ProductTranslatableContent } """ -Represents attribute's original translatable fields and related translations. +Represents product's original translatable fields and related translations. """ -type AttributeTranslatableContent implements Node @doc(category: "Attributes") { - """The ID of the attribute.""" +type ProductTranslatableContent implements Node @doc(category: "Products") { + """The ID of the product translatable content.""" id: ID! - """Name of the attribute to translate.""" - name: String! - - """Returns translated attribute fields for the given language code.""" - translation( - """A language code to return the translation for attribute.""" - languageCode: LanguageCodeEnum! - ): AttributeTranslation - - """Custom attribute of a product.""" - attribute: Attribute @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") -} - -""" -Represents collection's original translatable fields and related translations. -""" -type CollectionTranslatableContent implements Node @doc(category: "Products") { - """The ID of the collection translatable content.""" - id: ID! + """ + The ID of the product to translate. + + Added in Saleor 3.14. + """ + productId: ID! """SEO title to translate.""" seoTitle: String @@ -8589,96 +8719,95 @@ type CollectionTranslatableContent implements Node @doc(category: "Products") { """SEO description to translate.""" seoDescription: String - """Collection's name to translate.""" + """Product's name to translate.""" name: String! """ - Collection's description to translate. + Product's description to translate. Rich text format. For reference see https://editorjs.io/ """ description: JSONString """ - Description of the collection. + Description of the product. Rich text format. For reference see https://editorjs.io/ """ descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") - """Returns translated collection fields for the given language code.""" + """Returns translated product fields for the given language code.""" translation( - """A language code to return the translation for collection.""" + """A language code to return the translation for product.""" languageCode: LanguageCodeEnum! - ): CollectionTranslation + ): ProductTranslation - """Represents a collection of products.""" - collection: Collection @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") -} + """Represents an individual item for sale in the storefront.""" + product: Product @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") -""" -Represents category original translatable fields and related translations. -""" -type CategoryTranslatableContent implements Node @doc(category: "Products") { - """The ID of the category translatable content.""" - id: ID! + """List of product attribute values that can be translated.""" + attributeValues: [AttributeValueTranslatableContent!]! +} - """SEO title to translate.""" - seoTitle: String +type WarehouseCountableConnection @doc(category: "Products") { + """Pagination data for this connection.""" + pageInfo: PageInfo! + edges: [WarehouseCountableEdge!]! - """SEO description to translate.""" - seoDescription: String + """A total count of items in the collection.""" + totalCount: Int +} - """Name of the category translatable content.""" - name: String! +type WarehouseCountableEdge @doc(category: "Products") { + """The item at the end of the edge.""" + node: Warehouse! - """ - Category description to translate. - - Rich text format. For reference see https://editorjs.io/ - """ - description: JSONString + """A cursor for use in pagination.""" + cursor: String! +} - """ - Description of the category. - - Rich text format. For reference see https://editorjs.io/ - """ - descriptionJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `description` field instead.") +input WarehouseFilterInput @doc(category: "Products") { + clickAndCollectOption: WarehouseClickAndCollectOptionEnum + metadata: [MetadataFilter!] + search: String + ids: [ID!] + isPrivate: Boolean + channels: [ID!] + slugs: [String!] +} - """Returns translated category fields for the given language code.""" - translation( - """A language code to return the translation for category.""" - languageCode: LanguageCodeEnum! - ): CategoryTranslation +input WarehouseSortingInput @doc(category: "Products") { + """Specifies the direction in which to sort warehouses.""" + direction: OrderDirection! - """Represents a single category of products.""" - category: Category @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") + """Sort warehouses by the selected field.""" + field: WarehouseSortField! } -""" -Represents product variant's original translatable fields and related translations. -""" -type ProductVariantTranslatableContent implements Node @doc(category: "Products") { - """The ID of the product variant translatable content.""" - id: ID! +enum WarehouseSortField @doc(category: "Products") { + """Sort warehouses by name.""" + NAME +} - """Name of the product variant to translate.""" - name: String! +type TranslatableItemConnection { + """Pagination data for this connection.""" + pageInfo: PageInfo! + edges: [TranslatableItemEdge!]! - """Returns translated product variant fields for the given language code.""" - translation( - """A language code to return the translation for product variant.""" - languageCode: LanguageCodeEnum! - ): ProductVariantTranslation + """A total count of items in the collection.""" + totalCount: Int +} - """Represents a version of a product such as different size or color.""" - productVariant: ProductVariant @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") +type TranslatableItemEdge { + """The item at the end of the edge.""" + node: TranslatableItem! - """List of product variant attribute values that can be translated.""" - attributeValues: [AttributeValueTranslatableContent!]! + """A cursor for use in pagination.""" + cursor: String! } +union TranslatableItem = ProductTranslatableContent | CollectionTranslatableContent | CategoryTranslatableContent | AttributeTranslatableContent | AttributeValueTranslatableContent | ProductVariantTranslatableContent | PageTranslatableContent | ShippingMethodTranslatableContent | VoucherTranslatableContent | MenuItemTranslatableContent | PromotionTranslatableContent | PromotionRuleTranslatableContent | SaleTranslatableContent + """ Represents page's original translatable fields and related translations. """ @@ -8686,6 +8815,13 @@ type PageTranslatableContent implements Node @doc(category: "Pages") { """The ID of the page translatable content.""" id: ID! + """ + The ID of the page to translate. + + Added in Saleor 3.14. + """ + pageId: ID! + """SEO title to translate.""" seoTitle: String @@ -8754,6 +8890,13 @@ type PageTranslation implements Node @doc(category: "Pages") { Rich text format. For reference see https://editorjs.io/ """ contentJson: JSONString @deprecated(reason: "This field will be removed in Saleor 4.0. Use the `content` field instead.") + + """ + Represents the page fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: PageTranslatableContent } """ @@ -8943,42 +9086,18 @@ type PageType implements Node & ObjectWithMetadata @doc(category: "Pages") { } """ -Represents shipping method's original translatable fields and related translations. +Represents voucher's original translatable fields and related translations. """ -type ShippingMethodTranslatableContent implements Node @doc(category: "Shipping") { - """The ID of the shipping method translatable content.""" +type VoucherTranslatableContent implements Node @doc(category: "Discounts") { + """The ID of the voucher translatable content.""" id: ID! - """Shipping method name to translate.""" - name: String! - - """ - Shipping method description to translate. - - Rich text format. For reference see https://editorjs.io/ - """ - description: JSONString - - """Returns translated shipping method fields for the given language code.""" - translation( - """A language code to return the translation for shipping method.""" - languageCode: LanguageCodeEnum! - ): ShippingMethodTranslation - """ - Shipping method are the methods you'll use to get customer's orders to them. They are directly exposed to the customers. + The ID of the voucher to translate. - Requires one of the following permissions: MANAGE_SHIPPING. + Added in Saleor 3.14. """ - shippingMethod: ShippingMethodType @deprecated(reason: "This field will be removed in Saleor 4.0. Get model fields from the root level queries.") -} - -""" -Represents voucher's original translatable fields and related translations. -""" -type VoucherTranslatableContent implements Node @doc(category: "Discounts") { - """The ID of the voucher translatable content.""" - id: ID! + voucherId: ID! """Voucher name to translate.""" name: String @@ -9007,6 +9126,13 @@ type VoucherTranslation implements Node @doc(category: "Discounts") { """Translated voucher name.""" name: String + + """ + Represents the voucher fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: VoucherTranslatableContent } """ @@ -9354,6 +9480,13 @@ type MenuItemTranslatableContent implements Node @doc(category: "Menu") { """The ID of the menu item translatable content.""" id: ID! + """ + The ID of the menu item to translate. + + Added in Saleor 3.14. + """ + menuItemId: ID! + """Name of the menu item to translate.""" name: String! @@ -9379,6 +9512,13 @@ type MenuItemTranslation implements Node @doc(category: "Menu") { """Translated menu item name.""" name: String! + + """ + Represents the menu item fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: MenuItemTranslatableContent } """ @@ -9528,6 +9668,9 @@ type PromotionTranslatableContent implements Node @doc(category: "Discounts") { """ID of the promotion translatable content.""" id: ID! + """ID of the promotion to translate.""" + promotionId: ID! + """Name of the promotion.""" name: String! @@ -9566,6 +9709,13 @@ type PromotionTranslation implements Node @doc(category: "Discounts") { Rich text format. For reference see https://editorjs.io/ """ description: JSONString + + """ + Represents the promotion fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: PromotionTranslatableContent } """ @@ -9577,6 +9727,13 @@ type PromotionRuleTranslatableContent implements Node @doc(category: "Discounts" """ID of the promotion rule translatable content.""" id: ID! + """ + ID of the promotion rule to translate. + + Added in Saleor 3.14. + """ + promotionRuleId: ID! + """Name of the promotion rule.""" name: String @@ -9615,6 +9772,13 @@ type PromotionRuleTranslation implements Node @doc(category: "Discounts") { Rich text format. For reference see https://editorjs.io/ """ description: JSONString + + """ + Represents the promotion rule fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: PromotionRuleTranslatableContent } """ @@ -9626,6 +9790,13 @@ type SaleTranslatableContent implements Node @doc(category: "Discounts") { """The ID of the sale translatable content.""" id: ID! + """ + The ID of the sale to translate. + + Added in Saleor 3.14. + """ + saleId: ID! + """Name of the sale to translate.""" name: String! @@ -9657,6 +9828,13 @@ type SaleTranslation implements Node @doc(category: "Discounts") { """Translated name of sale.""" name: String + + """ + Represents the sale fields to translate. + + Added in Saleor 3.14. + """ + translatableContent: SaleTranslatableContent } """ diff --git a/saleor/graphql/translations/tests/test_translations.py b/saleor/graphql/translations/tests/test_translations.py index 59bbce2f7d6..4195bcdffed 100644 --- a/saleor/graphql/translations/tests/test_translations.py +++ b/saleor/graphql/translations/tests/test_translations.py @@ -3206,8 +3206,10 @@ def test_translations_query_inline_fragment( __typename ...on ProductTranslatableContent{ id + productId name translation(languageCode: $languageCode){ + id name } } @@ -3223,7 +3225,9 @@ def test_translation_query_product( product_translation_fr, ): product_id = graphene.Node.to_global_id("Product", product.id) - + translation_id = graphene.Node.to_global_id( + "ProductTranslation", product_translation_fr.id + ) variables = { "id": product_id, "kind": TranslatableKinds.PRODUCT.name, @@ -3236,7 +3240,9 @@ def test_translation_query_product( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["productId"] == product_id assert data["name"] == product.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == product_translation_fr.name @@ -3248,8 +3254,10 @@ def test_translation_query_product( __typename ...on CollectionTranslatableContent{ id + collectionId name translation(languageCode: $languageCode){ + id name } } @@ -3268,6 +3276,9 @@ def test_translation_query_collection( channel_listing = published_collection.channel_listings.get() channel_listing.save() collection_id = graphene.Node.to_global_id("Collection", published_collection.id) + translation_id = graphene.Node.to_global_id( + "CollectionTranslation", collection_translation_fr.id + ) variables = { "id": collection_id, @@ -3281,7 +3292,9 @@ def test_translation_query_collection( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["collectionId"] == collection_id assert data["name"] == published_collection.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == collection_translation_fr.name @@ -3293,9 +3306,15 @@ def test_translation_query_collection( __typename ...on CategoryTranslatableContent{ id + categoryId name translation(languageCode: $languageCode){ + id name + translatableContent { + id + name + } } } } @@ -3307,12 +3326,15 @@ def test_translation_query_category( staff_api_client, category, category_translation_fr, permission_manage_translations ): category_id = graphene.Node.to_global_id("Category", category.id) - + translation_id = graphene.Node.to_global_id( + "CategoryTranslation", category_translation_fr.id + ) variables = { "id": category_id, "kind": TranslatableKinds.CATEGORY.name, "languageCode": LanguageCodeEnum.FR.name, } + response = staff_api_client.post_graphql( QUERY_TRANSLATION_CATEGORY, variables, @@ -3320,8 +3342,10 @@ def test_translation_query_category( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["categoryId"] == category_id assert data["name"] == category.name assert data["translation"]["name"] == category_translation_fr.name + assert data["translation"]["id"] == translation_id QUERY_TRANSLATION_ATTRIBUTE = """ @@ -3332,8 +3356,10 @@ def test_translation_query_category( __typename ...on AttributeTranslatableContent{ id + attributeId name translation(languageCode: $languageCode){ + id name } } @@ -3347,6 +3373,9 @@ def test_translation_query_attribute( ): attribute = translated_attribute.attribute attribute_id = graphene.Node.to_global_id("Attribute", attribute.id) + translation_id = graphene.Node.to_global_id( + "AttributeTranslation", translated_attribute.id + ) variables = { "id": attribute_id, @@ -3360,7 +3389,9 @@ def test_translation_query_attribute( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["attributeId"] == attribute_id assert data["name"] == attribute.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == translated_attribute.name @@ -3372,8 +3403,10 @@ def test_translation_query_attribute( __typename ...on AttributeValueTranslatableContent{ id + attributeValueId name translation(languageCode: $languageCode){ + id name } } @@ -3391,6 +3424,9 @@ def test_translation_query_attribute_value( attribute_value_id = graphene.Node.to_global_id( "AttributeValue", pink_attribute_value.id ) + translation_id = graphene.Node.to_global_id( + "AttributeValueTranslation", translated_attribute_value.id + ) variables = { "id": attribute_value_id, @@ -3404,7 +3440,9 @@ def test_translation_query_attribute_value( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["attributeValueId"] == attribute_value_id assert data["name"] == pink_attribute_value.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == translated_attribute_value.name @@ -3416,8 +3454,10 @@ def test_translation_query_attribute_value( __typename ...on ProductVariantTranslatableContent{ id + productVariantId name translation(languageCode: $languageCode){ + id name } } @@ -3434,6 +3474,9 @@ def test_translation_query_variant( variant_translation_fr, ): variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + translation_id = graphene.Node.to_global_id( + "ProductVariantTranslation", variant_translation_fr.id + ) variables = { "id": variant_id, "kind": TranslatableKinds.VARIANT.name, @@ -3446,7 +3489,9 @@ def test_translation_query_variant( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["productVariantId"] == variant_id assert data["name"] == variant.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == variant_translation_fr.name @@ -3458,8 +3503,10 @@ def test_translation_query_variant( __typename ...on PageTranslatableContent{ id + pageId title translation(languageCode: $languageCode){ + id title } } @@ -3487,6 +3534,9 @@ def test_translation_query_page( page.save() page_id = graphene.Node.to_global_id("Page", page.id) + translation_id = graphene.Node.to_global_id( + "PageTranslation", page_translation_fr.id + ) perms = list(Permission.objects.filter(codename__in=perm_codenames)) variables = { @@ -3499,7 +3549,9 @@ def test_translation_query_page( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["pageId"] == page_id assert data["title"] == page.title + assert data["translation"]["id"] == translation_id assert data["translation"]["title"] == page_translation_fr.title @@ -3511,9 +3563,11 @@ def test_translation_query_page( __typename ...on ShippingMethodTranslatableContent{ id + shippingMethodId name description translation(languageCode: $languageCode){ + id name } } @@ -3539,6 +3593,9 @@ def test_translation_query_shipping_method( shipping_method_id = graphene.Node.to_global_id( "ShippingMethodType", shipping_method.id ) + translation_id = graphene.Node.to_global_id( + "ShippingMethodTranslation", shipping_method_translation_fr.id + ) perms = list(Permission.objects.filter(codename__in=perm_codenames)) variables = { @@ -3551,8 +3608,10 @@ def test_translation_query_shipping_method( ) content = get_graphql_content(response, ignore_errors=True) data = content["data"]["translation"] + assert data["shippingMethodId"] == shipping_method_id assert data["name"] == shipping_method.name assert data["description"] == shipping_method.description + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == shipping_method_translation_fr.name @@ -3564,8 +3623,10 @@ def test_translation_query_shipping_method( __typename ...on SaleTranslatableContent{ id + saleId name translation(languageCode: $languageCode){ + id name } } @@ -3585,6 +3646,9 @@ def test_translation_query_sale( promotion = promotion_converted_from_sale promotion_translation = promotion.translations.first() sale_id = graphene.Node.to_global_id("Sale", promotion.old_sale_id) + translation_id = graphene.Node.to_global_id( + "SaleTranslation", promotion_converted_from_sale_translation_fr.id + ) variables = { "id": sale_id, @@ -3602,7 +3666,9 @@ def test_translation_query_sale( # then content = get_graphql_content(response, ignore_errors=True) data = content["data"]["translation"] + assert data["saleId"] == sale_id assert data["name"] == promotion.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == promotion_translation.name @@ -3614,8 +3680,10 @@ def test_translation_query_sale( __typename ...on VoucherTranslatableContent{ id + voucherId name translation(languageCode: $languageCode){ + id name } } @@ -3635,6 +3703,9 @@ def test_translation_query_voucher( staff_api_client, voucher, voucher_translation_fr, perm_codenames, return_voucher ): voucher_id = graphene.Node.to_global_id("Voucher", voucher.id) + translation_id = graphene.Node.to_global_id( + "VoucherTranslation", voucher_translation_fr.id + ) perms = list(Permission.objects.filter(codename__in=perm_codenames)) variables = { @@ -3647,7 +3718,9 @@ def test_translation_query_voucher( ) content = get_graphql_content(response, ignore_errors=True) data = content["data"]["translation"] + assert data["voucherId"] == voucher_id assert data["name"] == voucher.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == voucher_translation_fr.name @@ -3659,8 +3732,10 @@ def test_translation_query_voucher( __typename ...on MenuItemTranslatableContent{ id + menuItemId name translation(languageCode: $languageCode){ + id name } } @@ -3676,6 +3751,9 @@ def test_translation_query_menu_item( permission_manage_translations, ): menu_item_id = graphene.Node.to_global_id("MenuItem", menu_item.id) + translation_id = graphene.Node.to_global_id( + "MenuItemTranslation", menu_item_translation_fr.id + ) variables = { "id": menu_item_id, @@ -3689,7 +3767,9 @@ def test_translation_query_menu_item( ) content = get_graphql_content(response) data = content["data"]["translation"] + assert data["menuItemId"] == menu_item_id assert data["name"] == menu_item.name + assert data["translation"]["id"] == translation_id assert data["translation"]["name"] == menu_item_translation_fr.name diff --git a/saleor/graphql/translations/types.py b/saleor/graphql/translations/types.py index df4281852ac..1c780cc7f04 100644 --- a/saleor/graphql/translations/types.py +++ b/saleor/graphql/translations/types.py @@ -19,11 +19,12 @@ from ...product import models as product_models from ...shipping import models as shipping_models from ...site import models as site_models -from ..attribute.dataloaders import AttributesByAttributeId +from ..attribute.dataloaders import AttributesByAttributeId, AttributeValueByIdLoader from ..channel import ChannelContext from ..core.context import get_database_connection_name from ..core.descriptions import ( ADDED_IN_39, + ADDED_IN_314, ADDED_IN_317, DEPRECATED_IN_3X_FIELD, DEPRECATED_IN_3X_TYPE, @@ -34,15 +35,27 @@ from ..core.tracing import traced_resolver from ..core.types import LanguageDisplay, ModelObjectType, NonNullList from ..core.utils import str_to_enum +from ..discount.dataloaders import ( + PromotionByIdLoader, + PromotionRuleByIdLoader, + VoucherByIdLoader, +) +from ..menu.dataloaders import MenuItemByIdLoader from ..page.dataloaders import ( + PageByIdLoader, SelectedAttributesAllByPageIdLoader, SelectedAttributesVisibleInStorefrontPageIdLoader, ) from ..product.dataloaders import ( + CategoryByIdLoader, + CollectionByIdLoader, + ProductByIdLoader, + ProductVariantByIdLoader, SelectedAttributesAllByProductIdLoader, SelectedAttributesByProductVariantIdLoader, SelectedAttributesVisibleInStorefrontByProductIdLoader, ) +from ..shipping.dataloaders import ShippingMethodByIdLoader from ..utils import get_user_or_app_from_context from .fields import TranslationField @@ -100,27 +113,52 @@ class AttributeValueTranslation( description="Translated rich-text attribute value." + RICH_CONTENT ) plain_text = graphene.String(description="Translated plain text attribute value .") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.AttributeValueTranslatableContent", + description="Represents the attribute value fields to translate." + + ADDED_IN_314, + ) class Meta: model = attribute_models.AttributeValueTranslation interfaces = [graphene.relay.Node] description = "Represents attribute value translations." + @staticmethod + def resolve_translatable_content( + root: attribute_models.AttributeValueTranslation, info + ): + return AttributeValueByIdLoader(info.context).load(root.attribute_value_id) + class AttributeTranslation(BaseTranslationType[attribute_models.AttributeTranslation]): id = graphene.GlobalID( required=True, description="The ID of the attribute translation." ) name = graphene.String(required=True, description="Translated attribute name.") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.AttributeTranslatableContent", + description="Represents the attribute fields to translate." + ADDED_IN_314, + ) class Meta: model = attribute_models.AttributeTranslation interfaces = [graphene.relay.Node] description = "Represents attribute translations." + @staticmethod + def resolve_translatable_content(root: attribute_models.AttributeTranslation, info): + return AttributesByAttributeId(info.context).load(root.attribute_id) + class AttributeTranslatableContent(ModelObjectType[attribute_models.Attribute]): - id = graphene.GlobalID(required=True, description="The ID of the attribute.") + id = graphene.GlobalID( + required=True, description="The ID of the attribute translatable content." + ) + attribute_id = graphene.ID( + required=True, + description="The ID of the attribute to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Name of the attribute to translate." ) @@ -145,6 +183,10 @@ class Meta: def resolve_attribute(root: attribute_models.Attribute, _info): return root + @staticmethod + def resolve_attribute_id(root: attribute_models.Attribute, _info): + return graphene.Node.to_global_id("Attribute", root.id) + class AttributeValueTranslatableContent( ModelObjectType[attribute_models.AttributeValue] @@ -152,6 +194,10 @@ class AttributeValueTranslatableContent( id = graphene.GlobalID( required=True, description="The ID of the attribute value translatable content." ) + attribute_value_id = graphene.ID( + required=True, + description="The ID of the attribute value to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Name of the attribute value to translate.", @@ -189,6 +235,10 @@ def resolve_attribute_value(root: attribute_models.AttributeValue, _info): def resolve_attribute(root: attribute_models.AttributeValue, info): return AttributesByAttributeId(info.context).load(root.attribute_id) + @staticmethod + def resolve_attribute_value_id(root: attribute_models.AttributeValue, _info): + return graphene.Node.to_global_id("AttributeValue", root.id) + class ProductVariantTranslation( BaseTranslationType[product_models.ProductVariantTranslation] @@ -199,17 +249,32 @@ class ProductVariantTranslation( name = graphene.String( required=True, description="Translated product variant name." ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.ProductVariantTranslatableContent", + description="Represents the product variant fields to translate." + + ADDED_IN_314, + ) class Meta: model = product_models.ProductVariantTranslation interfaces = [graphene.relay.Node] description = "Represents product variant translations." + @staticmethod + def resolve_translatable_content( + root: product_models.ProductVariantTranslation, info + ): + return ProductVariantByIdLoader(info.context).load(root.product_variant_id) + class ProductVariantTranslatableContent(ModelObjectType[product_models.ProductVariant]): id = graphene.GlobalID( required=True, description="The ID of the product variant translatable content." ) + product_variant_id = graphene.ID( + required=True, + description="The ID of the product variant to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Name of the product variant to translate.", @@ -252,6 +317,10 @@ def resolve_attribute_values(root: product_models.ProductVariant, info): .then(get_translatable_attribute_values) ) + @staticmethod + def resolve_product_variant_id(root: product_models.ProductVariant, _info): + return graphene.Node.to_global_id("ProductVariant", root.id) + class ProductTranslation(BaseTranslationType[product_models.ProductTranslation]): id = graphene.GlobalID( @@ -269,6 +338,10 @@ class ProductTranslation(BaseTranslationType[product_models.ProductTranslation]) f"{DEPRECATED_IN_3X_FIELD} Use the `description` field instead." ), ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.ProductTranslatableContent", + description="Represents the product fields to translate." + ADDED_IN_314, + ) class Meta: model = product_models.ProductTranslation @@ -280,11 +353,19 @@ def resolve_description_json(root: product_models.ProductTranslation, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_translatable_content(root: product_models.ProductTranslation, info): + return ProductByIdLoader(info.context).load(root.product_id) + class ProductTranslatableContent(ModelObjectType[product_models.Product]): id = graphene.GlobalID( required=True, description="The ID of the product translatable content." ) + product_id = graphene.ID( + required=True, + description="The ID of the product to translate." + ADDED_IN_314, + ) seo_title = graphene.String(description="SEO title to translate.") seo_description = graphene.String(description="SEO description to translate.") name = graphene.String(required=True, description="Product's name to translate.") @@ -348,6 +429,10 @@ def resolve_attribute_values(root: product_models.Product, info): .then(get_translatable_attribute_values) ) + @staticmethod + def resolve_product_id(root: product_models.Product, _info): + return graphene.Node.to_global_id("Product", root.id) + class CollectionTranslation(BaseTranslationType[product_models.CollectionTranslation]): id = graphene.GlobalID( @@ -365,6 +450,10 @@ class CollectionTranslation(BaseTranslationType[product_models.CollectionTransla f"{DEPRECATED_IN_3X_FIELD} Use the `description` field instead." ), ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.CollectionTranslatableContent", + description="Represents the collection fields to translate." + ADDED_IN_314, + ) class Meta: model = product_models.CollectionTranslation @@ -376,11 +465,19 @@ def resolve_description_json(root: product_models.CollectionTranslation, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_translatable_content(root: product_models.CollectionTranslation, info): + return CollectionByIdLoader(info.context).load(root.collection_id) + class CollectionTranslatableContent(ModelObjectType[product_models.Collection]): id = graphene.GlobalID( required=True, description="The ID of the collection translatable content." ) + collection_id = graphene.ID( + required=True, + description="The ID of the collection to translate." + ADDED_IN_314, + ) seo_title = graphene.String(description="SEO title to translate.") seo_description = graphene.String(description="SEO description to translate.") name = graphene.String(required=True, description="Collection's name to translate.") @@ -429,6 +526,10 @@ def resolve_description_json(root: product_models.Collection, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_collection_id(root: product_models.Collection, _info): + return graphene.Node.to_global_id("Collection", root.id) + class CategoryTranslation(BaseTranslationType[product_models.CategoryTranslation]): id = graphene.GlobalID( @@ -446,6 +547,10 @@ class CategoryTranslation(BaseTranslationType[product_models.CategoryTranslation f"{DEPRECATED_IN_3X_FIELD} Use the `description` field instead." ), ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.CategoryTranslatableContent", + description="Represents the category fields to translate." + ADDED_IN_314, + ) class Meta: model = product_models.CategoryTranslation @@ -457,11 +562,19 @@ def resolve_description_json(root: product_models.CategoryTranslation, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_translatable_content(root: product_models.CategoryTranslation, info): + return CategoryByIdLoader(info.context).load(root.category_id) + class CategoryTranslatableContent(ModelObjectType[product_models.Category]): id = graphene.GlobalID( required=True, description="The ID of the category translatable content." ) + category_id = graphene.ID( + required=True, + description="The ID of the category to translate." + ADDED_IN_314, + ) seo_title = graphene.String(description="SEO title to translate.") seo_description = graphene.String(description="SEO description to translate.") name = graphene.String( @@ -501,6 +614,10 @@ def resolve_description_json(root: product_models.Category, _info): description = root.description return description if description is not None else {} + @staticmethod + def resolve_category_id(root: product_models.Category, _info): + return graphene.Node.to_global_id("Category", root.id) + class PageTranslation(BaseTranslationType[page_models.PageTranslation]): id = graphene.GlobalID(required=True, description="The ID of the page translation.") @@ -512,6 +629,10 @@ class PageTranslation(BaseTranslationType[page_models.PageTranslation]): description="Translated description of the page." + RICH_CONTENT, deprecation_reason=f"{DEPRECATED_IN_3X_FIELD} Use the `content` field instead.", ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.PageTranslatableContent", + description="Represents the page fields to translate." + ADDED_IN_314, + ) class Meta: model = page_models.PageTranslation @@ -523,11 +644,18 @@ def resolve_content_json(root: page_models.PageTranslation, _info): content = root.content return content if content is not None else {} + @staticmethod + def resolve_translatable_content(root: page_models.PageTranslation, info): + return PageByIdLoader(info.context).load(root.page_id) + class PageTranslatableContent(ModelObjectType[page_models.Page]): id = graphene.GlobalID( required=True, description="The ID of the page translatable content." ) + page_id = graphene.ID( + required=True, description="The ID of the page to translate." + ADDED_IN_314 + ) seo_title = graphene.String(description="SEO title to translate.") seo_description = graphene.String(description="SEO description to translate.") title = graphene.String(required=True, description="Page title to translate.") @@ -594,23 +722,39 @@ def resolve_attribute_values(root: page_models.Page, info): .then(get_translatable_attribute_values) ) + @staticmethod + def resolve_page_id(root: page_models.Page, _info): + return graphene.Node.to_global_id("Page", root.id) + class VoucherTranslation(BaseTranslationType[discount_models.VoucherTranslation]): id = graphene.GlobalID( required=True, description="The ID of the voucher translation." ) name = graphene.String(description="Translated voucher name.") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.VoucherTranslatableContent", + description="Represents the voucher fields to translate." + ADDED_IN_314, + ) class Meta: model = discount_models.VoucherTranslation interfaces = [graphene.relay.Node] description = "Represents voucher translations." + @staticmethod + def resolve_translatable_content(root: discount_models.VoucherTranslation, info): + return VoucherByIdLoader(info.context).load(root.voucher_id) + class VoucherTranslatableContent(ModelObjectType[discount_models.Voucher]): id = graphene.GlobalID( required=True, description="The ID of the voucher translatable content." ) + voucher_id = graphene.ID( + required=True, + description="The ID of the voucher to translate." + ADDED_IN_314, + ) name = graphene.String(description="Voucher name to translate.") translation = TranslationField(VoucherTranslation, type_name="voucher") voucher = PermissionsField( @@ -638,10 +782,18 @@ class Meta: def resolve_voucher(root: discount_models.Voucher, _info): return ChannelContext(node=root, channel_slug=None) + @staticmethod + def resolve_voucher_id(root: discount_models.Voucher, _info): + return graphene.Node.to_global_id("Voucher", root.id) + class SaleTranslation(BaseTranslationType[discount_models.PromotionTranslation]): id = graphene.GlobalID(required=True, description="The ID of the sale translation.") name = graphene.String(description="Translated name of sale.") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.SaleTranslatableContent", + description="Represents the sale fields to translate." + ADDED_IN_314, + ) class Meta: model = discount_models.PromotionTranslation @@ -652,11 +804,19 @@ class Meta: + " Use `PromotionTranslation` instead." ) + @staticmethod + def resolve_translatable_content(root: discount_models.PromotionTranslation, info): + return PromotionByIdLoader(info.context).load(root.promotion_id) + class SaleTranslatableContent(ModelObjectType[discount_models.Promotion]): id = graphene.GlobalID( required=True, description="The ID of the sale translatable content." ) + sale_id = graphene.ID( + required=True, + description="The ID of the sale to translate." + ADDED_IN_314, + ) name = graphene.String(required=True, description="Name of the sale to translate.") translation = TranslationField(SaleTranslation, type_name="sale") sale = PermissionsField( @@ -684,6 +844,10 @@ class Meta: def resolve_sale(root: discount_models.Promotion, _info): return ChannelContext(node=root, channel_slug=None) + @staticmethod + def resolve_sale_id(root: discount_models.Promotion, _info): + return graphene.Node.to_global_id("Sale", root.old_sale_id) + class ShopTranslation(BaseTranslationType[site_models.SiteSettingsTranslation]): id = graphene.GlobalID(required=True, description="The ID of the shop translation.") @@ -705,17 +869,29 @@ class MenuItemTranslation(BaseTranslationType[menu_models.MenuItemTranslation]): required=True, description="The ID of the menu item translation." ) name = graphene.String(required=True, description="Translated menu item name.") + translatable_content = graphene.Field( + "saleor.graphql.translations.types.MenuItemTranslatableContent", + description="Represents the menu item fields to translate." + ADDED_IN_314, + ) class Meta: model = menu_models.MenuItemTranslation interfaces = [graphene.relay.Node] description = "Represents menu item translations." + @staticmethod + def resolve_translatable_content(root: menu_models.MenuItemTranslation, info): + return MenuItemByIdLoader(info.context).load(root.menu_item_id) + class MenuItemTranslatableContent(ModelObjectType[menu_models.MenuItem]): id = graphene.GlobalID( required=True, description="The ID of the menu item translatable content." ) + menu_item_id = graphene.ID( + required=True, + description="The ID of the menu item to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Name of the menu item to translate." ) @@ -743,6 +919,10 @@ class Meta: def resolve_menu_item(root: menu_models.MenuItem, _info): return ChannelContext(node=root, channel_slug=None) + @staticmethod + def resolve_menu_item_id(root: menu_models.MenuItem, _info): + return graphene.Node.to_global_id("MenuItem", root.id) + class ShippingMethodTranslation( BaseTranslationType[shipping_models.ShippingMethodTranslation] @@ -756,12 +936,23 @@ class ShippingMethodTranslation( description = JSONString( description="Translated description of the shipping method." + RICH_CONTENT ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.ShippingMethodTranslatableContent", + description="Represents the shipping method fields to translate." + + ADDED_IN_314, + ) class Meta: model = shipping_models.ShippingMethodTranslation interfaces = [graphene.relay.Node] description = "Represents shipping method translations." + @staticmethod + def resolve_translatable_content( + root: shipping_models.ShippingMethodTranslation, info + ): + return ShippingMethodByIdLoader(info.context).load(root.shipping_method_id) + class ShippingMethodTranslatableContent( ModelObjectType[shipping_models.ShippingMethod] @@ -769,6 +960,10 @@ class ShippingMethodTranslatableContent( id = graphene.GlobalID( required=True, description="The ID of the shipping method translatable content." ) + shipping_method_id = graphene.ID( + required=True, + description="The ID of the shipping method to translate." + ADDED_IN_314, + ) name = graphene.String( required=True, description="Shipping method name to translate." ) @@ -804,6 +999,10 @@ class Meta: def resolve_shipping_method(root: shipping_models.ShippingMethod, _info): return ChannelContext(node=root, channel_slug=None) + @staticmethod + def resolve_shipping_method_id(root: shipping_models.ShippingMethod, _info): + return graphene.Node.to_global_id("ShippingMethodType", root.id) + class PromotionTranslation(BaseTranslationType[discount_models.PromotionTranslation]): id = graphene.GlobalID( @@ -813,17 +1012,28 @@ class PromotionTranslation(BaseTranslationType[discount_models.PromotionTranslat description = JSONString( description="Translated description of the promotion." + RICH_CONTENT ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.PromotionTranslatableContent", + description="Represents the promotion fields to translate." + ADDED_IN_314, + ) class Meta: model = discount_models.Promotion interfaces = [graphene.relay.Node] description = "Represents promotion translations." + ADDED_IN_317 + @staticmethod + def resolve_translatable_content(root: discount_models.PromotionTranslation, info): + return PromotionByIdLoader(info.context).load(root.promotion_id) + class PromotionTranslatableContent(ModelObjectType[discount_models.Promotion]): id = graphene.GlobalID( required=True, description="ID of the promotion translatable content." ) + promotion_id = graphene.ID( + required=True, description="ID of the promotion to translate." + ) name = graphene.String(required=True, description="Name of the promotion.") description = JSONString(description="Description of the promotion." + RICH_CONTENT) translation = TranslationField(PromotionTranslation, type_name="promotion") @@ -836,6 +1046,10 @@ class Meta: "and related translations." + ADDED_IN_317 ) + @staticmethod + def resolve_promotion_id(root: discount_models.Promotion, _info): + return graphene.Node.to_global_id("Promotion", root.id) + class PromotionRuleTranslation( BaseTranslationType[discount_models.PromotionTranslation] @@ -847,17 +1061,31 @@ class PromotionRuleTranslation( description = JSONString( description="Translated description of the promotion rule." + RICH_CONTENT ) + translatable_content = graphene.Field( + "saleor.graphql.translations.types.PromotionRuleTranslatableContent", + description="Represents the promotion rule fields to translate." + ADDED_IN_314, + ) class Meta: model = discount_models.PromotionRule interfaces = [graphene.relay.Node] description = "Represents promotion rule translations." + ADDED_IN_317 + @staticmethod + def resolve_translatable_content( + root: discount_models.PromotionRuleTranslation, info + ): + return PromotionRuleByIdLoader(info.context).load(root.promotion_rule_id) + class PromotionRuleTranslatableContent(ModelObjectType[discount_models.Promotion]): id = graphene.GlobalID( required=True, description="ID of the promotion rule translatable content." ) + promotion_rule_id = graphene.ID( + required=True, + description="ID of the promotion rule to translate." + ADDED_IN_314, + ) name = graphene.String(description="Name of the promotion rule.") description = JSONString( description="Description of the promotion rule." + RICH_CONTENT @@ -871,3 +1099,7 @@ class Meta: "Represents promotion rule's original translatable fields " "and related translations." + ADDED_IN_317 ) + + @staticmethod + def resolve_promotion_rule_id(root: discount_models.PromotionRule, _info): + return graphene.Node.to_global_id("PromotionRule", root.id) diff --git a/saleor/plugins/webhook/tests/subscription_webhooks/fixtures.py b/saleor/plugins/webhook/tests/subscription_webhooks/fixtures.py index 0ab0f291ce3..00c6e43bebf 100644 --- a/saleor/plugins/webhook/tests/subscription_webhooks/fixtures.py +++ b/saleor/plugins/webhook/tests/subscription_webhooks/fixtures.py @@ -942,6 +942,110 @@ def subscription_translation_created_webhook(subscription_webhook): ) +@pytest.fixture +def subscription_product_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PRODUCT, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_product_variant_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PRODUCT_VARIANT, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_collection_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_COLLECTION, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_category_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_CATEGORY, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_attribute_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_ATTRIBUTE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_attribute_value_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_ATTRIBUTE_VALUE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_page_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PAGE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_shipping_method_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_SHIPPING_METHOD, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_promotion_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PROMOTION, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_sale_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_SALE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_promotion_rule_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_PROMOTION_RULE, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_voucher_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_VOUCHER, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + +@pytest.fixture +def subscription_menu_item_translation_created_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_CREATED_MENU_ITEM, + WebhookEventAsyncType.TRANSLATION_CREATED, + ) + + @pytest.fixture def subscription_translation_updated_webhook(subscription_webhook): return subscription_webhook( @@ -950,6 +1054,102 @@ def subscription_translation_updated_webhook(subscription_webhook): ) +@pytest.fixture +def subscription_product_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PRODUCT, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_product_variant_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PRODUCT_VARIANT, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_collection_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_COLLECTION, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_category_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_CATEGORY, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_attribute_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_ATTRIBUTE, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_attribute_value_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_ATTRIBUTE_VALUE, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_page_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PAGE, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_shipping_method_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_SHIPPING_METHOD, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_promotion_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PROMOTION, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_promotion_rule_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_PROMOTION_RULE, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_voucher_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_VOUCHER, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + +@pytest.fixture +def subscription_menu_item_translation_updated_webhook(subscription_webhook): + return subscription_webhook( + queries.TRANSLATION_UPDATED_MENU_ITEM, + WebhookEventAsyncType.TRANSLATION_UPDATED, + ) + + @pytest.fixture def subscription_warehouse_created_webhook(subscription_webhook): return subscription_webhook( diff --git a/saleor/plugins/webhook/tests/subscription_webhooks/subscription_queries.py b/saleor/plugins/webhook/tests/subscription_webhooks/subscription_queries.py index 91c3138406b..2a141a325bd 100644 --- a/saleor/plugins/webhook/tests/subscription_webhooks/subscription_queries.py +++ b/saleor/plugins/webhook/tests/subscription_webhooks/subscription_queries.py @@ -1,4 +1,9 @@ +from enum import Enum + +from graphene.utils.str_converters import to_snake_case + from .....graphql.tests.queries import fragments +from .....graphql.webhook.subscription_types import TRANSLATIONS_TYPES_MAP ACCOUNT_CONFIRMATION_REQUESTED = ( fragments.CUSTOMER_DETAILS @@ -2035,6 +2040,140 @@ } """ +TranslationTypes = Enum( + "TranslationTypes", + { + to_snake_case(k.__name__).upper(): k.__name__ + for k in TRANSLATIONS_TYPES_MAP.keys() + }, +) + + +class TranslationQueryType(Enum): + CREATED = "TranslationCreated" + UPDATED = "TranslationUpdated" + + +def build_translation_query( + type: TranslationTypes, + query_type: TranslationQueryType, + translated_object_id: str, +) -> str: + return ( # noqa: UP031 + """ + subscription { + event { + ... on %s { + translation { + ... on %s { + id + name + translatableContent { + %s + name + } + } + } + } + } + } + """ + ) % (query_type.value, type.value, translated_object_id) + + +TRANSLATION_CREATED_PRODUCT = build_translation_query( + TranslationTypes.PRODUCT_TRANSLATION, + TranslationQueryType.CREATED, + "productId", +) +TRANSLATION_CREATED_PRODUCT_VARIANT = build_translation_query( + TranslationTypes.PRODUCT_VARIANT_TRANSLATION, + TranslationQueryType.CREATED, + "productVariantId", +) +TRANSLATION_CREATED_COLLECTION = build_translation_query( + TranslationTypes.COLLECTION_TRANSLATION, + TranslationQueryType.CREATED, + "collectionId", +) +TRANSLATION_CREATED_CATEGORY = build_translation_query( + TranslationTypes.CATEGORY_TRANSLATION, + TranslationQueryType.CREATED, + "categoryId", +) +TRANSLATION_CREATED_ATTRIBUTE = build_translation_query( + TranslationTypes.ATTRIBUTE_TRANSLATION, + TranslationQueryType.CREATED, + "attributeId", +) +TRANSLATION_CREATED_ATTRIBUTE_VALUE = build_translation_query( + TranslationTypes.ATTRIBUTE_VALUE_TRANSLATION, + TranslationQueryType.CREATED, + "attributeValueId", +) +TRANSLATION_CREATED_SHIPPING_METHOD = build_translation_query( + TranslationTypes.SHIPPING_METHOD_TRANSLATION, + TranslationQueryType.CREATED, + "shippingMethodId", +) +TRANSLATION_CREATED_PROMOTION = build_translation_query( + TranslationTypes.PROMOTION_TRANSLATION, + TranslationQueryType.CREATED, + "promotionId", +) +TRANSLATION_CREATED_PROMOTION_RULE = build_translation_query( + TranslationTypes.PROMOTION_RULE_TRANSLATION, + TranslationQueryType.CREATED, + "promotionRuleId", +) +TRANSLATION_CREATED_VOUCHER = build_translation_query( + TranslationTypes.VOUCHER_TRANSLATION, + TranslationQueryType.CREATED, + "voucherId", +) +TRANSLATION_CREATED_MENU_ITEM = build_translation_query( + TranslationTypes.MENU_ITEM_TRANSLATION, + TranslationQueryType.CREATED, + "menuItemId", +) +TRANSLATION_CREATED_PAGE = """ + subscription { + event { + ... on TranslationCreated { + translation { + ... on PageTranslation { + id + title + translatableContent { + pageId + title + } + } + } + } + } + } +""" +TRANSLATION_CREATED_SALE = """ + subscription { + event { + ... on TranslationCreated { + translation { + ... on SaleTranslation { + __typename + id + name + translatableContent { + saleId + name + } + } + } + } + } + } +""" + TRANSLATION_UPDATED = """ subscription { event { @@ -2087,6 +2226,80 @@ } """ +TRANSLATION_UPDATED_PRODUCT = build_translation_query( + TranslationTypes.PRODUCT_TRANSLATION, + TranslationQueryType.UPDATED, + "productId", +) +TRANSLATION_UPDATED_PRODUCT_VARIANT = build_translation_query( + TranslationTypes.PRODUCT_VARIANT_TRANSLATION, + TranslationQueryType.UPDATED, + "productVariantId", +) +TRANSLATION_UPDATED_COLLECTION = build_translation_query( + TranslationTypes.COLLECTION_TRANSLATION, + TranslationQueryType.UPDATED, + "collectionId", +) +TRANSLATION_UPDATED_CATEGORY = build_translation_query( + TranslationTypes.CATEGORY_TRANSLATION, + TranslationQueryType.UPDATED, + "categoryId", +) +TRANSLATION_UPDATED_ATTRIBUTE = build_translation_query( + TranslationTypes.ATTRIBUTE_TRANSLATION, + TranslationQueryType.UPDATED, + "attributeId", +) +TRANSLATION_UPDATED_ATTRIBUTE_VALUE = build_translation_query( + TranslationTypes.ATTRIBUTE_VALUE_TRANSLATION, + TranslationQueryType.UPDATED, + "attributeValueId", +) +TRANSLATION_UPDATED_SHIPPING_METHOD = build_translation_query( + TranslationTypes.SHIPPING_METHOD_TRANSLATION, + TranslationQueryType.UPDATED, + "shippingMethodId", +) +TRANSLATION_UPDATED_PROMOTION = build_translation_query( + TranslationTypes.PROMOTION_TRANSLATION, + TranslationQueryType.UPDATED, + "promotionId", +) +TRANSLATION_UPDATED_PROMOTION_RULE = build_translation_query( + TranslationTypes.PROMOTION_RULE_TRANSLATION, + TranslationQueryType.UPDATED, + "promotionRuleId", +) +TRANSLATION_UPDATED_VOUCHER = build_translation_query( + TranslationTypes.VOUCHER_TRANSLATION, + TranslationQueryType.UPDATED, + "voucherId", +) +TRANSLATION_UPDATED_MENU_ITEM = build_translation_query( + TranslationTypes.MENU_ITEM_TRANSLATION, + TranslationQueryType.UPDATED, + "menuItemId", +) +TRANSLATION_UPDATED_PAGE = """ + subscription { + event { + ... on TranslationUpdated { + translation { + ... on PageTranslation { + id + title + translatableContent { + pageId + title + } + } + } + } + } + } +""" + TEST_VALID_SUBSCRIPTION = """ subscription{ event{ diff --git a/saleor/plugins/webhook/tests/subscription_webhooks/test_create_deliveries_for_translation_subscription.py b/saleor/plugins/webhook/tests/subscription_webhooks/test_create_deliveries_for_translation_subscription.py index 6d265780511..03070f4cd05 100644 --- a/saleor/plugins/webhook/tests/subscription_webhooks/test_create_deliveries_for_translation_subscription.py +++ b/saleor/plugins/webhook/tests/subscription_webhooks/test_create_deliveries_for_translation_subscription.py @@ -9,18 +9,31 @@ def test_translation_created_product( - product_translation_fr, subscription_translation_created_webhook + product_translation_fr, subscription_product_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_product_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "ProductTranslation", product_translation_fr.id ) + product = product_translation_fr.product + product_id = graphene.Node.to_global_id("Product", product.id) deliveries = create_deliveries_for_subscriptions( event_type, product_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": product_translation_fr.name, + "translatableContent": { + "productId": product_id, + "name": product.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -28,18 +41,31 @@ def test_translation_created_product( def test_translation_created_product_variant( - variant_translation_fr, subscription_translation_created_webhook + variant_translation_fr, subscription_product_variant_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_product_variant_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "ProductVariantTranslation", variant_translation_fr.id ) + variant = variant_translation_fr.product_variant + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) deliveries = create_deliveries_for_subscriptions( event_type, variant_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": variant_translation_fr.name, + "translatableContent": { + "productVariantId": variant_id, + "name": variant.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -47,18 +73,31 @@ def test_translation_created_product_variant( def test_translation_created_collection( - collection_translation_fr, subscription_translation_created_webhook + collection_translation_fr, subscription_collection_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_collection_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "CollectionTranslation", collection_translation_fr.id ) + collection = collection_translation_fr.collection + collection_id = graphene.Node.to_global_id("Collection", collection.id) deliveries = create_deliveries_for_subscriptions( event_type, collection_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": collection_translation_fr.name, + "translatableContent": { + "collectionId": collection_id, + "name": collection.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -66,18 +105,31 @@ def test_translation_created_collection( def test_translation_created_category( - category_translation_fr, subscription_translation_created_webhook + category_translation_fr, subscription_category_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_category_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "CategoryTranslation", category_translation_fr.id ) + category = category_translation_fr.category + category_id = graphene.Node.to_global_id("Category", category.id) deliveries = create_deliveries_for_subscriptions( event_type, category_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": category_translation_fr.name, + "translatableContent": { + "categoryId": category_id, + "name": category.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -85,18 +137,31 @@ def test_translation_created_category( def test_translation_created_attribute( - translated_attribute, subscription_translation_created_webhook + translated_attribute, subscription_attribute_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_attribute_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "AttributeTranslation", translated_attribute.id ) + attribute = translated_attribute.attribute + attribute_id = graphene.Node.to_global_id("Attribute", attribute.id) deliveries = create_deliveries_for_subscriptions( event_type, translated_attribute, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": translated_attribute.name, + "translatableContent": { + "attributeId": attribute_id, + "name": attribute.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -104,18 +169,33 @@ def test_translation_created_attribute( def test_translation_created_attribute_value( - translated_attribute_value, subscription_translation_created_webhook + translated_attribute_value, subscription_attribute_value_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_attribute_value_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "AttributeValueTranslation", translated_attribute_value.id ) + attribute_value = translated_attribute_value.attribute_value + attribute_value_id = graphene.Node.to_global_id( + "AttributeValue", attribute_value.id + ) deliveries = create_deliveries_for_subscriptions( event_type, translated_attribute_value, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": translated_attribute_value.name, + "translatableContent": { + "attributeValueId": attribute_value_id, + "name": attribute_value.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -123,18 +203,31 @@ def test_translation_created_attribute_value( def test_translation_created_page( - page_translation_fr, subscription_translation_created_webhook + page_translation_fr, subscription_page_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_page_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "PageTranslation", page_translation_fr.id ) + page = page_translation_fr.page + page_id = graphene.Node.to_global_id("Page", page.id) deliveries = create_deliveries_for_subscriptions( event_type, page_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "title": page_translation_fr.title, + "translatableContent": { + "pageId": page_id, + "title": page.title, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -142,18 +235,34 @@ def test_translation_created_page( def test_translation_created_shipping_method( - shipping_method_translation_fr, subscription_translation_created_webhook + shipping_method_translation_fr, + subscription_shipping_method_translation_created_webhook, ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_shipping_method_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "ShippingMethodTranslation", shipping_method_translation_fr.id ) + shipping_method = shipping_method_translation_fr.shipping_method + shipping_method_id = graphene.Node.to_global_id( + "ShippingMethodType", shipping_method.id + ) deliveries = create_deliveries_for_subscriptions( event_type, shipping_method_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": shipping_method_translation_fr.name, + "translatableContent": { + "shippingMethodId": shipping_method_id, + "name": shipping_method.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -161,13 +270,15 @@ def test_translation_created_shipping_method( def test_translation_created_promotion( - promotion_translation_fr, subscription_translation_created_webhook + promotion_translation_fr, subscription_promotion_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_promotion_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "PromotionTranslation", promotion_translation_fr.id ) + promotion = promotion_translation_fr.promotion + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) deliveries = create_deliveries_for_subscriptions( event_type, promotion_translation_fr, webhooks ) @@ -176,7 +287,11 @@ def test_translation_created_promotion( { "translation": { "id": translation_id, - "__typename": "PromotionTranslation", + "name": promotion_translation_fr.name, + "translatableContent": { + "promotionId": promotion_id, + "name": promotion.name, + }, } } ) @@ -188,19 +303,26 @@ def test_translation_created_promotion( def test_translation_created_promotion_converted_from_sale( promotion_converted_from_sale_translation_fr, - subscription_translation_created_webhook, + subscription_sale_translation_created_webhook, ): translation = promotion_converted_from_sale_translation_fr - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_sale_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id("SaleTranslation", translation.id) + promotion = promotion_converted_from_sale_translation_fr.promotion + promotion_id = graphene.Node.to_global_id("Sale", promotion.old_sale_id) deliveries = create_deliveries_for_subscriptions(event_type, translation, webhooks) expected_payload = json.dumps( { "translation": { - "id": translation_id, "__typename": "SaleTranslation", + "id": translation_id, + "name": promotion_converted_from_sale_translation_fr.name, + "translatableContent": { + "saleId": promotion_id, + "name": promotion.name, + }, } } ) @@ -211,18 +333,32 @@ def test_translation_created_promotion_converted_from_sale( def test_translation_created_promotion_rule( - promotion_rule_translation_fr, subscription_translation_created_webhook + promotion_rule_translation_fr, + subscription_promotion_rule_translation_created_webhook, ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_promotion_rule_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "PromotionRuleTranslation", promotion_rule_translation_fr.id ) + promotion_rule = promotion_rule_translation_fr.promotion_rule + promotion_rule_id = graphene.Node.to_global_id("PromotionRule", promotion_rule.id) deliveries = create_deliveries_for_subscriptions( event_type, promotion_rule_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": promotion_rule_translation_fr.name, + "translatableContent": { + "promotionRuleId": promotion_rule_id, + "name": promotion_rule.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -230,18 +366,31 @@ def test_translation_created_promotion_rule( def test_translation_created_voucher( - voucher_translation_fr, subscription_translation_created_webhook + voucher_translation_fr, subscription_voucher_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_voucher_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "VoucherTranslation", voucher_translation_fr.id ) + voucher = voucher_translation_fr.voucher + voucher_id = graphene.Node.to_global_id("Voucher", voucher.id) deliveries = create_deliveries_for_subscriptions( event_type, voucher_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": voucher_translation_fr.name, + "translatableContent": { + "voucherId": voucher_id, + "name": voucher.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -249,18 +398,31 @@ def test_translation_created_voucher( def test_translation_created_menu_item( - menu_item_translation_fr, subscription_translation_created_webhook + menu_item_translation_fr, subscription_menu_item_translation_created_webhook ): - webhooks = [subscription_translation_created_webhook] + webhooks = [subscription_menu_item_translation_created_webhook] event_type = WebhookEventAsyncType.TRANSLATION_CREATED translation_id = graphene.Node.to_global_id( "MenuItemTranslation", menu_item_translation_fr.id ) + menu_item = menu_item_translation_fr.menu_item + menu_item_id = graphene.Node.to_global_id("MenuItem", menu_item.id) deliveries = create_deliveries_for_subscriptions( event_type, menu_item_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": menu_item_translation_fr.name, + "translatableContent": { + "menuItemId": menu_item_id, + "name": menu_item.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -268,18 +430,31 @@ def test_translation_created_menu_item( def test_translation_updated_product( - product_translation_fr, subscription_translation_updated_webhook + product_translation_fr, subscription_product_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_product_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "ProductTranslation", product_translation_fr.id ) + product = product_translation_fr.product + product_id = graphene.Node.to_global_id("Product", product.id) deliveries = create_deliveries_for_subscriptions( event_type, product_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": product_translation_fr.name, + "translatableContent": { + "productId": product_id, + "name": product.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -287,18 +462,31 @@ def test_translation_updated_product( def test_translation_updated_product_variant( - variant_translation_fr, subscription_translation_updated_webhook + variant_translation_fr, subscription_product_variant_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_product_variant_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "ProductVariantTranslation", variant_translation_fr.id ) + variant = variant_translation_fr.product_variant + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) deliveries = create_deliveries_for_subscriptions( event_type, variant_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": variant_translation_fr.name, + "translatableContent": { + "productVariantId": variant_id, + "name": variant.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -306,18 +494,31 @@ def test_translation_updated_product_variant( def test_translation_updated_collection( - collection_translation_fr, subscription_translation_updated_webhook + collection_translation_fr, subscription_collection_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_collection_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "CollectionTranslation", collection_translation_fr.id ) + collection = collection_translation_fr.collection + collection_id = graphene.Node.to_global_id("Collection", collection.id) deliveries = create_deliveries_for_subscriptions( event_type, collection_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": collection_translation_fr.name, + "translatableContent": { + "collectionId": collection_id, + "name": collection.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -325,18 +526,31 @@ def test_translation_updated_collection( def test_translation_updated_category( - category_translation_fr, subscription_translation_updated_webhook + category_translation_fr, subscription_category_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_category_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "CategoryTranslation", category_translation_fr.id ) + category = category_translation_fr.category + category_id = graphene.Node.to_global_id("Category", category.id) deliveries = create_deliveries_for_subscriptions( event_type, category_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": category_translation_fr.name, + "translatableContent": { + "categoryId": category_id, + "name": category.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -344,18 +558,31 @@ def test_translation_updated_category( def test_translation_updated_attribute( - translated_attribute, subscription_translation_updated_webhook + translated_attribute, subscription_attribute_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_attribute_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "AttributeTranslation", translated_attribute.id ) + attribute = translated_attribute.attribute + attribute_id = graphene.Node.to_global_id("Attribute", attribute.id) deliveries = create_deliveries_for_subscriptions( event_type, translated_attribute, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": translated_attribute.name, + "translatableContent": { + "attributeId": attribute_id, + "name": attribute.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -363,18 +590,33 @@ def test_translation_updated_attribute( def test_translation_updated_attribute_value( - translated_attribute_value, subscription_translation_updated_webhook + translated_attribute_value, subscription_attribute_value_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_attribute_value_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "AttributeValueTranslation", translated_attribute_value.id ) + attribute_value = translated_attribute_value.attribute_value + attribute_value_id = graphene.Node.to_global_id( + "AttributeValue", attribute_value.id + ) deliveries = create_deliveries_for_subscriptions( event_type, translated_attribute_value, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": translated_attribute_value.name, + "translatableContent": { + "attributeValueId": attribute_value_id, + "name": attribute_value.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -382,18 +624,31 @@ def test_translation_updated_attribute_value( def test_translation_updated_page( - page_translation_fr, subscription_translation_updated_webhook + page_translation_fr, subscription_page_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_page_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "PageTranslation", page_translation_fr.id ) + page = page_translation_fr.page + page_id = graphene.Node.to_global_id("Page", page.id) deliveries = create_deliveries_for_subscriptions( event_type, page_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "title": page_translation_fr.title, + "translatableContent": { + "pageId": page_id, + "title": page.title, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -401,18 +656,34 @@ def test_translation_updated_page( def test_translation_updated_shipping_method( - shipping_method_translation_fr, subscription_translation_updated_webhook + shipping_method_translation_fr, + subscription_shipping_method_translation_updated_webhook, ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_shipping_method_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "ShippingMethodTranslation", shipping_method_translation_fr.id ) + shipping_method = shipping_method_translation_fr.shipping_method + shipping_method_id = graphene.Node.to_global_id( + "ShippingMethodType", shipping_method.id + ) deliveries = create_deliveries_for_subscriptions( event_type, shipping_method_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": shipping_method_translation_fr.name, + "translatableContent": { + "shippingMethodId": shipping_method_id, + "name": shipping_method.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -420,19 +691,30 @@ def test_translation_updated_shipping_method( def test_translation_updated_promotion( - promotion_translation_fr, subscription_translation_updated_webhook + promotion_translation_fr, subscription_promotion_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_promotion_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "PromotionTranslation", promotion_translation_fr.id ) + promotion = promotion_translation_fr.promotion + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) deliveries = create_deliveries_for_subscriptions( event_type, promotion_translation_fr, webhooks ) expected_payload = json.dumps( - {"translation": {"id": translation_id, "__typename": "PromotionTranslation"}} + { + "translation": { + "id": translation_id, + "name": promotion_translation_fr.name, + "translatableContent": { + "promotionId": promotion_id, + "name": promotion.name, + }, + } + } ) assert deliveries[0].payload.payload == expected_payload @@ -441,18 +723,32 @@ def test_translation_updated_promotion( def test_translation_updated_promotion_rule( - promotion_rule_translation_fr, subscription_translation_updated_webhook + promotion_rule_translation_fr, + subscription_promotion_rule_translation_updated_webhook, ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_promotion_rule_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "PromotionRuleTranslation", promotion_rule_translation_fr.id ) + promotion_rule = promotion_rule_translation_fr.promotion_rule + promotion_rule_id = graphene.Node.to_global_id("PromotionRule", promotion_rule.id) deliveries = create_deliveries_for_subscriptions( event_type, promotion_rule_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": promotion_rule_translation_fr.name, + "translatableContent": { + "promotionRuleId": promotion_rule_id, + "name": promotion_rule.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -460,18 +756,31 @@ def test_translation_updated_promotion_rule( def test_translation_updated_voucher( - voucher_translation_fr, subscription_translation_updated_webhook + voucher_translation_fr, subscription_voucher_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_voucher_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "VoucherTranslation", voucher_translation_fr.id ) + voucher = voucher_translation_fr.voucher + voucher_id = graphene.Node.to_global_id("Voucher", voucher.id) deliveries = create_deliveries_for_subscriptions( event_type, voucher_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": voucher_translation_fr.name, + "translatableContent": { + "voucherId": voucher_id, + "name": voucher.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) @@ -479,18 +788,31 @@ def test_translation_updated_voucher( def test_translation_updated_menu_item( - menu_item_translation_fr, subscription_translation_updated_webhook + menu_item_translation_fr, subscription_menu_item_translation_updated_webhook ): - webhooks = [subscription_translation_updated_webhook] + webhooks = [subscription_menu_item_translation_updated_webhook] event_type = WebhookEventAsyncType.TRANSLATION_UPDATED translation_id = graphene.Node.to_global_id( "MenuItemTranslation", menu_item_translation_fr.id ) + menu_item = menu_item_translation_fr.menu_item + menu_item_id = graphene.Node.to_global_id("MenuItem", menu_item.id) deliveries = create_deliveries_for_subscriptions( event_type, menu_item_translation_fr, webhooks ) - expected_payload = json.dumps({"translation": {"id": translation_id}}) + expected_payload = json.dumps( + { + "translation": { + "id": translation_id, + "name": menu_item_translation_fr.name, + "translatableContent": { + "menuItemId": menu_item_id, + "name": menu_item.name, + }, + } + } + ) assert deliveries[0].payload.payload == expected_payload assert len(deliveries) == len(webhooks) From 7363e5bafb284658e7c2f8b84184cbcce65e5992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Tue, 26 Mar 2024 14:00:52 +0100 Subject: [PATCH 10/35] Handle replica/writer usages in thubnails views (#15698) --- saleor/thumbnail/views.py | 41 ++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/saleor/thumbnail/views.py b/saleor/thumbnail/views.py index 1472a6fce8d..7e7079f999d 100644 --- a/saleor/thumbnail/views.py +++ b/saleor/thumbnail/views.py @@ -2,6 +2,7 @@ from collections import namedtuple from typing import Optional +from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.http import ( HttpResponseBadRequest, @@ -12,6 +13,7 @@ from ..account.models import User from ..app.models import App, AppInstallation +from ..core.db.connection import allow_writer from ..core.utils.events import call_event from ..graphql.core.utils import from_global_id_or_error from ..plugins.manager import get_plugins_manager @@ -82,16 +84,22 @@ def handle_thumbnail( else: instance_id_lookup = model_data.thumbnail_field + "_id" - if thumbnail := Thumbnail.objects.filter( - format=format, size=size_px, **{instance_id_lookup: pk} - ).first(): + if ( + thumbnail := Thumbnail.objects.using(settings.DATABASE_CONNECTION_REPLICA_NAME) + .filter(format=format, size=size_px, **{instance_id_lookup: pk}) + .first() + ): return HttpResponseRedirect(thumbnail.image.url) try: if object_type in UUID_IDENTIFIABLE_TYPES: - instance = model_data.model.objects.get(uuid=pk) + instance = model_data.model.objects.using( + settings.DATABASE_CONNECTION_REPLICA_NAME + ).get(uuid=pk) else: - instance = model_data.model.objects.get(id=pk) + instance = model_data.model.objects.using( + settings.DATABASE_CONNECTION_REPLICA_NAME + ).get(id=pk) except ObjectDoesNotExist: return HttpResponseNotFound("Instance with the given id cannot be found.") @@ -118,16 +126,17 @@ def handle_thumbnail( thumbnail_file_name = prepare_thumbnail_file_name(image.name, size_px, format) # save image thumbnail - thumbnail = Thumbnail( - size=size_px, format=format, **{model_data.thumbnail_field: instance} - ) - thumbnail.image.save(thumbnail_file_name, thumbnail_file) - thumbnail.save() - - # set additional `instance` attribute, to easily get instance data - # for ThumbnailCreated subscription type - setattr(thumbnail, "instance", instance) - manager = get_plugins_manager(allow_replica=False) - call_event(manager.thumbnail_created, thumbnail) + with allow_writer(): + thumbnail = Thumbnail( + size=size_px, format=format, **{model_data.thumbnail_field: instance} + ) + thumbnail.image.save(thumbnail_file_name, thumbnail_file) + thumbnail.save() + + # set additional `instance` attribute, to easily get instance data + # for ThumbnailCreated subscription type + setattr(thumbnail, "instance", instance) + manager = get_plugins_manager(allow_replica=False) + call_event(manager.thumbnail_created, thumbnail) return HttpResponseRedirect(thumbnail.image.url) From dd96cf1e27be4dfd566adfbc57408e9ecb1b4791 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Wed, 27 Mar 2024 10:16:35 +0100 Subject: [PATCH 11/35] Bump dependencies (#15681) Co-authored-by: maarcingebala --- .pre-commit-config.yaml | 6 +- poetry.lock | 139 ++++++++++++++++++++-------------------- 2 files changed, 73 insertions(+), 72 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9d6819dcafe..37be86ac92e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: exclude_types: [svg] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.3 + rev: v0.3.4 hooks: - id: ruff - id: ruff-format @@ -24,12 +24,12 @@ repos: exclude: tests/ - repo: https://github.com/fpgmaas/deptry.git - rev: "0.14.0" + rev: "0.15.0" hooks: - id: deptry - repo: https://github.com/returntocorp/semgrep - rev: v1.64.0 + rev: v1.66.0 hooks: - id: semgrep language_version: python3.9 diff --git a/poetry.lock b/poetry.lock index 681a059ab4e..3e32fe2c163 100644 --- a/poetry.lock +++ b/poetry.lock @@ -59,13 +59,13 @@ trio = ["trio (>=0.23)"] [[package]] name = "asgiref" -version = "3.7.2" +version = "3.8.1" description = "ASGI specs, helper code, and adapters" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "asgiref-3.7.2-py3-none-any.whl", hash = "sha256:89b2ef2247e3b562a16eef663bc0e2e703ec6468e2fa8a5cd61cd449786d4f6e"}, - {file = "asgiref-3.7.2.tar.gz", hash = "sha256:9e0ce3aa93a819ba5b45120216b23878cf6e8525eb3848653452b4192b92afed"}, + {file = "asgiref-3.8.1-py3-none-any.whl", hash = "sha256:3e1e3ecc849832fe52ccf2cb6686b7a55f82bb1d6aee72a58826471390335e47"}, + {file = "asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"}, ] [package.dependencies] @@ -321,17 +321,17 @@ files = [ [[package]] name = "boto3" -version = "1.34.64" +version = "1.34.69" description = "The AWS SDK for Python" optional = false -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ - {file = "boto3-1.34.64-py3-none-any.whl", hash = "sha256:8c6fbd3d45399a4e4685010117fb2dc52fc6afdab5a9460957d463ae0c2cc55d"}, - {file = "boto3-1.34.64.tar.gz", hash = "sha256:e5d681f443645e6953ed0727bf756bf16d85efefcb69cf051d04a070ce65e545"}, + {file = "boto3-1.34.69-py3-none-any.whl", hash = "sha256:2e25ef6bd325217c2da329829478be063155897d8d3b29f31f7f23ab548519b1"}, + {file = "boto3-1.34.69.tar.gz", hash = "sha256:898a5fed26b1351352703421d1a8b886ef2a74be6c97d5ecc92432ae01fda203"}, ] [package.dependencies] -botocore = ">=1.34.64,<1.35.0" +botocore = ">=1.34.69,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -340,13 +340,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.64" +version = "1.34.69" description = "Low-level, data-driven core of boto 3." optional = false -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ - {file = "botocore-1.34.64-py3-none-any.whl", hash = "sha256:0ab760908749fe82325698591c49755a5bb20307d85a419aca9cc74e783b9407"}, - {file = "botocore-1.34.64.tar.gz", hash = "sha256:084f8c45216d62dc1add2350e236a2d5283526aacd0681e9818b37a6a5e5438b"}, + {file = "botocore-1.34.69-py3-none-any.whl", hash = "sha256:d3802d076d4d507bf506f9845a6970ce43adc3d819dd57c2791f5c19ed6e5950"}, + {file = "botocore-1.34.69.tar.gz", hash = "sha256:d1ab2bff3c2fd51719c2021d9fa2f30fbb9ed0a308f69e9a774ac92c8091380a"}, ] [package.dependencies] @@ -792,13 +792,13 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "click-didyoumean" -version = "0.3.0" +version = "0.3.1" description = "Enables git-like *did-you-mean* feature in click" optional = false -python-versions = ">=3.6.2,<4.0.0" +python-versions = ">=3.6.2" files = [ - {file = "click-didyoumean-0.3.0.tar.gz", hash = "sha256:f184f0d851d96b6d29297354ed981b7dd71df7ff500d82fa6d11f0856bee8035"}, - {file = "click_didyoumean-0.3.0-py3-none-any.whl", hash = "sha256:a0713dc7a1de3f06bc0df5a9567ad19ead2d3d5689b434768a6145bff77c0667"}, + {file = "click_didyoumean-0.3.1-py3-none-any.whl", hash = "sha256:5c4bb6007cfea5f2fd6583a2fb6701a22a41eb98957e63d0fac41c10e7c3117c"}, + {file = "click_didyoumean-0.3.1.tar.gz", hash = "sha256:4f82fdff0dbe64ef8ab2279bd6aa3f6a99c3b28c05aa09cbfc07c9d7fbb5a463"}, ] [package.dependencies] @@ -1665,13 +1665,13 @@ yaml = ["PyYAML"] [[package]] name = "google-api-core" -version = "2.17.1" +version = "2.18.0" description = "Google API client core library" optional = false python-versions = ">=3.7" files = [ - {file = "google-api-core-2.17.1.tar.gz", hash = "sha256:9df18a1f87ee0df0bc4eea2770ebc4228392d8cc4066655b320e2cfccb15db95"}, - {file = "google_api_core-2.17.1-py3-none-any.whl", hash = "sha256:610c5b90092c360736baccf17bd3efbcb30dd380e7a6dc28a71059edb8bd0d8e"}, + {file = "google-api-core-2.18.0.tar.gz", hash = "sha256:62d97417bfc674d6cef251e5c4d639a9655e00c45528c4364fbfebb478ce72a9"}, + {file = "google_api_core-2.18.0-py3-none-any.whl", hash = "sha256:5a63aa102e0049abe85b5b88cb9409234c1f70afcda21ce1e40b285b9629c1d6"}, ] [package.dependencies] @@ -1679,6 +1679,7 @@ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""} grpcio-status = {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""} +proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -1689,13 +1690,13 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] [[package]] name = "google-auth" -version = "2.28.2" +version = "2.29.0" description = "Google Authentication Library" optional = false python-versions = ">=3.7" files = [ - {file = "google-auth-2.28.2.tar.gz", hash = "sha256:80b8b4969aa9ed5938c7828308f20f035bc79f9d8fb8120bf9dc8db20b41ba30"}, - {file = "google_auth-2.28.2-py2.py3-none-any.whl", hash = "sha256:9fd67bbcd40f16d9d42f950228e9cf02a2ded4ae49198b27432d0cded5a74c38"}, + {file = "google-auth-2.29.0.tar.gz", hash = "sha256:672dff332d073227550ffc7457868ac4218d6c500b155fe6cc17d2b13602c360"}, + {file = "google_auth-2.29.0-py2.py3-none-any.whl", hash = "sha256:d452ad095688cd52bae0ad6fafe027f6a6d6f560e810fec20914e17a09526415"}, ] [package.dependencies] @@ -1730,13 +1731,13 @@ grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"] [[package]] name = "google-cloud-pubsub" -version = "2.20.2" +version = "2.20.3" description = "Google Cloud Pub/Sub API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-pubsub-2.20.2.tar.gz", hash = "sha256:236046ea860230c788e4d4ea2d0f12299cdf1d94ac71ec42ed1a0ce1ba28d66f"}, - {file = "google_cloud_pubsub-2.20.2-py2.py3-none-any.whl", hash = "sha256:9607bb8f973cbd123b5fa2db9c0aa38501a1a42f18593739067cd307263d090f"}, + {file = "google-cloud-pubsub-2.20.3.tar.gz", hash = "sha256:76af0f179509e431d2bbe3f51f426256783dc959b8f458c1370cb19d6f7be0b1"}, + {file = "google_cloud_pubsub-2.20.3-py2.py3-none-any.whl", hash = "sha256:d9ae1e812800208492afa1a35ec194637600cd1c27eca382bd80aa1d4a168ffa"}, ] [package.dependencies] @@ -1753,13 +1754,13 @@ libcst = ["libcst (>=0.3.10)"] [[package]] name = "google-cloud-storage" -version = "2.15.0" +version = "2.16.0" description = "Google Cloud Storage API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-storage-2.15.0.tar.gz", hash = "sha256:7560a3c48a03d66c553dc55215d35883c680fe0ab44c23aa4832800ccc855c74"}, - {file = "google_cloud_storage-2.15.0-py2.py3-none-any.whl", hash = "sha256:5d9237f88b648e1d724a0f20b5cde65996a37fe51d75d17660b1404097327dd2"}, + {file = "google-cloud-storage-2.16.0.tar.gz", hash = "sha256:dda485fa503710a828d01246bd16ce9db0823dc51bbca742ce96a6817d58669f"}, + {file = "google_cloud_storage-2.16.0-py2.py3-none-any.whl", hash = "sha256:91a06b96fb79cf9cdfb4e759f178ce11ea885c79938f89590344d079305f5852"}, ] [package.dependencies] @@ -2242,13 +2243,13 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.0.2" +version = "7.1.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.2-py3-none-any.whl", hash = "sha256:f4bc4c0c070c490abf4ce96d715f68e95923320370efb66143df00199bb6c100"}, - {file = "importlib_metadata-7.0.2.tar.gz", hash = "sha256:198f568f3230878cb1b44fbd7975f87906c22336dba2e4a7f05278c281fbd792"}, + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, ] [package.dependencies] @@ -2257,7 +2258,7 @@ zipp = ">=0.5" [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "iniconfig" @@ -3057,13 +3058,13 @@ xpath = ["lxml (>=4.4.0)"] [[package]] name = "phonenumberslite" -version = "8.13.32" +version = "8.13.33" description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers." optional = false python-versions = "*" files = [ - {file = "phonenumberslite-8.13.32-py2.py3-none-any.whl", hash = "sha256:7b6a539c2dd483a10385528a9cb31e217c6deecfa1fcce6fd0386cadadae2490"}, - {file = "phonenumberslite-8.13.32.tar.gz", hash = "sha256:e1368eb5a2622c2a1d8330fbddc8e7ac871e1b3776f82c338423823c27739159"}, + {file = "phonenumberslite-8.13.33-py2.py3-none-any.whl", hash = "sha256:4d92f4f9079bb83588dde45fd8a414bc13e4962886aa4d23576984196f4d83c2"}, + {file = "phonenumberslite-8.13.33.tar.gz", hash = "sha256:7426bc46af3de5a800a4c8f33ab13e33225d2c8ed4fc52aa3c0380dadd8d7381"}, ] [[package]] @@ -3240,13 +3241,13 @@ files = [ [[package]] name = "pre-commit" -version = "3.6.2" +version = "3.7.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." optional = false python-versions = ">=3.9" files = [ - {file = "pre_commit-3.6.2-py2.py3-none-any.whl", hash = "sha256:ba637c2d7a670c10daedc059f5c49b5bd0aadbccfcd7ec15592cf9665117532c"}, - {file = "pre_commit-3.6.2.tar.gz", hash = "sha256:c3ef34f463045c88658c5b99f38c1e297abdcc0ff13f98d3370055fbbfabc67e"}, + {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"}, + {file = "pre_commit-3.7.0.tar.gz", hash = "sha256:e209d61b8acdcf742404408531f0c37d49d2c734fd7cff2d6076083d191cb060"}, ] [package.dependencies] @@ -3614,17 +3615,17 @@ test = ["covdefaults (>=2.2.2)", "coverage (>=7.0.5)", "flaky (>=3.7)", "pytest [[package]] name = "pytest-mock" -version = "3.12.0" +version = "3.14.0" description = "Thin-wrapper around the mock package for easier use with pytest" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, - {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, ] [package.dependencies] -pytest = ">=5.0" +pytest = ">=6.2.5" [package.extras] dev = ["pre-commit", "pytest-asyncio", "tox"] @@ -3882,13 +3883,13 @@ files = [ [[package]] name = "razorpay" -version = "1.4.1" +version = "1.4.2" description = "Razorpay Python Client" optional = false python-versions = "*" files = [ - {file = "razorpay-1.4.1-py3-none-any.whl", hash = "sha256:89b01c386cc93cc5af07d613c99c24fa3ded95d7def94b89e0ac294fb49cbe7c"}, - {file = "razorpay-1.4.1.tar.gz", hash = "sha256:236af5a6ae6512345907f4b7f42297980bb8314dbbbc684313a476b7f45179bd"}, + {file = "razorpay-1.4.2-py3-none-any.whl", hash = "sha256:aaf525baebc5001e5c54ae317a372c3c17c81314cedcbbf71aa0a2d3419a9757"}, + {file = "razorpay-1.4.2.tar.gz", hash = "sha256:0d812030ba29e776e66d1ceaacc31635810fde18b66ff2591e582eb651bb5131"}, ] [package.dependencies] @@ -4183,28 +4184,28 @@ files = [ [[package]] name = "ruff" -version = "0.3.3" +version = "0.3.4" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:973a0e388b7bc2e9148c7f9be8b8c6ae7471b9be37e1cc732f8f44a6f6d7720d"}, - {file = "ruff-0.3.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfa60d23269d6e2031129b053fdb4e5a7b0637fc6c9c0586737b962b2f834493"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eca7ff7a47043cf6ce5c7f45f603b09121a7cc047447744b029d1b719278eb5"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7d3f6762217c1da954de24b4a1a70515630d29f71e268ec5000afe81377642d"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b24c19e8598916d9c6f5a5437671f55ee93c212a2c4c569605dc3842b6820386"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5a6cbf216b69c7090f0fe4669501a27326c34e119068c1494f35aaf4cc683778"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:352e95ead6964974b234e16ba8a66dad102ec7bf8ac064a23f95371d8b198aab"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d6ab88c81c4040a817aa432484e838aaddf8bfd7ca70e4e615482757acb64f8"}, - {file = "ruff-0.3.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79bca3a03a759cc773fca69e0bdeac8abd1c13c31b798d5bb3c9da4a03144a9f"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2700a804d5336bcffe063fd789ca2c7b02b552d2e323a336700abb8ae9e6a3f8"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd66469f1a18fdb9d32e22b79f486223052ddf057dc56dea0caaf1a47bdfaf4e"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:45817af234605525cdf6317005923bf532514e1ea3d9270acf61ca2440691376"}, - {file = "ruff-0.3.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0da458989ce0159555ef224d5b7c24d3d2e4bf4c300b85467b08c3261c6bc6a8"}, - {file = "ruff-0.3.3-py3-none-win32.whl", hash = "sha256:f2831ec6a580a97f1ea82ea1eda0401c3cdf512cf2045fa3c85e8ef109e87de0"}, - {file = "ruff-0.3.3-py3-none-win_amd64.whl", hash = "sha256:be90bcae57c24d9f9d023b12d627e958eb55f595428bafcb7fec0791ad25ddfc"}, - {file = "ruff-0.3.3-py3-none-win_arm64.whl", hash = "sha256:0171aab5fecdc54383993389710a3d1227f2da124d76a2784a7098e818f92d61"}, - {file = "ruff-0.3.3.tar.gz", hash = "sha256:38671be06f57a2f8aba957d9f701ea889aa5736be806f18c0cd03d6ff0cbca8d"}, + {file = "ruff-0.3.4-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:60c870a7d46efcbc8385d27ec07fe534ac32f3b251e4fc44b3cbfd9e09609ef4"}, + {file = "ruff-0.3.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6fc14fa742e1d8f24910e1fff0bd5e26d395b0e0e04cc1b15c7c5e5fe5b4af91"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3ee7880f653cc03749a3bfea720cf2a192e4f884925b0cf7eecce82f0ce5854"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf133dd744f2470b347f602452a88e70dadfbe0fcfb5fd46e093d55da65f82f7"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f3860057590e810c7ffea75669bdc6927bfd91e29b4baa9258fd48b540a4365"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:986f2377f7cf12efac1f515fc1a5b753c000ed1e0a6de96747cdf2da20a1b369"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fd98e85869603e65f554fdc5cddf0712e352fe6e61d29d5a6fe087ec82b76c"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64abeed785dad51801b423fa51840b1764b35d6c461ea8caef9cf9e5e5ab34d9"}, + {file = "ruff-0.3.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df52972138318bc7546d92348a1ee58449bc3f9eaf0db278906eb511889c4b50"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:98e98300056445ba2cc27d0b325fd044dc17fcc38e4e4d2c7711585bd0a958ed"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:519cf6a0ebed244dce1dc8aecd3dc99add7a2ee15bb68cf19588bb5bf58e0488"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:bb0acfb921030d00070539c038cd24bb1df73a2981e9f55942514af8b17be94e"}, + {file = "ruff-0.3.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cf187a7e7098233d0d0c71175375c5162f880126c4c716fa28a8ac418dcf3378"}, + {file = "ruff-0.3.4-py3-none-win32.whl", hash = "sha256:af27ac187c0a331e8ef91d84bf1c3c6a5dea97e912a7560ac0cef25c526a4102"}, + {file = "ruff-0.3.4-py3-none-win_amd64.whl", hash = "sha256:de0d5069b165e5a32b3c6ffbb81c350b1e3d3483347196ffdf86dc0ef9e37dd6"}, + {file = "ruff-0.3.4-py3-none-win_arm64.whl", hash = "sha256:6810563cc08ad0096b57c717bd78aeac888a1bfd38654d9113cb3dc4d3f74232"}, + {file = "ruff-0.3.4.tar.gz", hash = "sha256:f0f4484c6541a99862b693e13a151435a279b271cff20e37101116a21e2a1ad1"}, ] [[package]] @@ -4487,13 +4488,13 @@ files = [ [[package]] name = "textual" -version = "0.52.1" +version = "0.53.1" description = "Modern Text User Interface framework" optional = false python-versions = ">=3.8,<4.0" files = [ - {file = "textual-0.52.1-py3-none-any.whl", hash = "sha256:960a19df2319482918b4a58736d9552cdc1ab65d170ba0bc15273ce0e1922b7a"}, - {file = "textual-0.52.1.tar.gz", hash = "sha256:4232e5c2b423ed7c63baaeb6030355e14e1de1b9df096c9655b68a1e60e4de5f"}, + {file = "textual-0.53.1-py3-none-any.whl", hash = "sha256:32201aa9d334ed064d5e670f15fe3d7f19c736ca54cecb054a5b995691104434"}, + {file = "textual-0.53.1.tar.gz", hash = "sha256:23ba673be7974819ded35ea88d28df7117987e53d58f15b2cc890ac2ecf56401"}, ] [package.dependencies] @@ -4520,12 +4521,12 @@ tornado = "*" [[package]] name = "thrift" -version = "0.16.0" +version = "0.20.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" files = [ - {file = "thrift-0.16.0.tar.gz", hash = "sha256:2b5b6488fcded21f9d312aa23c9ff6a0195d0f6ae26ddbd5ad9e3e25dfc14408"}, + {file = "thrift-0.20.0.tar.gz", hash = "sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba"}, ] [package.dependencies] From ce4b31717ca2be3cc6c6f66dd5b876e9dd023253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Wed, 27 Mar 2024 10:17:08 +0100 Subject: [PATCH 12/35] Use dataloaders for default account addresses (#15713) --- saleor/graphql/account/types.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/saleor/graphql/account/types.py b/saleor/graphql/account/types.py index f2e8637964e..beeb8341f7c 100644 --- a/saleor/graphql/account/types.py +++ b/saleor/graphql/account/types.py @@ -64,6 +64,7 @@ from .dataloaders import ( AccessibleChannelsByGroupIdLoader, AccessibleChannelsByUserIdLoader, + AddressByIdLoader, CustomerEventsByUserLoader, RestrictedChannelAccessByUserIdLoader, ThumbnailByUserIdSizeAndFormatLoader, @@ -715,6 +716,20 @@ def get_stored_payment_methods(data): ] ).then(get_stored_payment_methods) + @staticmethod + def resolve_default_billing_address(root: models.User, info: ResolveInfo): + if root.default_billing_address_id: + return AddressByIdLoader(info.context).load(root.default_billing_address_id) + return None + + @staticmethod + def resolve_default_shipping_address(root: models.User, info: ResolveInfo): + if root.default_shipping_address_id: + return AddressByIdLoader(info.context).load( + root.default_shipping_address_id + ) + return None + class UserCountableConnection(CountableConnection): class Meta: From 47cedfd7d6524d79bdb04708edcdbb235874de6b Mon Sep 17 00:00:00 2001 From: Iga Karbowiak <40886528+IKarbowiak@users.noreply.github.com> Date: Wed, 27 Mar 2024 13:56:53 +0100 Subject: [PATCH 13/35] Fix relation between order and click and collect address (#15697) * Create a copy of collection point address during checkout completion * Add migration for creating order shipping addresses --- saleor/checkout/complete_checkout.py | 3 ++ .../test_checkout_complete_with_payment.py | 8 ++-- ...est_checkout_complete_with_transactions.py | 6 ++- .../0172_update_order_cc_addresses.py | 39 +++++++++++++++++++ .../migrations/0176_merge_20240325_1315.py | 12 ++++++ .../migrations/0177_merge_20240325_1329.py | 12 ++++++ .../migrations/0180_merge_20240325_1333.py | 12 ++++++ .../migrations/0182_merge_20240325_1338.py | 12 ++++++ ...erge_20240325_1338_0183_order_tax_error.py | 12 ++++++ saleor/order/migrations/tasks/saleor3_19.py | 29 ++++++++++++++ 10 files changed, 140 insertions(+), 5 deletions(-) create mode 100644 saleor/order/migrations/0172_update_order_cc_addresses.py create mode 100644 saleor/order/migrations/0176_merge_20240325_1315.py create mode 100644 saleor/order/migrations/0177_merge_20240325_1329.py create mode 100644 saleor/order/migrations/0180_merge_20240325_1333.py create mode 100644 saleor/order/migrations/0182_merge_20240325_1338.py create mode 100644 saleor/order/migrations/0184_merge_0182_merge_20240325_1338_0183_order_tax_error.py create mode 100644 saleor/order/migrations/tasks/saleor3_19.py diff --git a/saleor/checkout/complete_checkout.py b/saleor/checkout/complete_checkout.py index 0cdd8795293..a974813af6b 100644 --- a/saleor/checkout/complete_checkout.py +++ b/saleor/checkout/complete_checkout.py @@ -145,6 +145,9 @@ def _process_shipping_data_for_order( if checkout_info.user.addresses.filter(pk=shipping_address.pk).exists(): shipping_address = shipping_address.get_copy() + if shipping_address and delivery_method_info.warehouse_pk: + shipping_address = shipping_address.get_copy() + shipping_method = delivery_method_info.delivery_method tax_class = getattr(shipping_method, "tax_class", None) diff --git a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py index 21b68047422..901904b4781 100644 --- a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py +++ b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py @@ -3105,7 +3105,7 @@ def test_complete_checkout_for_local_click_and_collect( order_count = Order.objects.count() checkout = checkout_with_item_for_cc checkout.collection_point = warehouse_for_cc - checkout.shipping_address = None + checkout.shipping_address = warehouse_for_cc.address checkout.save(update_fields=["collection_point", "shipping_address"]) variables = { @@ -3147,7 +3147,8 @@ def test_complete_checkout_for_local_click_and_collect( assert order.collection_point == warehouse_for_cc assert order.shipping_method is None - assert order.shipping_address == warehouse_for_cc.address + assert order.shipping_address + assert order.shipping_address.id != warehouse_for_cc.address.id assert order.shipping_price == zero_taxed_money(payment.currency) assert order.lines.count() == 1 @@ -3215,7 +3216,8 @@ def test_complete_checkout_for_global_click_and_collect( assert order.collection_point == warehouse_for_cc assert order.shipping_method is None - assert order.shipping_address == warehouse_for_cc.address + assert order.shipping_address + assert order.shipping_address.id != warehouse_for_cc.address.id assert order.shipping_price == zero_taxed_money(payment.currency) assert order.lines.count() == 1 diff --git a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_transactions.py b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_transactions.py index a7c39822c05..c553436e272 100644 --- a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_transactions.py +++ b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_transactions.py @@ -2773,7 +2773,8 @@ def test_complete_checkout_for_local_click_and_collect( assert order.collection_point == warehouse_for_cc assert order.shipping_method is None - assert order.shipping_address == warehouse_for_cc.address + assert order.shipping_address + assert order.shipping_address.id != warehouse_for_cc.address.id assert order.shipping_price == zero_taxed_money(order.channel.currency_code) assert order.lines.count() == 1 @@ -2833,7 +2834,8 @@ def test_complete_checkout_for_global_click_and_collect( assert order.collection_point == warehouse_for_cc assert order.shipping_method is None - assert order.shipping_address == warehouse_for_cc.address + assert order.shipping_address + assert order.shipping_address.id != warehouse_for_cc.address.id assert order.shipping_price == zero_taxed_money(order.channel.currency_code) assert order.lines.count() == 1 diff --git a/saleor/order/migrations/0172_update_order_cc_addresses.py b/saleor/order/migrations/0172_update_order_cc_addresses.py new file mode 100644 index 00000000000..773ebd9bbb6 --- /dev/null +++ b/saleor/order/migrations/0172_update_order_cc_addresses.py @@ -0,0 +1,39 @@ +from django.db import migrations +from django.db.models import Exists, OuterRef +from django.forms.models import model_to_dict + +from .tasks.saleor3_19 import update_order_addresses_task + +# The batch of size 250 takes ~0.5 second and consumes ~20MB memory at peak +ADDRESS_UPDATE_BATCH_SIZE = 250 + + +def update_order_addresses(apps, schema_editor): + Order = apps.get_model("order", "Order") + Warehouse = apps.get_model("warehouse", "Warehouse") + Address = apps.get_model("account", "Address") + qs = Order.objects.filter( + Exists(Warehouse.objects.filter(address_id=OuterRef("shipping_address_id"))), + ) + order_ids = qs.values_list("pk", flat=True)[:ADDRESS_UPDATE_BATCH_SIZE] + addresses = [] + if order_ids: + orders = Order.objects.filter(id__in=order_ids) + for order in orders: + if cc_address := order.shipping_address: + order_address = Address(**model_to_dict(cc_address, exclude=["id"])) + order.shipping_address = order_address + addresses.append(order_address) + Address.objects.bulk_create(addresses, ignore_conflicts=True) + Order.objects.bulk_update(orders, ["shipping_address"]) + update_order_addresses_task.delay() + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0171_order_order_user_email_user_id_idx"), + ] + + operations = [ + migrations.RunPython(update_order_addresses, migrations.RunPython.noop), + ] diff --git a/saleor/order/migrations/0176_merge_20240325_1315.py b/saleor/order/migrations/0176_merge_20240325_1315.py new file mode 100644 index 00000000000..7bb6b948017 --- /dev/null +++ b/saleor/order/migrations/0176_merge_20240325_1315.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:15 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0172_update_order_cc_addresses"), + ("order", "0175_merge_20231122_1040"), + ] + + operations = [] diff --git a/saleor/order/migrations/0177_merge_20240325_1329.py b/saleor/order/migrations/0177_merge_20240325_1329.py new file mode 100644 index 00000000000..2fd49329770 --- /dev/null +++ b/saleor/order/migrations/0177_merge_20240325_1329.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:29 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0176_merge_20231122_1346"), + ("order", "0176_merge_20240325_1315"), + ] + + operations = [] diff --git a/saleor/order/migrations/0180_merge_20240325_1333.py b/saleor/order/migrations/0180_merge_20240325_1333.py new file mode 100644 index 00000000000..675bc15cab5 --- /dev/null +++ b/saleor/order/migrations/0180_merge_20240325_1333.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:33 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0177_merge_20240325_1329"), + ("order", "0179_merge_20231122_1348"), + ] + + operations = [] diff --git a/saleor/order/migrations/0182_merge_20240325_1338.py b/saleor/order/migrations/0182_merge_20240325_1338.py new file mode 100644 index 00000000000..d11d3a3d82d --- /dev/null +++ b/saleor/order/migrations/0182_merge_20240325_1338.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:38 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0180_merge_20240325_1333"), + ("order", "0181_order_subtotal_as_a_field"), + ] + + operations = [] diff --git a/saleor/order/migrations/0184_merge_0182_merge_20240325_1338_0183_order_tax_error.py b/saleor/order/migrations/0184_merge_0182_merge_20240325_1338_0183_order_tax_error.py new file mode 100644 index 00000000000..0b8a3c21901 --- /dev/null +++ b/saleor/order/migrations/0184_merge_0182_merge_20240325_1338_0183_order_tax_error.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-03-25 13:42 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("order", "0182_merge_20240325_1338"), + ("order", "0183_order_tax_error"), + ] + + operations = [] diff --git a/saleor/order/migrations/tasks/saleor3_19.py b/saleor/order/migrations/tasks/saleor3_19.py new file mode 100644 index 00000000000..ddc468009b7 --- /dev/null +++ b/saleor/order/migrations/tasks/saleor3_19.py @@ -0,0 +1,29 @@ +from django.db.models import Exists, OuterRef +from django.forms.models import model_to_dict + +from ....account.models import Address +from ....celeryconf import app +from ....warehouse.models import Warehouse +from ...models import Order + +# The batch of size 250 takes ~0.5 second and consumes ~20MB memory at peak +ADDRESS_UPDATE_BATCH_SIZE = 250 + + +@app.task +def update_order_addresses_task(): + qs = Order.objects.filter( + Exists(Warehouse.objects.filter(address_id=OuterRef("shipping_address_id"))), + ) + order_ids = qs.values_list("pk", flat=True)[:ADDRESS_UPDATE_BATCH_SIZE] + addresses = [] + if order_ids: + orders = Order.objects.filter(id__in=order_ids) + for order in orders: + if cc_address := order.shipping_address: + order_address = Address(**model_to_dict(cc_address, exclude=["id"])) + order.shipping_address = order_address + addresses.append(order_address) + Address.objects.bulk_create(addresses, ignore_conflicts=True) + Order.objects.bulk_update(orders, ["shipping_address"]) + update_order_addresses_task.delay() From 16ba23e9b2486a97587fd831a4e9e8988ff073d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Wed, 27 Mar 2024 15:07:06 +0100 Subject: [PATCH 14/35] Refactor line weight calculation to reuse DB queries (#15714) * Refactor line weight calculation to reuse DB queries * Fix tests * Pass already fetched channel to excluded_shipping_methods_for_checkout manager method (#15715) --- saleor/checkout/complete_checkout.py | 3 ++- saleor/checkout/fetch.py | 2 +- saleor/checkout/models.py | 21 +--------------- saleor/checkout/tests/test_cart.py | 8 ++++-- saleor/checkout/utils.py | 22 ++++++++++++++++ .../benchmark/test_checkout_mutations.py | 2 +- .../graphql/order/tests/queries/test_order.py | 4 +-- saleor/order/models.py | 3 --- saleor/order/tests/test_order.py | 4 +-- saleor/order/utils.py | 6 +++-- saleor/plugins/manager.py | 3 ++- saleor/shipping/models.py | 25 ++++++++++++++----- 12 files changed, 62 insertions(+), 41 deletions(-) diff --git a/saleor/checkout/complete_checkout.py b/saleor/checkout/complete_checkout.py index a974813af6b..4aaad9f2067 100644 --- a/saleor/checkout/complete_checkout.py +++ b/saleor/checkout/complete_checkout.py @@ -82,6 +82,7 @@ ) from .models import Checkout from .utils import ( + calculate_checkout_weight, get_checkout_metadata, get_or_create_checkout_metadata, get_voucher_for_checkout_info, @@ -155,7 +156,7 @@ def _process_shipping_data_for_order( "shipping_address": shipping_address, "base_shipping_price": base_shipping_price, "shipping_price": shipping_price, - "weight": checkout_info.checkout.get_total_weight(lines), + "weight": calculate_checkout_weight(lines), **get_shipping_tax_class_kwargs_for_order(tax_class), } result.update(delivery_method_info.delivery_method_order_field) diff --git a/saleor/checkout/fetch.py b/saleor/checkout/fetch.py index 4965f0c7ba1..c4861ea4333 100644 --- a/saleor/checkout/fetch.py +++ b/saleor/checkout/fetch.py @@ -671,7 +671,7 @@ def _resolve_all_shipping_methods(): ) # Filter shipping methods using sync webhooks excluded_methods = manager.excluded_shipping_methods_for_checkout( - checkout_info.checkout, all_methods + checkout_info.checkout, checkout_info.channel, all_methods ) initialize_shipping_method_active_status(all_methods, excluded_methods) return all_methods diff --git a/saleor/checkout/models.py b/saleor/checkout/models.py index 04011e966e9..8f4ae38288c 100644 --- a/saleor/checkout/models.py +++ b/saleor/checkout/models.py @@ -1,10 +1,9 @@ """Checkout-related ORM models.""" -from collections.abc import Iterable from datetime import date from decimal import Decimal from operator import attrgetter -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional from uuid import uuid4 from django.conf import settings @@ -19,19 +18,14 @@ from ..channel.models import Channel from ..core.models import ModelWithMetadata from ..core.taxes import zero_money -from ..core.weight import zero_weight from ..giftcard.models import GiftCard from ..permission.enums import CheckoutPermissions from ..shipping.models import ShippingMethod from . import CheckoutAuthorizeStatus, CheckoutChargeStatus if TYPE_CHECKING: - from django_measurement import Weight - - from ..order.fetch import OrderLineInfo from ..payment.models import Payment from ..product.models import ProductVariant - from .fetch import CheckoutLineInfo def get_default_country(): @@ -234,19 +228,6 @@ def get_total_gift_cards_balance( return zero_money(currency=self.currency) return Money(balance, self.currency) - def get_total_weight( - self, lines: Union[Iterable["CheckoutLineInfo"], Iterable["OrderLineInfo"]] - ) -> "Weight": - # FIXME: it does not make sense for this method to live in the Checkout model - # since it's used in the Order model as well. We should move it to a separate - # helper. - weights = zero_weight() - for checkout_line_info in lines: - line = checkout_line_info.line - if line.variant: - weights += line.variant.get_weight() * line.quantity - return weights - def get_line(self, variant: "ProductVariant") -> Optional["CheckoutLine"]: """Return a line matching the given variant and data if any.""" matching_lines = (line for line in self if line.variant.pk == variant.pk) diff --git a/saleor/checkout/tests/test_cart.py b/saleor/checkout/tests/test_cart.py index 48747849fff..55011b6f7a6 100644 --- a/saleor/checkout/tests/test_cart.py +++ b/saleor/checkout/tests/test_cart.py @@ -10,7 +10,11 @@ from ...product.models import Category from .. import calculations, utils from ..models import Checkout -from ..utils import add_variant_to_checkout, calculate_checkout_quantity +from ..utils import ( + add_variant_to_checkout, + calculate_checkout_quantity, + calculate_checkout_weight, +) @pytest.fixture @@ -283,4 +287,4 @@ def test_get_total_weight(checkout_with_item): line.quantity = 6 line.save() lines, _ = fetch_checkout_lines(checkout_with_item) - assert checkout_with_item.get_total_weight(lines) == Weight(kg=60) + assert calculate_checkout_weight(lines) == Weight(kg=60) diff --git a/saleor/checkout/utils.py b/saleor/checkout/utils.py index 9d8f97082c3..4aa267258db 100644 --- a/saleor/checkout/utils.py +++ b/saleor/checkout/utils.py @@ -22,6 +22,7 @@ promo_code_is_voucher, ) from ..core.utils.translations import get_translation +from ..core.weight import zero_weight from ..discount import DiscountType, VoucherType from ..discount.interface import VoucherInfo, fetch_voucher_info from ..discount.models import ( @@ -59,6 +60,8 @@ from .models import Checkout, CheckoutLine, CheckoutMetadata if TYPE_CHECKING: + from measurement.measures import Weight + from ..account.models import Address from ..order.models import Order from .fetch import CheckoutInfo, CheckoutLineInfo @@ -865,6 +868,7 @@ def get_valid_internal_shipping_methods_for_checkout( checkout_info.checkout, channel_id=checkout_info.checkout.channel_id, price=subtotal, + shipping_address=checkout_info.shipping_address, country_code=country_code, lines=lines, ) @@ -1024,3 +1028,21 @@ def get_checkout_metadata(checkout: "Checkout"): return checkout.metadata_storage else: return CheckoutMetadata(checkout=checkout) + + +def calculate_checkout_weight(lines: Iterable["CheckoutLineInfo"]) -> "Weight": + weights = zero_weight() + for checkout_line_info in lines: + variant = checkout_line_info.variant + if variant: + line_weight = get_checkout_line_weight(checkout_line_info) + weights += line_weight * checkout_line_info.line.quantity + return weights + + +def get_checkout_line_weight(line_info: "CheckoutLineInfo"): + return ( + line_info.variant.weight + or line_info.product.weight + or line_info.product_type.weight + ) diff --git a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py index d112bbcf12e..802fbe36d7f 100644 --- a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py +++ b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py @@ -562,7 +562,7 @@ def test_create_checkout_with_order_promotion( } # when - with django_assert_num_queries(88): + with django_assert_num_queries(71): response = user_api_client.post_graphql(MUTATION_CHECKOUT_CREATE, variables) # then diff --git a/saleor/graphql/order/tests/queries/test_order.py b/saleor/graphql/order/tests/queries/test_order.py index 3b25d606076..20e2890993f 100644 --- a/saleor/graphql/order/tests/queries/test_order.py +++ b/saleor/graphql/order/tests/queries/test_order.py @@ -349,7 +349,7 @@ def test_order_query( expected_methods = ShippingMethod.objects.applicable_shipping_methods( price=order.subtotal.gross, - weight=order.get_total_weight(), + weight=order.weight, country_code=order.shipping_address.country.code, channel_id=order.channel_id, ) @@ -997,7 +997,7 @@ def test_order_query_in_pln_channel( expected_methods = ShippingMethod.objects.applicable_shipping_methods( price=order.subtotal.gross, - weight=order.get_total_weight(), + weight=order.weight, country_code=order.shipping_address.country.code, channel_id=order.channel_id, ) diff --git a/saleor/order/models.py b/saleor/order/models.py index 79ce62b9a01..469dc016593 100644 --- a/saleor/order/models.py +++ b/saleor/order/models.py @@ -498,9 +498,6 @@ def can_mark_as_paid(self, payments=None): def total_balance(self): return self.total_charged - self.total.gross - def get_total_weight(self, _lines=None): - return self.weight - class OrderLineQueryset(models.QuerySet["OrderLine"]): def digital(self): diff --git a/saleor/order/tests/test_order.py b/saleor/order/tests/test_order.py index 89c6ad9f3d6..7cc2c18b782 100644 --- a/saleor/order/tests/test_order.py +++ b/saleor/order/tests/test_order.py @@ -658,12 +658,12 @@ def test_get_order_weight_non_existing_product( app=None, manager=anonymous_plugins, ) - old_weight = order.get_total_weight() + old_weight = order.weight product.delete() order.refresh_from_db() - new_weight = order.get_total_weight() + new_weight = order.weight assert old_weight == new_weight diff --git a/saleor/order/utils.py b/saleor/order/utils.py index ed0317155c5..fcbe3e1503e 100644 --- a/saleor/order/utils.py +++ b/saleor/order/utils.py @@ -611,7 +611,8 @@ def get_all_shipping_methods_for_order( if not order.is_shipping_required(): return [] - if not order.shipping_address: + shipping_address = order.shipping_address + if not shipping_address: return [] all_methods = [] @@ -622,7 +623,8 @@ def get_all_shipping_methods_for_order( order, channel_id=order.channel_id, price=order.subtotal.gross, - country_code=order.shipping_address.country.code, + shipping_address=shipping_address, + country_code=shipping_address.country.code, ) .prefetch_related("channel_listings") ) diff --git a/saleor/plugins/manager.py b/saleor/plugins/manager.py index ae05b8c4640..9ae90c1e61d 100644 --- a/saleor/plugins/manager.py +++ b/saleor/plugins/manager.py @@ -2164,6 +2164,7 @@ def excluded_shipping_methods_for_order( def excluded_shipping_methods_for_checkout( self, checkout: "Checkout", + channel: "Channel", available_shipping_methods: list["ShippingMethodData"], ) -> list[ExcludedShippingMethod]: return self.__run_method_on_plugins( @@ -2171,7 +2172,7 @@ def excluded_shipping_methods_for_checkout( [], checkout, available_shipping_methods, - channel_slug=checkout.channel.slug, + channel_slug=channel.slug, ) def perform_mutation( diff --git a/saleor/shipping/models.py b/saleor/shipping/models.py index 996aed42753..bebcdc01cc1 100644 --- a/saleor/shipping/models.py +++ b/saleor/shipping/models.py @@ -24,6 +24,7 @@ from .postal_codes import filter_shipping_methods_by_postal_code_rules if TYPE_CHECKING: + from ..account.models import Address from ..checkout.fetch import CheckoutLineInfo from ..checkout.models import Checkout from ..order.fetch import OrderLineInfo @@ -165,32 +166,44 @@ def applicable_shipping_methods_for_instance( instance: Union["Checkout", "Order"], channel_id, price: Money, + shipping_address: Optional["Address"] = None, country_code: Optional[str] = None, lines: Union[ Iterable["CheckoutLineInfo"], Iterable["OrderLineInfo"], None ] = None, ): - if not instance.shipping_address: + if not shipping_address: return None + if not country_code: - # TODO: country_code should come from argument - country_code = instance.shipping_address.country.code + country_code = shipping_address.country.code + if lines is None: # TODO: lines should comes from args in get_valid_shipping_methods_for_order lines = list(instance.lines.prefetch_related("variant__product").all()) # type: ignore[misc] # this is hack # noqa: E501 instance_product_ids = { line.variant.product_id for line in lines if line.variant } + + from ..checkout.models import Checkout + + if isinstance(instance, Checkout): + from ..checkout.utils import calculate_checkout_weight + + weight = calculate_checkout_weight(lines) # type: ignore[arg-type] + else: + weight = instance.weight + applicable_methods = self.applicable_shipping_methods( price=price, channel_id=channel_id, - weight=instance.get_total_weight(lines), - country_code=country_code or instance.shipping_address.country.code, + weight=weight, + country_code=country_code, product_ids=instance_product_ids, ).prefetch_related("postal_code_rules") return filter_shipping_methods_by_postal_code_rules( - applicable_methods, instance.shipping_address + applicable_methods, shipping_address ) From 1d47e4485eb4a30ef620fb7341be8fc7b4a05cb3 Mon Sep 17 00:00:00 2001 From: Filip Owczarek Date: Thu, 28 Mar 2024 09:56:38 +0100 Subject: [PATCH 15/35] Turn off Jaeger during pytest (#15718) --- saleor/settings.py | 5 +++-- saleor/tests/settings.py | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/saleor/settings.py b/saleor/settings.py index 3c70df8960d..e0acb03a389 100644 --- a/saleor/settings.py +++ b/saleor/settings.py @@ -805,7 +805,8 @@ def SENTRY_INIT(dsn: str, sentry_opts: dict): # # If running locally, set: # JAEGER_AGENT_HOST=localhost -if "JAEGER_AGENT_HOST" in os.environ: +JAEGER_HOST = os.environ.get("JAEGER_AGENT_HOST") +if JAEGER_HOST: jaeger_client.Config( config={ "sampler": {"type": "const", "param": 1}, @@ -813,7 +814,7 @@ def SENTRY_INIT(dsn: str, sentry_opts: dict): "reporting_port": os.environ.get( "JAEGER_AGENT_PORT", jaeger_client.config.DEFAULT_REPORTING_PORT ), - "reporting_host": os.environ.get("JAEGER_AGENT_HOST"), + "reporting_host": JAEGER_HOST, }, "logging": get_bool_from_env("JAEGER_LOGGING", False), }, diff --git a/saleor/tests/settings.py b/saleor/tests/settings.py index e2e376ef3f8..deb53017348 100644 --- a/saleor/tests/settings.py +++ b/saleor/tests/settings.py @@ -1,9 +1,14 @@ +import os import re from re import Pattern from typing import Union from django.utils.functional import SimpleLazyObject +# Disable Jaeger tracing should be done before importing settings. +# without this line pytest will start sending traces to Jaeger agent. +os.environ["JAEGER_AGENT_HOST"] = "" + from ..settings import * # noqa From 20929abf68d6829cd1cf900fc46439107e4bf9d2 Mon Sep 17 00:00:00 2001 From: Maciej Korycinski Date: Thu, 28 Mar 2024 12:41:12 +0100 Subject: [PATCH 16/35] Fix failing exif data validation (#15684) --- saleor/thumbnail/tests/test_utils.py | 23 +++++++++++++++++++++++ saleor/thumbnail/utils.py | 11 ++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/saleor/thumbnail/tests/test_utils.py b/saleor/thumbnail/tests/test_utils.py index 9d7cd2f1e44..db5d308ed67 100644 --- a/saleor/thumbnail/tests/test_utils.py +++ b/saleor/thumbnail/tests/test_utils.py @@ -1,8 +1,10 @@ +from unittest import mock from unittest.mock import MagicMock import graphene import pytest from django.core.files import File +from PIL.JpegImagePlugin import JpegImageFile from .. import FILE_NAME_MAX_LENGTH, ThumbnailFormat from ..models import Thumbnail @@ -115,6 +117,27 @@ def test_processed_image_preprocess_method_called(category_with_image, thumb_for preprocess_mock.assert_called_once() +@pytest.mark.parametrize("thumb_format", [ThumbnailFormat.WEBP, ThumbnailFormat.AVIF]) +@mock.patch.object(JpegImageFile, "_getexif") +def test_processed_image_preprocess_with_exif_corrupted( + mocked_getexif, category_with_image, thumb_format +): + # given + image_path = category_with_image.background_image.name + processed_image = ProcessedImage(image_path, 128, thumb_format) + preprocess_method_name = f"preprocess_{thumb_format.upper()}" + preprocess_mock = MagicMock() + preprocess_mock.side_effect = getattr(processed_image, preprocess_method_name) + setattr(processed_image, preprocess_method_name, preprocess_mock) + mocked_getexif.side_effect = SyntaxError() + + # when + processed_image.create_thumbnail() + + # then + preprocess_mock.assert_called_once() + + def test_get_filename_from_url_unique(): # given file_format = "jpg" diff --git a/saleor/thumbnail/utils.py b/saleor/thumbnail/utils.py index adca9198bea..102be5f6cc6 100644 --- a/saleor/thumbnail/utils.py +++ b/saleor/thumbnail/utils.py @@ -173,7 +173,16 @@ def preprocess(self, image, image_format): # Ensuring image is properly rotated if hasattr(image, "_getexif"): - exif_datadict = image._getexif() # returns None if no EXIF data + try: + # validation of the exif data was added in separate PR: + # https://github.com/saleor/saleor/pull/11224, it means that there is a + # possibility that we could have the file with corrupted exif data. + # exif data is only used to apply some optional action on the image, + # but without it, we are still able to create a thumbnail. + exif_datadict = image._getexif() # returns None if no EXIF data + except SyntaxError: + exif_datadict = None + if exif_datadict is not None: exif = dict(exif_datadict.items()) orientation = exif.get(self.EXIF_ORIENTATION_KEY, None) From 5b95bbf929aad091032a2faeba54f21d2fe68b4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:37:43 +0200 Subject: [PATCH 17/35] Add allow_writer usages to delivery attempts and payments (#15729) --- saleor/payment/utils.py | 10 +++- saleor/webhook/payloads.py | 32 ++++++++++++ .../transport/asynchronous/transport.py | 22 ++++---- .../transport/synchronous/transport.py | 51 +++++++++++-------- saleor/webhook/transport/utils.py | 20 +++++--- 5 files changed, 95 insertions(+), 40 deletions(-) diff --git a/saleor/payment/utils.py b/saleor/payment/utils.py index b54f90dbca3..621342ea74a 100644 --- a/saleor/payment/utils.py +++ b/saleor/payment/utils.py @@ -24,6 +24,7 @@ from ..checkout.fetch import fetch_checkout_info, fetch_checkout_lines from ..checkout.models import Checkout from ..checkout.payment_utils import update_refundable_for_checkout +from ..core.db.connection import allow_writer from ..core.prices import quantize_price from ..core.tracing import traced_atomic_transaction from ..graphql.core.utils import str_to_enum @@ -388,7 +389,8 @@ def create_payment( "metadata": {} if metadata is None else metadata, } - payment, _ = Payment.objects.get_or_create(defaults=defaults, **data) + with allow_writer(): + payment, _ = Payment.objects.get_or_create(defaults=defaults, **data) return payment @@ -1008,6 +1010,7 @@ def get_failed_type_based_on_event(event: TransactionEvent): return event.type +@allow_writer() def create_failed_transaction_event( event: TransactionEvent, cause: str, @@ -1485,7 +1488,8 @@ def create_manual_adjustment_events( user=user, ) if events_to_create: - return TransactionEvent.objects.bulk_create(events_to_create) + with allow_writer(): + return TransactionEvent.objects.bulk_create(events_to_create) return [] @@ -1512,6 +1516,7 @@ def get_transaction_item_params( } +@allow_writer() def create_transaction_for_order( order: "Order", user: Optional["User"], @@ -1540,6 +1545,7 @@ def create_transaction_for_order( return transaction_item +@allow_writer() def handle_transaction_initialize_session( source_object: Union[Checkout, Order], payment_gateway_data: PaymentGatewayData, diff --git a/saleor/webhook/payloads.py b/saleor/webhook/payloads.py index d982a787892..578b0deac31 100644 --- a/saleor/webhook/payloads.py +++ b/saleor/webhook/payloads.py @@ -23,6 +23,7 @@ from ..checkout.fetch import CheckoutInfo, CheckoutLineInfo from ..checkout.models import Checkout from ..checkout.utils import get_checkout_metadata +from ..core.db.connection import allow_writer from ..core.prices import quantize_price, quantize_price_fields from ..core.utils import build_absolute_uri from ..core.utils.anonymization import ( @@ -139,6 +140,7 @@ def generate_meta(*, requestor_data: dict[str, Any], camel_case=False, **kwargs) return meta +@allow_writer() @traced_payload_generator def generate_metadata_updated_payload( instance: Any, requestor: Optional["RequestorOrLazyObject"] = None @@ -173,6 +175,7 @@ def prepare_order_lines_allocations_payload(line): return warehouse_id_quantity_allocated_map +@allow_writer() @traced_payload_generator def generate_order_lines_payload(lines: Iterable[OrderLine]): line_fields = ( @@ -271,6 +274,7 @@ def _generate_shipping_method_payload(shipping_method, channel): return json.loads(payload)[0] +@allow_writer() @traced_payload_generator def generate_order_payload( order: "Order", @@ -411,6 +415,7 @@ def _calculate_removed( return _calculate_added(current_catalogue, previous_catalogue, key) +@allow_writer() @traced_payload_generator def generate_sale_payload( promotion: "Promotion", @@ -459,6 +464,7 @@ def generate_sale_payload( ) +@allow_writer() @traced_payload_generator def generate_sale_toggle_payload( promotion: "Promotion", @@ -481,6 +487,7 @@ def generate_sale_toggle_payload( ) +@allow_writer() @traced_payload_generator def generate_invoice_payload( invoice: "Invoice", requestor: Optional["RequestorOrLazyObject"] = None @@ -521,6 +528,7 @@ def _generate_order_payload_for_invoice(order: "Order"): return payload +@allow_writer() @traced_payload_generator def generate_checkout_payload( checkout: "Checkout", requestor: Optional["RequestorOrLazyObject"] = None @@ -595,6 +603,7 @@ def generate_checkout_payload( return checkout_data +@allow_writer() @traced_payload_generator def generate_customer_payload( customer: "User", requestor: Optional["RequestorOrLazyObject"] = None @@ -633,6 +642,7 @@ def generate_customer_payload( return data +@allow_writer() @traced_payload_generator def generate_collection_payload( collection: "Collection", requestor: Optional["RequestorOrLazyObject"] = None @@ -704,6 +714,7 @@ def _get_charge_taxes_for_product(product: "Product") -> bool: return charge_taxes +@allow_writer() @traced_payload_generator def generate_product_payload( product: "Product", requestor: Optional["RequestorOrLazyObject"] = None @@ -747,6 +758,7 @@ def generate_product_payload( return product_payload +@allow_writer() @traced_payload_generator def generate_product_deleted_payload( product: "Product", variants_id, requestor: Optional["RequestorOrLazyObject"] = None @@ -777,6 +789,7 @@ def generate_product_deleted_payload( ) +@allow_writer() @traced_payload_generator def generate_product_variant_listings_payload(variant_channel_listings): serializer = PayloadSerializer() @@ -793,6 +806,7 @@ def generate_product_variant_listings_payload(variant_channel_listings): return channel_listing_payload +@allow_writer() @traced_payload_generator def generate_product_variant_media_payload(product_variant): return [ @@ -808,6 +822,7 @@ def generate_product_variant_media_payload(product_variant): ] +@allow_writer() @traced_payload_generator def generate_product_variant_with_stock_payload( stocks: Iterable["Stock"], requestor: Optional["RequestorOrLazyObject"] = None @@ -829,6 +844,7 @@ def generate_product_variant_with_stock_payload( return serializer.serialize(stocks, fields=[], extra_dict_data=extra_dict_data) +@allow_writer() @traced_payload_generator def generate_product_variant_payload( product_variants: Iterable["ProductVariant"], @@ -867,11 +883,13 @@ def generate_product_variant_payload( return payload +@allow_writer() @traced_payload_generator def generate_product_variant_stocks_payload(product_variant: "ProductVariant"): return product_variant.stocks.aggregate(Sum("quantity"))["quantity__sum"] or 0 +@allow_writer() @traced_payload_generator def generate_fulfillment_lines_payload(fulfillment: Fulfillment): serializer = PayloadSerializer() @@ -942,6 +960,7 @@ def generate_fulfillment_lines_payload(fulfillment: Fulfillment): ) +@allow_writer() @traced_payload_generator def generate_fulfillment_payload( fulfillment: Fulfillment, requestor: Optional["RequestorOrLazyObject"] = None @@ -989,6 +1008,7 @@ def generate_fulfillment_payload( return fulfillment_data +@allow_writer() @traced_payload_generator def generate_page_payload( page: Page, requestor: Optional["RequestorOrLazyObject"] = None @@ -1039,6 +1059,7 @@ def _generate_refund_data_payload(data): return data +@allow_writer() @traced_payload_generator def generate_payment_payload( payment_data: "PaymentData", requestor: Optional["RequestorOrLazyObject"] = None @@ -1057,6 +1078,7 @@ def generate_payment_payload( return json.dumps(data, cls=CustomJsonEncoder) +@allow_writer() @traced_payload_generator def generate_list_gateways_payload( currency: Optional[str], checkout: Optional["Checkout"] @@ -1114,6 +1136,7 @@ def _generate_sample_order_payload(event_name): return generate_order_payload(anonymized_order) +@allow_writer() @traced_payload_generator def generate_sample_payload(event_name: str) -> Optional[dict]: checkout_events = [ @@ -1176,6 +1199,7 @@ def process_translation_context(context): return result +@allow_writer() @traced_payload_generator def generate_translation_payload( translation: "Translation", requestor: Optional["RequestorOrLazyObject"] = None @@ -1218,6 +1242,7 @@ def _generate_payload_for_shipping_method(method: ShippingMethodData): return payload +@allow_writer() @traced_payload_generator def generate_excluded_shipping_methods_for_order_payload( order: "Order", @@ -1234,6 +1259,7 @@ def generate_excluded_shipping_methods_for_order_payload( return json.dumps(payload, cls=CustomJsonEncoder) +@allow_writer() @traced_payload_generator def generate_excluded_shipping_methods_for_checkout_payload( checkout: "Checkout", @@ -1250,6 +1276,7 @@ def generate_excluded_shipping_methods_for_checkout_payload( return json.dumps(payload, cls=CustomJsonEncoder) +@allow_writer() @traced_payload_generator def generate_checkout_payload_for_tax_calculation( checkout_info: "CheckoutInfo", @@ -1381,6 +1408,7 @@ def _generate_order_lines_payload_for_tax_calculation(lines: QuerySet[OrderLine] ) +@allow_writer() @traced_payload_generator def generate_order_payload_for_tax_calculation(order: "Order"): serializer = PayloadSerializer() @@ -1437,6 +1465,7 @@ def generate_order_payload_for_tax_calculation(order: "Order"): return order_data +@allow_writer() @traced_payload_generator def generate_transaction_action_request_payload( transaction_data: "TransactionActionData", @@ -1496,6 +1525,7 @@ def generate_transaction_action_request_payload( return json.dumps(payload, cls=CustomJsonEncoder) +@allow_writer() def generate_transaction_session_payload( transaction_process_action: "TransactionProcessActionData", transaction: "TransactionItem", @@ -1519,12 +1549,14 @@ def generate_transaction_session_payload( return json.dumps(payload, cls=CustomJsonEncoder) +@allow_writer() @traced_payload_generator def generate_thumbnail_payload(thumbnail: Thumbnail): thumbnail_id = graphene.Node.to_global_id("Thumbnail", thumbnail.id) return json.dumps({"id": thumbnail_id}) +@allow_writer() @traced_payload_generator def generate_product_media_payload(media: ProductMedia): product_media_id = graphene.Node.to_global_id("ProductMedia", media.id) diff --git a/saleor/webhook/transport/asynchronous/transport.py b/saleor/webhook/transport/asynchronous/transport.py index 83034ebe485..0041355f070 100644 --- a/saleor/webhook/transport/asynchronous/transport.py +++ b/saleor/webhook/transport/asynchronous/transport.py @@ -11,6 +11,7 @@ from ....celeryconf import app from ....core import EventDeliveryStatus +from ....core.db.connection import allow_writer from ....core.models import EventDelivery, EventPayload from ....core.tracing import webhooks_opentracing_trace from ....core.utils import get_domain @@ -129,8 +130,9 @@ def create_deliveries_for_subscriptions( ) ) - EventPayload.objects.bulk_create(event_payloads) - return EventDelivery.objects.bulk_create(event_deliveries) + with allow_writer(): + EventPayload.objects.bulk_create(event_payloads) + return EventDelivery.objects.bulk_create(event_deliveries) def group_webhooks_by_subscription(webhooks): @@ -140,6 +142,7 @@ def group_webhooks_by_subscription(webhooks): return regular, subscription +@allow_writer() def create_event_delivery_list_for_webhooks( webhooks: Sequence["Webhook"], event_payload: "EventPayload", @@ -190,14 +193,15 @@ def trigger_webhooks_async( elif data is None: raise NotImplementedError("No payload was provided for regular webhooks.") - payload = EventPayload.objects.create(payload=data) - deliveries.extend( - create_event_delivery_list_for_webhooks( - webhooks=regular_webhooks, - event_payload=payload, - event_type=event_type, + with allow_writer(): + payload = EventPayload.objects.create(payload=data) + deliveries.extend( + create_event_delivery_list_for_webhooks( + webhooks=regular_webhooks, + event_payload=payload, + event_type=event_type, + ) ) - ) if subscription_webhooks: deliveries.extend( create_deliveries_for_subscriptions( diff --git a/saleor/webhook/transport/synchronous/transport.py b/saleor/webhook/transport/synchronous/transport.py index 7f762838552..90eb1c66e6a 100644 --- a/saleor/webhook/transport/synchronous/transport.py +++ b/saleor/webhook/transport/synchronous/transport.py @@ -10,6 +10,7 @@ from ....celeryconf import app from ....core import EventDeliveryStatus +from ....core.db.connection import allow_writer from ....core.models import EventDelivery, EventPayload from ....core.tracing import webhooks_opentracing_trace from ....core.utils import get_domain @@ -252,13 +253,15 @@ def create_delivery_for_subscription_sync_event( # Return None so if subscription query returns no data Saleor will not crash but # log the issue and continue without creating a delivery. return None - event_payload = EventPayload.objects.create(payload=json.dumps({**data})) - event_delivery = EventDelivery.objects.create( - status=EventDeliveryStatus.PENDING, - event_type=event_type, - payload=event_payload, - webhook=webhook, - ) + + with allow_writer(): + event_payload = EventPayload.objects.create(payload=json.dumps({**data})) + event_delivery = EventDelivery.objects.create( + status=EventDeliveryStatus.PENDING, + event_type=event_type, + payload=event_payload, + webhook=webhook, + ) return event_delivery @@ -283,13 +286,14 @@ def trigger_webhook_sync( if not delivery: return None else: - event_payload = EventPayload.objects.create(payload=payload) - delivery = EventDelivery.objects.create( - status=EventDeliveryStatus.PENDING, - event_type=event_type, - payload=event_payload, - webhook=webhook, - ) + with allow_writer(): + event_payload = EventPayload.objects.create(payload=payload) + delivery = EventDelivery.objects.create( + status=EventDeliveryStatus.PENDING, + event_type=event_type, + payload=event_payload, + webhook=webhook, + ) kwargs = {} if timeout: @@ -337,14 +341,17 @@ def trigger_all_webhooks_sync( if not delivery: return None else: - if event_payload is None: - event_payload = EventPayload.objects.create(payload=generate_payload()) - delivery = EventDelivery.objects.create( - status=EventDeliveryStatus.PENDING, - event_type=event_type, - payload=event_payload, - webhook=webhook, - ) + with allow_writer(): + if event_payload is None: + event_payload = EventPayload.objects.create( + payload=generate_payload() + ) + delivery = EventDelivery.objects.create( + status=EventDeliveryStatus.PENDING, + event_type=event_type, + payload=event_payload, + webhook=webhook, + ) response_data = send_webhook_request_sync(delivery) if parsed_response := parse_response(response_data): diff --git a/saleor/webhook/transport/utils.py b/saleor/webhook/transport/utils.py index e5e45f4e678..8357fba8d26 100644 --- a/saleor/webhook/transport/utils.py +++ b/saleor/webhook/transport/utils.py @@ -22,6 +22,7 @@ from ...app.headers import AppHeaders, DeprecatedAppHeaders from ...app.models import App +from ...core.db.connection import allow_writer from ...core.http_client import HTTPClient from ...core.models import ( EventDelivery, @@ -378,6 +379,7 @@ def catch_duration_time(): yield lambda: time() - start +@allow_writer() def create_attempt( delivery: "EventDelivery", task_id: Optional[str] = None, @@ -394,6 +396,7 @@ def create_attempt( return attempt +@allow_writer() def attempt_update( attempt: "EventDeliveryAttempt", webhook_response: "WebhookResponse", @@ -416,6 +419,7 @@ def attempt_update( ) +@allow_writer() def clear_successful_delivery(delivery: "EventDelivery"): if delivery.status == EventDeliveryStatus.SUCCESS: payload_id = delivery.payload_id @@ -424,6 +428,7 @@ def clear_successful_delivery(delivery: "EventDelivery"): EventPayload.objects.filter(pk=payload_id, deliveries__isnull=True).delete() +@allow_writer() def delivery_update(delivery: "EventDelivery", status: str): delivery.status = status delivery.save(update_fields=["status"]) @@ -486,13 +491,14 @@ def trigger_transaction_request( payload = generate_transaction_action_request_payload( transaction_data, requestor ) - event_payload = EventPayload.objects.create(payload=payload) - delivery = EventDelivery.objects.create( - status=EventDeliveryStatus.PENDING, - event_type=event_type, - payload=event_payload, - webhook=webhook, - ) + with allow_writer(): + event_payload = EventPayload.objects.create(payload=payload) + delivery = EventDelivery.objects.create( + status=EventDeliveryStatus.PENDING, + event_type=event_type, + payload=event_payload, + webhook=webhook, + ) call_event( handle_transaction_request_task.delay, delivery.id, From a25c73e9f74b78e8d709968fe8263a6c8fdafa8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:46:26 +0200 Subject: [PATCH 18/35] Add allow_writer_in_context in resolvers (#15730) --- saleor/graphql/checkout/dataloaders.py | 23 +++++----- saleor/graphql/checkout/types.py | 59 +++++++++++++++++--------- saleor/graphql/order/types.py | 30 ++++++++++--- saleor/permission/enums.py | 10 +++-- saleor/shipping/utils.py | 13 ++++-- saleor/tax/calculations/order.py | 8 ++-- saleor/warehouse/models.py | 3 +- 7 files changed, 100 insertions(+), 46 deletions(-) diff --git a/saleor/graphql/checkout/dataloaders.py b/saleor/graphql/checkout/dataloaders.py index 8220bcee1ca..80c44df819d 100644 --- a/saleor/graphql/checkout/dataloaders.py +++ b/saleor/graphql/checkout/dataloaders.py @@ -22,6 +22,7 @@ get_checkout_lines_problems, get_checkout_problems, ) +from ...core.db.connection import allow_writer_in_context from ...discount import VoucherType from ...discount.interface import VariantPromotionRuleInfo from ...payment.models import TransactionItem @@ -91,6 +92,7 @@ def with_checkout_lines(results): channel_pks = [checkout.channel_id for checkout in checkouts] + @allow_writer_in_context(self.context) def with_variants_products_collections(results): ( variants, @@ -554,16 +556,17 @@ def with_checkout_info(results): for listing in channel_listings if listing.channel_id == channel.id ] - update_delivery_method_lists_for_checkout_info( - checkout_info, - shipping_method, - collection_point, - shipping_address, - checkout_lines, - manager, - shipping_method_listings, - database_connection_name=self.database_connection_name, - ) + with allow_writer_in_context(self.context): + update_delivery_method_lists_for_checkout_info( + checkout_info, + shipping_method, + collection_point, + shipping_address, + checkout_lines, + manager, + shipping_method_listings, + database_connection_name=self.database_connection_name, + ) checkout_info_map[key] = checkout_info return [checkout_info_map[key] for key in keys] diff --git a/saleor/graphql/checkout/types.py b/saleor/graphql/checkout/types.py index 9980516c7d4..c2bb387e79f 100644 --- a/saleor/graphql/checkout/types.py +++ b/saleor/graphql/checkout/types.py @@ -11,6 +11,7 @@ ) from ...checkout.calculations import fetch_checkout_data from ...checkout.utils import get_valid_collection_points_for_checkout +from ...core.db.connection import allow_writer_in_context from ...core.taxes import zero_money, zero_taxed_money from ...core.utils.lazyobjects import unwrap_lazy from ...graphql.core.context import get_database_connection_name @@ -288,6 +289,7 @@ def with_checkout(data): checkout.token ) + @allow_writer_in_context(info.context) def calculate_line_unit_price(data): checkout_info, lines = data database_connection_name = get_database_connection_name(info.context) @@ -365,6 +367,7 @@ def with_checkout(data): checkout.token ) + @allow_writer_in_context(info.context) def calculate_line_total_price(data): (checkout_info, lines) = data database_connection_name = get_database_connection_name(info.context) @@ -854,20 +857,26 @@ def with_checkout_info(checkout_info): @traced_resolver @prevent_sync_event_circular_query def resolve_shipping_methods(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) + def with_checkout_info(checkout_info): + return unwrap_lazy(checkout_info.all_shipping_methods) + return ( CheckoutInfoByCheckoutTokenLoader(info.context) .load(root.token) - .then(lambda checkout_info: unwrap_lazy(checkout_info.all_shipping_methods)) + .then(with_checkout_info) ) @staticmethod def resolve_delivery_method(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) + def with_checkout_info(checkout_info): + return checkout_info.delivery_method_info.delivery_method + return ( CheckoutInfoByCheckoutTokenLoader(info.context) .load(root.token) - .then( - lambda checkout_info: checkout_info.delivery_method_info.delivery_method - ) + .then(with_checkout_info) ) @staticmethod @@ -885,6 +894,7 @@ def calculate_quantity(lines): @traced_resolver @prevent_sync_event_circular_query def resolve_total_price(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) def calculate_total_price(data): address, lines, checkout_info, manager = data database_connection_name = get_database_connection_name(info.context) @@ -904,6 +914,7 @@ def calculate_total_price(data): @traced_resolver @prevent_sync_event_circular_query def resolve_subtotal_price(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) def calculate_subtotal_price(data): address, lines, checkout_info, manager = data database_connection_name = get_database_connection_name(info.context) @@ -922,6 +933,7 @@ def calculate_subtotal_price(data): @traced_resolver @prevent_sync_event_circular_query def resolve_shipping_price(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) def calculate_shipping_price(data): address, lines, checkout_info, manager = data database_connection_name = get_database_connection_name(info.context) @@ -944,22 +956,28 @@ def resolve_lines(root: models.Checkout, info: ResolveInfo): @traced_resolver @prevent_sync_event_circular_query def resolve_available_shipping_methods(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) + def with_checkout_info(checkout_info): + return checkout_info.valid_shipping_methods + return ( CheckoutInfoByCheckoutTokenLoader(info.context) .load(root.token) - .then(lambda checkout_info: checkout_info.valid_shipping_methods) + .then(with_checkout_info) ) @staticmethod @traced_resolver def resolve_available_collection_points(root: models.Checkout, info: ResolveInfo): + @allow_writer_in_context(info.context) def get_available_collection_points(lines): database_connection_name = get_database_connection_name(info.context) - return get_valid_collection_points_for_checkout( + result = get_valid_collection_points_for_checkout( lines, root.channel_id, database_connection_name=database_connection_name, ) + return list(result) return ( CheckoutLinesInfoByCheckoutTokenLoader(info.context) @@ -979,12 +997,12 @@ def resolve_available_payment_gateways( ) def get_available_payment_gateways(results): - (checkout, lines_info) = results + (checkout_info, lines_info) = results return manager.list_payment_gateways( currency=root.currency, - checkout_info=checkout, + checkout_info=checkout_info, checkout_lines=lines_info, - channel_slug=root.channel.slug, + channel_slug=checkout_info.channel.slug, ) return Promise.all([checkout_info, checkout_lines_info]).then( @@ -1025,20 +1043,21 @@ def resolve_language_code(root, _info): def resolve_stock_reservation_expires( root: models.Checkout, info: ResolveInfo, site ): - if not is_reservation_enabled(site.settings): - return None - - def get_oldest_stock_reservation_expiration_date(reservations): - if not reservations: + with allow_writer_in_context(info.context): + if not is_reservation_enabled(site.settings): return None - return min(reservation.reserved_until for reservation in reservations) + def get_oldest_stock_reservation_expiration_date(reservations): + if not reservations: + return None - return ( - StocksReservationsByCheckoutTokenLoader(info.context) - .load(root.token) - .then(get_oldest_stock_reservation_expiration_date) - ) + return min(reservation.reserved_until for reservation in reservations) + + return ( + StocksReservationsByCheckoutTokenLoader(info.context) + .load(root.token) + .then(get_oldest_stock_reservation_expiration_date) + ) @staticmethod @one_of_permissions_required( diff --git a/saleor/graphql/order/types.py b/saleor/graphql/order/types.py index 3112f2065f3..ae7db036ec3 100644 --- a/saleor/graphql/order/types.py +++ b/saleor/graphql/order/types.py @@ -13,6 +13,7 @@ from ...account.models import User as UserModel from ...checkout.utils import get_external_shipping_id from ...core.anonymize import obfuscate_address, obfuscate_email +from ...core.db.connection import allow_writer_in_context from ...core.prices import quantize_price from ...core.taxes import zero_money from ...discount import DiscountType @@ -44,7 +45,6 @@ from ...permission.utils import has_one_of_permissions from ...product.models import ALL_PRODUCTS_PERMISSIONS, ProductMediaTypes from ...shipping.interface import ShippingMethodData -from ...shipping.models import ShippingMethodChannelListing from ...shipping.utils import convert_to_shipping_method_data from ...tax.utils import get_display_gross_prices from ...thumbnail.utils import ( @@ -932,6 +932,7 @@ def _resolve_thumbnail(result): @traced_resolver @prevent_sync_event_circular_query def resolve_unit_price(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_unit_price(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -956,6 +957,7 @@ def resolve_quantity_to_fulfill(root: models.OrderLine, info): @traced_resolver @prevent_sync_event_circular_query def resolve_undiscounted_unit_price(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_undiscounted_unit_price(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -989,6 +991,7 @@ def resolve_unit_discount(root: models.OrderLine, _info): @staticmethod @traced_resolver def resolve_tax_rate(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_tax_rate(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1009,6 +1012,7 @@ def _resolve_tax_rate(data): @traced_resolver @prevent_sync_event_circular_query def resolve_total_price(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_total_price(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1029,6 +1033,7 @@ def _resolve_total_price(data): @traced_resolver @prevent_sync_event_circular_query def resolve_undiscounted_total_price(root: models.OrderLine, info): + @allow_writer_in_context(info.context) def _resolve_undiscounted_total_price(data): order, lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1637,6 +1642,7 @@ def _resolve_shipping_address(data): @traced_resolver @prevent_sync_event_circular_query def resolve_shipping_price(root: models.Order, info): + @allow_writer_in_context(info.context) def _resolve_shipping_price(data): lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1652,6 +1658,7 @@ def _resolve_shipping_price(data): @traced_resolver @prevent_sync_event_circular_query def resolve_shipping_tax_rate(root: models.Order, info): + @allow_writer_in_context(info.context) def _resolve_shipping_tax_rate(data): lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1685,6 +1692,7 @@ def _resolve_actions(payments): @staticmethod @traced_resolver def resolve_subtotal(root: models.Order, info): + @allow_writer_in_context(info.context) def _resolve_subtotal(data): order_lines, manager = data database_connection_name = get_database_connection_name(info.context) @@ -1705,6 +1713,7 @@ def _resolve_subtotal(data): @prevent_sync_event_circular_query @plugin_manager_promise_callback def resolve_total(root: models.Order, info, manager): + @allow_writer_in_context(info.context) def _resolve_total(lines): database_connection_name = get_database_connection_name(info.context) return calculations.order_total( @@ -1719,6 +1728,7 @@ def _resolve_total(lines): @traced_resolver @prevent_sync_event_circular_query def resolve_undiscounted_total(root: models.Order, info): + @allow_writer_in_context(info.context) def _resolve_undiscounted_total(lines_and_manager): lines, manager = lines_and_manager database_connection_name = get_database_connection_name(info.context) @@ -2007,14 +2017,21 @@ def wrap_shipping_method_with_channel_context(data): ).load((shipping_method.id, channel.slug)) ) - def calculate_price( - listing: Optional[ShippingMethodChannelListing], - ) -> Optional[ShippingMethodData]: + tax_class = None + if shipping_method.tax_class_id: + tax_class = TaxClassByIdLoader(info.context).load( + shipping_method.tax_class_id + ) + + def calculate_price(data) -> Optional[ShippingMethodData]: + listing, tax_class = data if not listing: return None - return convert_to_shipping_method_data(shipping_method, listing) + return convert_to_shipping_method_data( + shipping_method, listing, tax_class + ) - return listing.then(calculate_price) + return Promise.all([listing, tax_class]).then(calculate_price) shipping_method = ShippingMethodByIdLoader(info.context).load( int(root.shipping_method_id) @@ -2045,6 +2062,7 @@ def with_channel(data): channel, manager = data database_connection_name = get_database_connection_name(info.context) + @allow_writer_in_context(info.context) def with_listings(channel_listings): return get_valid_shipping_methods_for_order( root, diff --git a/saleor/permission/enums.py b/saleor/permission/enums.py index 9f49e27b5de..fac96c66cae 100644 --- a/saleor/permission/enums.py +++ b/saleor/permission/enums.py @@ -160,12 +160,16 @@ def get_permissions( codenames = get_permissions_codename() else: codenames = split_permission_codename(permissions) - return get_permissions_from_codenames(codenames) + return get_permissions_from_codenames(codenames, database_connection_name) -def get_permissions_from_codenames(permission_codenames: list[str]) -> QuerySet: +def get_permissions_from_codenames( + permission_codenames: list[str], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +) -> QuerySet: return ( - Permission.objects.filter(codename__in=permission_codenames) + Permission.objects.using(database_connection_name) + .filter(codename__in=permission_codenames) .prefetch_related("content_type") .order_by("codename") ) diff --git a/saleor/shipping/utils.py b/saleor/shipping/utils.py index 7a07c3beb97..54dba131973 100644 --- a/saleor/shipping/utils.py +++ b/saleor/shipping/utils.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from django_countries import countries @@ -7,6 +7,7 @@ from .interface import ShippingMethodData if TYPE_CHECKING: + from ..tax.models import TaxClass from .models import ShippingMethod, ShippingMethodChannelListing @@ -30,12 +31,18 @@ def get_countries_without_shipping_zone(): def convert_to_shipping_method_data( - shipping_method: "ShippingMethod", listing: "ShippingMethodChannelListing" + shipping_method: "ShippingMethod", + listing: "ShippingMethodChannelListing", + tax_class: Optional["TaxClass"] = None, ) -> "ShippingMethodData": price = listing.price minimum_order_price = listing.minimum_order_price maximum_order_price = listing.maximum_order_price + if not tax_class: + # Tax class should be passed as argument, this is a fallback. + tax_class = shipping_method.tax_class + return ShippingMethodData( id=str(shipping_method.id), name=shipping_method.name, @@ -48,7 +55,7 @@ def convert_to_shipping_method_data( metadata=shipping_method.metadata, private_metadata=shipping_method.private_metadata, price=price, - tax_class=shipping_method.tax_class, + tax_class=tax_class, minimum_order_price=minimum_order_price, maximum_order_price=maximum_order_price, ) diff --git a/saleor/tax/calculations/order.py b/saleor/tax/calculations/order.py index dfa1a0b0229..0d88de8cb5f 100644 --- a/saleor/tax/calculations/order.py +++ b/saleor/tax/calculations/order.py @@ -28,9 +28,11 @@ def update_order_prices_with_flat_rates( database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): country_code = get_order_country(order) - default_country_rate_obj = TaxClassCountryRate.objects.filter( - country=country_code, tax_class=None - ).first() + default_country_rate_obj = ( + TaxClassCountryRate.objects.using(database_connection_name) + .filter(country=country_code, tax_class=None) + .first() + ) default_tax_rate = ( default_country_rate_obj.rate if default_country_rate_obj else Decimal(0) ) diff --git a/saleor/warehouse/models.py b/saleor/warehouse/models.py index e9c77243152..c18bc393712 100644 --- a/saleor/warehouse/models.py +++ b/saleor/warehouse/models.py @@ -127,7 +127,8 @@ def applicable_for_click_and_collect( ) stocks_qs = ( - Stock.objects.annotate_available_quantity() + Stock.objects.using(self.db) + .annotate_available_quantity() .annotate(line_quantity=F("available_quantity") - Subquery(lines_quantity)) .order_by("line_quantity") .filter( From 8699ee03cd33159c320cd96229897683ded6ae2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:46:52 +0200 Subject: [PATCH 19/35] Bump dependencies (#15737) Co-authored-by: maarcingebala --- .pre-commit-config.yaml | 2 +- poetry.lock | 82 ++++++++++++++++++++--------------------- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 37be86ac92e..2df08998562 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: - id: deptry - repo: https://github.com/returntocorp/semgrep - rev: v1.66.0 + rev: v1.66.2 hooks: - id: semgrep language_version: python3.9 diff --git a/poetry.lock b/poetry.lock index 3e32fe2c163..7874a0658d1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -321,17 +321,17 @@ files = [ [[package]] name = "boto3" -version = "1.34.69" +version = "1.34.74" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.69-py3-none-any.whl", hash = "sha256:2e25ef6bd325217c2da329829478be063155897d8d3b29f31f7f23ab548519b1"}, - {file = "boto3-1.34.69.tar.gz", hash = "sha256:898a5fed26b1351352703421d1a8b886ef2a74be6c97d5ecc92432ae01fda203"}, + {file = "boto3-1.34.74-py3-none-any.whl", hash = "sha256:71f551491fb12fe07727d371d5561c5919fdf33dbc1d4251c57940d267a53a9e"}, + {file = "boto3-1.34.74.tar.gz", hash = "sha256:b703e22775561a748adc4576c30424b81abd2a00d3c6fb28eec2e5cde92c1eed"}, ] [package.dependencies] -botocore = ">=1.34.69,<1.35.0" +botocore = ">=1.34.74,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -340,13 +340,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.69" +version = "1.34.74" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.69-py3-none-any.whl", hash = "sha256:d3802d076d4d507bf506f9845a6970ce43adc3d819dd57c2791f5c19ed6e5950"}, - {file = "botocore-1.34.69.tar.gz", hash = "sha256:d1ab2bff3c2fd51719c2021d9fa2f30fbb9ed0a308f69e9a774ac92c8091380a"}, + {file = "botocore-1.34.74-py3-none-any.whl", hash = "sha256:5d2015b5d91d6c402c122783729ce995ed7283a746b0380957026dc2b3b75969"}, + {file = "botocore-1.34.74.tar.gz", hash = "sha256:32bb519bae62483893330c18a0ea4fd09d1ffa32bc573cd8559c2d9a08fb8c5c"}, ] [package.dependencies] @@ -1155,13 +1155,13 @@ tzdata = "*" [[package]] name = "django-countries" -version = "7.5.1" +version = "7.6" description = "Provides a country field for Django models." optional = false python-versions = "*" files = [ - {file = "django-countries-7.5.1.tar.gz", hash = "sha256:22915d9b9403932b731622619940a54894a3eb0da9a374e7249c8fc453c122d7"}, - {file = "django_countries-7.5.1-py3-none-any.whl", hash = "sha256:2df707aca7a5e677254bed116cf6021a136ebaccd5c2f46860abd6452bb45521"}, + {file = "django-countries-7.6.tar.gz", hash = "sha256:aba80a9ce7c293671bb36507c98c4f8886b768868160a1b7afd2fa2aee4c5a0b"}, + {file = "django_countries-7.6-py3-none-any.whl", hash = "sha256:1939b6b28fc341615f1b62ba4a681b54e67414df40cdea5e99ce3897546af92c"}, ] [package.dependencies] @@ -1169,8 +1169,8 @@ asgiref = "*" typing-extensions = "*" [package.extras] -dev = ["black", "django", "djangorestframework", "graphene-django", "pytest", "pytest-django", "tox"] -maintainer = ["django", "transifex-client", "zest.releaser[recommended]"] +dev = ["black", "django", "djangorestframework", "graphene-django", "pytest", "pytest-django", "tox (==4.*)"] +maintainer = ["django", "zest.releaser[recommended]"] pyuca = ["pyuca"] test = ["djangorestframework", "graphene-django", "pytest", "pytest-cov", "pytest-django"] @@ -1546,18 +1546,18 @@ probabilistic = ["pyprobables (>=0.6,<0.7)"] [[package]] name = "filelock" -version = "3.13.1" +version = "3.13.3" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, - {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, + {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"}, + {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -1731,13 +1731,13 @@ grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"] [[package]] name = "google-cloud-pubsub" -version = "2.20.3" +version = "2.21.0" description = "Google Cloud Pub/Sub API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-pubsub-2.20.3.tar.gz", hash = "sha256:76af0f179509e431d2bbe3f51f426256783dc959b8f458c1370cb19d6f7be0b1"}, - {file = "google_cloud_pubsub-2.20.3-py2.py3-none-any.whl", hash = "sha256:d9ae1e812800208492afa1a35ec194637600cd1c27eca382bd80aa1d4a168ffa"}, + {file = "google-cloud-pubsub-2.21.0.tar.gz", hash = "sha256:94017f0bc9a85fa3f4d913f312e930a0fe21775bd68dde5c666e2f1b1addf811"}, + {file = "google_cloud_pubsub-2.21.0-py2.py3-none-any.whl", hash = "sha256:fabd19e08faa1f70081b0e5ea003a3c031a1d2c1b798cf1be5ded2dbf1d1dbef"}, ] [package.dependencies] @@ -2369,13 +2369,13 @@ referencing = ">=0.31.0" [[package]] name = "kombu" -version = "5.3.5" +version = "5.3.6" description = "Messaging library for Python." optional = false python-versions = ">=3.8" files = [ - {file = "kombu-5.3.5-py3-none-any.whl", hash = "sha256:0eac1bbb464afe6fb0924b21bf79460416d25d8abc52546d4f16cad94f789488"}, - {file = "kombu-5.3.5.tar.gz", hash = "sha256:30e470f1a6b49c70dc6f6d13c3e4cc4e178aa6c469ceb6bcd55645385fc84b93"}, + {file = "kombu-5.3.6-py3-none-any.whl", hash = "sha256:49f1e62b12369045de2662f62cc584e7df83481a513db83b01f87b5b9785e378"}, + {file = "kombu-5.3.6.tar.gz", hash = "sha256:f3da5b570a147a5da8280180aa80b03807283d63ea5081fcdb510d18242431d9"}, ] [package.dependencies] @@ -2393,7 +2393,7 @@ mongodb = ["pymongo (>=4.1.1)"] msgpack = ["msgpack"] pyro = ["pyro4"] qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] -redis = ["redis (>=4.5.2,!=4.5.5,<6.0.0)"] +redis = ["redis (>=4.5.2,!=4.5.5,!=5.0.2)"] slmq = ["softlayer-messaging (>=1.0.3)"] sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] sqs = ["boto3 (>=1.26.143)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] @@ -3362,28 +3362,28 @@ files = [ [[package]] name = "pyasn1" -version = "0.5.1" +version = "0.6.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=3.8" files = [ - {file = "pyasn1-0.5.1-py2.py3-none-any.whl", hash = "sha256:4439847c58d40b1d0a573d07e3856e95333f1976294494c325775aeca506eb58"}, - {file = "pyasn1-0.5.1.tar.gz", hash = "sha256:6d391a96e59b23130a5cfa74d6fd7f388dbbe26cc8f1edf39fdddf08d9d6676c"}, + {file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"}, + {file = "pyasn1-0.6.0.tar.gz", hash = "sha256:3a35ab2c4b5ef98e17dfdec8ab074046fbda76e281c5a706ccd82328cfc8f64c"}, ] [[package]] name = "pyasn1-modules" -version = "0.3.0" +version = "0.4.0" description = "A collection of ASN.1-based protocols modules" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=3.8" files = [ - {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"}, - {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"}, + {file = "pyasn1_modules-0.4.0-py3-none-any.whl", hash = "sha256:be04f15b66c206eed667e0bb5ab27e2b1855ea54a842e5037738099e8ca4ae0b"}, + {file = "pyasn1_modules-0.4.0.tar.gz", hash = "sha256:831dbcea1b177b28c9baddf4c6d1013c24c3accd14a1873fffaa6a2e905f17b6"}, ] [package.dependencies] -pyasn1 = ">=0.4.6,<0.6.0" +pyasn1 = ">=0.4.6,<0.7.0" [[package]] name = "pybars3" @@ -3400,13 +3400,13 @@ PyMeta3 = ">=0.5.1" [[package]] name = "pycparser" -version = "2.21" +version = "2.22" description = "C parser in Python" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +python-versions = ">=3.8" files = [ - {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, - {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, + {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, + {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] [[package]] @@ -4488,13 +4488,13 @@ files = [ [[package]] name = "textual" -version = "0.53.1" +version = "0.54.0" description = "Modern Text User Interface framework" optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ - {file = "textual-0.53.1-py3-none-any.whl", hash = "sha256:32201aa9d334ed064d5e670f15fe3d7f19c736ca54cecb054a5b995691104434"}, - {file = "textual-0.53.1.tar.gz", hash = "sha256:23ba673be7974819ded35ea88d28df7117987e53d58f15b2cc890ac2ecf56401"}, + {file = "textual-0.54.0-py3-none-any.whl", hash = "sha256:94aacf28dece20a44f0b94b087e17ff4ac961acd92e12e648f060fe2555b3adc"}, + {file = "textual-0.54.0.tar.gz", hash = "sha256:0cfd134dde5ae49d64dd73bb32a2fb5a86d878d9caeacecaa1d640082f31124e"}, ] [package.dependencies] From 36699c6f5c99590d24f46e3d5c5b1a3c2fd072e7 Mon Sep 17 00:00:00 2001 From: zedzior Date: Tue, 2 Apr 2024 13:13:36 +0200 Subject: [PATCH 20/35] Advisory fix merx-280. --- .../mutations/authentication/refresh_token.py | 7 +++--- .../authentication/test_token_refresh.py | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/saleor/graphql/account/mutations/authentication/refresh_token.py b/saleor/graphql/account/mutations/authentication/refresh_token.py index cf4fb63484b..b461fffc36d 100644 --- a/saleor/graphql/account/mutations/authentication/refresh_token.py +++ b/saleor/graphql/account/mutations/authentication/refresh_token.py @@ -58,14 +58,13 @@ def get_refresh_token( cls, info: ResolveInfo, refresh_token: Optional[str] = None ) -> Optional[str]: request = info.context - refresh_token = refresh_token or request.COOKIES.get( - JWT_REFRESH_TOKEN_COOKIE_NAME, None - ) + if refresh_token is None: + refresh_token = request.COOKIES.get(JWT_REFRESH_TOKEN_COOKIE_NAME, None) return refresh_token @classmethod def clean_refresh_token(cls, refresh_token): - if not refresh_token: + if refresh_token is None: raise ValidationError( { "refresh_token": ValidationError( diff --git a/saleor/graphql/account/tests/mutations/authentication/test_token_refresh.py b/saleor/graphql/account/tests/mutations/authentication/test_token_refresh.py index e969dbcbf49..f0baffa79d4 100644 --- a/saleor/graphql/account/tests/mutations/authentication/test_token_refresh.py +++ b/saleor/graphql/account/tests/mutations/authentication/test_token_refresh.py @@ -1,5 +1,6 @@ from datetime import datetime +import pytest from django.urls import reverse from freezegun import freeze_time @@ -269,3 +270,25 @@ def test_refresh_token_when_user_deactivated_token(api_client, customer_user): assert not data["token"] assert len(errors) == 1 assert errors[0]["code"] == AccountErrorCode.JWT_INVALID_TOKEN.name + + +@pytest.mark.parametrize("token", ["incorrect-token", ""]) +def test_refresh_token_incorrect_token_provided(api_client, customer_user, token): + # given + csrf_token = _get_new_csrf_token() + refresh_token = create_refresh_token(customer_user, {"csrfToken": csrf_token}) + api_client.cookies[JWT_REFRESH_TOKEN_COOKIE_NAME] = refresh_token + api_client.cookies[JWT_REFRESH_TOKEN_COOKIE_NAME]["httponly"] = True + + variables = {"token": token, "csrf_token": csrf_token} + + # when + response = api_client.post_graphql(MUTATION_TOKEN_REFRESH, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["tokenRefresh"] + errors = data["errors"] + assert not data.get("token") + assert len(errors) == 1 + assert errors[0]["code"] == AccountErrorCode.JWT_DECODE_ERROR.name From 19197e2c957be95c261deb5ddcd668c16019ef11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Wcis=C5=82o?= <115464873+wcislo-saleor@users.noreply.github.com> Date: Fri, 5 Apr 2024 09:48:53 +0200 Subject: [PATCH 21/35] Bump requests-hardened (#15752) --- poetry.lock | 7 +++---- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 7874a0658d1..2e667b4d91b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3846,7 +3846,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3951,12 +3950,12 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "requests-hardened" -version = "1.0.0b2" +version = "1.0.0b3" description = "A library that overrides the default behaviors of the requests library, and adds new security features." optional = false python-versions = ">=3.8" files = [ - {file = "requests-hardened-1.0.0b2.tar.gz", hash = "sha256:28c5591f1f346f4b1a014d97cb0e1addb852bcdcfb5c129b6ca5443587c32c73"}, + {file = "requests-hardened-1.0.0b3.tar.gz", hash = "sha256:125057fb864e4283c926021f594c9e4695432036f13fd76fee3ef738510231e2"}, ] [package.dependencies] @@ -5401,4 +5400,4 @@ test = ["pytest"] [metadata] lock-version = "2.0" python-versions = "~3.9" -content-hash = "16a2c854e9a7d0795c7e1222c51c247373015023e4f738575472d075bbc3b014" +content-hash = "59f3b92f1122bc8f6a11be9e6d6c94fc4cd60f8a161ed392a3f69554040b92c5" diff --git a/pyproject.toml b/pyproject.toml index 474d144c676..e7089fdb93a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ documentation = "https://docs.saleor.io/" razorpay = "^1.2" redis = "^5.0.1" requests = "^2.22" - requests-hardened = "1.0.0b2" + requests-hardened = "1.0.0b3" Rx = "^1.6.3" semantic-version = "^2.10.0" sendgrid = "^6.7.1" From a9d2e2dcae26c1233846d228be284b5b6c65d4cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Fri, 5 Apr 2024 10:51:54 +0200 Subject: [PATCH 22/35] Optimize validate_draft_order function to run less queries (#15748) --- saleor/discount/utils.py | 14 ++-- .../order/mutations/draft_order_complete.py | 2 +- .../order/tests/test_draft_order_validate.py | 77 +++++++++++++----- saleor/graphql/order/types.py | 30 ++++--- saleor/graphql/order/utils.py | 78 ++++++++++++------- saleor/order/tests/test_order.py | 11 ++- saleor/order/utils.py | 6 +- 7 files changed, 150 insertions(+), 68 deletions(-) diff --git a/saleor/discount/utils.py b/saleor/discount/utils.py index 4c21523cfd4..8e081480f0d 100644 --- a/saleor/discount/utils.py +++ b/saleor/discount/utils.py @@ -54,7 +54,7 @@ if TYPE_CHECKING: from ..account.models import User from ..checkout.fetch import CheckoutInfo - from ..order.models import Order + from ..order.models import Order, OrderLine from ..plugins.manager import PluginsManager from ..product.managers import ProductVariantQueryset from ..product.models import VariantChannelListingPromotionRule @@ -283,19 +283,23 @@ def validate_voucher_for_checkout( ) -def validate_voucher_in_order(order: "Order"): +def validate_voucher_in_order( + order: "Order", lines: Iterable["OrderLine"], channel: "Channel" +): if not order.voucher: return + from ..order.utils import get_total_quantity + subtotal = order.subtotal - quantity = order.get_total_quantity() + quantity = get_total_quantity(lines) customer_email = order.get_customer_email() - tax_configuration = order.channel.tax_configuration + tax_configuration = channel.tax_configuration prices_entered_with_tax = tax_configuration.prices_entered_with_tax value = subtotal.gross if prices_entered_with_tax else subtotal.net validate_voucher( - order.voucher, value, quantity, customer_email, order.channel, order.user + order.voucher, value, quantity, customer_email, channel, order.user ) diff --git a/saleor/graphql/order/mutations/draft_order_complete.py b/saleor/graphql/order/mutations/draft_order_complete.py index 08c610d6546..46e0ca5000a 100644 --- a/saleor/graphql/order/mutations/draft_order_complete.py +++ b/saleor/graphql/order/mutations/draft_order_complete.py @@ -111,7 +111,7 @@ def perform_mutation( # type: ignore[override] cls.validate_order(order) country = get_order_country(order) - validate_draft_order(order, country, manager) + validate_draft_order(order, order.lines.all(), country, manager) with traced_atomic_transaction(): cls.update_user_fields(order) order.status = OrderStatus.UNFULFILLED diff --git a/saleor/graphql/order/tests/test_draft_order_validate.py b/saleor/graphql/order/tests/test_draft_order_validate.py index 1a7b5eef418..c5c49437733 100644 --- a/saleor/graphql/order/tests/test_draft_order_validate.py +++ b/saleor/graphql/order/tests/test_draft_order_validate.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from unittest.mock import Mock +from unittest.mock import patch import pytest import pytz @@ -15,7 +15,10 @@ def test_validate_draft_order(draft_order): # should not raise any errors assert ( validate_draft_order( - draft_order, "US", get_plugins_manager(allow_replica=False) + draft_order, + draft_order.lines.all(), + "US", + get_plugins_manager(allow_replica=False), ) is None ) @@ -27,7 +30,10 @@ def test_validate_draft_order_without_sku(draft_order): # should not raise any errors assert ( validate_draft_order( - draft_order, "US", get_plugins_manager(allow_replica=False) + draft_order, + draft_order.lines.all(), + "US", + get_plugins_manager(allow_replica=False), ) is None ) @@ -40,7 +46,9 @@ def test_validate_draft_order_wrong_shipping(draft_order): shipping_zone.save() assert order.shipping_address.country.code not in shipping_zone.countries with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Shipping method is not valid for chosen shipping address" assert e.value.error_dict["shipping"][0].message == msg @@ -48,7 +56,9 @@ def test_validate_draft_order_wrong_shipping(draft_order): def test_validate_draft_order_no_order_lines(order, shipping_method): order.shipping_method = shipping_method with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Could not create order without any products." assert e.value.error_dict["lines"][0].message == msg @@ -62,7 +72,9 @@ def test_validate_draft_order_non_existing_variant(draft_order): assert line.variant is None with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Could not create orders with non-existing products." assert e.value.error_dict["lines"][0].message == msg @@ -77,7 +89,9 @@ def test_validate_draft_order_with_unpublished_product(draft_order): line.refresh_from_db() with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Can't finalize draft with unpublished product." error = e.value.error_dict["lines"][0] @@ -93,7 +107,9 @@ def test_validate_draft_order_with_unavailable_for_purchase_product(draft_order) line.refresh_from_db() with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Can't finalize draft with product unavailable for purchase." error = e.value.error_dict["lines"][0] @@ -113,7 +129,9 @@ def test_validate_draft_order_with_product_available_for_purchase_in_future( line.refresh_from_db() with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Can't finalize draft with product unavailable for purchase." error = e.value.error_dict["lines"][0] @@ -131,7 +149,9 @@ def test_validate_draft_order_out_of_stock_variant(draft_order): stock.save(update_fields=["quantity"]) with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) msg = "Insufficient product stock." assert e.value.error_dict["lines"][0].message == msg @@ -141,7 +161,9 @@ def test_validate_draft_order_no_shipping_address(draft_order): order.shipping_address = None with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) error = e.value.error_dict["order"][0] assert error.message == "Can't finalize draft with no shipping address." assert error.code == OrderErrorCode.ORDER_NO_SHIPPING_ADDRESS.value @@ -152,7 +174,9 @@ def test_validate_draft_order_no_billing_address(draft_order): order.billing_address = None with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) error = e.value.error_dict["order"][0] assert error.message == "Can't finalize draft with no billing address." assert error.code == OrderErrorCode.BILLING_ADDRESS_NOT_SET.value @@ -163,35 +187,44 @@ def test_validate_draft_order_no_shipping_method(draft_order): order.shipping_method = None with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) error = e.value.error_dict["shipping"][0] assert error.message == "Shipping method is required." assert error.code == OrderErrorCode.SHIPPING_METHOD_REQUIRED.value -def test_validate_draft_order_no_shipping_method_shipping_not_required(draft_order): +@patch("saleor.graphql.order.utils.is_shipping_required") +def test_validate_draft_order_no_shipping_method_shipping_not_required( + mocked_is_shipping_required, draft_order +): order = draft_order order.shipping_method = None - required_mock = Mock(return_value=False) - order.is_shipping_required = required_mock + mocked_is_shipping_required.return_value = False assert ( - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) is None ) +@patch("saleor.graphql.order.utils.is_shipping_required") def test_validate_draft_order_no_shipping_address_no_method_shipping_not_required( + mocked_is_shipping_required, draft_order, ): order = draft_order order.shipping_method = None order.shipping_address = None - required_mock = Mock(return_value=False) - order.is_shipping_required = required_mock + mocked_is_shipping_required.return_value = False assert ( - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) is None ) @@ -203,7 +236,9 @@ def test_validate_draft_order_voucher(draft_order_with_voucher): # when & then with pytest.raises(ValidationError) as e: - validate_draft_order(order, "US", get_plugins_manager(allow_replica=False)) + validate_draft_order( + order, order.lines.all(), "US", get_plugins_manager(allow_replica=False) + ) error = e.value.error_dict["voucher"][0] assert error.code == OrderErrorCode.INVALID_VOUCHER.value diff --git a/saleor/graphql/order/types.py b/saleor/graphql/order/types.py index ae7db036ec3..d581bc2853b 100644 --- a/saleor/graphql/order/types.py +++ b/saleor/graphql/order/types.py @@ -1930,21 +1930,26 @@ def resolve_status_display(root: models.Order, _info): def resolve_can_finalize(root: models.Order, info): if root.status == OrderStatus.DRAFT: - def _validate_draft_order(manager): + @allow_writer_in_context(info.context) + def _validate_draft_order(data): + lines, manager = data country = get_order_country(root) database_connection_name = get_database_connection_name(info.context) try: validate_draft_order( - root, - country, - manager, + order=root, + lines=lines, + country=country, + manager=manager, database_connection_name=database_connection_name, ) except ValidationError: return False return True - return get_plugin_manager_promise(info.context).then(_validate_draft_order) + lines = OrderLinesByOrderIdLoader(info.context).load(root.id) + manager = get_plugin_manager_promise(info.context) + return Promise.all([lines, manager]).then(_validate_draft_order) return True @staticmethod @@ -2158,21 +2163,26 @@ def resolve_original(root: models.Order, _info): def resolve_errors(root: models.Order, info): if root.status == OrderStatus.DRAFT: - def _validate_order(manager): + @allow_writer_in_context(info.context) + def _validate_order(data): + lines, manager = data country = get_order_country(root) database_connection_name = get_database_connection_name(info.context) try: validate_draft_order( - root, - country, - manager, + order=root, + lines=lines, + country=country, + manager=manager, database_connection_name=database_connection_name, ) except ValidationError as e: return validation_error_to_error_type(e, OrderError) return [] - return get_plugin_manager_promise(info.context).then(_validate_order) + lines = OrderLinesByOrderIdLoader(info.context).load(root.id) + manager = get_plugin_manager_promise(info.context) + return Promise.all([lines, manager]).then(_validate_order) return [] diff --git a/saleor/graphql/order/utils.py b/saleor/graphql/order/utils.py index 53197fc0b85..0113fc487db 100644 --- a/saleor/graphql/order/utils.py +++ b/saleor/graphql/order/utils.py @@ -12,7 +12,11 @@ from ...discount.models import NotApplicable from ...discount.utils import validate_voucher_in_order from ...order.error_codes import OrderErrorCode -from ...order.utils import get_valid_shipping_methods_for_order +from ...order.utils import ( + get_total_quantity, + get_valid_shipping_methods_for_order, + is_shipping_required, +) from ...plugins.manager import PluginsManager from ...product.models import Product, ProductChannelListing, ProductVariant from ...shipping.interface import ShippingMethodData @@ -22,7 +26,7 @@ if TYPE_CHECKING: from ...channel.models import Channel - from ...order.models import Order + from ...order.models import Order, OrderLine from dataclasses import dataclass @@ -39,8 +43,9 @@ class OrderLineData: rules_info: Optional[Iterable[VariantPromotionRuleInfo]] = None -def validate_total_quantity(order: "Order", errors: T_ERRORS): - if order.get_total_quantity() == 0: +def validate_total_quantity(lines: Iterable["OrderLine"], errors: T_ERRORS): + total_quantity = get_total_quantity(lines) + if total_quantity == 0: errors["lines"].append( ValidationError( "Could not create order without any products.", @@ -79,6 +84,7 @@ def get_shipping_method_availability_error( def validate_shipping_method( order: "Order", + channel: "Channel", errors: T_ERRORS, manager: "PluginsManager", database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, @@ -103,7 +109,7 @@ def validate_shipping_method( code=OrderErrorCode.SHIPPING_METHOD_NOT_APPLICABLE.value, ) else: - listing = order.channel.shipping_method_listings.filter( + listing = channel.shipping_method_listings.filter( shipping_method=order.shipping_method ).last() if not listing: @@ -145,11 +151,13 @@ def validate_shipping_address(order: "Order", errors: T_ERRORS): def validate_order_lines( order: "Order", + lines: Iterable["OrderLine"], + channel: "Channel", country: str, errors: T_ERRORS, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): - for line in order.lines.all(): + for line in lines: if line.variant is None: errors["lines"].append( ValidationError( @@ -162,7 +170,7 @@ def validate_order_lines( check_stock_and_preorder_quantity( line.variant, country, - order.channel.slug, + channel.slug, line.quantity, order_line=line, database_connection_name=database_connection_name, @@ -174,15 +182,16 @@ def validate_order_lines( def validate_variants_is_available( - order: "Order", + channel: "Channel", + lines: Iterable["OrderLine"], errors: T_ERRORS, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): - variants_ids = {line.variant_id for line in order.lines.all()} + variants_ids = {line.variant_id for line in lines} try: validate_variants_available_in_channel( variants_ids, - order.channel_id, + channel.id, OrderErrorCode.NOT_AVAILABLE_IN_CHANNEL.value, database_connection_name=database_connection_name, ) @@ -191,15 +200,16 @@ def validate_variants_is_available( def validate_product_is_published( - order: "Order", + channel: "Channel", + lines: Iterable["OrderLine"], errors: T_ERRORS, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): - variant_ids = [line.variant_id for line in order.lines.all()] + variant_ids = [line.variant_id for line in lines] unpublished_product = ( Product.objects.using(database_connection_name) .filter(variants__id__in=variant_ids) - .not_published(order.channel.slug) + .not_published(channel.slug) ) if unpublished_product.exists(): errors["lines"].append( @@ -272,18 +282,19 @@ def validate_variant_channel_listings( def validate_product_is_available_for_purchase( - order: "Order", + channel: "Channel", + lines: Iterable["OrderLine"], errors: T_ERRORS, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): invalid_lines = [] - for line in order.lines.all(): + for line in lines: variant = line.variant if not variant: continue product_channel_listing = ( ProductChannelListing.objects.using(database_connection_name) - .filter(channel_id=order.channel_id, product_id=variant.product_id) + .filter(channel_id=channel.id, product_id=variant.product_id) .first() ) if not ( @@ -311,10 +322,12 @@ def validate_channel_is_active(channel: "Channel", errors: T_ERRORS): ) -def _validate_voucher(order: "Order", errors: T_ERRORS): - if order.channel.include_draft_order_in_voucher_usage: +def _validate_voucher( + order: "Order", lines: Iterable["OrderLine"], channel: "Channel", errors: T_ERRORS +): + if channel.include_draft_order_in_voucher_usage: try: - validate_voucher_in_order(order) + validate_voucher_in_order(order, lines, channel) except NotApplicable as e: errors["voucher"].append( ValidationError( @@ -326,6 +339,7 @@ def _validate_voucher(order: "Order", errors: T_ERRORS): def validate_draft_order( order: "Order", + lines: Iterable["OrderLine"], country: str, manager: "PluginsManager", database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, @@ -341,18 +355,26 @@ def validate_draft_order( Returns a list of errors if any were found. """ + channel = order.channel + errors: T_ERRORS = defaultdict(list) validate_billing_address(order, errors) - if order.is_shipping_required(): + if is_shipping_required(lines): validate_shipping_address(order, errors) - validate_shipping_method(order, errors, manager, database_connection_name) - validate_total_quantity(order, errors) - validate_order_lines(order, country, errors, database_connection_name) - validate_channel_is_active(order.channel, errors) - validate_product_is_published(order, errors, database_connection_name) - validate_product_is_available_for_purchase(order, errors, database_connection_name) - validate_variants_is_available(order, errors, database_connection_name) - _validate_voucher(order, errors) + validate_shipping_method( + order, channel, errors, manager, database_connection_name + ) + validate_total_quantity(lines, errors) + validate_order_lines( + order, lines, channel, country, errors, database_connection_name + ) + validate_channel_is_active(channel, errors) + validate_product_is_published(channel, lines, errors, database_connection_name) + validate_product_is_available_for_purchase( + channel, lines, errors, database_connection_name + ) + validate_variants_is_available(channel, lines, errors, database_connection_name) + _validate_voucher(order, lines, channel, errors) if errors: raise ValidationError(errors) diff --git a/saleor/order/tests/test_order.py b/saleor/order/tests/test_order.py index 7cc2c18b782..820b01a7c60 100644 --- a/saleor/order/tests/test_order.py +++ b/saleor/order/tests/test_order.py @@ -678,7 +678,9 @@ def test_get_voucher_discount_for_order_voucher_validation( quantity = order_with_lines.get_total_quantity() customer_email = order_with_lines.get_customer_email() - validate_voucher_in_order(order_with_lines) + validate_voucher_in_order( + order_with_lines, order_with_lines.lines.all(), order_with_lines.channel + ) mock_validate_voucher.assert_called_once_with( voucher, @@ -699,7 +701,9 @@ def test_validate_voucher_in_order_without_voucher( assert not order_with_lines.voucher - validate_voucher_in_order(order_with_lines) + validate_voucher_in_order( + order_with_lines, order_with_lines.lines.all(), order_with_lines.channel + ) mock_validate_voucher.assert_not_called() @@ -745,6 +749,7 @@ def test_value_voucher_order_discount( billing_address=address_usa, channel=channel_USD, ) + order.lines = Mock(all=Mock(return_value=[])) discount = get_voucher_discount_for_order(order) assert discount == Money(expected_value, "USD") @@ -784,6 +789,7 @@ def test_shipping_voucher_order_discount( voucher=voucher, channel=channel_USD, ) + order.lines = Mock(all=Mock(return_value=[])) discount = get_voucher_discount_for_order(order) assert discount == Money(expected_value, "USD") @@ -840,6 +846,7 @@ def test_shipping_voucher_checkout_discount_not_applicable_returns_zero( voucher=voucher, channel=channel_USD, ) + order.lines = Mock(all=Mock(return_value=[])) with pytest.raises(NotApplicable): get_voucher_discount_for_order(order) diff --git a/saleor/order/utils.py b/saleor/order/utils.py index fcbe3e1503e..9e88fa5f8e0 100644 --- a/saleor/order/utils.py +++ b/saleor/order/utils.py @@ -664,6 +664,10 @@ def is_shipping_required(lines: Iterable["OrderLine"]): return any(line.is_shipping_required for line in lines) +def get_total_quantity(lines: Iterable["OrderLine"]): + return sum([line.quantity for line in lines]) + + def get_valid_collection_points_for_order( lines: Iterable["OrderLine"], channel_id: int, @@ -741,7 +745,7 @@ def get_voucher_discount_for_order(order: Order) -> Money: """ if not order.voucher: return zero_money(order.currency) - validate_voucher_in_order(order) + validate_voucher_in_order(order, order.lines.all(), order.channel) subtotal = order.subtotal if order.voucher.type == VoucherType.ENTIRE_ORDER: return order.voucher.get_discount_amount_for(subtotal.gross, order.channel) From 52b3886bef42bcaf3e8423cff825edafc5b9ec9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Fri, 5 Apr 2024 10:52:17 +0200 Subject: [PATCH 23/35] Add more usages of allow_writer context manager (#15749) * Optimize validate_draft_order function to run less queries * Add misc allow_writer usages * Update comment --- saleor/graphql/executor.py | 23 +++++++++++++++++++ saleor/plugins/manager.py | 7 +++++- .../webhook/tests/test_shipping_webhook.py | 3 ++- saleor/product/utils/digital_products.py | 2 ++ saleor/product/views.py | 5 +++- saleor/settings.py | 3 +++ 6 files changed, 40 insertions(+), 3 deletions(-) create mode 100644 saleor/graphql/executor.py diff --git a/saleor/graphql/executor.py b/saleor/graphql/executor.py new file mode 100644 index 00000000000..ddafe3a2e99 --- /dev/null +++ b/saleor/graphql/executor.py @@ -0,0 +1,23 @@ +from graphql.execution import executor + +original_complete_value_catching_error = executor.complete_value_catching_error + + +def _patched_complete_value_catching_error(*args, **kwargs): + info = args[3] + from saleor.core.db.connection import allow_writer_in_context + + with allow_writer_in_context(info.context): + return original_complete_value_catching_error(*args, **kwargs) + + +def patch_executor(): + """Patch `complete_value_catching_error` function to allow writer DB in mutations. + + The `complete_value_catching_error` function is called when resolving a field in + GraphQL. This patch wraps each call with `allow_writer_in_context` context manager. + This allows to use writer DB in resolvers, when they are called via mutation, while + they will still raise or log error when a resolver is run in a query. + """ + + executor.complete_value_catching_error = _patched_complete_value_catching_error diff --git a/saleor/plugins/manager.py b/saleor/plugins/manager.py index 9ae90c1e61d..169eb881690 100644 --- a/saleor/plugins/manager.py +++ b/saleor/plugins/manager.py @@ -14,6 +14,7 @@ from ..channel.models import Channel from ..checkout import base_calculations +from ..core.db.connection import allow_writer from ..core.models import EventDelivery from ..core.payments import PaymentInterface from ..core.prices import quantize_price @@ -2219,4 +2220,8 @@ def get_plugins_manager( requestor_getter: Optional[Callable[[], "Requestor"]] = None, ) -> PluginsManager: with opentracing.global_tracer().start_active_span("get_plugins_manager"): - return PluginsManager(settings.PLUGINS, requestor_getter, allow_replica) + if allow_replica: + return PluginsManager(settings.PLUGINS, requestor_getter, allow_replica) + else: + with allow_writer(): + return PluginsManager(settings.PLUGINS, requestor_getter, allow_replica) diff --git a/saleor/plugins/webhook/tests/test_shipping_webhook.py b/saleor/plugins/webhook/tests/test_shipping_webhook.py index b3c30d81e15..7e1f5bb3dc2 100644 --- a/saleor/plugins/webhook/tests/test_shipping_webhook.py +++ b/saleor/plugins/webhook/tests/test_shipping_webhook.py @@ -412,9 +412,10 @@ def test_order_available_shipping_methods( ): # given settings.PLUGINS = ["saleor.plugins.webhook.plugin.WebhookPlugin"] + shipping_method = order_with_lines.shipping_method def respond(*args, **kwargs): - return webhook_response(order_with_lines.shipping_method) + return webhook_response(shipping_method) mocked_webhook.side_effect = respond permission_group_manage_orders.user_set.add(staff_api_client.user) diff --git a/saleor/product/utils/digital_products.py b/saleor/product/utils/digital_products.py index d6f604a37a7..913ed1ff0c1 100644 --- a/saleor/product/utils/digital_products.py +++ b/saleor/product/utils/digital_products.py @@ -4,6 +4,7 @@ from django.utils.timezone import now from ...account import events as account_events +from ...core.db.connection import allow_writer from ..models import DigitalContentUrl @@ -42,6 +43,7 @@ def digital_content_url_is_valid(content_url: DigitalContentUrl) -> bool: return True +@allow_writer() def increment_download_count(content_url: DigitalContentUrl): content_url.download_num += 1 content_url.save(update_fields=["download_num"]) diff --git a/saleor/product/views.py b/saleor/product/views.py index cc75fcdac6e..75c82ec118c 100644 --- a/saleor/product/views.py +++ b/saleor/product/views.py @@ -2,6 +2,7 @@ import os from typing import Union +from django.conf import settings from django.http import FileResponse, HttpResponseNotFound from django.shortcuts import get_object_or_404 @@ -15,7 +16,9 @@ def digital_product(request, token: str) -> Union[FileResponse, HttpResponseNotFound]: """Return the direct download link to content if given token is still valid.""" - qs = DigitalContentUrl.objects.prefetch_related("line__order__user") + qs = DigitalContentUrl.objects.using( + settings.DATABASE_CONNECTION_REPLICA_NAME + ).prefetch_related("line__order__user") content_url = get_object_or_404(qs, token=token) # type: DigitalContentUrl if not digital_content_url_is_valid(content_url): return HttpResponseNotFound("Url is not valid anymore") diff --git a/saleor/settings.py b/saleor/settings.py index e0acb03a389..b69ebaecc38 100644 --- a/saleor/settings.py +++ b/saleor/settings.py @@ -29,6 +29,7 @@ from . import PatchedSubscriberExecutionContext, __version__ from .core.languages import LANGUAGES as CORE_LANGUAGES from .core.schedules import initiated_promotion_webhook_schedule +from .graphql.executor import patch_executor django_stubs_ext.monkeypatch() @@ -878,6 +879,8 @@ def SENTRY_INIT(dsn: str, sentry_opts: dict): executor.SubscriberExecutionContext = PatchedSubscriberExecutionContext # type: ignore +patch_executor() + # Optional queue names for Celery tasks. # Set None to route to the default queue, or a string value to use a separate one # From 1ab4f3bbcebec211c169277d9cc94d250e943d3a Mon Sep 17 00:00:00 2001 From: Piotr Zabieglik <55899043+zedzior@users.noreply.github.com> Date: Mon, 8 Apr 2024 16:43:40 +0200 Subject: [PATCH 24/35] Add `unique_type` to line discount models. (#15775) * Add unique_type to line discount models. * Update changelog. * Remove redundant stuff from migration. --- CHANGELOG.md | 1 + .../migrations/0078_add_unique_type.py | 44 +++++++++++++++ .../0079_add_index_for_unique_type.py | 33 ++++++++++++ .../0080_add_unique_type_constraint.py | 54 +++++++++++++++++++ saleor/discount/models.py | 26 +++++++++ 5 files changed, 158 insertions(+) create mode 100644 saleor/discount/migrations/0078_add_unique_type.py create mode 100644 saleor/discount/migrations/0079_add_index_for_unique_type.py create mode 100644 saleor/discount/migrations/0080_add_unique_type_constraint.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ce1ddd81e2..e79571decbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -58,6 +58,7 @@ All notable, unreleased changes to this project will be documented in this file. - Added new input `AppInput.identifier`. - Added new parameter `identifier` for `create_app` command. - When `taxAppId` is provided for `TaxConfiguration` do not allow to finalize `checkoutComplete` or `draftOrderComplete` mutations if Tax App or Avatax plugin didn't respond. +- Add `unique_type` to `OrderLineDiscount` and `CheckoutLineDiscount` models - #15774 by @zedzior # 3.18.0 diff --git a/saleor/discount/migrations/0078_add_unique_type.py b/saleor/discount/migrations/0078_add_unique_type.py new file mode 100644 index 00000000000..85dd280310e --- /dev/null +++ b/saleor/discount/migrations/0078_add_unique_type.py @@ -0,0 +1,44 @@ +# Generated by Django 3.2.24 on 2024-04-08 11:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("discount", "0077_merge_20240307_1217"), + ] + + operations = [ + migrations.AddField( + model_name="checkoutlinediscount", + name="unique_type", + field=models.CharField( + blank=True, + choices=[ + ("sale", "Sale"), + ("voucher", "Voucher"), + ("manual", "Manual"), + ("promotion", "Promotion"), + ("order_promotion", "Order promotion"), + ], + max_length=64, + null=True, + ), + ), + migrations.AddField( + model_name="orderlinediscount", + name="unique_type", + field=models.CharField( + blank=True, + choices=[ + ("sale", "Sale"), + ("voucher", "Voucher"), + ("manual", "Manual"), + ("promotion", "Promotion"), + ("order_promotion", "Order promotion"), + ], + max_length=64, + null=True, + ), + ), + ] diff --git a/saleor/discount/migrations/0079_add_index_for_unique_type.py b/saleor/discount/migrations/0079_add_index_for_unique_type.py new file mode 100644 index 00000000000..b97205ae32c --- /dev/null +++ b/saleor/discount/migrations/0079_add_index_for_unique_type.py @@ -0,0 +1,33 @@ +# Generated by Django 3.2.24 on 2024-04-08 11:18 + +from django.db import migrations + + +class Migration(migrations.Migration): + atomic = False + dependencies = [ + ("discount", "0078_add_unique_type"), + ] + + operations = [ + migrations.RunSQL( + sql=""" + CREATE UNIQUE INDEX CONCURRENTLY checkoutlinediscount_line_id_unique_type_idx + ON discount_checkoutlinediscount USING btree (line_id, unique_type); + """, + reverse_sql=""" + DROP INDEX CONCURRENTLY IF EXISTS + checkoutlinediscount_line_id_unique_type_idx; + """, + ), + migrations.RunSQL( + sql=""" + CREATE UNIQUE INDEX CONCURRENTLY orderlinediscount_line_id_unique_type_idx + ON discount_orderlinediscount USING btree (line_id, unique_type); + """, + reverse_sql=""" + DROP INDEX CONCURRENTLY IF EXISTS + orderlinediscount_line_id_unique_type_idx; + """, + ), + ] diff --git a/saleor/discount/migrations/0080_add_unique_type_constraint.py b/saleor/discount/migrations/0080_add_unique_type_constraint.py new file mode 100644 index 00000000000..ce3c8e8e000 --- /dev/null +++ b/saleor/discount/migrations/0080_add_unique_type_constraint.py @@ -0,0 +1,54 @@ +# Generated by Django 3.2.24 on 2024-04-08 11:20 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("discount", "0079_add_index_for_unique_type"), + ] + + operations = [ + migrations.SeparateDatabaseAndState( + database_operations=[ + migrations.RunSQL( + sql=""" + ALTER TABLE discount_checkoutlinediscount + ADD CONSTRAINT unique_checkoutline_discount_type + UNIQUE USING INDEX checkoutlinediscount_line_id_unique_type_idx; + """, + reverse_sql=""" + ALTER TABLE discount_checkoutlinediscount DROP CONSTRAINT IF EXISTS + unique_checkoutline_discount_type; + """, + ), + migrations.RunSQL( + sql=""" + ALTER TABLE discount_orderlinediscount + ADD CONSTRAINT unique_orderline_discount_type + UNIQUE USING INDEX orderlinediscount_line_id_unique_type_idx; + """, + reverse_sql=""" + ALTER TABLE discount_orderlinediscount DROP CONSTRAINT IF EXISTS + unique_orderline_discount_type; + """, + ), + ], + state_operations=[ + migrations.AddConstraint( + model_name="checkoutlinediscount", + constraint=models.UniqueConstraint( + fields=("line_id", "unique_type"), + name="unique_checkoutline_discount_type", + ), + ), + migrations.AddConstraint( + model_name="orderlinediscount", + constraint=models.UniqueConstraint( + fields=("line_id", "unique_type"), + name="unique_orderline_discount_type", + ), + ), + ], + ) + ] diff --git a/saleor/discount/models.py b/saleor/discount/models.py index 61461e6d9e6..2027e4b7058 100644 --- a/saleor/discount/models.py +++ b/saleor/discount/models.py @@ -539,6 +539,13 @@ class OrderLineDiscount(BaseDiscount): null=True, on_delete=models.CASCADE, ) + # Saleor in version 3.19 and below, doesn't have any unique constraint applied on + # discounts for checkout/order. To not have an impact on existing DB objects, + # the new field `unique_type` will be used for new discount records. + # This will ensure that we always apply a single specific discount type. + unique_type = models.CharField( + max_length=64, null=True, blank=True, choices=DiscountType.CHOICES + ) class Meta: indexes = [ @@ -548,6 +555,12 @@ class Meta: GinIndex(fields=["voucher_code"], name="orderlinedisc_voucher_code_idx"), ] ordering = ("created_at", "id") + constraints = [ + models.UniqueConstraint( + fields=["line_id", "unique_type"], + name="unique_orderline_discount_type", + ), + ] class CheckoutDiscount(BaseDiscount): @@ -578,6 +591,13 @@ class CheckoutLineDiscount(BaseDiscount): null=True, on_delete=models.CASCADE, ) + # Saleor in version 3.19 and below, doesn't have any unique constraint applied on + # discounts for checkout/order. To not have an impact on existing DB objects, + # the new field `unique_type` will be used for new discount records. + # This will ensure that we always apply a single specific discount type. + unique_type = models.CharField( + max_length=64, null=True, blank=True, choices=DiscountType.CHOICES + ) class Meta: indexes = [ @@ -587,6 +607,12 @@ class Meta: GinIndex(fields=["voucher_code"], name="checklinedisc_voucher_code_idx"), ] ordering = ("created_at", "id") + constraints = [ + models.UniqueConstraint( + fields=["line_id", "unique_type"], + name="unique_checkoutline_discount_type", + ), + ] class PromotionEvent(models.Model): From de8e3a5085e13a8a7311692ca95c216f193ccd2b Mon Sep 17 00:00:00 2001 From: Artur <31221055+Air-t@users.noreply.github.com> Date: Tue, 9 Apr 2024 15:28:51 +0200 Subject: [PATCH 25/35] Fix checkout shipping addresses which had reference to warehouse addresses (#15747) --- .../migrations/0063_auto_20240402_1114.py | 17 ++++++ ...o_20240402_1114_0065_checkout_tax_error.py | 11 ++++ .../migrations/0067_auto_20240405_0756.py | 55 +++++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 saleor/checkout/migrations/0063_auto_20240402_1114.py create mode 100644 saleor/checkout/migrations/0066_merge_0063_auto_20240402_1114_0065_checkout_tax_error.py create mode 100644 saleor/checkout/migrations/0067_auto_20240405_0756.py diff --git a/saleor/checkout/migrations/0063_auto_20240402_1114.py b/saleor/checkout/migrations/0063_auto_20240402_1114.py new file mode 100644 index 00000000000..2267d5dd53c --- /dev/null +++ b/saleor/checkout/migrations/0063_auto_20240402_1114.py @@ -0,0 +1,17 @@ +# Generated by Django 3.2.22 on 2024-04-02 11:14 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ( + "checkout", + "0062_update_checkout_last_transaction_modified_at_and_refundable", + ) + ] + + # Execution of a task moved to 0067_auto_20240405_0756.py. + # The migration task was not triggered due to missing proper celeryconf setup. + # Migration file left for history consistency. + operations = [] diff --git a/saleor/checkout/migrations/0066_merge_0063_auto_20240402_1114_0065_checkout_tax_error.py b/saleor/checkout/migrations/0066_merge_0063_auto_20240402_1114_0065_checkout_tax_error.py new file mode 100644 index 00000000000..3737c768b46 --- /dev/null +++ b/saleor/checkout/migrations/0066_merge_0063_auto_20240402_1114_0065_checkout_tax_error.py @@ -0,0 +1,11 @@ +# Generated by Django 3.2.25 on 2024-04-03 08:09 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("checkout", "0063_auto_20240402_1114"), + ("checkout", "0065_checkout_tax_error"), + ] + operations = [] diff --git a/saleor/checkout/migrations/0067_auto_20240405_0756.py b/saleor/checkout/migrations/0067_auto_20240405_0756.py new file mode 100644 index 00000000000..d8fd5ccb813 --- /dev/null +++ b/saleor/checkout/migrations/0067_auto_20240405_0756.py @@ -0,0 +1,55 @@ +# Generated by Django 3.2.25 on 2024-04-05 07:56 + +from django.db import migrations +from django.db.models.expressions import Exists, OuterRef +from django.forms import model_to_dict + +# The batch of size 250 takes ~0.5 second and consumes ~20MB memory at peak +ADDRESS_UPDATE_BATCH_SIZE = 250 + + +def queryset_in_batches(queryset): + """Slice a queryset into batches. + + Input queryset should be sorted be pk. + """ + start_pk = 0 + + while True: + qs = queryset.filter(pk__gt=start_pk)[:ADDRESS_UPDATE_BATCH_SIZE] + pks = list(qs.values_list("pk", flat=True)) + if not pks: + break + yield pks + start_pk = pks[-1] + + +def update_checkout_addresses(apps, schema_editor): + Checkout = apps.get_model("checkout", "Checkout") + Warehouse = apps.get_model("warehouse", "Warehouse") + Address = apps.get_model("account", "Address") + + queryset = Checkout.objects.filter( + Exists(Warehouse.objects.filter(address_id=OuterRef("shipping_address_id"))), + ).order_by("pk") + + for batch_pks in queryset_in_batches(queryset): + checkouts = Checkout.objects.filter(pk__in=batch_pks) + addresses = [] + for checkout in checkouts: + if cc_address := checkout.shipping_address: + checkout_address = Address(**model_to_dict(cc_address, exclude=["id"])) + checkout.shipping_address = checkout_address + addresses.append(checkout_address) + Address.objects.bulk_create(addresses, ignore_conflicts=True) + Checkout.objects.bulk_update(checkouts, ["shipping_address"]) + + +class Migration(migrations.Migration): + dependencies = [ + ("checkout", "0066_merge_0063_auto_20240402_1114_0065_checkout_tax_error"), + ] + + operations = [ + migrations.RunPython(update_checkout_addresses, migrations.RunPython.noop), + ] From 609f295ea00bb12a49bc80b6493198b1407754b2 Mon Sep 17 00:00:00 2001 From: Artur <31221055+Air-t@users.noreply.github.com> Date: Tue, 9 Apr 2024 15:30:29 +0200 Subject: [PATCH 26/35] remove celery task from order migration (#15776) --- .../0172_update_order_cc_addresses.py | 29 ++++++++++++++----- saleor/order/migrations/tasks/__init__.py | 0 saleor/order/migrations/tasks/saleor3_19.py | 29 ------------------- 3 files changed, 21 insertions(+), 37 deletions(-) delete mode 100644 saleor/order/migrations/tasks/__init__.py delete mode 100644 saleor/order/migrations/tasks/saleor3_19.py diff --git a/saleor/order/migrations/0172_update_order_cc_addresses.py b/saleor/order/migrations/0172_update_order_cc_addresses.py index 773ebd9bbb6..c686fdfd5fa 100644 --- a/saleor/order/migrations/0172_update_order_cc_addresses.py +++ b/saleor/order/migrations/0172_update_order_cc_addresses.py @@ -2,23 +2,37 @@ from django.db.models import Exists, OuterRef from django.forms.models import model_to_dict -from .tasks.saleor3_19 import update_order_addresses_task - # The batch of size 250 takes ~0.5 second and consumes ~20MB memory at peak ADDRESS_UPDATE_BATCH_SIZE = 250 +def queryset_in_batches(queryset): + """Slice a queryset into batches. + + Input queryset should be sorted be pk. + """ + start_pk = 0 + + while True: + qs = queryset.filter(pk__gt=start_pk)[:ADDRESS_UPDATE_BATCH_SIZE] + pks = list(qs.values_list("pk", flat=True)) + if not pks: + break + yield pks + start_pk = pks[-1] + + def update_order_addresses(apps, schema_editor): Order = apps.get_model("order", "Order") Warehouse = apps.get_model("warehouse", "Warehouse") Address = apps.get_model("account", "Address") - qs = Order.objects.filter( + queryset = Order.objects.filter( Exists(Warehouse.objects.filter(address_id=OuterRef("shipping_address_id"))), - ) - order_ids = qs.values_list("pk", flat=True)[:ADDRESS_UPDATE_BATCH_SIZE] - addresses = [] - if order_ids: + ).order_by("pk") + + for order_ids in queryset_in_batches(queryset): orders = Order.objects.filter(id__in=order_ids) + addresses = [] for order in orders: if cc_address := order.shipping_address: order_address = Address(**model_to_dict(cc_address, exclude=["id"])) @@ -26,7 +40,6 @@ def update_order_addresses(apps, schema_editor): addresses.append(order_address) Address.objects.bulk_create(addresses, ignore_conflicts=True) Order.objects.bulk_update(orders, ["shipping_address"]) - update_order_addresses_task.delay() class Migration(migrations.Migration): diff --git a/saleor/order/migrations/tasks/__init__.py b/saleor/order/migrations/tasks/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/saleor/order/migrations/tasks/saleor3_19.py b/saleor/order/migrations/tasks/saleor3_19.py deleted file mode 100644 index ddc468009b7..00000000000 --- a/saleor/order/migrations/tasks/saleor3_19.py +++ /dev/null @@ -1,29 +0,0 @@ -from django.db.models import Exists, OuterRef -from django.forms.models import model_to_dict - -from ....account.models import Address -from ....celeryconf import app -from ....warehouse.models import Warehouse -from ...models import Order - -# The batch of size 250 takes ~0.5 second and consumes ~20MB memory at peak -ADDRESS_UPDATE_BATCH_SIZE = 250 - - -@app.task -def update_order_addresses_task(): - qs = Order.objects.filter( - Exists(Warehouse.objects.filter(address_id=OuterRef("shipping_address_id"))), - ) - order_ids = qs.values_list("pk", flat=True)[:ADDRESS_UPDATE_BATCH_SIZE] - addresses = [] - if order_ids: - orders = Order.objects.filter(id__in=order_ids) - for order in orders: - if cc_address := order.shipping_address: - order_address = Address(**model_to_dict(cc_address, exclude=["id"])) - order.shipping_address = order_address - addresses.append(order_address) - Address.objects.bulk_create(addresses, ignore_conflicts=True) - Order.objects.bulk_update(orders, ["shipping_address"]) - update_order_addresses_task.delay() From fcf1ac55b150f29be0f2473ebdc2fcf73d3bc4a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20Szyma=C5=84ski?= Date: Wed, 10 Apr 2024 12:20:56 +0200 Subject: [PATCH 27/35] Fix checkout base prices when tax app is broken (#15783) --- saleor/checkout/calculations.py | 6 ++++-- saleor/checkout/tests/test_calculations.py | 11 ++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/saleor/checkout/calculations.py b/saleor/checkout/calculations.py index 92b1070af1f..6e0e9302a90 100644 --- a/saleor/checkout/calculations.py +++ b/saleor/checkout/calculations.py @@ -276,6 +276,7 @@ def _fetch_checkout_prices_if_expired( database_connection_name=database_connection_name, ) except TaxEmptyData as e: + _set_checkout_base_prices(checkout, checkout_info, lines) checkout.tax_error = str(e) if not should_charge_tax: @@ -301,10 +302,11 @@ def _fetch_checkout_prices_if_expired( database_connection_name=database_connection_name, ) except TaxEmptyData as e: + _set_checkout_base_prices(checkout, checkout_info, lines) checkout.tax_error = str(e) else: # Calculate net prices without taxes. - _get_checkout_base_prices(checkout, checkout_info, lines) + _set_checkout_base_prices(checkout, checkout_info, lines) checkout_update_fields = [ "voucher_code", @@ -532,7 +534,7 @@ def _apply_tax_data_from_plugins( ) -def _get_checkout_base_prices( +def _set_checkout_base_prices( checkout: "Checkout", checkout_info: "CheckoutInfo", lines: Iterable["CheckoutLineInfo"], diff --git a/saleor/checkout/tests/test_calculations.py b/saleor/checkout/tests/test_calculations.py index a4c1d7dc24a..82ab2e8327b 100644 --- a/saleor/checkout/tests/test_calculations.py +++ b/saleor/checkout/tests/test_calculations.py @@ -24,7 +24,7 @@ from ..calculations import ( _apply_tax_data, _calculate_and_add_tax, - _get_checkout_base_prices, + _set_checkout_base_prices, fetch_checkout_data, ) from ..fetch import CheckoutLineInfo, fetch_checkout_info, fetch_checkout_lines @@ -269,7 +269,7 @@ def test_fetch_checkout_data_flat_rates_and_no_tax_calc_strategy( assert checkout.shipping_tax_rate == Decimal("0.2300") -def test_get_checkout_base_prices_no_charge_taxes_with_voucher( +def test_set_checkout_base_prices_no_charge_taxes_with_voucher( checkout_with_item, voucher_percentage ): # given @@ -312,7 +312,7 @@ def test_get_checkout_base_prices_no_charge_taxes_with_voucher( checkout_info = fetch_checkout_info(checkout, lines, manager) # when - _get_checkout_base_prices(checkout, checkout_info, lines) + _set_checkout_base_prices(checkout, checkout_info, lines) checkout.save() checkout.lines.bulk_update( [line_info.line for line_info in lines], @@ -339,7 +339,7 @@ def test_get_checkout_base_prices_no_charge_taxes_with_voucher( assert line.total_price == checkout.total -def test_get_checkout_base_prices_no_charge_taxes_with_order_promotion( +def test_set_checkout_base_prices_no_charge_taxes_with_order_promotion( checkout_with_item_and_order_discount, ): # given @@ -359,7 +359,7 @@ def test_get_checkout_base_prices_no_charge_taxes_with_order_promotion( checkout_info = fetch_checkout_info(checkout, lines, manager) # when - _get_checkout_base_prices(checkout, checkout_info, lines) + _set_checkout_base_prices(checkout, checkout_info, lines) checkout.save() checkout.lines.bulk_update( [line_info.line for line_info in lines], @@ -634,6 +634,7 @@ def test_fetch_checkout_data_calls_inactive_plugin( fetch_checkout_data(**fetch_kwargs) # then + assert checkout.total.gross.amount > 0 assert checkout_with_items.tax_error == "Empty tax data." From cd38de41627a49a9e18a6ba50adda5ac024cce85 Mon Sep 17 00:00:00 2001 From: Krzysztof Waliczek Date: Thu, 11 Apr 2024 10:21:24 +0200 Subject: [PATCH 28/35] fix migration nodes (#15764) --- .../product/migrations/0187_merge_20231221_1030.py | 12 ++++++++++++ .../product/migrations/0188_merge_20231221_1119.py | 12 ++++++++++++ .../product/migrations/0189_merge_20240405_1121.py | 12 ++++++++++++ .../product/migrations/0191_merge_20240405_1125.py | 12 ++++++++++++ .../product/migrations/0192_merge_20240405_1154.py | 12 ++++++++++++ 5 files changed, 60 insertions(+) create mode 100644 saleor/product/migrations/0187_merge_20231221_1030.py create mode 100644 saleor/product/migrations/0188_merge_20231221_1119.py create mode 100644 saleor/product/migrations/0189_merge_20240405_1121.py create mode 100644 saleor/product/migrations/0191_merge_20240405_1125.py create mode 100644 saleor/product/migrations/0192_merge_20240405_1154.py diff --git a/saleor/product/migrations/0187_merge_20231221_1030.py b/saleor/product/migrations/0187_merge_20231221_1030.py new file mode 100644 index 00000000000..cd674db847d --- /dev/null +++ b/saleor/product/migrations/0187_merge_20231221_1030.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2023-12-21 10:30 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0186_alter_productmedia_alt"), + ("product", "0186_remove_product_charge_taxes"), + ] + + operations = [] diff --git a/saleor/product/migrations/0188_merge_20231221_1119.py b/saleor/product/migrations/0188_merge_20231221_1119.py new file mode 100644 index 00000000000..43ba261509e --- /dev/null +++ b/saleor/product/migrations/0188_merge_20231221_1119.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2023-12-21 11:19 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0186_alter_productmedia_alt"), + ("product", "0187_auto_20230614_0838"), + ] + + operations = [] diff --git a/saleor/product/migrations/0189_merge_20240405_1121.py b/saleor/product/migrations/0189_merge_20240405_1121.py new file mode 100644 index 00000000000..c29237d1a4f --- /dev/null +++ b/saleor/product/migrations/0189_merge_20240405_1121.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-04-05 11:21 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0187_merge_20231221_1030"), + ("product", "0188_merge_20231221_1119"), + ] + + operations = [] diff --git a/saleor/product/migrations/0191_merge_20240405_1125.py b/saleor/product/migrations/0191_merge_20240405_1125.py new file mode 100644 index 00000000000..f4b0edafa5a --- /dev/null +++ b/saleor/product/migrations/0191_merge_20240405_1125.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-04-05 11:25 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0189_merge_20240405_1121"), + ("product", "0190_merge_20231221_1356"), + ] + + operations = [] diff --git a/saleor/product/migrations/0192_merge_20240405_1154.py b/saleor/product/migrations/0192_merge_20240405_1154.py new file mode 100644 index 00000000000..c6523f14c3a --- /dev/null +++ b/saleor/product/migrations/0192_merge_20240405_1154.py @@ -0,0 +1,12 @@ +# Generated by Django 3.2.22 on 2024-04-05 11:54 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("product", "0191_merge_20240405_1125"), + ("product", "0191_productchannellisting_discounted_price_dirty"), + ] + + operations = [] From eb0c3f25e9f760c18a1af40a5faa13f53a53a2e9 Mon Sep 17 00:00:00 2001 From: Piotr Zabieglik <55899043+zedzior@users.noreply.github.com> Date: Tue, 16 Apr 2024 11:14:52 +0200 Subject: [PATCH 29/35] Calculate order promotions in draft order (#15459) * Refactor fetching variants for predicates (#14670) * Refactor method for fetching variants for predicate * Limit querysets used for filtering * Fix failing tests * Extend PromotionRule with checkout_and_order_predicate field (#14664) * Fix migrations * Allow defining `checkoutAndOrderPredicate` in promotion mutations (#15054) * Add checkout_and_order_predicate field to promotion_rule graphql type. * Add reward_type field. * Add predicate types. * Temporarily turn checkoutAndOrderPredicate field into JSON. * Add input types. * Validate new input fields. * Cover validations with tests. * Remove stuff from previous approach. * Do not allow mixing predicates. * Fix promotion rule validation * Adjust clean_promotion_rule validator * Extend PromotionCreate validation * Unify promotion rule validations * Rename PredicateType ORDER field to CHECKOUT_AND_ORDER * Add more tests to PromotionRuleUpdate * Update promotion/validators.py file * Add docstrings to functions in promotion validators.py file * Add tests for promotion validators * Add validation for price based predicate and rule with mixed currencies * Add mising test for PromotionRuleCreate * Clear graphql/discount/utils.py * Fix migrations and removed not needed changes * Apply code review suggestions * Refactor get_predicate_type function. --------- Co-authored-by: IKarbowiak * Update changelog * Include promotion discounts in checkout calculations (#15052) * Add CheckoutDiscountedObjectWhere * Add fetch_promotion_rules_for_checkout method * Introduce PredicateObjectType * Add schema of including checkout and order discount * Extend discount models * Include checkout and order promotion discount in checkoucalculations * Adjust create or update checkout discount - remove CheckoutDiscount when the voucher is assigned and update CheckoutOrder if any exist * Adjust checkout discount creation for checkout and order promotions * Delete CheckoutDiscount when no rule applies anymore * Add tests for CheckoutLinesAdd * Fix failing tests * Add test for generate_checkout_payload_for_tax_calculation * Include checkout discounts in CheckoutInfo dataloader * Add tests for checkout base calculations * Add tests for webhooks * Add tests for recalculate_checkout_discount * Add tests for price_override * Add tests for avatax * Add test for calculate_checkout_total method for Avatax plugin * Add test for checkout calculation * Fix migrations * Apply code review suggestions * test_unable_to_have_promotion_rule_with_mixed_predicates_CORE_2125 (#15210) * Adjust voucher assignement (#15189) * Adjsut CheckoutAddPromoCode * Adjust checkoutRemovePromoCode * Update test for recalculate_checkout_discount * Rename checkoutAndOrder discount to order discount (#15207) * Rename checkoutAndOrder discount to order discount * Rename occurrences in comments. --------- Co-authored-by: zedzior * Fix failing tests * Prevent race condition on `CheckoutDiscount` object creation. (#15212) * Add unique constrain on CheckoutDiscount model. * Use get_or_create when creating CheckoutDiscount. * Make it working. * Add test. * Add additional test. * Adjust checkout discounted object where filter (#15187) * Adjust total_price filtering in CheckoutDiscountedObjectWhere * Add filtering by subtotal_price in CheckoutDiscountedObjectWhere * Apply code review suggestiins and change DiscountedObjectPredicateInput to DiscountedObjectWhereInput * Fix failing CI * Fix failing test * Fix DiscountedObjectWhereInput - include AND, OR operators * Add rules limit for checkout and order promotions. (#15200) * Add CHECKOUT_AND_ORDER_RULES_LIMIT env variable * Add limit check to promotion validators * Add test for promotion_rule_create mutation. * Add validation to promotion_create mutation; cover with test. * Add TODO comment; fix linter. * Check the number of all checkout and order rules in database, instead of rules asocciated with given promotion. * Extend errors with rulesLimit and exceedBy fields. * Rename checkoutAndOrder; fix tests. * Add order promotions when populating db with dummy data. (#15232) * test_apply_promotion_with_order_predicate_eq_amount_on_checkout (#15222) * Get rid of total discount (#15229) * Get rid of total discount * Fix failing test * Adjust promotion rule type (#15213) * Add predicate_type to PromotionRule model * Add predicateType to PromotionRuleCreateInput * Update changelog * Make predicateType in PromotionRuleCreateInput optional, set the default value * Fix failing e2e * Adjust promotion validations * Fix setting default value for predicateType * Set predicateType on PromotionRule in sale mutations * Fix failing tests * Add type to Promotion * Get rif of MIXED_PROMOTION_PREDICATES error code * Update promotion rule validators * Replace promotion fixture with catalogue_promotion * Update update_products_discounted_prices_for_promotion_task * Drop commented line in PromotionCreate * Collapse migrations introduced for order predicate (#15252) * Do not re-calculate the promotions when catalogue predicate is empty (#15251) * Fix `OrderDiscount.type` when converting checkout to order. (#15248) * Create proper 'OrderDiscount' object, when converting checkout to order. * Add voucher and voucher code to 'OrderDiscount' when discount type is VOUCHER. * Add test. * Fix linter. * Simplify 'OrderDiscount' create. * Apply review changes. * Fix linter. * Apply database changes for gift promotion rule (#15254) * Extend PromotionRule with gifts field * Extend CheckoutLine and OrderLine with is_gift field * Add GIFT RewardType * Add gift_promotion_rule fixture * Fix migrations * Fix fixture * test_promotion_applied_on_checkout_with_specified_lte_gte_subtotal (#15261) * Checkout and order promotions adjustments (#15279) * Add missing labels * Add atomic block for setting voucher on checkout and deleting the discount * Gift promotions API. (#15274) * Add 'gifts' and 'giftsLimit' fields to 'PromotionRule' type. * Add 'is_gift' field to 'OrderLine' and 'CheckoutLine' types. * Add 'gifts' to promotion rule input; validate gift promotions. * Make it working. * Add tests to promotion rule create. * Adjust promotion create mutation. * Add test for promotion create mutation. * Add tests to promotion rule update. * Adjust validators. * Add more tests to promotion rule update. * Add dataloader. * Fix tests. * Test 'isGift' field. * Rename 'gifts' to 'giftIds'. * Adjust mutations after renaming. * Restore previous name in case of inputs. * Add PREVIEW_FEATURE flags. * Get rid of list() on iterator() * Fix failing test * Validate gift lines. (#15326) * Do not allow updating 'OrderLine.quantity' when line is a gift. * Do not allow updating 'CheckoutLine.quantity' when line is a gift. * Do not allow deleting gift from checkout. * Do not allow deleting gifts from checkout. * Extend error with line ids. * Do not allow deleting gifts from order. * Fix test. * Add missing test. * Fix flaky test. * Do not allow to update gift lines at all, not only its quantity; change error codes. * Simplify validation. * Refactor flaky tests. * Fix filtering by order predicate (#15332) * Convert camel case to snake case. * Improve tests. * Correct predicates in tests. * Apply review comments. * Include gift promotion in calculation (#15278) * Create discount object for gift reward * Create checkout line for gift reward * Update gift line creation * Handle race condition for gift line creation * Create CheckoutLineDiscount instead of CheckoutDiscount for gift reward * Gift reward calculation adjustment * Update complete_checkout for gift promotion * Add tests for checkout total calculations and CheckoutLinesAdd mutation * Add tests for webhooks * Add tests for checkout claculations * Add tests for avatax * Update create_discount_objects_for_catalogue_promotions * Update discount/utils.py * Add missing tests for _get_best_gift_reward * Clean duplicated isGift on CheckoutLine * Add more tests * Update gift line removal * Ensure that the gift line is created and returned when the gift reward is applicable during CheckoutCreate mutation * Rename invalidate_checkout_prices to invalidate_checkout * Ensure gift line is created in CheckoutCreateFromOrder mutation * Apply code review suggestions * Update giftreward calculations * Fix checkout calculations with catalogue and gift reward * Handle multiple gifts in _delete_gift_line * Move checkout creation logic to checkout mutation utils * Do not delete checkout line discounts obejcts in loop * Propagate order discount when using tax plugins * Adjust handling gift reward in calculations * Order promotions performance. (#15347) * Check if clear discounts is needed; call the function once * Update base_prices if needed only. * Fix tests. * Refactor. * Set gifts limit per promotion rule (#15353) * Set gifts limit per promotion rule * Apply code review suggestions * Optimize `fetch_promotion_rules_for_checkout` function. (#15352) * Optimize 'fetch_promotion_rules_for_checkout' function. * Fix tests. * Remove 'list(iterator)' pattern. * Clear gift line when not applicable (#15335) * Clear gift line in _clear_checkout_discount * Add tests for CheckoutAddPromoCode and CheckoutRemovePromoCode * Fix test * Adjust gift line removing * Apply code review suggestion * Add filter by promotion type (#15330) * Add filter by promotion type * Apply changes after review * Fix test to use correct promotion types * Apply changes after review * Sketch. * Add DradtOrderLineInfo class with fetch function. * Handle catalogue line discounts. * Refactor. * Further refcator. * Make it working. * Add CheckoutOrOrderModels class. * Apply helper class to order handler. * Merge checkout and order logic. * Cleanup. * Fix gift flow. * Preprae fixtures. * Add tests for create_or_update_discount_objects_from_promotion_for_order function. * More tests. * Cover fetch function. * More tests. * Test discounts update. * Make catalogue discounts working. * Make order discounts working. * Add more tests. * Fix discount reason. * Fix order.undiscounted_total_gross_amunt calculation. * Fix tests. * Add more tests; cleanup test files. * Move get_checkout_or_order_models function. * Merge from main. * Adjust tests. * Adjusts rest of tests. * Fet rid of redundant fields in DraftOrderLineInfo. * Get rid of Order.get_country function in favour of utils.get_order_country. * Move subtotal calculation. * Fill OrderLine's unit_discount fields. * Create separate test file: test_fetch_order_prices.py. * Cleanup after merge. * Self review. * Test draftOrderComplete. * Do not stack manual discounts and order promotion. * Test draftOrderCreate. * Test draftOrderUpdate. * Test orderLinesCreate. * Test orderLineUpdate. * Test avatax plugin. * Restore lazy recalculation. * Fix lint. * Fix tests. * Get rid of CheckoutOrOrderHelper. * Do not update prefetched objects; keep discount cache in line info only. * Further refactor. * Fix tests. * Test against manual discounts. * Test against vouchers. * Explain hardcoded values in test_order_create_from_checkout. * Test OrderDiscountedObjectWhere filter. * Explain hardcoded values in test_draft_order_complete. * Explain hardcoded values in test_draft_order_update. * Explain hardcoded values in test_order_line_update. * Explain hardcoded values in test_order_line_create. * Explain hardcoded values in test_fetch_order_prices.py. * Pull out clear cache logic. * Remove redundant save call. * Merge multiple discounts data for line unit discount fields. * Test against manual line discount. * Merge '_update_order_line_base_unit_prices' and _set_order_base_prices functions. * Save base_unit_price_amount in db. * Split set_base_price function again; fix some tests; clear line.discounts cache. * Update base_unit_price of affected lines only. * Fix _update_line_info_cached_discounts. * Create discount objects for manual line discounts. * Adjust orderLineDiscountRemove. * Update TODOs. * Explain numbers in test_draft_order_create. * Apply review suggestions. * Add more tests. * Refactor order.discounts resolver. * Refactor prefetch logic. * Delete gifts in transaction. * Save prices only if something has changed. * Fix resolver. * Correct clear_prefetched_discounts function. * Merge test queries * Use prefecthed data to get channel listings. * Test price filter with AND and OR operators. * Test manual line discount update. * Fix test. * Optimize fetch_draft_order_line_info function; delete redundant prefetches; test number of queries. * Remove potential duplicated order line discounts. * Extend the solution to handle checkout. * Fix tests. * Narrow duplicates handler to check for catalogue discounts only. * Remove wrong tests with stackable catalogue discounts. * Self-review. * Refactor order part. * Reafctor checkout part. * Adjust tests. * Correct lock. * Correct catalogue promotion discount update; fetch only best catalogue promotion to rules_info. * Add uniqe_type field and unique_together constraints to BaseDiscount models. * Adapt discount create logic. * Add tests. * Add dedicated enum for unique_type; correct migration. * Add more tests. * Fix performance issue. * Adjust query counts in checkout benchmark tests. * Alter unique_field to accept all strings. * Correct lock. * Adjust tests. * Remove UniqueDiscountType. * Remove redundant tests. * Add more comments. * Update comments. * Add choices for unique_type; add indexes concurrently. * Lock whole checkout instead of its lines. * Add back prefetch to fetching order lines. * Add benchmark test for fetching order prices. * Add more benchmar tests to checkout mutations. * Reduce db call number during fetch_draft_order_lines_info. * Add more benchmark tests. * Self review. * Fix failing test. * Fix test. * Update changelog. * Check if discount model has unique type; correct tests.. * Add falback for mismatch between discount_amount and rules_info. * Pass database_connection_name to create_or_update_discount_objects_from_promotion_for_order. * Add space between discount reasons. --------- Co-authored-by: Iga Karbowiak <40886528+IKarbowiak@users.noreply.github.com> Co-authored-by: IKarbowiak Co-authored-by: Renata Co-authored-by: Maciej Korycinski Co-authored-by: Maciej Korycinski --- CHANGELOG.md | 1 + saleor/checkout/utils.py | 10 +- saleor/discount/interface.py | 12 +- saleor/discount/tests/test_utils/fixtures.py | 86 + ...t_copy_unit_discount_data_to_order_line.py | 96 ++ ...unt_objects_from_promotion_for_checkout.py | 335 ++-- ...scount_objects_from_promotion_for_order.py | 523 +++++++ ...test_fetch_promotion_rules_for_checkout.py | 8 +- saleor/discount/utils.py | 797 +++++++--- saleor/graphql/checkout/mutations/utils.py | 28 +- .../benchmark/test_checkout_mutations.py | 158 +- .../test_checkout_complete_with_payment.py | 183 +-- .../test_order_create_from_checkout.py | 209 ++- saleor/graphql/discount/utils.py | 32 + saleor/graphql/order/filters.py | 30 +- .../order/mutations/draft_order_create.py | 4 +- .../order/mutations/draft_order_update.py | 3 +- .../graphql/order/mutations/order_update.py | 4 +- .../mutations/test_draft_order_complete.py | 213 +++ .../mutations/test_draft_order_create.py | 261 ++- .../mutations/test_draft_order_update.py | 385 ++++- .../tests/mutations/test_order_discount.py | 355 ++++- .../tests/mutations/test_order_line_update.py | 107 ++ .../mutations/test_order_lines_create.py | 158 +- .../queries/test_draft_order_with_filter.py | 243 +++ saleor/graphql/order/types.py | 7 +- saleor/order/base_calculations.py | 9 +- saleor/order/calculations.py | 46 +- saleor/order/fetch.py | 96 +- saleor/order/tests/benchmark/__init__.py | 0 .../benchmark/test_fetch_order_prices.py | 144 ++ ...ations.py => test_apply_order_discount.py} | 46 +- saleor/order/tests/test_fetch.py | 73 + saleor/order/tests/test_fetch_order_prices.py | 1394 +++++++++++++++++ saleor/order/utils.py | 50 + ..._calculate_order_total_gift_promotion.yaml | 81 + ...calculate_order_total_order_promotion.yaml | 77 + saleor/plugins/avatax/tests/test_avatax.py | 50 + saleor/tax/calculations/order.py | 36 +- .../tax/tests/test_checkout_calculations.py | 6 +- saleor/tax/tests/test_order_calculations.py | 2 +- ...order_product_with_percentage_promotion.py | 4 +- .../test_order_products_on_percentage_sale.py | 4 +- ..._on_promotion_and_manual_order_discount.py | 4 +- saleor/tests/fixtures.py | 119 +- saleor/tests/utils.py | 10 + 46 files changed, 5686 insertions(+), 813 deletions(-) create mode 100644 saleor/discount/tests/test_utils/test_copy_unit_discount_data_to_order_line.py create mode 100644 saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_order.py create mode 100644 saleor/order/tests/benchmark/__init__.py create mode 100644 saleor/order/tests/benchmark/test_fetch_order_prices.py rename saleor/order/tests/{test_discount_calculations.py => test_apply_order_discount.py} (94%) create mode 100644 saleor/order/tests/test_fetch.py create mode 100644 saleor/order/tests/test_fetch_order_prices.py create mode 100644 saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_gift_promotion.yaml create mode 100644 saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_order_promotion.yaml diff --git a/CHANGELOG.md b/CHANGELOG.md index e79571decbb..f6c83ed37fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ All notable, unreleased changes to this project will be documented in this file. - Remove `prefetched_for_webhook` to legacy payload generators - #15369 by @AjmalPonneth - Don't raise InsufficientStock for track_inventory=False variants #15475 by @carlosa54 - DB performance improvements in attribute dataloaders - #15474 by @AjmalPonneth +- Calculate order promotions in draft orders - #15459 by @zedzior # 3.19.0 diff --git a/saleor/checkout/utils.py b/saleor/checkout/utils.py index 4aa267258db..94084f66b21 100644 --- a/saleor/checkout/utils.py +++ b/saleor/checkout/utils.py @@ -32,8 +32,8 @@ VoucherCode, ) from ..discount.utils import ( - create_discount_objects_for_catalogue_promotions, - create_discount_objects_for_order_promotions, + create_checkout_discount_objects_for_order_promotions, + create_checkout_line_discount_objects_for_catalogue_promotions, delete_gift_line, get_products_voucher_discount, get_voucher_code_instance, @@ -96,7 +96,7 @@ def recalculate_checkout_discounts( Update line and checkout discounts from vouchers and promotions. Create or remove gift line if needed. """ - create_discount_objects_for_catalogue_promotions(lines) + create_checkout_line_discount_objects_for_catalogue_promotions(lines) recalculate_checkout_discount(manager, checkout_info, lines) @@ -692,7 +692,9 @@ def recalculate_checkout_discount( else: remove_voucher_from_checkout(checkout) - create_discount_objects_for_order_promotions(checkout_info, lines, save=True) + create_checkout_discount_objects_for_order_promotions( + checkout_info, lines, save=True + ) def add_promo_code_to_checkout( diff --git a/saleor/discount/interface.py b/saleor/discount/interface.py index 1919f332698..d835c58b2fa 100644 --- a/saleor/discount/interface.py +++ b/saleor/discount/interface.py @@ -53,14 +53,22 @@ class VariantPromotionRuleInfo(NamedTuple): def fetch_variant_rules_info( variant_channel_listing: "ProductVariantChannelListing", translation_language_code: str, -): +) -> list[VariantPromotionRuleInfo]: listings_rules = ( variant_channel_listing.variantlistingpromotionrule.all() if variant_channel_listing else [] ) + rules_info = [] - for listing_promotion_rule in listings_rules: + if listings_rules: + # Before introducing unique_type on discount models, there was possibility + # to have multiple catalogue discount associated with single line. In such a + # case, we should pick the best discount (with the highest discount amount) + listing_promotion_rule = max( + list(listings_rules), + key=lambda x: x.discount_amount, + ) promotion = listing_promotion_rule.promotion_rule.promotion promotion_translation, rule_translation = get_rule_translations( diff --git a/saleor/discount/tests/test_utils/fixtures.py b/saleor/discount/tests/test_utils/fixtures.py index 7cc5a0d4d7d..b3244fb7b1a 100644 --- a/saleor/discount/tests/test_utils/fixtures.py +++ b/saleor/discount/tests/test_utils/fixtures.py @@ -1,7 +1,16 @@ +from decimal import Decimal + +import graphene import pytest +from prices import TaxedMoney from ....checkout.fetch import fetch_checkout_info, fetch_checkout_lines +from ....core.taxes import zero_money +from ....discount import RewardType, RewardValueType +from ....order import OrderStatus from ....plugins.manager import get_plugins_manager +from ....product.models import VariantChannelListingPromotionRule +from ....warehouse.models import Stock @pytest.fixture @@ -51,3 +60,80 @@ def checkout_lines_with_multiple_quantity_info( lines_info, _ = fetch_checkout_lines(checkout_with_items) return lines_info + + +@pytest.fixture +def draft_order_and_promotions( + order_with_lines, + order_promotion_without_rules, + catalogue_promotion_without_rules, + channel_USD, +): + # given + order = order_with_lines + line_1 = order.lines.get(quantity=3) + line_2 = order.lines.get(quantity=2) + + # prepare catalogue promotions + catalogue_promotion = catalogue_promotion_without_rules + variant_1 = line_1.variant + variant_2 = line_2.variant + rule_catalogue = catalogue_promotion.rules.create( + name="Catalogue rule fixed", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant_2.id)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(3), + ) + rule_catalogue.channels.add(channel_USD) + + listing = variant_2.channel_listings.first() + listing.discounted_price_amount = Decimal(17) + listing.save(update_fields=["discounted_price_amount"]) + + currency = order.currency + VariantChannelListingPromotionRule.objects.create( + variant_channel_listing=listing, + promotion_rule=rule_catalogue, + discount_amount=Decimal(3), + currency=currency, + ) + + # prepare order promotion - subtotal + order_promotion = order_promotion_without_rules + rule_total = order_promotion.rules.create( + name="Fixed subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(25), + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule_total.channels.add(channel_USD) + + # prepare order promotion - gift + rule_gift = order_promotion.rules.create( + name="Gift subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_type=RewardType.GIFT, + ) + rule_gift.channels.add(channel_USD) + rule_gift.gifts.set([variant_1, variant_2]) + Stock.objects.update(quantity=100) + + # reset prices + order.total = TaxedMoney(net=zero_money(currency), gross=zero_money(currency)) + order.subtotal = TaxedMoney(net=zero_money(currency), gross=zero_money(currency)) + order.undiscounted_total = TaxedMoney( + net=zero_money(currency), gross=zero_money(currency) + ) + order.status = OrderStatus.DRAFT + order.save() + + return order, rule_catalogue, rule_total, rule_gift diff --git a/saleor/discount/tests/test_utils/test_copy_unit_discount_data_to_order_line.py b/saleor/discount/tests/test_utils/test_copy_unit_discount_data_to_order_line.py new file mode 100644 index 00000000000..cd832bdb518 --- /dev/null +++ b/saleor/discount/tests/test_utils/test_copy_unit_discount_data_to_order_line.py @@ -0,0 +1,96 @@ +from decimal import Decimal + +import graphene + +from ....order.fetch import fetch_draft_order_lines_info +from ... import DiscountType, DiscountValueType +from ...models import PromotionRule +from ...utils import _copy_unit_discount_data_to_order_line + + +def test_copy_unit_discount_data_to_order_line_multiple_discounts( + order_with_lines_and_catalogue_promotion, +): + # given + order = order_with_lines_and_catalogue_promotion + rule = PromotionRule.objects.get() + rule_reward_value = rule.reward_value + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + rule_discount_reason = f"Promotion: {promotion_id}" + + line = order.lines.first() + rule_discount = line.discounts.get() + rule_discount.reason = rule_discount_reason + rule_discount.save(update_fields=["reason"]) + + manual_reward_value = Decimal("2") + manual_discount_reason = "Manual discount" + line.discounts.create( + type=DiscountType.MANUAL, + value_type=DiscountValueType.FIXED, + value=manual_reward_value, + amount_value=manual_reward_value * line.quantity, + currency=order.currency, + reason=manual_discount_reason, + ) + + assert line.discounts.count() == 2 + lines_info = fetch_draft_order_lines_info(order) + + # when + _copy_unit_discount_data_to_order_line(lines_info) + + # then + line = lines_info[0].line + assert line.unit_discount_amount == rule_reward_value + manual_reward_value + assert rule_discount_reason in line.unit_discount_reason + assert manual_discount_reason in line.unit_discount_reason + assert line.unit_discount_type == DiscountValueType.FIXED + assert line.unit_discount_value == line.unit_discount_amount + + +def test_copy_unit_discount_data_to_order_line_single_discount( + order_with_lines_and_catalogue_promotion, +): + # given + order = order_with_lines_and_catalogue_promotion + rule = PromotionRule.objects.get() + rule_reward_value = rule.reward_value + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + rule_discount_reason = f"Promotion: {promotion_id}" + + line = order.lines.first() + rule_discount = line.discounts.get() + rule_discount.reason = rule_discount_reason + rule_discount.save(update_fields=["reason"]) + + assert line.discounts.count() == 1 + lines_info = fetch_draft_order_lines_info(order) + + # when + _copy_unit_discount_data_to_order_line(lines_info) + + # then + line = lines_info[0].line + assert line.unit_discount_amount == rule_reward_value + assert line.unit_discount_reason == rule_discount_reason + assert line.unit_discount_type == rule.reward_value_type + assert line.unit_discount_value == rule_reward_value + + +def test_copy_unit_discount_data_to_order_line_no_discount(order_with_lines): + # given + order = order_with_lines + line = order.lines.first() + assert not line.discounts.exists() + lines_info = fetch_draft_order_lines_info(order) + + # when + _copy_unit_discount_data_to_order_line(lines_info) + + # then + line = lines_info[0].line + assert line.unit_discount_amount == Decimal(0) + assert not line.unit_discount_reason + assert line.unit_discount_type == DiscountValueType.FIXED + assert line.unit_discount_value == Decimal(0) diff --git a/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_checkout.py b/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_checkout.py index 92d59f356db..2e5b940ab30 100644 --- a/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_checkout.py +++ b/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_checkout.py @@ -21,9 +21,9 @@ from ... import DiscountType, RewardType, RewardValueType from ...models import CheckoutDiscount, CheckoutLineDiscount, PromotionRule from ...utils import ( - _create_or_update_checkout_discount, _get_best_gift_reward, - create_discount_objects_for_order_promotions, + create_checkout_discount_objects_for_order_promotions, + create_checkout_line_discount_objects_for_catalogue_promotions, create_or_update_discount_objects_from_promotion_for_checkout, ) @@ -115,7 +115,14 @@ def test_create_fixed_discount( == discount_from_db.name == f"{catalogue_promotion_without_rules.name}: {rule.name}" ) - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None assert ( @@ -123,11 +130,90 @@ def test_create_fixed_discount( == discount_from_db.translated_name == promotion_translation_fr.name ) + assert ( + discount_from_info.unique_type + == discount_from_db.unique_type + == DiscountType.PROMOTION + ) for checkout_line_info in checkout_lines_info[1:]: assert not checkout_line_info.discounts +@freeze_time("2020-12-12 12:00:00") +def test_update_catalogue_discount( + checkout_info, + checkout_lines_info, + catalogue_promotion_without_rules, + promotion_translation_fr, +): + # given + line_info1 = checkout_lines_info[0] + product_line1 = line_info1.product + + actual_reward_value = Decimal("5") + discount_to_update = line_info1.line.discounts.create( + type=DiscountType.PROMOTION, + value_type=RewardValueType.FIXED, + value=actual_reward_value, + name="Fixed 5 catalogue discount", + currency=line_info1.channel.currency_code, + amount_value=actual_reward_value * line_info1.line.quantity, + ) + checkout_lines_info[0].discounts.append(discount_to_update) + + reward_value = Decimal("7") + assert reward_value > actual_reward_value + rule = catalogue_promotion_without_rules.rules.create( + name="Percentage promotion rule", + catalogue_predicate={ + "productPredicate": { + "ids": [graphene.Node.to_global_id("Product", product_line1.id)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + ) + rule.channels.add(line_info1.channel) + + listing = line_info1.channel_listing + discounted_price = listing.price.amount - reward_value + listing.discounted_price_amount = discounted_price + listing.save(update_fields=["discounted_price_amount"]) + + listing_promotion_rule = VariantChannelListingPromotionRule.objects.create( + variant_channel_listing=listing, + promotion_rule=rule, + discount_amount=reward_value, + currency=line_info1.channel.currency_code, + ) + line_info1.rules_info = [ + VariantPromotionRuleInfo( + rule=rule, + variant_listing_promotion_rule=listing_promotion_rule, + promotion=catalogue_promotion_without_rules, + promotion_translation=promotion_translation_fr, + rule_translation=None, + ) + ] + + # when + create_or_update_discount_objects_from_promotion_for_checkout( + checkout_info, checkout_lines_info + ) + + # then + assert len(line_info1.discounts) == 1 + assert CheckoutLineDiscount.objects.count() == 1 + + discount = line_info1.discounts[0] + assert discount.id == discount_to_update.id + assert discount.value == reward_value + assert discount.promotion_rule_id == rule.id + assert discount.amount_value == reward_value * line_info1.line.quantity + assert discount.unique_type == DiscountType.PROMOTION + + @freeze_time("2020-12-12 12:00:00") def test_create_fixed_discount_multiple_quantity_in_lines( checkout_info, @@ -203,7 +289,14 @@ def test_create_fixed_discount_multiple_quantity_in_lines( == discount_from_db.name == catalogue_promotion_without_rules.name ) - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None @@ -358,7 +451,14 @@ def test_create_percentage_discount( == discount_from_db.name == f"{catalogue_promotion_without_rules.name}: {rule.name}" ) - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None @@ -441,7 +541,14 @@ def test_create_percentage_discount_multiple_quantity_in_lines( assert discount_from_info.currency == discount_from_db.currency == "USD" discount_name = f"{catalogue_promotion_without_rules.name}: {rule.name}" assert discount_from_info.name == discount_from_db.name == discount_name - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None @@ -449,127 +556,6 @@ def test_create_percentage_discount_multiple_quantity_in_lines( assert not checkout_line_info.discounts -def test_create_discount_multiple_rules_applied( - checkout_info, checkout_lines_info, catalogue_promotion_without_rules -): - # given - line_info1 = checkout_lines_info[0] - product_line1 = line_info1.product - - reward_value_1 = Decimal("2") - reward_value_2 = Decimal("10") - rule_1, rule_2 = PromotionRule.objects.bulk_create( - [ - PromotionRule( - name="Percentage promotion rule 1", - promotion=catalogue_promotion_without_rules, - reward_value_type=RewardValueType.FIXED, - reward_value=reward_value_1, - catalogue_predicate={ - "productPredicate": { - "ids": [graphene.Node.to_global_id("Product", product_line1.id)] - } - }, - ), - PromotionRule( - name="Percentage promotion rule 2", - promotion=catalogue_promotion_without_rules, - reward_value_type=RewardValueType.PERCENTAGE, - reward_value=reward_value_2, - catalogue_predicate={ - "variantPredicate": { - "ids": [ - graphene.Node.to_global_id( - "ProductVariant", line_info1.variant.id - ) - ] - } - }, - ), - ] - ) - - rule_1.channels.add(line_info1.channel) - rule_2.channels.add(line_info1.channel) - - listing = line_info1.channel_listing - discount_amount_2 = reward_value_2 / 100 * listing.price.amount - discounted_price = listing.price.amount - reward_value_1 - discount_amount_2 - listing.discounted_price_amount = discounted_price - listing.save(update_fields=["discounted_price_amount"]) - - ( - listing_promotion_rule_1, - listing_promotion_rule_2, - ) = VariantChannelListingPromotionRule.objects.bulk_create( - [ - VariantChannelListingPromotionRule( - variant_channel_listing=listing, - promotion_rule=rule_1, - discount_amount=reward_value_1, - currency=line_info1.channel.currency_code, - ), - VariantChannelListingPromotionRule( - variant_channel_listing=listing, - promotion_rule=rule_2, - discount_amount=discount_amount_2, - currency=line_info1.channel.currency_code, - ), - ] - ) - - line_info1.rules_info = [ - VariantPromotionRuleInfo( - rule=rule_1, - variant_listing_promotion_rule=listing_promotion_rule_1, - promotion=catalogue_promotion_without_rules, - promotion_translation=None, - rule_translation=None, - ), - VariantPromotionRuleInfo( - rule=rule_2, - variant_listing_promotion_rule=listing_promotion_rule_2, - promotion=catalogue_promotion_without_rules, - promotion_translation=None, - rule_translation=None, - ), - ] - - # when - create_or_update_discount_objects_from_promotion_for_checkout( - checkout_info, checkout_lines_info - ) - - # then - assert len(line_info1.discounts) == 2 - discount_for_rule_1 = line_info1.line.discounts.get(promotion_rule=rule_1) - discount_for_rule_2 = line_info1.line.discounts.get(promotion_rule=rule_2) - - assert discount_for_rule_1.line == line_info1.line - assert discount_for_rule_2.line == line_info1.line - - assert discount_for_rule_1.type == DiscountType.PROMOTION - assert discount_for_rule_2.type == DiscountType.PROMOTION - - assert discount_for_rule_1.value_type == RewardValueType.FIXED - assert discount_for_rule_2.value_type == RewardValueType.PERCENTAGE - - assert discount_for_rule_1.value == reward_value_1 - assert discount_for_rule_2.value == reward_value_2 - - assert discount_for_rule_1.amount_value == reward_value_1 - assert discount_for_rule_2.amount_value == discount_amount_2 - - assert discount_for_rule_1.currency == "USD" - assert discount_for_rule_2.currency == "USD" - - assert discount_for_rule_1.promotion_rule == rule_1 - assert discount_for_rule_2.promotion_rule == rule_2 - - for checkout_line_info in checkout_lines_info[1:]: - assert not checkout_line_info.discounts - - def test_two_promotions_applied_to_two_different_lines( checkout_info, checkout_lines_info, catalogue_promotion_without_rules ): @@ -691,7 +677,14 @@ def test_two_promotions_applied_to_two_different_lines( == discount_from_db_1.name == f"{catalogue_promotion_without_rules.name}: {rule_1.name}" ) - assert discount_from_info_1.reason == discount_from_db_1.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info_1.reason + == discount_from_db_1.reason + == f"Promotion: {promotion_id}" + ) assert ( discount_from_info_1.promotion_rule == discount_from_db_1.promotion_rule @@ -722,7 +715,14 @@ def test_two_promotions_applied_to_two_different_lines( == discount_from_db_2.name == f"{catalogue_promotion_without_rules.name}: {rule_2.name}" ) - assert discount_from_info_2.reason == discount_from_db_2.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info_2.reason + == discount_from_db_2.reason + == f"Promotion: {promotion_id}" + ) assert ( discount_from_info_2.promotion_rule == discount_from_db_2.promotion_rule @@ -815,7 +815,14 @@ def test_create_percentage_discount_1_cent_variant_on_10_percentage_discount( == discount_from_db.name == f"{catalogue_promotion_without_rules.name}: {rule.name}" ) - assert discount_from_info.reason == discount_from_db.reason is None + promotion_id = graphene.Node.to_global_id( + "Promotion", catalogue_promotion_without_rules.pk + ) + assert ( + discount_from_info.reason + == discount_from_db.reason + == f"Promotion: {promotion_id}" + ) assert discount_from_info.promotion_rule == discount_from_db.promotion_rule == rule assert discount_from_info.voucher == discount_from_db.voucher is None @@ -2103,10 +2110,12 @@ def call_before_creating_discount_object(*args, **kwargs): ) with before_after.before( - "saleor.discount.utils._create_or_update_checkout_discount", + "saleor.discount.utils.create_checkout_discount_objects_for_order_promotions", call_before_creating_discount_object, ): - create_discount_objects_for_order_promotions(checkout_info, checkout_lines_info) + create_checkout_discount_objects_for_order_promotions( + checkout_info, checkout_lines_info + ) # then discounts = list(checkout_info.checkout.discounts.all()) @@ -2114,16 +2123,14 @@ def call_before_creating_discount_object(*args, **kwargs): assert discounts[0].amount_value == reward_value -def test_create_or_update_checkout_discount_race_condition( +def test_create_or_update_order_discount_race_condition( checkout_info, checkout_lines_info, catalogue_promotion_without_rules, ): # given promotion = catalogue_promotion_without_rules - checkout = checkout_info.checkout channel = checkout_info.channel - currency = channel.currency_code reward_value = Decimal("2") rule = promotion.rules.create( @@ -2141,20 +2148,14 @@ def test_create_or_update_checkout_discount_race_condition( rule.channels.add(channel) def call_update(*args, **kwargs): - _create_or_update_checkout_discount( - checkout, + create_checkout_discount_objects_for_order_promotions( checkout_info, checkout_lines_info, - rule, - reward_value, - None, - currency, - promotion, - True, + save=True, ) with before_after.before( - "saleor.discount.utils.get_rule_translations", call_update + "saleor.discount.utils._set_checkout_base_prices", call_update ): call_update() @@ -2163,37 +2164,23 @@ def call_update(*args, **kwargs): assert len(discounts) == 1 -def test_create_or_update_checkout_discount_gift_reward_race_condition( +def test_create_or_update_order_discount_gift_reward_race_condition( checkout_info, checkout_lines_info, gift_promotion_rule, ): # given - rule = gift_promotion_rule - promotion = gift_promotion_rule.promotion checkout = checkout_info.checkout - channel = checkout_info.channel - currency = channel.currency_code - - variants = gift_promotion_rule.gifts.all() - variant_listings = ProductVariantChannelListing.objects.filter(variant__in=variants) - listing = max(list(variant_listings), key=lambda x: x.discounted_price_amount) def call_update(*args, **kwargs): - _create_or_update_checkout_discount( - checkout, + create_checkout_discount_objects_for_order_promotions( checkout_info, checkout_lines_info, - rule, - listing.discounted_price_amount, - listing, - currency, - promotion, - True, + save=True, ) with before_after.before( - "saleor.discount.utils.get_rule_translations", call_update + "saleor.discount.utils._set_checkout_base_prices", call_update ): call_update() @@ -2298,3 +2285,27 @@ def test_get_best_gift_reward_no_variants_in_channel(gift_promotion_rule, channe # then assert rule is None assert listing is None + + +def test_create_checkout_line_discount_objects_for_catalogue_promotions_race_condition( + checkout_with_item_on_promotion, + plugins_manager, +): + # given + checkout = checkout_with_item_on_promotion + CheckoutLineDiscount.objects.all().delete() + + # when + def call_before_creating_catalogue_line_discount(*args, **kwargs): + lines_info, _ = fetch_checkout_lines(checkout) + create_checkout_line_discount_objects_for_catalogue_promotions(lines_info) + + with before_after.before( + "saleor.discount.utils.prepare_line_discount_objects_for_catalogue_promotions", + call_before_creating_catalogue_line_discount, + ): + lines_info, _ = fetch_checkout_lines(checkout) + create_checkout_line_discount_objects_for_catalogue_promotions(lines_info) + + # then + assert CheckoutLineDiscount.objects.count() == 1 diff --git a/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_order.py b/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_order.py new file mode 100644 index 00000000000..86a3a48bdb0 --- /dev/null +++ b/saleor/discount/tests/test_utils/test_create_or_update_discount_objects_from_promotion_for_order.py @@ -0,0 +1,523 @@ +from decimal import Decimal + +import graphene + +from ....order.fetch import fetch_draft_order_lines_info +from ....product.models import ( + ProductVariantChannelListing, + VariantChannelListingPromotionRule, +) +from ....warehouse.models import Stock +from ... import DiscountType, RewardType, RewardValueType +from ...models import OrderDiscount, OrderLineDiscount +from ...utils import create_or_update_discount_objects_from_promotion_for_order + + +def test_create_catalogue_discount_fixed( + order_with_lines, + catalogue_promotion_without_rules, +): + # given + order = order_with_lines + promotion = catalogue_promotion_without_rules + channel = order.channel + line_1 = order.lines.get(quantity=3) + + # prepare catalogue promotions + variant_1 = line_1.variant + reward_value = Decimal(3) + rule = promotion.rules.create( + name="Catalogue rule fixed", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant_1.id)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + ) + rule.channels.add(channel) + + listing = variant_1.channel_listings.get(channel=channel) + undiscounted_price = listing.price_amount + listing.discounted_price_amount = undiscounted_price - reward_value + listing.save(update_fields=["discounted_price_amount"]) + + currency = order.currency + VariantChannelListingPromotionRule.objects.create( + variant_channel_listing=listing, + promotion_rule=rule, + discount_amount=reward_value, + currency=currency, + ) + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + discount = OrderLineDiscount.objects.get() + assert discount.line == line_1 + assert discount.promotion_rule == rule + assert discount.type == DiscountType.PROMOTION + assert discount.value_type == RewardValueType.FIXED + assert discount.value == reward_value == Decimal(3) + assert discount.amount_value == reward_value * line_1.quantity == Decimal(9) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + + line = [line_info.line for line_info in lines_info if line_info.line == line_1][0] + assert line.base_unit_price_amount == Decimal(7) + + +def test_create_catalogue_discount_percentage( + order_with_lines, + catalogue_promotion_without_rules, +): + # given + order = order_with_lines + promotion = catalogue_promotion_without_rules + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) + channel = order.channel + line_1 = order.lines.get(quantity=3) + + variant_1 = line_1.variant + reward_value = Decimal(50) + rule = promotion.rules.create( + name="Catalogue rule percentage", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant_1.id)] + } + }, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=reward_value, + ) + rule.channels.add(channel) + + listing = variant_1.channel_listings.get(channel=channel) + undiscounted_price = listing.price_amount + discount_amount = undiscounted_price * reward_value / 100 + listing.discounted_price_amount = discount_amount + listing.save(update_fields=["discounted_price_amount"]) + + currency = order.currency + VariantChannelListingPromotionRule.objects.create( + variant_channel_listing=listing, + promotion_rule=rule, + discount_amount=discount_amount, + currency=currency, + ) + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + discount = OrderLineDiscount.objects.get() + assert discount.line == line_1 + assert discount.promotion_rule == rule + assert discount.type == DiscountType.PROMOTION + assert discount.value_type == RewardValueType.PERCENTAGE + assert discount.value == reward_value == Decimal(50) + assert discount.amount_value == discount_amount * line_1.quantity == Decimal(15) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + assert discount.reason == f"Promotion: {promotion_id}" + + line = [line_info.line for line_info in lines_info if line_info.line == line_1][0] + assert line.base_unit_price_amount == Decimal(5) + + +def test_create_order_discount_subtotal_fixed( + order_with_lines, order_promotion_without_rules +): + # given + order = order_with_lines + channel = order.channel + promotion = order_promotion_without_rules + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) + reward_value = Decimal(25) + rule = promotion.rules.create( + name="Fixed subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseTotalPrice": {"range": {"gte": 10}}} + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule.channels.add(order.channel) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderDiscount.objects.count() == 1 + assert not OrderLineDiscount.objects.exists() + discount = OrderDiscount.objects.get() + assert discount.order == order + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.value_type == RewardValueType.FIXED + assert discount.value == reward_value == Decimal(25) + assert discount.amount_value == reward_value == Decimal(25) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + assert discount.reason == f"Promotion: {promotion_id}" + + +def test_create_order_discount_subtotal_percentage( + order_with_lines, order_promotion_without_rules +): + # given + order = order_with_lines + channel = order.channel + promotion = order_promotion_without_rules + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) + reward_value = Decimal(50) + rule = promotion.rules.create( + name="Percentage subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"eq": 70}} + }, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=reward_value, + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule.channels.add(order.channel) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderDiscount.objects.count() == 1 + assert not OrderLineDiscount.objects.exists() + discount = OrderDiscount.objects.get() + assert discount.order == order + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.value_type == RewardValueType.PERCENTAGE + assert discount.value == reward_value == Decimal(50) + assert discount.amount_value == Decimal(35) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + assert discount.reason == f"Promotion: {promotion_id}" + + +def test_create_order_discount_gift( + order_with_lines, order_promotion_without_rules, variant_with_many_stocks +): + # given + order = order_with_lines + variant = variant_with_many_stocks + channel = order.channel + promotion = order_promotion_without_rules + promotion_id = graphene.Node.to_global_id("Promotion", promotion.id) + rule = promotion.rules.create( + name="Gift subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_type=RewardType.GIFT, + ) + rule.channels.add(channel) + rule.gifts.set([variant]) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + lines = order.lines.all() + assert len(lines) == 3 + + gift_line = [line for line in lines if line.is_gift][0] + discount = OrderLineDiscount.objects.get() + assert discount.line == gift_line + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + listing = ProductVariantChannelListing.objects.filter( + channel=channel, variant=variant + ).first() + assert discount.value == listing.price_amount == Decimal(10) + assert discount.amount_value == Decimal(10) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + assert discount.reason == f"Promotion: {promotion_id}" + + assert gift_line.quantity == 1 + assert gift_line.variant == variant + assert gift_line.total_price_gross_amount == Decimal(0) + assert gift_line.total_price_net_amount == Decimal(0) + assert gift_line.undiscounted_total_price_gross_amount == Decimal(0) + assert gift_line.undiscounted_total_price_net_amount == Decimal(0) + assert gift_line.unit_price_gross_amount == Decimal(0) + assert gift_line.unit_price_net_amount == Decimal(0) + assert gift_line.base_unit_price_amount == Decimal(0) + assert gift_line.unit_discount_amount == Decimal(0) + assert gift_line.unit_discount_type == RewardValueType.FIXED + assert gift_line.unit_discount_value == Decimal(0) + + +def test_multiple_rules_subtotal_and_catalogue_discount_applied( + draft_order_and_promotions, +): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + lines_info = fetch_draft_order_lines_info(order) + discounted_variant_global_id = rule_catalogue.catalogue_predicate[ + "variantPredicate" + ]["ids"][0] + _, discounted_variant_id = graphene.Node.from_global_id( + discounted_variant_global_id + ) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + order.refresh_from_db() + assert OrderLineDiscount.objects.count() == 1 + line = order.lines.get(variant_id=discounted_variant_id) + catalogue_discount = line.discounts.first() + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.value == Decimal(3) + assert catalogue_discount.value == rule_catalogue.reward_value + assert catalogue_discount.amount_value == Decimal(6) + assert ( + catalogue_discount.amount_value == line.quantity * rule_catalogue.reward_value + ) + assert catalogue_discount.value_type == RewardValueType.FIXED + + assert OrderDiscount.objects.count() == 1 + order_discount = order.discounts.first() + assert order_discount.type == DiscountType.ORDER_PROMOTION + assert order_discount.amount_value == Decimal(25) + assert order_discount.amount_value == rule_total.reward_value + assert order_discount.value_type == RewardValueType.FIXED + + +def test_multiple_rules_gift_and_catalogue_discount_applied(draft_order_and_promotions): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + lines_info = fetch_draft_order_lines_info(order) + rule_total.reward_value = Decimal(0) + rule_total.save(update_fields=["reward_value"]) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + order.refresh_from_db() + # If gift reward applies and gift is discounted by catalogue promotion, + # do not create discount object for catalogue promotion. Instead, create discount + # object for gift promotion and set reward amount to undiscounted price + assert OrderLineDiscount.objects.count() == 2 + lines = order.lines.all() + assert len(lines) == 3 + gift_line = [line for line in lines if line.is_gift][0] + gift_discount = gift_line.discounts.get() + assert gift_discount.type == DiscountType.ORDER_PROMOTION + listing = ProductVariantChannelListing.objects.filter( + channel=order.channel, variant=gift_line.variant + ).first() + assert gift_discount.value == listing.price_amount + assert not gift_discount.value == listing.discounted_price_amount + assert gift_discount.value == Decimal(20) + + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + assert not line_1.discounts.exists() + catalogue_discount = line_2.discounts.first() + assert catalogue_discount.type == DiscountType.PROMOTION + + assert not OrderDiscount.objects.exists() + + +def test_multiple_rules_no_discount_applied( + draft_order_and_promotions, product_variant_list +): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + rule_total.order_predicate = { + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 100000}}} + } + rule_total.save(update_fields=["order_predicate"]) + rule_gift.order_predicate = { + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 100000}}} + } + rule_gift.save(update_fields=["order_predicate"]) + + line_2 = [line for line in order.lines.all() if line.quantity == 2][0] + discounted_variant = line_2.variant + listing = discounted_variant.channel_listings.get(channel=order.channel) + listing.discounted_price_amount = listing.price_amount + listing.variantlistingpromotionrule.all().delete() + listing.save(update_fields=["discounted_price_amount"]) + rule_catalogue.catalogue_predicate = { + "variantPredicate": { + "ids": [ + graphene.Node.to_global_id("ProductVariant", product_variant_list[0].id) + ] + } + } + rule_catalogue.save(update_fields=["catalogue_predicate"]) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert not OrderLineDiscount.objects.exists() + assert not OrderDiscount.objects.exists() + + +def test_update_catalogue_discount( + order_with_lines_and_catalogue_promotion, catalogue_promotion_without_rules +): + # given + order = order_with_lines_and_catalogue_promotion + promotion = catalogue_promotion_without_rules + + channel = order.channel + line = order.lines.get(quantity=3) + variant = line.variant + assert OrderLineDiscount.objects.count() == 1 + discount_to_update = line.discounts.get() + + reward_value = Decimal(6) + assert reward_value > promotion.rules.first().reward_value + rule = promotion.rules.create( + name="New catalogue rule fixed", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + ) + rule.channels.add(channel) + + variant_channel_listing = variant.channel_listings.get(channel=channel) + undiscounted_price = variant_channel_listing.price_amount + variant_channel_listing.discounted_price_amount = undiscounted_price - reward_value + variant_channel_listing.save(update_fields=["discounted_price_amount"]) + + variant_rule_listing = variant_channel_listing.variantlistingpromotionrule.get() + variant_rule_listing.discount_amount = reward_value + variant_rule_listing.promotion_rule = rule + variant_rule_listing.save(update_fields=["discount_amount", "promotion_rule"]) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + discount = OrderLineDiscount.objects.get() + assert discount_to_update.id == discount.id + assert discount.line == line + assert discount.promotion_rule == rule + assert discount.type == DiscountType.PROMOTION + assert discount.value_type == RewardValueType.FIXED + assert discount.value == reward_value == Decimal(6) + assert discount.amount_value == reward_value * line.quantity == Decimal(18) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + + line = [line_info.line for line_info in lines_info if line_info.line == line][0] + assert line.base_unit_price_amount == Decimal(4) + + +def test_update_order_discount_subtotal( + order_with_lines_and_order_promotion, order_promotion_without_rules +): + # given + order = order_with_lines_and_order_promotion + channel = order.channel + promotion = order_promotion_without_rules + + reward_value = Decimal(30) + assert reward_value > promotion.rules.first().reward_value + rule = promotion.rules.create( + name="Fixed subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseTotalPrice": {"range": {"gte": 10}}} + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule.channels.add(order.channel) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderDiscount.objects.count() == 1 + discount = order.discounts.get() + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.value_type == RewardValueType.FIXED + assert discount.value == reward_value == Decimal(30) + assert discount.amount_value == reward_value == Decimal(30) + assert discount.currency == channel.currency_code + assert discount.name == f"{promotion.name}: {rule.name}" + + +def test_update_gift_discount_new_gift_available( + order_with_lines_and_gift_promotion, product_variant_list, warehouse +): + # given + order = order_with_lines_and_gift_promotion + variant = product_variant_list[0] + Stock.objects.create(product_variant=variant, warehouse=warehouse, quantity=100) + channel = order.channel + current_discount = OrderLineDiscount.objects.get() + rule = current_discount.promotion_rule + + gift_price = Decimal(50) + listing = variant.channel_listings.get(channel=channel) + listing.discounted_price_amount = gift_price + listing.price_amount = gift_price + listing.save(update_fields=["discounted_price_amount", "price_amount"]) + rule.gifts.add(variant) + + lines_info = fetch_draft_order_lines_info(order) + + # when + create_or_update_discount_objects_from_promotion_for_order(order, lines_info) + + # then + assert OrderLineDiscount.objects.count() == 1 + lines = order.lines.all() + assert len(lines) == 3 + + gift_line = [line for line in lines if line.is_gift][0] + discount = OrderLineDiscount.objects.get() + assert discount.line == gift_line + assert discount.promotion_rule == rule + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.value == gift_price == Decimal(50) + assert discount.amount_value == gift_price + assert discount.currency == channel.currency_code + + assert gift_line.quantity == 1 + assert gift_line.variant == variant diff --git a/saleor/discount/tests/test_utils/test_fetch_promotion_rules_for_checkout.py b/saleor/discount/tests/test_utils/test_fetch_promotion_rules_for_checkout.py index a090088b4c2..8e41130a536 100644 --- a/saleor/discount/tests/test_utils/test_fetch_promotion_rules_for_checkout.py +++ b/saleor/discount/tests/test_utils/test_fetch_promotion_rules_for_checkout.py @@ -2,7 +2,7 @@ from ... import RewardType, RewardValueType from ...models import PromotionRule -from ...utils import fetch_promotion_rules_for_checkout +from ...utils import fetch_promotion_rules_for_checkout_or_order def test_fetch_promotion_rules_for_checkout( @@ -23,7 +23,7 @@ def test_fetch_promotion_rules_for_checkout( ) # when - rules_per_promotion_id = fetch_promotion_rules_for_checkout(checkout) + rules_per_promotion_id = fetch_promotion_rules_for_checkout_or_order(checkout) # then assert len(rules_per_promotion_id) == 1 @@ -48,7 +48,7 @@ def test_fetch_promotion_rules_for_checkout_no_matching_rule( ) # when - rules_per_promotion_id = fetch_promotion_rules_for_checkout(checkout) + rules_per_promotion_id = fetch_promotion_rules_for_checkout_or_order(checkout) # then assert not rules_per_promotion_id @@ -77,7 +77,7 @@ def test_fetch_promotion_rules_for_checkout_relevant_channel_only( rule_2.channels.add(checkout_JPY.channel) # when - rules_per_promotion_id = fetch_promotion_rules_for_checkout(checkout_JPY) + rules_per_promotion_id = fetch_promotion_rules_for_checkout_or_order(checkout_JPY) # then assert len(rules_per_promotion_id) == 1 diff --git a/saleor/discount/utils.py b/saleor/discount/utils.py index 8e081480f0d..c3db0fb11a1 100644 --- a/saleor/discount/utils.py +++ b/saleor/discount/utils.py @@ -19,12 +19,14 @@ base_checkout_delivery_price, base_checkout_subtotal, ) -from ..checkout.fetch import CheckoutLineInfo, find_checkout_line_info -from ..checkout.models import Checkout, CheckoutLine +from ..checkout.fetch import CheckoutInfo, CheckoutLineInfo +from ..checkout.models import Checkout from ..core.db.connection import allow_writer from ..core.exceptions import InsufficientStock from ..core.taxes import zero_money from ..core.utils.promo_code import InvalidPromoCode +from ..order.fetch import DraftOrderLineInfo +from ..order.models import Order, OrderLine from ..product.models import ( Product, ProductChannelListing, @@ -44,6 +46,8 @@ CheckoutLineDiscount, DiscountValueType, NotApplicable, + OrderDiscount, + OrderLineDiscount, Promotion, PromotionRule, Voucher, @@ -53,8 +57,6 @@ if TYPE_CHECKING: from ..account.models import User - from ..checkout.fetch import CheckoutInfo - from ..order.models import Order, OrderLine from ..plugins.manager import PluginsManager from ..product.managers import ProductVariantQueryset from ..product.models import VariantChannelListingPromotionRule @@ -305,7 +307,7 @@ def validate_voucher_in_order( def validate_voucher( voucher: "Voucher", - total_price: TaxedMoney, + total_price: Money, quantity: int, customer_email: str, channel: Channel, @@ -355,76 +357,125 @@ def create_or_update_discount_objects_from_promotion_for_checkout( lines_info: Iterable["CheckoutLineInfo"], database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): - create_discount_objects_for_catalogue_promotions(lines_info) - create_discount_objects_for_order_promotions( + create_checkout_line_discount_objects_for_catalogue_promotions(lines_info) + create_checkout_discount_objects_for_order_promotions( checkout_info, lines_info, database_connection_name=database_connection_name ) -def create_discount_objects_for_catalogue_promotions( - lines_info: Iterable["CheckoutLineInfo"], +def create_checkout_line_discount_objects_for_catalogue_promotions( + lines_info: Iterable[CheckoutLineInfo], ): - line_discounts_to_create = [] - line_discounts_to_update = [] - line_discount_ids_to_remove = [] + discount_data = prepare_line_discount_objects_for_catalogue_promotions(lines_info) + if not discount_data or not lines_info: + return + + ( + discounts_to_create_inputs, + discounts_to_update, + discount_to_remove, + updated_fields, + ) = discount_data + + new_line_discounts = [] + with allow_writer(): + with transaction.atomic(): + # Protect against potential thread race. CheckoutLine object can have only + # single catalogue discount applied. + checkout_id = lines_info[0].line.checkout_id # type: ignore[index] + _checkout_lock = list( + Checkout.objects.filter(pk=checkout_id).select_for_update(of=(["self"])) + ) + + if discount_ids_to_remove := [ + discount.id for discount in discount_to_remove + ]: + CheckoutLineDiscount.objects.filter( + id__in=discount_ids_to_remove + ).delete() + + if discounts_to_create_inputs: + new_line_discounts = [ + CheckoutLineDiscount(**input) + for input in discounts_to_create_inputs + ] + CheckoutLineDiscount.objects.bulk_create( + new_line_discounts, ignore_conflicts=True + ) + + if discounts_to_update and updated_fields: + CheckoutLineDiscount.objects.bulk_update( + discounts_to_update, updated_fields + ) + + _update_line_info_cached_discounts( + lines_info, new_line_discounts, discounts_to_update, discount_ids_to_remove + ) + + +def prepare_line_discount_objects_for_catalogue_promotions( + lines_info: Union[Iterable["CheckoutLineInfo"], Iterable["DraftOrderLineInfo"]], +): + line_discounts_to_create_inputs: list[dict] = [] + line_discounts_to_update: list[Union[CheckoutLineDiscount, OrderLineDiscount]] = [] + line_discounts_to_remove: list[Union[CheckoutLineDiscount, OrderLineDiscount]] = [] updated_fields: list[str] = [] + if not lines_info: + return + for line_info in lines_info: line = line_info.line + # get the existing catalogue discount for the line + discount_to_update = None + if discounts_to_update := line_info.get_catalogue_discounts(): + discount_to_update = discounts_to_update[0] + # Line should never have multiple catalogue discounts associated. Before + # introducing unique_type on discount models, there was such a possibility. + line_discounts_to_remove.extend(discounts_to_update[1:]) + + # manual line discount do not stack with other discounts + if [ + discount + for discount in line_info.discounts + if discount.type == DiscountType.MANUAL + ]: + line_discounts_to_remove.extend(discounts_to_update) + continue + # discount_amount based on the difference between discounted_price and price discount_amount = _get_discount_amount(line_info.channel_listing, line.quantity) - # get the existing discounts for the line - discounts_to_update = line_info.get_catalogue_discounts() - rule_id_to_discount = { - discount.promotion_rule_id: discount for discount in discounts_to_update - } - # delete all existing discounts if the line is not discounted or it is a gift if not discount_amount or line.is_gift: - ids_to_remove = [discount.id for discount in discounts_to_update] - if ids_to_remove: - line_discount_ids_to_remove.extend(ids_to_remove) - line_info.discounts = [ - discount - for discount in line_info.discounts - if discount.id not in ids_to_remove - ] + line_discounts_to_remove.extend(discounts_to_update) continue - # delete the discount objects that are not valid anymore - line_discount_ids_to_remove.extend( - _get_discounts_that_are_not_valid_anymore( - line_info.rules_info, - rule_id_to_discount, # type: ignore[arg-type] - line_info, - ) - ) - - for rule_info in line_info.rules_info: + if line_info.rules_info: + rule_info = line_info.rules_info[0] rule = rule_info.rule - discount_to_update = rule_id_to_discount.get(rule.id) rule_discount_amount = _get_rule_discount_amount( rule_info.variant_listing_promotion_rule, line.quantity ) discount_name = get_discount_name(rule, rule_info.promotion) translated_name = get_discount_translated_name(rule_info) + reason = _get_discount_reason(rule) if not discount_to_update: - line_discount = CheckoutLineDiscount( - line=line, - type=DiscountType.PROMOTION, - value_type=rule.reward_value_type, - value=rule.reward_value, - amount_value=rule_discount_amount, - currency=line.currency, - name=discount_name, - translated_name=translated_name, - reason=None, - promotion_rule=rule, - ) - line_discounts_to_create.append(line_discount) - line_info.discounts.append(line_discount) + line_discount_input = { + "line": line, + "type": DiscountType.PROMOTION, + "value_type": rule.reward_value_type, + "value": rule.reward_value, + "amount_value": rule_discount_amount, + "currency": line.currency, + "name": discount_name, + "translated_name": translated_name, + "reason": reason, + "promotion_rule": rule, + "unique_type": DiscountType.PROMOTION, + } + line_discounts_to_create_inputs.append(line_discount_input) else: _update_discount( rule, @@ -433,20 +484,17 @@ def create_discount_objects_for_catalogue_promotions( discount_to_update, updated_fields, ) - line_discounts_to_update.append(discount_to_update) + else: + # Fallback for unlike mismatch between discount_amount and rules_info + line_discounts_to_remove.extend(discounts_to_update) - with allow_writer(): - if line_discounts_to_create: - CheckoutLineDiscount.objects.bulk_create(line_discounts_to_create) - if line_discounts_to_update and updated_fields: - CheckoutLineDiscount.objects.bulk_update( - line_discounts_to_update, updated_fields - ) - if line_discount_ids_to_remove: - CheckoutLineDiscount.objects.filter( - id__in=line_discount_ids_to_remove - ).delete() + return ( + line_discounts_to_create_inputs, + line_discounts_to_update, + line_discounts_to_remove, + updated_fields, + ) def _get_discount_amount( @@ -466,20 +514,6 @@ def _get_discount_amount( return unit_discount * line_quantity -def _get_discounts_that_are_not_valid_anymore( - rules_info: list["VariantPromotionRuleInfo"], - rule_id_to_discount: dict[int, "CheckoutLineDiscount"], - line_info: "CheckoutLineInfo", -): - discount_ids = [] - rule_ids = {rule_info.rule.id for rule_info in rules_info} - for rule_id, discount in rule_id_to_discount.items(): - if rule_id not in rule_ids: - discount_ids.append(discount.id) - line_info.discounts.remove(discount) - return discount_ids - - def _get_rule_discount_amount( variant_listing_promotion_rule: Optional["VariantChannelListingPromotionRule"], line_quantity: int, @@ -496,6 +530,13 @@ def get_discount_name(rule: "PromotionRule", promotion: "Promotion"): return rule.name or promotion.name +def _get_discount_reason(rule: PromotionRule): + promotion = rule.promotion + if promotion.old_sale_id: + return f"Sale: {graphene.Node.to_global_id('Sale', promotion.old_sale_id)}" + return f"Promotion: {graphene.Node.to_global_id('Promotion', promotion.id)}" + + def get_discount_translated_name(rule_info: "VariantPromotionRuleInfo"): promotion_translation = rule_info.promotion_translation rule_translation = rule_info.rule_translation @@ -512,7 +553,9 @@ def _update_discount( rule: "PromotionRule", rule_info: "VariantPromotionRuleInfo", rule_discount_amount: Decimal, - discount_to_update: Union["CheckoutLineDiscount", "CheckoutDiscount"], + discount_to_update: Union[ + "CheckoutLineDiscount", "CheckoutDiscount", "OrderLineDiscount", "OrderDiscount" + ], updated_fields: list[str], ): if discount_to_update.promotion_rule_id != rule.id: @@ -545,9 +588,33 @@ def _update_discount( if discount_to_update.reason != reason: discount_to_update.reason = reason updated_fields.append("reason") + if hasattr(discount_to_update, "unique_type"): + if discount_to_update.unique_type is None: + discount_to_update.unique_type = DiscountType.PROMOTION + updated_fields.append("unique_type") -def create_discount_objects_for_order_promotions( +def _update_line_info_cached_discounts( + lines_info, new_line_discounts, updated_discounts, line_discount_ids_to_remove +): + if not any([new_line_discounts, updated_discounts, line_discount_ids_to_remove]): + return + + line_id_line_discounts_map = defaultdict(list) + for line_discount in new_line_discounts: + line_id_line_discounts_map[line_discount.line_id].append(line_discount) + + for line_info in lines_info: + line_info.discounts = [ + discount + for discount in line_info.discounts + if discount.id not in line_discount_ids_to_remove + ] + if discount := line_id_line_discounts_map.get(line_info.line.id): + line_info.discounts.extend(discount) + + +def create_checkout_discount_objects_for_order_promotions( checkout_info: "CheckoutInfo", lines_info: Iterable["CheckoutLineInfo"], *, @@ -565,42 +632,77 @@ def create_discount_objects_for_order_promotions( return channel = checkout_info.channel - rule_data = get_best_rule_for_checkout( - checkout, channel, checkout_info.get_country(), database_connection_name + rules = fetch_promotion_rules_for_checkout_or_order( + checkout, database_connection_name + ) + rule_data = get_best_rule( + rules=rules, + channel=channel, + country=checkout_info.get_country(), + subtotal=checkout.base_subtotal, + database_connection_name=database_connection_name, ) if not rule_data: _clear_checkout_discount(checkout_info, lines_info, save) return best_rule, best_discount_amount, gift_listing = rule_data - - _create_or_update_checkout_discount( - checkout, - checkout_info, - lines_info, - best_rule, - best_discount_amount, - gift_listing, - channel.currency_code, - best_rule.promotion, - save, + promotion = best_rule.promotion + currency = channel.currency_code + translation_language_code = checkout.language_code + promotion_translation, rule_translation = get_rule_translations( + promotion, best_rule, translation_language_code + ) + rule_info = VariantPromotionRuleInfo( + rule=best_rule, + variant_listing_promotion_rule=None, + promotion=promotion, + promotion_translation=promotion_translation, + rule_translation=rule_translation, ) + # gift rule has empty reward_value and reward_value_type + value_type = best_rule.reward_value_type or RewardValueType.FIXED + amount_value = gift_listing.price_amount if gift_listing else best_discount_amount + value = best_rule.reward_value or amount_value + discount_object_defaults = { + "promotion_rule": best_rule, + "value_type": value_type, + "value": value, + "amount_value": amount_value, + "currency": currency, + "name": get_discount_name(best_rule, promotion), + "translated_name": get_discount_translated_name(rule_info), + "reason": prepare_promotion_discount_reason(promotion, get_sale_id(promotion)), + } + if gift_listing: + _handle_gift_reward_for_checkout( + checkout_info, + lines_info, + gift_listing, + discount_object_defaults, + rule_info, + save, + ) + else: + _handle_order_promotion_for_checkout( + checkout_info, + lines_info, + discount_object_defaults, + rule_info, + save, + ) -def get_best_rule_for_checkout( - checkout: "Checkout", +def get_best_rule( + rules: Iterable["PromotionRule"], channel: "Channel", country: str, + subtotal: Money, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): RuleDiscount = namedtuple( "RuleDiscount", ["rule", "discount_amount", "gift_listing"] ) - subtotal = checkout.base_subtotal - rules = fetch_promotion_rules_for_checkout(checkout, database_connection_name) - if not rules: - return - currency_code = channel.currency_code rule_discounts: list[RuleDiscount] = [] gift_rules = [rule for rule in rules if rule.reward_type == RewardType.GIFT] @@ -615,12 +717,14 @@ def get_best_rule_for_checkout( rule_discounts.append(RuleDiscount(rule, discount_amount, None)) if gift_rules: - rule, gift_listing = _get_best_gift_reward( + best_gift_rule, gift_listing = _get_best_gift_reward( gift_rules, channel, country, database_connection_name ) - if rule and gift_listing: + if best_gift_rule and gift_listing: rule_discounts.append( - RuleDiscount(rule, gift_listing.discounted_price_amount, gift_listing) + RuleDiscount( + best_gift_rule, gift_listing.discounted_price_amount, gift_listing + ) ) if not rule_discounts: @@ -785,72 +889,14 @@ def _get_available_for_purchase_variant_ids( return set(available_variant_ids) -def _create_or_update_checkout_discount( - checkout: "Checkout", - checkout_info: "CheckoutInfo", - lines_info: Iterable["CheckoutLineInfo"], - best_rule: "PromotionRule", - best_discount_amount: Decimal, - gift_listing: Optional[ProductVariantChannelListing], - currency_code: str, - promotion: "Promotion", - save: bool, -): - translation_language_code = checkout.language_code - promotion_translation, rule_translation = get_rule_translations( - promotion, best_rule, translation_language_code - ) - rule_info = VariantPromotionRuleInfo( - rule=best_rule, - variant_listing_promotion_rule=None, - promotion=promotion, - promotion_translation=promotion_translation, - rule_translation=rule_translation, - ) - # gift rule has empty reward_value and reward_value_type - value_type = best_rule.reward_value_type or RewardValueType.FIXED - amount_value = gift_listing.price_amount if gift_listing else best_discount_amount - value = best_rule.reward_value or amount_value - discount_object_defaults = { - "promotion_rule": best_rule, - "value_type": value_type, - "value": value, - "amount_value": amount_value, - "currency": currency_code, - "name": get_discount_name(best_rule, promotion), - "translated_name": get_discount_translated_name(rule_info), - "reason": prepare_promotion_discount_reason(promotion, get_sale_id(promotion)), - } - if gift_listing: - _handle_gift_reward( - checkout, - checkout_info, - lines_info, - gift_listing, - discount_object_defaults, - rule_info, - save, - ) - else: - _handle_order_promotion( - checkout, - checkout_info, - lines_info, - discount_object_defaults, - rule_info, - save, - ) - - -@allow_writer() -def _handle_order_promotion( - checkout: "Checkout", - checkout_info: "CheckoutInfo", - lines_info: Iterable["CheckoutLineInfo"], +def _handle_order_promotion_for_checkout( + checkout_info: CheckoutInfo, + lines_info: Iterable[CheckoutLineInfo], discount_object_defaults: dict, rule_info: VariantPromotionRuleInfo, - save: bool, + save: bool = False, ): + checkout = checkout_info.checkout discount_object, created = checkout.discounts.get_or_create( type=DiscountType.ORDER_PROMOTION, defaults=discount_object_defaults, @@ -870,6 +916,7 @@ def _handle_order_promotion( discount_object.save(update_fields=fields_to_update) checkout_info.discounts = [discount_object] + checkout = checkout_info.checkout checkout.discount_amount = discount_amount checkout.discount_name = discount_object.name checkout.translated_discount_name = discount_object.translated_name @@ -885,27 +932,29 @@ def _handle_order_promotion( delete_gift_line(checkout, lines_info) -def delete_gift_line(checkout: "Checkout", lines_info: Iterable["CheckoutLineInfo"]): +def delete_gift_line( + order_or_checkout: Union[Checkout, Order], + lines_info: Iterable[Union["CheckoutLineInfo", "DraftOrderLineInfo"]], +): if gift_line_infos := [line for line in lines_info if line.line.is_gift]: - CheckoutLine.objects.filter(checkout_id=checkout.pk, is_gift=True).delete() + order_or_checkout.lines.filter(is_gift=True).delete() # type: ignore[misc] for gift_line_info in gift_line_infos: lines_info.remove(gift_line_info) # type: ignore[attr-defined] @allow_writer() -def _handle_gift_reward( - checkout: "Checkout", - checkout_info: "CheckoutInfo", - lines_info: Iterable["CheckoutLineInfo"], +def _handle_gift_reward_for_checkout( + checkout_info: CheckoutInfo, + lines_info: Iterable[CheckoutLineInfo], gift_listing: ProductVariantChannelListing, discount_object_defaults: dict, rule_info: VariantPromotionRuleInfo, - save: bool, + save: bool = False, ): with transaction.atomic(): - line, line_created = create_gift_line(checkout, gift_listing.variant_id) - line_discount = None - discount_created = False + line, line_created = create_gift_line( + checkout_info.checkout, gift_listing.variant_id + ) ( line_discount, discount_created, @@ -931,38 +980,38 @@ def _handle_gift_reward( line_discount.save(update_fields=fields_to_update) checkout_info.discounts = [] - checkout.discount_amount = Decimal("0") + checkout_info.checkout.discount_amount = Decimal("0") if save: - checkout.save(update_fields=["discount_amount"]) + checkout_info.checkout.save(update_fields=["discount_amount"]) if line_created: variant = gift_listing.variant - gift_line_info = CheckoutLineInfo( - line=line, - variant=variant, - channel_listing=gift_listing, - product=variant.product, - product_type=variant.product.product_type, - collections=[], - discounts=[line_discount], - rules_info=[rule_info], - channel=checkout_info.channel, - ) + init_values = { + "line": line, + "variant": variant, + "channel_listing": gift_listing, + "discounts": [line_discount], + "rules_info": [rule_info], + "channel": checkout_info.channel, + "product": variant.product, + "product_type": variant.product.product_type, + "collections": [], + } + + gift_line_info = CheckoutLineInfo(**init_values) lines_info.append(gift_line_info) # type: ignore[attr-defined] else: - line_info = find_checkout_line_info(lines_info, line.id) + line_info = next( + line_info for line_info in lines_info if line_info.line.pk == line.id + ) line_info.line = line line_info.discounts = [line_discount] -def create_gift_line(checkout: "Checkout", variant_id: int): - defaults = { - "variant_id": variant_id, - "quantity": 1, - "currency": checkout.currency, - } - line, created = CheckoutLine.objects.get_or_create( - checkout=checkout, is_gift=True, defaults=defaults +def create_gift_line(order_or_checkout: Union[Checkout, Order], variant_id: int): + defaults = _get_defaults_for_gift_line(order_or_checkout, variant_id) + line, created = order_or_checkout.lines.get_or_create( + is_gift=True, defaults=defaults ) if not created: fields_to_update = [] @@ -976,6 +1025,29 @@ def create_gift_line(checkout: "Checkout", variant_id: int): return line, created +def _get_defaults_for_gift_line( + order_or_checkout: Union[Checkout, Order], variant_id: int +): + if isinstance(order_or_checkout, Checkout): + return { + "variant_id": variant_id, + "quantity": 1, + "currency": order_or_checkout.currency, + } + else: + return { + "variant_id": variant_id, + "quantity": 1, + "currency": order_or_checkout.currency, + "unit_price_net_amount": Decimal(0), + "unit_price_gross_amount": Decimal(0), + "total_price_net_amount": Decimal(0), + "total_price_gross_amount": Decimal(0), + "is_shipping_required": True, + "is_gift_card": False, + } + + def get_variants_to_promotion_rules_map( variant_qs: "ProductVariantQueryset", ) -> dict[int, list[PromotionRuleInfo]]: @@ -1024,39 +1096,43 @@ def get_variants_to_promotion_rules_map( return rules_info_per_variant -def fetch_promotion_rules_for_checkout( - checkout: Checkout, +def fetch_promotion_rules_for_checkout_or_order( + instance: Union["Checkout", "Order"], database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): from ..graphql.discount.utils import PredicateObjectType, filter_qs_by_predicate applicable_rules = [] promotions = Promotion.objects.active() - checkout_channel_id = checkout.channel_id - PromotionRuleChannels = PromotionRule.channels.through.objects.filter( - channel_id=checkout_channel_id - ) rules = ( PromotionRule.objects.using(database_connection_name) - .filter( - Exists(promotions.filter(id=OuterRef("promotion_id"))), - Exists(PromotionRuleChannels.filter(promotionrule_id=OuterRef("id"))), - ) + .filter(Exists(promotions.filter(id=OuterRef("promotion_id")))) .exclude(order_predicate={}) + .prefetch_related("channels") ) + rule_to_channel_ids_map = _get_rule_to_channel_ids_map(rules) - currency = checkout.currency - checkout_qs = Checkout.objects.using(database_connection_name).filter( - pk=checkout.pk + channel_id = instance.channel_id + currency = instance.channel.currency_code + qs = instance._meta.model.objects.using(database_connection_name).filter( # type: ignore[attr-defined] # noqa: E501 + pk=instance.pk ) for rule in rules.iterator(): - checkouts = filter_qs_by_predicate( + rule_channel_ids = rule_to_channel_ids_map.get(rule.id, []) + if channel_id not in rule_channel_ids: + continue + predicate_type = ( + PredicateObjectType.CHECKOUT + if isinstance(instance, Checkout) + else PredicateObjectType.ORDER + ) + objects = filter_qs_by_predicate( rule.order_predicate, - checkout_qs, - PredicateObjectType.CHECKOUT, + qs, + predicate_type, currency, ) - if checkouts.exists(): + if objects.exists(): applicable_rules.append(rule) return applicable_rules @@ -1132,6 +1208,299 @@ def update_rule_variant_relation( ) +def create_or_update_discount_objects_from_promotion_for_order( + order: "Order", + lines_info: Iterable["DraftOrderLineInfo"], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +): + create_order_line_discount_objects_for_catalogue_promotions(lines_info) + create_order_discount_objects_for_order_promotions( + order, lines_info, database_connection_name=database_connection_name + ) + _copy_unit_discount_data_to_order_line(lines_info) + + +def create_order_line_discount_objects_for_catalogue_promotions( + lines_info: Iterable[DraftOrderLineInfo], +): + discount_data = prepare_line_discount_objects_for_catalogue_promotions(lines_info) + if not discount_data or not lines_info: + return + + ( + discounts_to_create_inputs, + discounts_to_update, + discount_to_remove, + updated_fields, + ) = discount_data + + new_line_discounts = [] + with allow_writer(): + with transaction.atomic(): + # Protect against potential thread race. OrderLine object can have only + # single catalogue discount applied. + order_id = lines_info[0].line.order_id # type: ignore[index] + _order_lock = list( + Order.objects.filter(id=order_id).select_for_update(of=(["self"])) + ) + + if discount_ids_to_remove := [ + discount.id for discount in discount_to_remove + ]: + OrderLineDiscount.objects.filter(id__in=discount_ids_to_remove).delete() + + if discounts_to_create_inputs: + new_line_discounts = [ + OrderLineDiscount(**input) for input in discounts_to_create_inputs + ] + OrderLineDiscount.objects.bulk_create( + new_line_discounts, ignore_conflicts=True + ) + + if discounts_to_update and updated_fields: + OrderLineDiscount.objects.bulk_update( + discounts_to_update, updated_fields + ) + + _update_line_info_cached_discounts( + lines_info, new_line_discounts, discounts_to_update, discount_ids_to_remove + ) + + affected_line_ids = [ + discount_line.line.id + for discount_line in new_line_discounts + + discounts_to_update + + discount_to_remove + ] + modified_lines_info = [ + line_info for line_info in lines_info if line_info.line.id in affected_line_ids + ] + # base unit price must reflect all actual catalogue discounts + _update_base_unit_price_amount(modified_lines_info) + + +def _copy_unit_discount_data_to_order_line(lines_info: Iterable[DraftOrderLineInfo]): + for line_info in lines_info: + if discounts := line_info.discounts: + line = line_info.line + discount_amount = sum([discount.amount_value for discount in discounts]) + unit_discount_amount = discount_amount / line.quantity + discount_reason = "; ".join( + [discount.reason for discount in discounts if discount.reason] + ) + discount_type = ( + discounts[0].value_type + if len(discounts) == 1 + else DiscountValueType.FIXED + ) + discount_value = ( + discounts[0].value if len(discounts) == 1 else unit_discount_amount + ) + + line.unit_discount_amount = unit_discount_amount + line.unit_discount_reason = discount_reason + line.unit_discount_type = discount_type + line.unit_discount_value = discount_value + + +def _update_base_unit_price_amount(lines_info: Iterable[DraftOrderLineInfo]): + for line_info in lines_info: + line = line_info.line + base_unit_price = line.undiscounted_base_unit_price_amount + for discount in line_info.discounts: + unit_discount = discount.amount_value / line.quantity + base_unit_price -= unit_discount + line.base_unit_price_amount = max(base_unit_price, Decimal(0)) + + +def create_order_discount_objects_for_order_promotions( + order: "Order", + lines_info: Iterable["DraftOrderLineInfo"], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +): + from ..order.base_calculations import base_order_subtotal + from ..order.utils import get_order_country + + # If voucher is set or manual discount applied, then skip order promotions + if order.voucher_code or order.discounts.filter(type=DiscountType.MANUAL): + _clear_order_discount(order, lines_info) + return + + # The base prices are required for order promotion discount qualification. + _set_order_base_prices(order, lines_info) + + lines = [line_info.line for line_info in lines_info] + subtotal = base_order_subtotal(order, lines) + channel = order.channel + rules = fetch_promotion_rules_for_checkout_or_order(order, database_connection_name) + rule_data = get_best_rule( + rules=rules, + channel=channel, + country=get_order_country(order), + subtotal=subtotal, + database_connection_name=database_connection_name, + ) + if not rule_data: + _clear_order_discount(order, lines_info) + return + + best_rule, best_discount_amount, gift_listing = rule_data + promotion = best_rule.promotion + currency = channel.currency_code + translation_language_code = order.language_code + promotion_translation, rule_translation = get_rule_translations( + promotion, best_rule, translation_language_code + ) + rule_info = VariantPromotionRuleInfo( + rule=best_rule, + variant_listing_promotion_rule=None, + promotion=best_rule.promotion, + promotion_translation=promotion_translation, + rule_translation=rule_translation, + ) + # gift rule has empty reward_value and reward_value_type + value_type = best_rule.reward_value_type or RewardValueType.FIXED + amount_value = gift_listing.price_amount if gift_listing else best_discount_amount + value = best_rule.reward_value or amount_value + discount_object_defaults = { + "promotion_rule": best_rule, + "value_type": value_type, + "value": value, + "amount_value": amount_value, + "currency": currency, + "name": get_discount_name(best_rule, promotion), + "translated_name": get_discount_translated_name(rule_info), + "reason": prepare_promotion_discount_reason(promotion, get_sale_id(promotion)), + } + if gift_listing: + _handle_gift_reward_for_order( + order, + lines_info, + gift_listing, + discount_object_defaults, + rule_info, + ) + else: + _handle_order_promotion_for_order( + order, + lines_info, + discount_object_defaults, + rule_info, + ) + + +def _clear_order_discount( + order_or_checkout: Union[Checkout, Order], + lines_info: Iterable[DraftOrderLineInfo], +): + with transaction.atomic(): + delete_gift_line(order_or_checkout, lines_info) + order_or_checkout.discounts.filter(type=DiscountType.ORDER_PROMOTION).delete() + + +def _set_order_base_prices(order: Order, lines_info: Iterable[DraftOrderLineInfo]): + """Set base order prices that includes only catalogue discounts.""" + from ..order.base_calculations import base_order_subtotal + + lines = [line_info.line for line_info in lines_info] + subtotal = base_order_subtotal(order, lines) + shipping_price = order.base_shipping_price + total = subtotal + shipping_price + + update_fields = [] + if order.subtotal != TaxedMoney(net=subtotal, gross=subtotal): + order.subtotal = TaxedMoney(net=subtotal, gross=subtotal) + update_fields.extend(["subtotal_net_amount", "subtotal_gross_amount"]) + if order.total != TaxedMoney(net=total, gross=total): + order.total = TaxedMoney(net=total, gross=total) + update_fields.extend(["total_net_amount", "total_gross_amount"]) + + if update_fields: + with allow_writer(): + order.save(update_fields=update_fields) + + +def _handle_order_promotion_for_order( + order: Order, + lines_info: Iterable[DraftOrderLineInfo], + discount_object_defaults: dict, + rule_info: VariantPromotionRuleInfo, +): + discount_object, created = order.discounts.get_or_create( + type=DiscountType.ORDER_PROMOTION, + defaults=discount_object_defaults, + ) + discount_amount = discount_object_defaults["amount_value"] + + if not created: + fields_to_update: list[str] = [] + _update_discount( + discount_object_defaults["promotion_rule"], + rule_info, + discount_amount, + discount_object, + fields_to_update, + ) + if fields_to_update: + discount_object.save(update_fields=fields_to_update) + + delete_gift_line(order, lines_info) + + +@allow_writer() +def _handle_gift_reward_for_order( + order: Order, + lines_info: Iterable[DraftOrderLineInfo], + gift_listing: ProductVariantChannelListing, + discount_object_defaults: dict, + rule_info: VariantPromotionRuleInfo, +): + with transaction.atomic(): + line, line_created = create_gift_line(order, gift_listing.variant_id) + ( + line_discount, + discount_created, + ) = OrderLineDiscount.objects.get_or_create( + type=DiscountType.ORDER_PROMOTION, + line=line, + defaults=discount_object_defaults, + ) + + if not discount_created: + fields_to_update = [] + if line_discount.line_id != line.id: + line_discount.line = line + fields_to_update.append("line_id") + _update_discount( + discount_object_defaults["promotion_rule"], + rule_info, + discount_object_defaults["amount_value"], + line_discount, + fields_to_update, + ) + if fields_to_update: + line_discount.save(update_fields=fields_to_update) + + if line_created: + variant = gift_listing.variant + init_values = { + "line": line, + "variant": variant, + "channel_listing": gift_listing, + "discounts": [line_discount], + "rules_info": [rule_info], + "channel": order.channel_id, + } + gift_line_info = DraftOrderLineInfo(**init_values) + lines_info.append(gift_line_info) # type: ignore[attr-defined] + else: + line_info = next( + line_info for line_info in lines_info if line_info.line.pk == line.id + ) + line_info.line = line + line_info.discounts = [line_discount] + + def get_active_catalogue_promotion_rules( allow_replica: bool = False, ) -> "QuerySet[PromotionRule]": diff --git a/saleor/graphql/checkout/mutations/utils.py b/saleor/graphql/checkout/mutations/utils.py index c80c36cfb69..c39a3a2d83c 100644 --- a/saleor/graphql/checkout/mutations/utils.py +++ b/saleor/graphql/checkout/mutations/utils.py @@ -33,7 +33,11 @@ from ....core.exceptions import InsufficientStock, PermissionDenied from ....discount import DiscountType, DiscountValueType from ....discount.models import CheckoutLineDiscount, PromotionRule -from ....discount.utils import create_gift_line, get_best_rule_for_checkout +from ....discount.utils import ( + create_gift_line, + fetch_promotion_rules_for_checkout_or_order, + get_best_rule, +) from ....permission.enums import CheckoutPermissions from ....product import models as product_models from ....product.models import ProductChannelListing, ProductVariant @@ -442,6 +446,7 @@ def group_lines_input_data_on_update( variant_id = cast(str, line.get("variant_id")) line_id = cast(str, line.get("line_id")) + line_db_id, variant_db_id = None, None if line_id: _, line_db_id = graphene.Node.from_global_id(line_id) @@ -452,7 +457,7 @@ def group_lines_input_data_on_update( ) if not line_db_id: - line_data = checkout_lines_data_map[variant_db_id] + line_data = checkout_lines_data_map[variant_db_id] # type: ignore[index] line_data.variant_id = variant_db_id else: line_data = checkout_lines_data_map[line_db_id] @@ -532,7 +537,7 @@ def find_variant_id_when_line_parameter_used( def apply_gift_reward_if_applicable_on_checkout_creation( - checkout: "Checkout", + checkout: "models.Checkout", database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ) -> None: """Apply gift reward if applicable on newly created checkout. @@ -549,9 +554,13 @@ def apply_gift_reward_if_applicable_on_checkout_creation( return _set_checkout_base_subtotal_and_total_on_checkout_creation(checkout) - - best_rule_data = get_best_rule_for_checkout( - checkout, checkout.channel, checkout.get_country(), database_connection_name + rules = fetch_promotion_rules_for_checkout_or_order(checkout) + best_rule_data = get_best_rule( + rules, + checkout.channel, + checkout.get_country(), + checkout.base_subtotal, + database_connection_name, ) if not best_rule_data: return @@ -560,22 +569,21 @@ def apply_gift_reward_if_applicable_on_checkout_creation( if not gift_listing: return - amount_value = gift_listing.price_amount with transaction.atomic(): line, _line_created = create_gift_line(checkout, gift_listing.variant_id) CheckoutLineDiscount.objects.create( type=DiscountType.ORDER_PROMOTION, line=line, - amount_value=amount_value, + amount_value=best_discount_amount, value_type=DiscountValueType.FIXED, - value=amount_value, + value=best_discount_amount, promotion_rule=best_rule, currency=checkout.currency, ) def _set_checkout_base_subtotal_and_total_on_checkout_creation( - checkout: "Checkout", + checkout: "models.Checkout", ): """Calculate and set base subtotal and total for newly created checkout.""" variants_id = [line.variant_id for line in checkout.lines.all()] diff --git a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py index 802fbe36d7f..e0cba93fdc0 100644 --- a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py +++ b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py @@ -9,8 +9,12 @@ from .....checkout.fetch import fetch_checkout_info, fetch_checkout_lines from .....checkout.models import Checkout from .....checkout.utils import add_variants_to_checkout, set_external_shipping_id +from .....discount import RewardValueType +from .....discount.models import CheckoutLineDiscount, PromotionRule from .....plugins.manager import get_plugins_manager -from .....product.models import ProductVariant, ProductVariantChannelListing +from .....product.models import Product, ProductVariant, ProductVariantChannelListing +from .....product.utils.variant_prices import update_discounted_prices_for_promotion +from .....product.utils.variants import fetch_variants_for_promotion_rules from .....warehouse.models import Stock from ....core.utils import to_global_id_or_none from ....tests.utils import get_graphql_content @@ -413,7 +417,7 @@ def test_create_checkout_with_reservations( } } - with django_assert_num_queries(66): + with django_assert_num_queries(71): response = api_client.post_graphql(query, variables) assert get_graphql_content(response)["data"]["checkoutCreate"] assert Checkout.objects.first().lines.count() == 1 @@ -431,7 +435,7 @@ def test_create_checkout_with_reservations( } } - with django_assert_num_queries(66): + with django_assert_num_queries(71): response = api_client.post_graphql(query, variables) assert get_graphql_content(response)["data"]["checkoutCreate"] assert Checkout.objects.first().lines.count() == 10 @@ -562,7 +566,7 @@ def test_create_checkout_with_order_promotion( } # when - with django_assert_num_queries(71): + with django_assert_num_queries(76): response = user_api_client.post_graphql(MUTATION_CHECKOUT_CREATE, variables) # then @@ -817,7 +821,7 @@ def test_update_checkout_lines_with_reservations( reservation_length=5, ) - with django_assert_num_queries(81): + with django_assert_num_queries(90): variant_id = graphene.Node.to_global_id("ProductVariant", variants[0].pk) variables = { "id": to_global_id_or_none(checkout), @@ -831,7 +835,7 @@ def test_update_checkout_lines_with_reservations( assert not data["errors"] # Updating multiple lines in checkout has same query count as updating one - with django_assert_num_queries(81): + with django_assert_num_queries(90): variables = { "id": to_global_id_or_none(checkout), "lines": [], @@ -1076,7 +1080,7 @@ def test_add_checkout_lines_with_reservations( new_lines.append({"quantity": 2, "variantId": variant_id}) # Adding multiple lines to checkout has same query count as adding one - with django_assert_num_queries(80): + with django_assert_num_queries(89): variables = { "id": Node.to_global_id("Checkout", checkout.pk), "lines": [new_lines[0]], @@ -1089,7 +1093,7 @@ def test_add_checkout_lines_with_reservations( checkout.lines.exclude(id=line.id).delete() - with django_assert_num_queries(80): + with django_assert_num_queries(89): variables = { "id": Node.to_global_id("Checkout", checkout.pk), "lines": new_lines, @@ -1101,6 +1105,140 @@ def test_add_checkout_lines_with_reservations( assert not data["errors"] +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_add_checkout_lines_catalogue_discount_applies( + user_api_client, + catalogue_promotion_without_rules, + checkout, + channel_USD, + django_assert_num_queries, + count_queries, + variant_with_many_stocks, +): + # given + Stock.objects.update(quantity=100) + variant = variant_with_many_stocks + variant_id = graphene.Node.to_global_id("ProductVariant", variant.pk) + + # prepare promotion with 50% discount + promotion = catalogue_promotion_without_rules + catalogue_predicate = {"variantPredicate": {"ids": [variant_id]}} + rule = promotion.rules.create( + name="Catalogue rule percentage 50", + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=Decimal(50), + ) + rule.channels.add(channel_USD) + fetch_variants_for_promotion_rules(PromotionRule.objects.all()) + + # update prices + update_discounted_prices_for_promotion(Product.objects.all()) + + variables = { + "id": to_global_id_or_none(checkout), + "lines": [{"variantId": variant_id, "quantity": 3}], + "channelSlug": checkout.channel.slug, + } + + # when + with django_assert_num_queries(81): + response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["checkoutLinesAdd"] + assert not data["errors"] + assert checkout.lines.count() == 1 + assert CheckoutLineDiscount.objects.count() == 1 + + +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_add_checkout_lines_multiple_catalogue_discount_applies( + user_api_client, + catalogue_promotion_without_rules, + checkout, + channel_USD, + django_assert_num_queries, + count_queries, + product_variant_list, + warehouse, +): + # given + variants = product_variant_list + variant_global_ids = [variant.get_global_id() for variant in variants] + + channel_listing = variants[2].channel_listings.first() + channel_listing.channel = channel_USD + channel_listing.currency = channel_USD.currency_code + channel_listing.save(update_fields=["channel_id", "currency"]) + + Stock.objects.bulk_create( + [ + Stock(product_variant=variant, warehouse=warehouse, quantity=1000) + for variant in variants + ] + ) + + # create many rules + promotion = catalogue_promotion_without_rules + rules = [] + catalogue_predicate = {"variantPredicate": {"ids": variant_global_ids}} + for idx in range(5): + reward_value = 2 + idx + rules.append( + PromotionRule( + name=f"Catalogue rule fixed {reward_value}", + promotion=promotion, + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(reward_value), + ) + ) + for idx in range(5): + reward_value = idx * 10 + 25 + rules.append( + PromotionRule( + name=f"Catalogue rule percentage {reward_value}", + promotion=promotion, + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=Decimal(reward_value), + ) + ) + rules = PromotionRule.objects.bulk_create(rules) + for rule in rules: + rule.channels.add(channel_USD) + fetch_variants_for_promotion_rules(PromotionRule.objects.all()) + + # update prices + update_discounted_prices_for_promotion(Product.objects.all()) + + variables = { + "id": to_global_id_or_none(checkout), + "lines": [ + {"variantId": variant_global_ids[0], "quantity": 4}, + {"variantId": variant_global_ids[1], "quantity": 5}, + {"variantId": variant_global_ids[2], "quantity": 6}, + {"variantId": variant_global_ids[3], "quantity": 7}, + ], + "channelSlug": checkout.channel.slug, + } + + # when + with django_assert_num_queries(81): + response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["checkoutLinesAdd"] + assert not data["errors"] + assert checkout.lines.count() == 4 + assert CheckoutLineDiscount.objects.count() == 4 + + @pytest.mark.django_db @pytest.mark.count_queries(autouse=False) def test_add_checkout_lines_order_discount_applies( @@ -1125,7 +1263,7 @@ def test_add_checkout_lines_order_discount_applies( } # when - with django_assert_num_queries(75): + with django_assert_num_queries(84): response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) # then @@ -1159,7 +1297,7 @@ def test_add_checkout_lines_gift_discount_applies( } # when - with django_assert_num_queries(101): + with django_assert_num_queries(111): response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) # then diff --git a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py index 901904b4781..c7727866827 100644 --- a/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py +++ b/saleor/graphql/checkout/tests/mutations/test_checkout_complete_with_payment.py @@ -19,7 +19,7 @@ from .....core.exceptions import InsufficientStock, InsufficientStockData from .....core.taxes import TaxError, zero_money, zero_taxed_money from .....discount import DiscountType, DiscountValueType, RewardValueType -from .....discount.models import CheckoutLineDiscount, PromotionRule +from .....discount.models import CheckoutLineDiscount from .....giftcard import GiftCardEvents from .....giftcard.models import GiftCard, GiftCardEvent from .....order import OrderOrigin, OrderStatus @@ -29,7 +29,6 @@ from .....payment.interface import GatewayResponse from .....payment.model_helpers import get_subtotal from .....plugins.manager import PluginsManager, get_plugins_manager -from .....product.models import VariantChannelListingPromotionRule from .....tests.utils import flush_post_commit_hooks from .....warehouse.models import Reservation, Stock, WarehouseClickAndCollectOption from .....warehouse.tests.utils import get_available_quantity_for_stock @@ -1679,186 +1678,6 @@ def test_checkout_complete_product_on_old_sale( ).exists(), "Checkout should have been deleted" -def test_checkout_complete_multiple_rules_applied( - user_api_client, - checkout_with_item, - catalogue_promotion_without_rules, - payment_dummy, - address, - shipping_method, -): - # given - checkout = checkout_with_item - checkout.shipping_address = address - checkout.shipping_method = shipping_method - checkout.billing_address = address - checkout.metadata_storage.store_value_in_metadata(items={"accepted": "true"}) - checkout.metadata_storage.store_value_in_private_metadata( - items={"accepted": "false"} - ) - checkout.save() - checkout.metadata_storage.save() - - checkout_line = checkout.lines.first() - checkout_line_quantity = checkout_line.quantity - checkout_line_variant = checkout_line.variant - - channel = checkout.channel - - reward_value_1 = Decimal("2") - reward_value_2 = Decimal("10") - rule_1, rule_2 = PromotionRule.objects.bulk_create( - [ - PromotionRule( - name="Percentage promotion rule 1", - promotion=catalogue_promotion_without_rules, - reward_value_type=RewardValueType.FIXED, - reward_value=reward_value_1, - catalogue_predicate={ - "productPredicate": { - "ids": [ - graphene.Node.to_global_id( - "Product", checkout_line_variant.product_id - ) - ] - } - }, - ), - PromotionRule( - name="Percentage promotion rule 2", - promotion=catalogue_promotion_without_rules, - reward_value_type=RewardValueType.PERCENTAGE, - reward_value=reward_value_2, - catalogue_predicate={ - "variantPredicate": { - "ids": [ - graphene.Node.to_global_id( - "ProductVariant", checkout_line_variant.id - ) - ] - } - }, - ), - ] - ) - - rule_1.channels.add(channel) - rule_2.channels.add(channel) - - variant_channel_listing = checkout_line_variant.channel_listings.get( - channel=channel - ) - discount_amount_2 = reward_value_2 / 100 * variant_channel_listing.price.amount - discounted_price = ( - variant_channel_listing.price.amount - reward_value_1 - discount_amount_2 - ) - variant_channel_listing.discounted_price_amount = discounted_price - variant_channel_listing.save(update_fields=["discounted_price_amount"]) - - VariantChannelListingPromotionRule.objects.bulk_create( - [ - VariantChannelListingPromotionRule( - variant_channel_listing=variant_channel_listing, - promotion_rule=rule_1, - discount_amount=reward_value_1, - currency=channel.currency_code, - ), - VariantChannelListingPromotionRule( - variant_channel_listing=variant_channel_listing, - promotion_rule=rule_2, - discount_amount=discount_amount_2, - currency=channel.currency_code, - ), - ] - ) - - CheckoutLineDiscount.objects.bulk_create( - [ - CheckoutLineDiscount( - line=checkout_line, - type=DiscountType.PROMOTION, - value_type=DiscountValueType.FIXED, - amount_value=reward_value_1, - currency=channel.currency_code, - promotion_rule=rule_1, - ), - CheckoutLineDiscount( - line=checkout_line, - type=DiscountType.PROMOTION, - value_type=DiscountValueType.FIXED, - amount_value=discount_amount_2, - currency=channel.currency_code, - promotion_rule=rule_2, - ), - ] - ) - - manager = get_plugins_manager(allow_replica=False) - lines, _ = fetch_checkout_lines(checkout) - checkout_info = fetch_checkout_info(checkout, lines, manager) - - total = calculations.checkout_total( - manager=manager, - checkout_info=checkout_info, - lines=lines, - address=address, - ) - payment = payment_dummy - payment.is_active = True - payment.order = None - payment.total = total.gross.amount - payment.currency = total.gross.currency - payment.checkout = checkout - payment.save() - assert not payment.transactions.exists() - - orders_count = Order.objects.count() - variables = { - "id": to_global_id_or_none(checkout), - "redirectUrl": "https://www.example.com", - } - - # when - response = user_api_client.post_graphql(MUTATION_CHECKOUT_COMPLETE, variables) - - # then - content = get_graphql_content(response) - data = content["data"]["checkoutComplete"] - assert not data["errors"] - - order_token = data["order"]["token"] - order_id = data["order"]["id"] - assert Order.objects.count() == orders_count + 1 - order = Order.objects.first() - assert str(order.id) == order_token - assert order_id == graphene.Node.to_global_id("Order", order.id) - assert order.metadata == checkout.metadata_storage.metadata - assert order.private_metadata == checkout.metadata_storage.private_metadata - - order_line = order.lines.first() - subtotal = get_subtotal(order.lines.all(), order.currency) - assert order.subtotal == subtotal - assert data["order"]["subtotal"]["gross"]["amount"] == subtotal.gross.amount - assert order.total == total - assert order.undiscounted_total == total + ( - order_line.undiscounted_total_price - order_line.total_price - ) - assert order_line.discounts.count() == 2 - - assert checkout_line_quantity == order_line.quantity - assert checkout_line_variant == order_line.variant - assert order.shipping_address == address - assert order.shipping_method == checkout.shipping_method - assert order.payments.exists() - order_payment = order.payments.first() - assert order_payment == payment - assert payment.transactions.count() == 1 - - assert not Checkout.objects.filter( - pk=checkout.pk - ).exists(), "Checkout should have been deleted" - - def test_checkout_with_voucher_on_specific_product_complete_with_product_on_promotion( user_api_client, checkout_with_item_and_voucher_specific_products, diff --git a/saleor/graphql/checkout/tests/mutations/test_order_create_from_checkout.py b/saleor/graphql/checkout/tests/mutations/test_order_create_from_checkout.py index a74f429ad2c..c2b3dccc05b 100644 --- a/saleor/graphql/checkout/tests/mutations/test_order_create_from_checkout.py +++ b/saleor/graphql/checkout/tests/mutations/test_order_create_from_checkout.py @@ -15,7 +15,7 @@ from .....checkout.models import Checkout, CheckoutLine from .....core.taxes import TaxError, zero_money, zero_taxed_money from .....discount import DiscountType, DiscountValueType, RewardValueType -from .....discount.models import CheckoutLineDiscount, Promotion +from .....discount.models import CheckoutLineDiscount from .....giftcard import GiftCardEvents from .....giftcard.models import GiftCard, GiftCardEvent from .....order import OrderOrigin, OrderStatus @@ -966,11 +966,54 @@ def test_order_from_checkout_voucher_not_increase_uses_on_preprocess_creation_fa assert code.used == 0 +MUTATION_ORDER_CREATE_FROM_CHECKOUT_PROMOTIONS = """ +mutation orderCreateFromCheckout($id: ID!){ + orderCreateFromCheckout(id: $id){ + order{ + id + total { + currency + net { + amount + } + gross { + amount + } + } + lines { + unitDiscount { + amount + } + unitDiscountType + unitDiscountValue + isGift + quantity + } + discounts { + amount { + amount + } + valueType + type + } + } + errors{ + field + message + code + variants + } + } +} +""" + + def test_order_from_checkout_on_catalogue_promotion( app_api_client, checkout_with_item_on_promotion, permission_handle_checkouts, permission_manage_checkouts, + permission_manage_orders, address, shipping_method, ): @@ -987,9 +1030,13 @@ def test_order_from_checkout_on_catalogue_promotion( # when response = app_api_client.post_graphql( - MUTATION_ORDER_CREATE_FROM_CHECKOUT, + MUTATION_ORDER_CREATE_FROM_CHECKOUT_PROMOTIONS, variables, - permissions=[permission_handle_checkouts, permission_manage_checkouts], + permissions=[ + permission_handle_checkouts, + permission_manage_checkouts, + permission_manage_orders, + ], ) # then @@ -997,22 +1044,30 @@ def test_order_from_checkout_on_catalogue_promotion( data = content["data"]["orderCreateFromCheckout"] assert not data["errors"] - order = Order.objects.first() - assert order.status == OrderStatus.UNCONFIRMED - assert order.origin == OrderOrigin.CHECKOUT - assert not order.original - - assert order.lines.count() == 1 - line = order.lines.first() - assert line.sale_id - assert line.unit_discount_reason - assert line.discounts.count() == 1 - discount = line.discounts.first() + order_db = Order.objects.first() + assert order_db.status == OrderStatus.UNCONFIRMED + assert order_db.origin == OrderOrigin.CHECKOUT + assert not order_db.original + + assert order_db.lines.count() == 1 + line_db = order_db.lines.first() + assert line_db.sale_id + assert line_db.unit_discount_reason + assert line_db.discounts.count() == 1 + discount = line_db.discounts.first() assert discount.promotion_rule assert ( - discount.amount_value == (order.undiscounted_total - order.total).gross.amount + discount.amount_value + == (order_db.undiscounted_total - order_db.total).gross.amount ) - assert not order.discounts.first() + assert not order_db.discounts.first() + + assert not data["order"]["discounts"] + assert len(data["order"]["lines"]) == 1 + line = data["order"]["lines"][0] + assert line["unitDiscount"]["amount"] == discount.amount_value / line["quantity"] + assert line["unitDiscountType"] == RewardValueType.FIXED.upper() + assert line["unitDiscountValue"] == discount.amount_value / line["quantity"] def test_order_from_checkout_on_order_promotion( @@ -1020,6 +1075,7 @@ def test_order_from_checkout_on_order_promotion( checkout_with_item_and_order_discount, permission_handle_checkouts, permission_manage_checkouts, + permission_manage_orders, address, shipping_method, ): @@ -1036,9 +1092,13 @@ def test_order_from_checkout_on_order_promotion( # when response = app_api_client.post_graphql( - MUTATION_ORDER_CREATE_FROM_CHECKOUT, + MUTATION_ORDER_CREATE_FROM_CHECKOUT_PROMOTIONS, variables, - permissions=[permission_handle_checkouts, permission_manage_checkouts], + permissions=[ + permission_handle_checkouts, + permission_manage_checkouts, + permission_manage_orders, + ], ) # then @@ -1059,6 +1119,12 @@ def test_order_from_checkout_on_order_promotion( ) assert order_discount.type == DiscountType.ORDER_PROMOTION + discounts = data["order"]["discounts"] + assert len(discounts) == 1 + assert discounts[0]["amount"]["amount"] == order_discount.amount_value + assert discounts[0]["type"] == DiscountType.ORDER_PROMOTION.upper() + assert discounts[0]["valueType"] == DiscountValueType.FIXED.upper() + def test_order_from_checkout_on_gift_promotion( app_api_client, @@ -1066,6 +1132,7 @@ def test_order_from_checkout_on_gift_promotion( gift_promotion_rule, permission_handle_checkouts, permission_manage_checkouts, + permission_manage_orders, address, shipping_method, ): @@ -1083,9 +1150,13 @@ def test_order_from_checkout_on_gift_promotion( # when response = app_api_client.post_graphql( - MUTATION_ORDER_CREATE_FROM_CHECKOUT, + MUTATION_ORDER_CREATE_FROM_CHECKOUT_PROMOTIONS, variables, - permissions=[permission_handle_checkouts, permission_manage_checkouts], + permissions=[ + permission_handle_checkouts, + permission_manage_checkouts, + permission_manage_orders, + ], ) # then @@ -1101,10 +1172,23 @@ def test_order_from_checkout_on_gift_promotion( assert not order.discounts.all() assert order.lines.count() == line_count gift_line = order.lines.get(is_gift=True) + gift_price = gift_line.variant.channel_listings.get( + channel=checkout.channel + ).discounted_price_amount assert gift_line.discounts.count() == 1 line_discount = gift_line.discounts.first() assert line_discount.promotion_rule == gift_promotion_rule assert line_discount.type == DiscountType.ORDER_PROMOTION + assert line_discount.amount_value == gift_price + assert line_discount.value == gift_price + + assert not data["order"]["discounts"] + lines = data["order"]["lines"] + assert len(lines) == 2 + gift_line_api = [line for line in lines if line["isGift"]][0] + assert gift_line_api["unitDiscount"]["amount"] == gift_price + assert gift_line_api["unitDiscountValue"] == gift_price + assert gift_line_api["unitDiscountType"] == RewardValueType.FIXED.upper() def test_order_from_checkout_on_catalogue_and_gift_promotion( @@ -1197,91 +1281,6 @@ def test_order_from_checkout_on_catalogue_and_gift_promotion( ).gross.amount == top_price + line_discount.amount_value -def test_order_from_checkout_multiple_rules_applied( - app_api_client, - checkout_with_item_on_promotion, - permission_handle_checkouts, - permission_manage_checkouts, - address, - shipping_method, -): - # given - checkout = checkout_with_item_on_promotion - checkout.shipping_address = address - checkout.shipping_method = shipping_method - checkout.billing_address = address - checkout.save() - - channel = checkout.channel - - line = checkout.lines.first() - variant = line.variant - variant_channel_listing = variant.channel_listings.get(channel=channel) - - reward_value_2 = Decimal("10.00") - promotion = Promotion.objects.first() - rule_2 = promotion.rules.create( - name="Percentage promotion rule 2", - reward_value_type=RewardValueType.PERCENTAGE, - reward_value=reward_value_2, - catalogue_predicate={ - "variantPredicate": { - "ids": [graphene.Node.to_global_id("ProductVariant", line.variant_id)] - } - }, - ) - rule_2.channels.add(channel) - - discount_amount_2 = reward_value_2 / 100 * variant_channel_listing.price.amount - - variant_channel_listing.variantlistingpromotionrule.create( - promotion_rule=rule_2, - discount_amount=discount_amount_2, - currency=channel.currency_code, - ) - CheckoutLineDiscount.objects.create( - line=line, - type=DiscountType.PROMOTION, - value_type=DiscountValueType.PERCENTAGE, - amount_value=discount_amount_2, - currency=channel.currency_code, - promotion_rule=rule_2, - ) - - variant_channel_listing.discounted_price_amount = ( - variant_channel_listing.discounted_price_amount - discount_amount_2 - ) - variant_channel_listing.save(update_fields=["discounted_price_amount"]) - - variables = { - "id": graphene.Node.to_global_id("Checkout", checkout.pk), - } - - # when - response = app_api_client.post_graphql( - MUTATION_ORDER_CREATE_FROM_CHECKOUT, - variables, - permissions=[permission_handle_checkouts, permission_manage_checkouts], - ) - - # then - content = get_graphql_content(response) - data = content["data"]["orderCreateFromCheckout"] - assert not data["errors"] - - order = Order.objects.first() - - assert order.status == OrderStatus.UNCONFIRMED - assert order.origin == OrderOrigin.CHECKOUT - assert not order.original - - assert order.lines.count() == 1 - line = order.lines.first() - assert line.sale_id == graphene.Node.to_global_id("Promotion", promotion.pk) - assert line.unit_discount_reason - assert line.discounts.count() == 2 - - @pytest.mark.integration def test_order_from_checkout_without_inventory_tracking( app_api_client, diff --git a/saleor/graphql/discount/utils.py b/saleor/graphql/discount/utils.py index a40740bc52b..6bac0d3b236 100644 --- a/saleor/graphql/discount/utils.py +++ b/saleor/graphql/discount/utils.py @@ -10,6 +10,7 @@ from ...checkout.models import Checkout from ...discount.models import Promotion, PromotionRule from ...discount.utils import update_rule_variant_relation +from ...order.models import Order from ...product.managers import ProductsQueryset, ProductVariantQueryset from ...product.models import ( Category, @@ -20,6 +21,7 @@ ) from ..checkout.filters import CheckoutDiscountedObjectWhere from ..core.connection import where_filter_qs +from ..order.filters import OrderDiscountedObjectWhere from ..product.filters import ( CategoryWhere, CollectionWhere, @@ -283,6 +285,10 @@ def _handle_predicate( return _handle_checkout_predicate( result_qs, base_qs, predicate_data, operator, currency ) + elif predicate_type == PredicateObjectType.ORDER: + return _handle_order_predicate( + result_qs, base_qs, predicate_data, operator, currency + ) def _handle_catalogue_predicate( @@ -327,6 +333,32 @@ def _handle_checkout_predicate( return result_qs +def _handle_order_predicate( + result_qs: QuerySet, + base_qs: QuerySet, + predicate_data: dict[str, Union[dict, str, list, bool]], + operator, + currency: Optional[str] = None, +): + predicate_data = _predicate_to_snake_case(predicate_data) + if predicate := predicate_data.get("discounted_object_predicate"): + if currency: + predicate["currency"] = currency + + orders = where_filter_qs( + Order.objects.filter(pk__in=base_qs.values("pk")), + {}, + OrderDiscountedObjectWhere, + predicate, + None, + ) + if operator == Operators.AND: + result_qs &= orders + else: + result_qs |= orders + return result_qs + + def _predicate_to_snake_case(obj: Any) -> Any: if isinstance(obj, dict): data = {} diff --git a/saleor/graphql/order/filters.py b/saleor/graphql/order/filters.py index aeb1fffe94e..aa6056ae404 100644 --- a/saleor/graphql/order/filters.py +++ b/saleor/graphql/order/filters.py @@ -2,6 +2,7 @@ import django_filters import graphene +from django.core.exceptions import ValidationError from django.db.models import Exists, OuterRef, Q from django.utils import timezone from graphql.error import GraphQLError @@ -12,6 +13,7 @@ from ...order.search import search_orders from ...payment import ChargeStatus from ...product.models import ProductVariant +from ..channel.filters import get_currency_from_filter_data from ..core.filters import ( GlobalIDMultipleChoiceFilter, ListObjectTypeFilter, @@ -20,9 +22,10 @@ ) from ..core.types import DateRangeInput, DateTimeRangeInput from ..core.utils import from_global_id_or_error +from ..discount.filters import DiscountedObjectWhere from ..payment.enums import PaymentChargeStatusEnum from ..utils import resolve_global_ids_to_primary_keys -from ..utils.filters import filter_range_field +from ..utils.filters import filter_range_field, filter_where_by_numeric_field from .enums import OrderAuthorizeStatusEnum, OrderChargeStatusEnum, OrderStatusFilter @@ -230,3 +233,28 @@ def is_valid(self): message="'ids' and 'numbers` are not allowed to use together in filter." ) return super().is_valid() + + +class OrderDiscountedObjectWhere(DiscountedObjectWhere): + class Meta: + model = Order + fields = ["subtotal_net_amount", "total_net_amount"] + + def filter_base_subtotal_price(self, queryset, name, value): + currency = get_currency_from_filter_data(self.data) + return _filter_price(queryset, name, "subtotal_net_amount", value, currency) + + def filter_base_total_price(self, queryset, name, value): + currency = get_currency_from_filter_data(self.data) + return _filter_price(queryset, name, "total_net_amount", value, currency) + + +def _filter_price(qs, _, field_name, value, currency): + # We will have single channel/currency as the rule can be applied only + # on channels with the same currencies + if not currency: + raise ValidationError( + "You must provide a currency to filter by price field.", code="required" + ) + qs = qs.filter(currency=currency) + return filter_where_by_numeric_field(qs, field_name, value) diff --git a/saleor/graphql/order/mutations/draft_order_create.py b/saleor/graphql/order/mutations/draft_order_create.py index 6dfe4c639b0..4e61972b67a 100644 --- a/saleor/graphql/order/mutations/draft_order_create.py +++ b/saleor/graphql/order/mutations/draft_order_create.py @@ -481,7 +481,7 @@ def _commit_changes( ) @classmethod - def should_invalidate_prices(cls, instance, cleaned_input, is_new_instance) -> bool: + def should_invalidate_prices(cls, cleaned_input, is_new_instance) -> bool: # Force price recalculation for all new instances return is_new_instance @@ -565,7 +565,7 @@ def _save_draft_order( "display_gross_prices", ] ) - if cls.should_invalidate_prices(instance, cleaned_input, is_new_instance): + if cls.should_invalidate_prices(cleaned_input, is_new_instance): invalidate_order_prices(instance) updated_fields.extend(["should_refresh_prices"]) recalculate_order_weight(instance) diff --git a/saleor/graphql/order/mutations/draft_order_update.py b/saleor/graphql/order/mutations/draft_order_update.py index d835ff65998..83cf777f575 100644 --- a/saleor/graphql/order/mutations/draft_order_update.py +++ b/saleor/graphql/order/mutations/draft_order_update.py @@ -51,14 +51,13 @@ def get_instance(cls, info: ResolveInfo, **data): return instance @classmethod - def should_invalidate_prices(cls, instance, cleaned_input, is_new_instance) -> bool: + def should_invalidate_prices(cls, cleaned_input, *args) -> bool: return any( field in cleaned_input for field in [ "shipping_address", "billing_address", "shipping_method", - "lines", "voucher", ] ) diff --git a/saleor/graphql/order/mutations/order_update.py b/saleor/graphql/order/mutations/order_update.py index ec5e8dee2d1..06012422d54 100644 --- a/saleor/graphql/order/mutations/order_update.py +++ b/saleor/graphql/order/mutations/order_update.py @@ -84,7 +84,7 @@ def get_instance(cls, info: ResolveInfo, **data): return instance @classmethod - def should_invalidate_prices(cls, instance, cleaned_input, is_new_instance) -> bool: + def should_invalidate_prices(cls, cleaned_input, *args) -> bool: return any( cleaned_input.get(field) is not None for field in ["shipping_address", "billing_address"] @@ -101,7 +101,7 @@ def save(cls, info: ResolveInfo, instance, cleaned_input): *prepare_order_search_vector_value(instance) ) manager = get_plugin_manager_promise(info.context).get() - if cls.should_invalidate_prices(instance, cleaned_input, False): + if cls.should_invalidate_prices(cleaned_input): invalidate_order_prices(instance) instance.save() diff --git a/saleor/graphql/order/tests/mutations/test_draft_order_complete.py b/saleor/graphql/order/tests/mutations/test_draft_order_complete.py index e568f56cf21..82267764ad1 100644 --- a/saleor/graphql/order/tests/mutations/test_draft_order_complete.py +++ b/saleor/graphql/order/tests/mutations/test_draft_order_complete.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from decimal import Decimal from unittest.mock import patch import graphene @@ -10,10 +11,13 @@ from .....core import EventDeliveryStatus from .....core.models import EventDelivery +from .....core.prices import quantize_price from .....core.taxes import zero_taxed_money +from .....discount import DiscountValueType from .....discount.models import VoucherCustomer from .....order import OrderOrigin, OrderStatus from .....order import events as order_events +from .....order.calculations import fetch_order_prices_if_expired from .....order.error_codes import OrderErrorCode from .....order.interface import OrderTaxedPricesData from .....order.models import OrderEvent @@ -1193,3 +1197,212 @@ def side_effect(order, *args, **kwargs): order.refresh_from_db() assert not order.should_refresh_prices assert order.tax_error == "Empty tax data." + + +DRAFT_ORDER_COMPLETE_WITH_DISCOUNTS_MUTATION = """ + mutation draftComplete($id: ID!) { + draftOrderComplete(id: $id) { + errors { + field + code + message + } + order { + id + total { + net { + amount + } + } + discounts { + amount { + amount + } + valueType + type + reason + } + lines { + id + quantity + totalPrice { + net { + amount + } + } + unitDiscount { + amount + } + unitDiscountValue + unitDiscountReason + unitDiscountType + isGift + } + } + } + } + """ + + +def test_draft_order_complete_with_catalogue_and_order_discount( + staff_api_client, + permission_group_manage_orders, + staff_user, + draft_order_and_promotions, + plugins_manager, +): + # given + Allocation.objects.all().delete() + permission_group_manage_orders.user_set.add(staff_api_client.user) + + order, rule_catalogue, rule_total, _ = draft_order_and_promotions + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + order_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_total.promotion_id + ) + rule_catalogue_value = rule_catalogue.reward_value + rule_total_value = rule_total.reward_value + + currency = order.currency + order_id = graphene.Node.to_global_id("Order", order.id) + variables = {"id": order_id} + fetch_order_prices_if_expired(order, plugins_manager, force_update=True) + + # when + response = staff_api_client.post_graphql( + DRAFT_ORDER_COMPLETE_WITH_DISCOUNTS_MUTATION, variables + ) + + # then + content = get_graphql_content(response) + order_data = content["data"]["draftOrderComplete"]["order"] + + assert len(order_data["discounts"]) == 1 + + order_discount = order_data["discounts"][0] + assert order_discount["amount"]["amount"] == 25.00 == rule_total_value + assert order_discount["reason"] == f"Promotion: {order_promotion_id}" + assert order_discount["amount"]["amount"] == 25.00 == rule_total_value + assert order_discount["valueType"] == DiscountValueType.FIXED.upper() + + lines_db = order.lines.all() + line_1_db = [line for line in lines_db if line.quantity == 3][0] + line_2_db = [line for line in lines_db if line.quantity == 2][0] + line_1_base_total = line_1_db.quantity * line_1_db.base_unit_price_amount + line_2_base_total = line_2_db.quantity * line_2_db.base_unit_price_amount + base_total = line_1_base_total + line_2_base_total + line_1_order_discount_portion = rule_total_value * line_1_base_total / base_total + line_2_order_discount_portion = rule_total_value - line_1_order_discount_portion + + lines = order_data["lines"] + line_1 = [line for line in lines if line["quantity"] == 3][0] + line_2 = [line for line in lines if line["quantity"] == 2][0] + line_1_total = quantize_price( + line_1_db.undiscounted_total_price_net_amount - line_1_order_discount_portion, + currency, + ) + assert line_1["totalPrice"]["net"]["amount"] == float(line_1_total) + assert line_1["unitDiscount"]["amount"] == 0.00 + assert line_1["unitDiscountReason"] is None + assert line_1["unitDiscountValue"] == 0.00 + + line_2_total = quantize_price( + line_2_db.undiscounted_total_price_net_amount + - rule_catalogue_value * line_2_db.quantity + - line_2_order_discount_portion, + currency, + ) + assert line_2["totalPrice"]["net"]["amount"] == float(line_2_total) + assert line_2["unitDiscount"]["amount"] == rule_catalogue_value + assert line_2["unitDiscountReason"] == f"Promotion: {catalogue_promotion_id}" + assert line_2["unitDiscountType"] == DiscountValueType.FIXED.upper() + assert line_2["unitDiscountValue"] == rule_catalogue_value + + total = ( + order.undiscounted_total_net_amount + - line_2["quantity"] * rule_catalogue_value + - rule_total_value + ) + assert order_data["total"]["net"]["amount"] == total + assert total == line_2_total + line_1_total + order.base_shipping_price_amount + + +def test_draft_order_complete_with_catalogue_and_gift_discount( + staff_api_client, + permission_group_manage_orders, + staff_user, + draft_order_and_promotions, + plugins_manager, +): + # given + Allocation.objects.all().delete() + permission_group_manage_orders.user_set.add(staff_api_client.user) + + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + rule_total.reward_value = Decimal(0) + rule_total.save(update_fields=["reward_value"]) + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + gift_promotion_id = graphene.Node.to_global_id("Promotion", rule_gift.promotion_id) + rule_catalogue_value = rule_catalogue.reward_value + + currency = order.currency + order_id = graphene.Node.to_global_id("Order", order.id) + variables = {"id": order_id} + fetch_order_prices_if_expired(order, plugins_manager, force_update=True) + + # when + response = staff_api_client.post_graphql( + DRAFT_ORDER_COMPLETE_WITH_DISCOUNTS_MUTATION, variables + ) + + # then + content = get_graphql_content(response) + order_data = content["data"]["draftOrderComplete"]["order"] + assert not order_data["discounts"] + + lines_db = order.lines.all() + line_1_db = [line for line in lines_db if line.quantity == 3][0] + line_2_db = [line for line in lines_db if line.quantity == 2][0] + gift_line_db = [line for line in lines_db if line.is_gift][0] + gift_price = gift_line_db.variant.channel_listings.get( + channel=order.channel + ).price_amount + + lines = order_data["lines"] + assert len(lines) == 3 + line_1 = [line for line in lines if line["quantity"] == 3][0] + line_2 = [line for line in lines if line["quantity"] == 2][0] + gift_line = [line for line in lines if line["isGift"] is True][0] + + line_1_total = line_1_db.undiscounted_total_price_net_amount + assert line_1["totalPrice"]["net"]["amount"] == line_1_total + assert line_1["unitDiscount"]["amount"] == 0.00 + assert line_1["unitDiscountReason"] is None + assert line_1["unitDiscountValue"] == 0.00 + + line_2_total = quantize_price( + line_2_db.undiscounted_total_price_net_amount + - rule_catalogue_value * line_2_db.quantity, + currency, + ) + assert line_2["totalPrice"]["net"]["amount"] == line_2_total + assert line_2["unitDiscount"]["amount"] == rule_catalogue_value + assert line_2["unitDiscountReason"] == f"Promotion: {catalogue_promotion_id}" + assert line_2["unitDiscountType"] == DiscountValueType.FIXED.upper() + assert line_2["unitDiscountValue"] == rule_catalogue_value + + assert gift_line["totalPrice"]["net"]["amount"] == 0.00 + assert gift_line["unitDiscount"]["amount"] == gift_price + assert gift_line["unitDiscountReason"] == f"Promotion: {gift_promotion_id}" + assert gift_line["unitDiscountType"] == DiscountValueType.FIXED.upper() + assert gift_line["unitDiscountValue"] == gift_price + + total = ( + order.undiscounted_total_net_amount - rule_catalogue_value * line_2_db.quantity + ) + assert order_data["total"]["net"]["amount"] == total + assert total == line_2_total + line_1_total + order.base_shipping_price_amount diff --git a/saleor/graphql/order/tests/mutations/test_draft_order_create.py b/saleor/graphql/order/tests/mutations/test_draft_order_create.py index bb5d8d14d33..9f6c74c85ff 100644 --- a/saleor/graphql/order/tests/mutations/test_draft_order_create.py +++ b/saleor/graphql/order/tests/mutations/test_draft_order_create.py @@ -8,8 +8,9 @@ from prices import Money from .....checkout import AddressType +from .....core.prices import quantize_price from .....core.taxes import TaxError, zero_taxed_money -from .....discount import DiscountType, DiscountValueType +from .....discount import DiscountType, DiscountValueType, RewardType, RewardValueType from .....discount.models import VoucherChannelListing, VoucherCustomer from .....order import OrderStatus from .....order import events as order_events @@ -18,6 +19,7 @@ from .....payment.model_helpers import get_subtotal from .....product.models import ProductVariant from .....tax import TaxCalculationStrategy +from .....tests.utils import round_up from ....tests.utils import assert_no_permission, get_graphql_content DRAFT_ORDER_CREATE_MUTATION = """ @@ -40,6 +42,14 @@ amount } discountName + discounts { + amount { + amount + } + valueType + type + reason + } redirectUrl lines { productName @@ -116,6 +126,7 @@ unitDiscountReason unitDiscountType unitDiscountValue + isGift } } } @@ -2289,7 +2300,7 @@ def test_draft_order_create_with_custom_price_in_order_line( assert order_line_1.undiscounted_base_unit_price_amount == expected_price_variant_1 -def test_draft_order_create_product_on_promotion( +def test_draft_order_create_product_catalogue_promotion( staff_api_client, permission_group_manage_orders, staff_user, @@ -2415,7 +2426,7 @@ def test_draft_order_create_product_on_promotion( assert event_parameters["lines"][0]["quantity"] == quantity -def test_draft_order_create_product_on_promotion_flat_taxes( +def test_draft_order_create_product_catalogue_promotion_flat_taxes( staff_api_client, permission_group_manage_orders, staff_user, @@ -2544,3 +2555,247 @@ def test_draft_order_create_product_on_promotion_flat_taxes( assert event_parameters["lines"][0]["item"] == str(order_lines[0]) assert event_parameters["lines"][0]["line_pk"] == str(order_lines[0].pk) assert event_parameters["lines"][0]["quantity"] == quantity + + +def test_draft_order_create_order_promotion_flat_rates( + staff_api_client, + permission_group_manage_orders, + customer_user, + shipping_method, + graphql_address_data, + order_promotion_rule, + variant_with_many_stocks, + channel_USD, +): + # given + query = DRAFT_ORDER_CREATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + currency = channel_USD.currency_code + + tc = channel_USD.tax_configuration + tc.country_exceptions.all().delete() + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.prices_entered_with_tax = False + tc.save() + tax_rate = Decimal("1.23") + + rule = order_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + assert rule.reward_value_type == RewardValueType.PERCENTAGE + reward_value = rule.reward_value + assert rule.reward_value == Decimal("25") + + variant = variant_with_many_stocks + user_id = graphene.Node.to_global_id("User", customer_user.id) + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + + quantity = 4 + variant_list = [ + {"variantId": variant_id, "quantity": quantity}, + ] + + # calculate expected values + variant_price = variant.channel_listings.get( + channel=channel_USD + ).discounted_price_amount + undiscounted_subtotal_net = Decimal(quantity * variant_price) + discount_amount = quantize_price( + reward_value / 100 * undiscounted_subtotal_net, currency + ) + subtotal_net = undiscounted_subtotal_net - discount_amount + subtotal_gross = quantize_price(tax_rate * subtotal_net, currency) + shipping_price_net = shipping_method.channel_listings.get( + channel=channel_USD + ).price_amount + shipping_price_gross = quantize_price(tax_rate * shipping_price_net, currency) + total_gross = quantize_price(subtotal_gross + shipping_price_gross, currency) + + shipping_address = graphql_address_data + shipping_id = graphene.Node.to_global_id("ShippingMethod", shipping_method.id) + channel_id = graphene.Node.to_global_id("Channel", channel_USD.id) + redirect_url = "https://www.example.com" + + variables = { + "input": { + "user": user_id, + "lines": variant_list, + "billingAddress": shipping_address, + "shippingAddress": shipping_address, + "shippingMethod": shipping_id, + "channelId": channel_id, + "redirectUrl": redirect_url, + } + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + assert not content["data"]["draftOrderCreate"]["errors"] + order = content["data"]["draftOrderCreate"]["order"] + assert order["status"] == OrderStatus.DRAFT.upper() + assert order["subtotal"]["gross"]["amount"] == float(subtotal_gross) + assert order["total"]["gross"]["amount"] == float(total_gross) + assert order["shippingPrice"]["gross"]["amount"] == float(shipping_price_gross) + + assert len(order["discounts"]) == 1 + assert order["discounts"][0]["amount"]["amount"] == discount_amount + assert order["discounts"][0]["reason"] == f"Promotion: {promotion_id}" + assert order["discounts"][0]["type"] == DiscountType.ORDER_PROMOTION.upper() + assert order["discounts"][0]["valueType"] == RewardValueType.PERCENTAGE.upper() + + assert len(order["lines"]) == 1 + assert order["lines"][0]["quantity"] == quantity + assert order["lines"][0]["totalPrice"]["gross"]["amount"] == float(subtotal_gross) + assert order["lines"][0]["undiscountedUnitPrice"]["gross"]["amount"] == float( + quantize_price(undiscounted_subtotal_net * tax_rate / quantity, currency) + ) + assert order["lines"][0]["unitPrice"]["gross"]["amount"] == float( + round_up(subtotal_gross / quantity) + ) + + order_db = Order.objects.get() + assert order_db.total_gross_amount == total_gross + assert order_db.subtotal_gross_amount == subtotal_gross + assert order_db.shipping_price_gross_amount == shipping_price_gross + + line_db = order_db.lines.get() + assert line_db.total_price_gross_amount == subtotal_gross + assert line_db.undiscounted_unit_price_gross_amount == quantize_price( + undiscounted_subtotal_net * tax_rate / quantity, currency + ) + assert line_db.unit_price_net_amount == quantize_price( + subtotal_net / quantity, currency + ) + assert line_db.unit_price_gross_amount == round_up(subtotal_gross / quantity) + + discount_db = order_db.discounts.get() + assert discount_db.amount_value == discount_amount + assert discount_db.reason == f"Promotion: {promotion_id}" + assert discount_db.value == reward_value + assert discount_db.value_type == RewardValueType.PERCENTAGE + + +def test_draft_order_create_gift_promotion_flat_rates( + staff_api_client, + permission_group_manage_orders, + customer_user, + shipping_method, + graphql_address_data, + gift_promotion_rule, + variant_with_many_stocks, + channel_USD, +): + # given + query = DRAFT_ORDER_CREATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + currency = channel_USD.currency_code + + tc = channel_USD.tax_configuration + tc.country_exceptions.all().delete() + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.prices_entered_with_tax = False + tc.save() + tax_rate = Decimal("1.23") + + rule = gift_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + assert rule.reward_type == RewardType.GIFT + + variant = variant_with_many_stocks + user_id = graphene.Node.to_global_id("User", customer_user.id) + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + + quantity = 4 + variant_list = [ + {"variantId": variant_id, "quantity": quantity}, + ] + + # calculate expected values + variant_price = variant.channel_listings.get( + channel=channel_USD + ).discounted_price_amount + subtotal_net = quantity * variant_price + subtotal_gross = quantize_price(tax_rate * subtotal_net, currency) + shipping_price_net = shipping_method.channel_listings.get( + channel=channel_USD + ).price_amount + shipping_price_gross = quantize_price(tax_rate * shipping_price_net, currency) + total_gross = quantize_price(subtotal_gross + shipping_price_gross, currency) + + shipping_address = graphql_address_data + shipping_id = graphene.Node.to_global_id("ShippingMethod", shipping_method.id) + channel_id = graphene.Node.to_global_id("Channel", channel_USD.id) + redirect_url = "https://www.example.com" + + variables = { + "input": { + "user": user_id, + "lines": variant_list, + "billingAddress": shipping_address, + "shippingAddress": shipping_address, + "shippingMethod": shipping_id, + "channelId": channel_id, + "redirectUrl": redirect_url, + } + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + assert not content["data"]["draftOrderCreate"]["errors"] + order = content["data"]["draftOrderCreate"]["order"] + + assert order["status"] == OrderStatus.DRAFT.upper() + assert order["subtotal"]["gross"]["amount"] == float(subtotal_gross) + assert Decimal(order["total"]["gross"]["amount"]) == float(total_gross) + assert Decimal(order["shippingPrice"]["gross"]["amount"]) == float( + shipping_price_gross + ) + + assert not order["discounts"] + + assert len(order["lines"]) == 2 + line = [line for line in order["lines"] if line["quantity"] == 4][0] + gift_line = [line for line in order["lines"] if line["isGift"]][0] + + assert line["totalPrice"]["gross"]["amount"] == float(subtotal_gross) + assert line["undiscountedUnitPrice"]["gross"]["amount"] == float( + subtotal_gross / quantity + ) + assert line["unitPrice"]["gross"]["amount"] == float(subtotal_gross / quantity) + assert line["unitDiscount"]["amount"] == 0.00 + + order_db = Order.objects.get() + assert order_db.total_gross_amount == total_gross + assert order_db.subtotal_gross_amount == subtotal_gross + assert order_db.shipping_price_gross_amount == shipping_price_gross + + lines_db = order_db.lines.all() + assert len(lines_db) == 2 + gift_line_db = [line for line in lines_db if line.is_gift][0] + gift_price = gift_line_db.variant.channel_listings.get( + channel=channel_USD + ).price_amount + + assert gift_line_db.total_price_gross_amount == Decimal(0) + assert gift_line_db.undiscounted_unit_price_gross_amount == Decimal(0) + assert gift_line_db.unit_price_gross_amount == Decimal(0) + assert gift_line_db.base_unit_price_amount == Decimal(0) + assert gift_line_db.unit_discount_value == gift_price + + assert gift_line["totalPrice"]["gross"]["amount"] == 0.00 + assert gift_line["undiscountedUnitPrice"]["gross"]["amount"] == 0.00 + assert gift_line["unitPrice"]["gross"]["amount"] == 0.00 + assert gift_line["unitDiscount"]["amount"] == gift_price + assert gift_line["unitDiscountReason"] == f"Promotion: {promotion_id}" + assert gift_line["unitDiscountType"] == RewardValueType.FIXED.upper() + assert gift_line["unitDiscountValue"] == gift_price + + discount_db = gift_line_db.discounts.get() + assert discount_db.amount_value == gift_price + assert discount_db.reason == f"Promotion: {promotion_id}" + assert discount_db.type == DiscountType.ORDER_PROMOTION diff --git a/saleor/graphql/order/tests/mutations/test_draft_order_update.py b/saleor/graphql/order/tests/mutations/test_draft_order_update.py index e568cc5c84b..3670562fb3f 100644 --- a/saleor/graphql/order/tests/mutations/test_draft_order_update.py +++ b/saleor/graphql/order/tests/mutations/test_draft_order_update.py @@ -1,9 +1,10 @@ import graphene +import pytest from prices import TaxedMoney from .....core.prices import quantize_price from .....core.taxes import zero_money -from .....discount import DiscountType, DiscountValueType +from .....discount import DiscountType, DiscountValueType, RewardValueType from .....order import OrderStatus from .....order.error_codes import OrderErrorCode from .....order.models import OrderEvent @@ -66,16 +67,58 @@ gross { amount } + net { + amount + } } subtotal { gross { amount } + net { + amount + } } undiscountedTotal { gross { amount } + net { + amount + } + } + discounts { + amount { + amount + } + valueType + type + reason + } + lines { + quantity + unitDiscount { + amount + } + undiscountedUnitPrice { + net { + amount + } + } + unitPrice { + net { + amount + } + } + totalPrice { + net { + amount + } + } + unitDiscountReason + unitDiscountType + unitDiscountValue + isGift } } } @@ -537,6 +580,234 @@ def test_draft_order_update_voucher_including_drafts_in_voucher_usage_invalid_co assert error["field"] == "voucher" +def test_draft_order_update_add_voucher_code_remove_order_promotion( + staff_api_client, + permission_group_manage_orders, + order_with_lines_and_order_promotion, + voucher, +): + # given + order = order_with_lines_and_order_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + order_discount = order.discounts.get() + assert order_discount.type == DiscountType.ORDER_PROMOTION + + discount_amount = voucher.channel_listings.get(channel=order.channel).discount_value + + permission_group_manage_orders.user_set.add(staff_api_client.user) + query = DRAFT_ORDER_UPDATE_MUTATION + order_id = graphene.Node.to_global_id("Order", order.id) + + variables = { + "id": order_id, + "input": { + "voucherCode": voucher.codes.first().code, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["draftOrderUpdate"] + assert not data["errors"] + + with pytest.raises(order_discount._meta.model.DoesNotExist): + order_discount.refresh_from_db() + + order.refresh_from_db() + voucher_discount = order.discounts.get() + assert voucher_discount.amount_value == discount_amount + assert voucher_discount.value == discount_amount + assert voucher_discount.type == DiscountType.VOUCHER + + assert ( + order.total_net_amount == order.undiscounted_total_net_amount - discount_amount + ) + + +def test_draft_order_update_add_voucher_code_remove_gift_promotion( + staff_api_client, + permission_group_manage_orders, + order_with_lines_and_gift_promotion, + voucher, +): + # given + order = order_with_lines_and_gift_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + + assert order.lines.count() == 3 + gift_line = order.lines.get(is_gift=True) + gift_discount = gift_line.discounts.get() + + discount_amount = voucher.channel_listings.get(channel=order.channel).discount_value + + permission_group_manage_orders.user_set.add(staff_api_client.user) + query = DRAFT_ORDER_UPDATE_MUTATION + order_id = graphene.Node.to_global_id("Order", order.id) + + variables = { + "id": order_id, + "input": { + "voucherCode": voucher.codes.first().code, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["draftOrderUpdate"] + assert not data["errors"] + + with pytest.raises(gift_line._meta.model.DoesNotExist): + gift_line.refresh_from_db() + + with pytest.raises(gift_discount._meta.model.DoesNotExist): + gift_discount.refresh_from_db() + + order.refresh_from_db() + assert order.lines.count() == 2 + voucher_discount = order.discounts.get() + assert voucher_discount.amount_value == discount_amount + assert voucher_discount.value == discount_amount + assert voucher_discount.type == DiscountType.VOUCHER + + assert ( + order.total_net_amount == order.undiscounted_total_net_amount - discount_amount + ) + + +def test_draft_order_update_remove_voucher_code_add_order_promotion( + staff_api_client, + permission_group_manage_orders, + draft_order, + voucher, + order_promotion_rule, +): + # given + order = draft_order + order.voucher = voucher + order.save(update_fields=["voucher"]) + + voucher_listing = voucher.channel_listings.get(channel=order.channel) + discount_amount = voucher_listing.discount_value + voucher_discount = order.discounts.create( + voucher=voucher, + value=discount_amount, + type=DiscountType.VOUCHER, + ) + + order.total_gross_amount -= discount_amount + order.total_net_amount -= discount_amount + order.save(update_fields=["total_net_amount", "total_gross_amount"]) + + permission_group_manage_orders.user_set.add(staff_api_client.user) + query = DRAFT_ORDER_UPDATE_MUTATION + order_id = graphene.Node.to_global_id("Order", order.id) + + variables = { + "id": order_id, + "input": { + "voucherCode": None, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["draftOrderUpdate"] + assert not data["errors"] + + with pytest.raises(voucher_discount._meta.model.DoesNotExist): + voucher_discount.refresh_from_db() + + order.refresh_from_db() + order_discount = order.discounts.get() + reward_value = order_promotion_rule.reward_value + assert order_discount.value == reward_value + assert order_discount.value_type == order_promotion_rule.reward_value_type + + undiscounted_subtotal = ( + order.undiscounted_total_net_amount - order.base_shipping_price_amount + ) + assert order_discount.amount.amount == reward_value / 100 * undiscounted_subtotal + assert ( + order.total_net_amount + == order.undiscounted_total_net_amount - order_discount.amount.amount + ) + + +def test_draft_order_update_remove_voucher_code_add_gift_promotion( + staff_api_client, + permission_group_manage_orders, + draft_order, + voucher, + gift_promotion_rule, +): + # given + order = draft_order + order.voucher = voucher + order.save(update_fields=["voucher"]) + assert order.lines.count() == 2 + + voucher_listing = voucher.channel_listings.get(channel=order.channel) + discount_amount = voucher_listing.discount_value + voucher_discount = order.discounts.create( + voucher=voucher, + value=discount_amount, + type=DiscountType.VOUCHER, + ) + + order.total_gross_amount -= discount_amount + order.total_net_amount -= discount_amount + order.save(update_fields=["total_net_amount", "total_gross_amount"]) + + permission_group_manage_orders.user_set.add(staff_api_client.user) + query = DRAFT_ORDER_UPDATE_MUTATION + order_id = graphene.Node.to_global_id("Order", order.id) + + variables = { + "id": order_id, + "input": { + "voucherCode": None, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["draftOrderUpdate"] + assert not data["errors"] + + with pytest.raises(voucher_discount._meta.model.DoesNotExist): + voucher_discount.refresh_from_db() + + order.refresh_from_db() + assert order.lines.count() == 3 + assert not order.discounts.exists() + + gift_line = order.lines.filter(is_gift=True).first() + gift_discount = gift_line.discounts.get() + gift_price = gift_line.variant.channel_listings.get( + channel=order.channel + ).price_amount + + assert gift_discount.value == gift_price + assert gift_discount.amount.amount == gift_price + assert gift_discount.value_type == DiscountValueType.FIXED + + assert order.total_net_amount == order.undiscounted_total_net_amount + + def test_draft_order_update_with_non_draft_order( staff_api_client, permission_group_manage_orders, order_with_lines, voucher ): @@ -1163,3 +1434,115 @@ def test_draft_order_update_no_shipping_method_channel_listings( assert len(errors) == 1 assert errors[0]["code"] == OrderErrorCode.SHIPPING_METHOD_NOT_APPLICABLE.name assert errors[0]["field"] == "shippingMethod" + + +def test_draft_order_update_order_promotion( + staff_api_client, + permission_group_manage_orders, + customer_user, + shipping_method, + graphql_address_data, + variant_with_many_stocks, + channel_USD, + draft_order, + order_promotion_rule, +): + # given + query = DRAFT_ORDER_UPDATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + + rule = order_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + assert rule.reward_value_type == RewardValueType.PERCENTAGE + reward_value = rule.reward_value + + variables = { + "id": graphene.Node.to_global_id("Order", draft_order.pk), + "input": { + "billingAddress": graphql_address_data, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + assert not content["data"]["draftOrderUpdate"]["errors"] + draft_order.refresh_from_db() + undiscounted_total = draft_order.undiscounted_total_net_amount + shipping_price = draft_order.base_shipping_price_amount + order = content["data"]["draftOrderUpdate"]["order"] + assert len(order["discounts"]) == 1 + discount_amount = reward_value / 100 * (undiscounted_total - shipping_price) + assert order["discounts"][0]["amount"]["amount"] == discount_amount + assert order["discounts"][0]["reason"] == f"Promotion: {promotion_id}" + assert order["discounts"][0]["type"] == DiscountType.ORDER_PROMOTION.upper() + assert order["discounts"][0]["valueType"] == RewardValueType.PERCENTAGE.upper() + + assert ( + order["subtotal"]["net"]["amount"] + == undiscounted_total - discount_amount - shipping_price + ) + assert order["total"]["net"]["amount"] == undiscounted_total - discount_amount + assert order["undiscountedTotal"]["net"]["amount"] == undiscounted_total + + +def test_draft_order_update_gift_promotion( + staff_api_client, + permission_group_manage_orders, + customer_user, + shipping_method, + graphql_address_data, + variant_with_many_stocks, + channel_USD, + draft_order, + gift_promotion_rule, +): + # given + query = DRAFT_ORDER_UPDATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + + rule = gift_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + variables = { + "id": graphene.Node.to_global_id("Order", draft_order.pk), + "input": { + "billingAddress": graphql_address_data, + }, + } + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + assert not content["data"]["draftOrderUpdate"]["errors"] + + gift_line_db = [line for line in draft_order.lines.all() if line.is_gift][0] + gift_price = gift_line_db.variant.channel_listings.get( + channel=draft_order.channel + ).price_amount + + order = content["data"]["draftOrderUpdate"]["order"] + lines = order["lines"] + assert len(lines) == 3 + gift_line = [line for line in lines if line["isGift"]][0] + + assert gift_line["totalPrice"]["net"]["amount"] == 0.00 + assert gift_line["unitDiscount"]["amount"] == gift_price + assert gift_line["unitDiscountReason"] == f"Promotion: {promotion_id}" + assert gift_line["unitDiscountType"] == RewardValueType.FIXED.upper() + assert gift_line["unitDiscountValue"] == gift_price + + assert ( + order["subtotal"]["net"]["amount"] + == draft_order.undiscounted_total_net_amount + - draft_order.base_shipping_price_amount + ) + assert order["total"]["net"]["amount"] == draft_order.undiscounted_total_net_amount + assert ( + order["undiscountedTotal"]["net"]["amount"] + == draft_order.undiscounted_total_net_amount + ) diff --git a/saleor/graphql/order/tests/mutations/test_order_discount.py b/saleor/graphql/order/tests/mutations/test_order_discount.py index 9ac84ea9298..e33cdcbca41 100644 --- a/saleor/graphql/order/tests/mutations/test_order_discount.py +++ b/saleor/graphql/order/tests/mutations/test_order_discount.py @@ -7,7 +7,7 @@ from prices import Money, TaxedMoney, fixed_discount, percentage_discount from .....core.prices import quantize_price -from .....discount import DiscountValueType +from .....discount import DiscountType, DiscountValueType from .....order import OrderEvents, OrderStatus from .....order.error_codes import OrderErrorCode from .....order.interface import OrderTaxedPricesData @@ -280,6 +280,116 @@ def test_add_fixed_order_discount_to_order_by_app( assert discount_data["amount_value"] == str(order_discount.amount.amount) +def test_add_manual_discount_to_order_with_order_discount( + order_with_lines_and_order_promotion, + staff_api_client, + permission_group_manage_orders, +): + """Order discount should be deleted in a favour of manual discount.""" + # given + order = order_with_lines_and_order_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + order_discount = order.discounts.get() + + permission_group_manage_orders.user_set.add(staff_api_client.user) + discount_value = Decimal("10.00") + + variables = { + "orderId": graphene.Node.to_global_id("Order", order.pk), + "input": { + "valueType": DiscountValueTypeEnum.FIXED.name, + "value": discount_value, + }, + } + + # when + response = staff_api_client.post_graphql(ORDER_DISCOUNT_ADD, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["orderDiscountAdd"] + order.refresh_from_db() + assert not data["errors"] + + with pytest.raises(order_discount._meta.model.DoesNotExist): + order_discount.refresh_from_db() + + assert order.discounts.count() == 1 + manual_discount = order.discounts.get() + + assert manual_discount.value == discount_value + assert manual_discount.value_type == DiscountValueType.FIXED + assert manual_discount.amount.amount == discount_value + + assert ( + order.total_net_amount == order.undiscounted_total_net_amount - discount_value + ) + assert ( + order.shipping_price_net_amount + order.subtotal_net_amount + == order.total_net_amount + ) + + +def test_add_manual_discount_to_order_with_gift_discount( + order_with_lines_and_gift_promotion, + staff_api_client, + permission_group_manage_orders, +): + """Order discount should be deleted in a favour of manual discount.""" + # given + order = order_with_lines_and_gift_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + + assert order.lines.count() == 3 + gift_line = order.lines.filter(is_gift=True).first() + gift_discount = gift_line.discounts.get() + + permission_group_manage_orders.user_set.add(staff_api_client.user) + discount_value = Decimal("10.00") + + variables = { + "orderId": graphene.Node.to_global_id("Order", order.pk), + "input": { + "valueType": DiscountValueTypeEnum.FIXED.name, + "value": discount_value, + }, + } + + # when + response = staff_api_client.post_graphql(ORDER_DISCOUNT_ADD, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["orderDiscountAdd"] + order.refresh_from_db() + assert not data["errors"] + + assert order.lines.count() == 2 + + with pytest.raises(gift_line._meta.model.DoesNotExist): + gift_line.refresh_from_db() + + with pytest.raises(gift_discount._meta.model.DoesNotExist): + gift_discount.refresh_from_db() + + assert order.discounts.count() == 1 + manual_discount = order.discounts.get() + + assert manual_discount.value == discount_value + assert manual_discount.value_type == DiscountValueType.FIXED + assert manual_discount.amount.amount == discount_value + + assert ( + order.total_net_amount == order.undiscounted_total_net_amount - discount_value + ) + assert ( + order.shipping_price_net_amount + order.subtotal_net_amount + == order.total_net_amount + ) + + ORDER_DISCOUNT_UPDATE = """ mutation OrderDiscountUpdate($discountId: ID!, $input: OrderDiscountCommonInput!){ orderDiscountUpdate(discountId:$discountId, input: $input){ @@ -598,6 +708,9 @@ def test_update_percentage_order_discount_to_order_by_app( orderDiscountDelete(discountId: $discountId){ order{ id + discounts { + id + } } errors{ field @@ -642,8 +755,8 @@ def test_delete_order_discount_from_order( errors = data["errors"] assert len(errors) == 0 - assert order.undiscounted_total == current_undiscounted_total - assert order.total == current_undiscounted_total + assert order.undiscounted_total.net == current_undiscounted_total.net + assert order.total.net == current_undiscounted_total.net event = order.events.get() assert event.type == OrderEvents.ORDER_DISCOUNT_DELETED @@ -747,8 +860,8 @@ def test_delete_order_discount_from_order_by_app( errors = data["errors"] assert len(errors) == 0 - assert order.undiscounted_total == current_undiscounted_total - assert order.total == current_undiscounted_total + assert order.undiscounted_total.net == current_undiscounted_total.net + assert order.total.net == current_undiscounted_total.net event = order.events.get() assert event.type == OrderEvents.ORDER_DISCOUNT_DELETED @@ -756,6 +869,92 @@ def test_delete_order_discount_from_order_by_app( assert order.search_vector +def test_delete_manual_discount_from_order_with_subtotal_promotion( + draft_order_with_fixed_discount_order, + staff_api_client, + permission_group_manage_orders, + order_promotion_rule, +): + # given + order = draft_order_with_fixed_discount_order + manual_discount = draft_order_with_fixed_discount_order.discounts.get() + + permission_group_manage_orders.user_set.add(staff_api_client.user) + variables = { + "discountId": graphene.Node.to_global_id("OrderDiscount", manual_discount.pk), + } + + # when + response = staff_api_client.post_graphql(ORDER_DISCOUNT_DELETE, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["orderDiscountDelete"] + assert not data["errors"] + + with pytest.raises(manual_discount._meta.model.DoesNotExist): + manual_discount.refresh_from_db() + + order.refresh_from_db() + order_discount = order.discounts.get() + reward_value = order_promotion_rule.reward_value + assert order_discount.value == reward_value + assert order_discount.value_type == order_promotion_rule.reward_value_type + + undiscounted_subtotal = ( + order.undiscounted_total_net_amount - order.base_shipping_price_amount + ) + assert order_discount.amount.amount == reward_value / 100 * undiscounted_subtotal + assert ( + order.total_net_amount + == order.undiscounted_total_net_amount - order_discount.amount.amount + ) + + +def test_delete_manual_discount_from_order_with_gift_promotion( + draft_order_with_fixed_discount_order, + staff_api_client, + permission_group_manage_orders, + gift_promotion_rule, +): + # given + order = draft_order_with_fixed_discount_order + manual_discount = draft_order_with_fixed_discount_order.discounts.get() + assert order.lines.count() == 2 + + permission_group_manage_orders.user_set.add(staff_api_client.user) + variables = { + "discountId": graphene.Node.to_global_id("OrderDiscount", manual_discount.pk), + } + + # when + response = staff_api_client.post_graphql(ORDER_DISCOUNT_DELETE, variables) + content = get_graphql_content(response) + + # then + data = content["data"]["orderDiscountDelete"] + assert not data["errors"] + + with pytest.raises(manual_discount._meta.model.DoesNotExist): + manual_discount.refresh_from_db() + + order.refresh_from_db() + assert order.lines.count() == 3 + assert not order.discounts.exists() + + gift_line = order.lines.filter(is_gift=True).first() + gift_discount = gift_line.discounts.get() + gift_price = gift_line.variant.channel_listings.get( + channel=order.channel + ).price_amount + + assert gift_discount.value == gift_price + assert gift_discount.amount.amount == gift_price + assert gift_discount.value_type == DiscountValueType.FIXED + + assert order.total_net_amount == order.undiscounted_total_net_amount + + ORDER_LINE_DISCOUNT_UPDATE = """ mutation OrderLineDiscountUpdate($input: OrderDiscountCommonInput!, $orderLineId: ID!){ orderLineDiscountUpdate(orderLineId: $orderLineId, input: $input){ @@ -798,14 +997,23 @@ def test_update_order_line_discount( line_to_discount.undiscounted_total_price = total_price line_to_discount.save() + line_to_discount.discounts.create( + value_type="fixed", + value=0, + amount_value=0, + name="Manual line discount", + type="manual", + ) + line_price_before_discount = line_to_discount.unit_price value = Decimal("5") + value_type = DiscountValueTypeEnum.FIXED reason = "New reason for unit discount" variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line_to_discount.pk), "input": { - "valueType": DiscountValueTypeEnum.FIXED.name, + "valueType": value_type.name, "value": value, "reason": reason, }, @@ -866,9 +1074,16 @@ def test_update_order_line_discount( discount_data = line_data.get("discount") assert discount_data["value"] == str(value) - assert discount_data["value_type"] == DiscountValueTypeEnum.FIXED.value + assert discount_data["value_type"] == value_type.value assert discount_data["amount_value"] == str(unit_discount.amount) + line_discount = line_to_discount.discounts.get() + assert line_discount.type == DiscountType.MANUAL + assert line_discount.value == value + assert line_discount.value_type == value_type.value + assert line_discount.reason == reason + assert line_discount.amount_value == value * line_to_discount.quantity + def test_update_order_line_discount_by_user_no_channel_access( draft_order_with_fixed_discount_order, @@ -916,11 +1131,12 @@ def test_update_order_line_discount_by_app( line_to_discount = order.lines.first() value = Decimal("5") + value_type = DiscountValueTypeEnum.FIXED reason = "New reason for unit discount" variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line_to_discount.pk), "input": { - "valueType": DiscountValueTypeEnum.FIXED.name, + "valueType": value_type.name, "value": value, "reason": reason, }, @@ -953,9 +1169,16 @@ def test_update_order_line_discount_by_app( discount_data = line_data.get("discount") assert discount_data["value"] == str(value) - assert discount_data["value_type"] == DiscountValueTypeEnum.FIXED.value + assert discount_data["value_type"] == value_type.value assert discount_data["amount_value"] == str(unit_discount.amount) + line_discount = line_to_discount.discounts.get() + assert line_discount.type == DiscountType.MANUAL + assert line_discount.value == value + assert line_discount.value_type == value_type.value + assert line_discount.reason == reason + assert line_discount.amount_value == value * line_to_discount.quantity + @pytest.mark.parametrize("status", [OrderStatus.DRAFT, OrderStatus.UNCONFIRMED]) def test_update_order_line_discount_line_with_discount( @@ -1002,11 +1225,12 @@ def test_update_order_line_discount_line_with_discount( line_undiscounted_price = line_to_discount.undiscounted_unit_price value = Decimal("50") + value_type = DiscountValueTypeEnum.PERCENTAGE reason = "New reason for unit discount" variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line_to_discount.pk), "input": { - "valueType": DiscountValueTypeEnum.PERCENTAGE.name, + "valueType": value_type.name, "value": value, "reason": reason, }, @@ -1022,7 +1246,6 @@ def test_update_order_line_discount_line_with_discount( data = content["data"]["orderLineDiscountUpdate"] line_to_discount.refresh_from_db() - errors = data["errors"] assert not errors @@ -1047,13 +1270,64 @@ def test_update_order_line_discount_line_with_discount( discount_data = line_data.get("discount") assert discount_data["value"] == str(value) - assert discount_data["value_type"] == DiscountValueTypeEnum.PERCENTAGE.value + assert discount_data["value_type"] == value_type.value assert discount_data["amount_value"] == str(unit_discount.amount) assert discount_data["old_value"] == str(line_discount_value_before_update) assert discount_data["old_value_type"] == DiscountValueTypeEnum.FIXED.value assert discount_data["old_amount_value"] == str(line_discount_amount_before_update) + line_discount = line_to_discount.discounts.get() + assert line_discount.type == DiscountType.MANUAL + assert line_discount.value == value + assert line_discount.value_type == value_type.value + assert line_discount.reason == reason + assert ( + line_discount.amount_value + == line_to_discount.unit_discount_amount * line_to_discount.quantity + ) + + +def test_update_order_line_discount_line_with_catalogue_promotion( + order_with_lines_and_catalogue_promotion, + staff_api_client, + permission_group_manage_orders, +): + # given + permission_group_manage_orders.user_set.add(staff_api_client.user) + order = order_with_lines_and_catalogue_promotion + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + line = order.lines.get(quantity=3) + assert line.discounts.filter(type=DiscountType.PROMOTION).exists() + + value = Decimal("5") + value_type = DiscountValueTypeEnum.FIXED + reason = "Manual fixed line discount" + variables = { + "orderLineId": graphene.Node.to_global_id("OrderLine", line.pk), + "input": { + "valueType": value_type.name, + "value": value, + "reason": reason, + }, + } + + # when + response = staff_api_client.post_graphql(ORDER_LINE_DISCOUNT_UPDATE, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["orderLineDiscountUpdate"] + assert not data["errors"] + + line_discount = line.discounts.get() + assert line_discount.type == DiscountType.MANUAL + assert line_discount.value == value + assert line_discount.value_type == value_type.value + assert line_discount.reason == reason + assert line_discount.amount_value == value * line.quantity + def test_update_order_line_discount_order_is_not_draft( draft_order_with_fixed_discount_order, @@ -1141,6 +1415,13 @@ def test_delete_discount_from_order_line( line.unit_discount_value = Decimal("2.5") line.save() + line.discounts.create( + type=DiscountType.MANUAL, + value_type=DiscountValueType.FIXED, + value=Decimal("2.5"), + currency=order.currency, + ) + variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line.pk), } @@ -1169,6 +1450,8 @@ def test_delete_discount_from_order_line( line_data = lines[0] assert line_data.get("line_pk") == str(line.pk) + assert not line.discounts.exists() + @patch("saleor.plugins.manager.PluginsManager.calculate_order_line_unit") @patch("saleor.plugins.manager.PluginsManager.calculate_order_line_total") @@ -1254,6 +1537,13 @@ def test_delete_discount_from_order_line_by_app( line.unit_discount_value = Decimal("2.5") line.save() + line.discounts.create( + type=DiscountType.MANUAL, + value_type=DiscountValueType.FIXED, + value=Decimal("2.5"), + currency=order.currency, + ) + variables = { "orderLineId": graphene.Node.to_global_id("OrderLine", line.pk), } @@ -1287,6 +1577,8 @@ def test_delete_discount_from_order_line_by_app( line_data = lines[0] assert line_data.get("line_pk") == str(line.pk) + assert not line.discounts.exists() + def test_delete_order_line_discount_order_is_not_draft( draft_order_with_fixed_discount_order, @@ -1321,3 +1613,42 @@ def test_delete_order_line_discount_order_is_not_draft( assert error["code"] == OrderErrorCode.CANNOT_DISCOUNT.name assert line.unit_discount_amount == Decimal("2.5") + + +def test_delete_order_line_discount_line_with_catalogue_promotion( + order_with_lines, + staff_api_client, + permission_group_manage_orders, + catalogue_promotion, +): + # given + permission_group_manage_orders.user_set.add(staff_api_client.user) + order = order_with_lines + order.status = OrderStatus.DRAFT + order.save(update_fields=["status"]) + line = order.lines.get(quantity=3) + + manual_reward_value = Decimal(1) + line.discounts.create( + type=DiscountType.MANUAL, + value_type=DiscountValueType.FIXED, + value=manual_reward_value, + amount_value=manual_reward_value * line.quantity, + currency=order.currency, + reason="Manual line discount", + ) + + variables = { + "orderLineId": graphene.Node.to_global_id("OrderLine", line.pk), + } + + # when + response = staff_api_client.post_graphql(ORDER_LINE_DISCOUNT_REMOVE, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["orderLineDiscountRemove"] + assert not data["errors"] + # Deleting manual discount should result in creating catalogue discount in this case + # https://github.com/saleor/saleor/issues/15517 + # assert line.discounts.filter(type=DiscountType.PROMOTION).exists() diff --git a/saleor/graphql/order/tests/mutations/test_order_line_update.py b/saleor/graphql/order/tests/mutations/test_order_line_update.py index 9a9e09697c3..19ad0526f9c 100644 --- a/saleor/graphql/order/tests/mutations/test_order_line_update.py +++ b/saleor/graphql/order/tests/mutations/test_order_line_update.py @@ -1,8 +1,10 @@ +from decimal import Decimal from unittest.mock import patch import graphene import pytest +from .....discount import DiscountType, RewardValueType from .....order import OrderStatus from .....order import events as order_events from .....order.error_codes import OrderErrorCode @@ -24,6 +26,12 @@ orderLine { id quantity + unitDiscount { + amount + } + unitDiscountType + unitDiscountValue + isGift } order { total { @@ -31,6 +39,11 @@ amount } } + discounts { + amount { + amount + } + } } } } @@ -417,3 +430,97 @@ def test_order_line_update_quantity_gift( assert len(errors) == 1 assert errors[0]["field"] == "id" assert errors[0]["code"] == OrderErrorCode.NON_EDITABLE_GIFT_LINE.name + + +def test_order_line_update_order_promotion( + draft_order, + staff_api_client, + permission_group_manage_orders, + order_promotion_rule, +): + # given + query = ORDER_LINE_UPDATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + order = draft_order + + rule = order_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + reward_value = Decimal("25") + assert rule.reward_value == reward_value + assert rule.reward_value_type == RewardValueType.PERCENTAGE + + order.lines.last().delete() + line = order.lines.first() + line_id = graphene.Node.to_global_id("OrderLine", line.id) + variant = line.variant + variant_channel_listing = variant.channel_listings.get(channel=order.channel) + quantity = 4 + undiscounted_subtotal = quantity * variant_channel_listing.discounted_price_amount + expected_discount = round(reward_value / 100 * undiscounted_subtotal, 2) + + variables = {"lineId": line_id, "quantity": quantity} + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["orderLineUpdate"] + + discounts = data["order"]["discounts"] + assert len(discounts) == 1 + assert discounts[0]["amount"]["amount"] == expected_discount + + discount_db = order.discounts.get() + assert discount_db.promotion_rule == rule + assert discount_db.amount_value == expected_discount + assert discount_db.type == DiscountType.ORDER_PROMOTION + assert discount_db.reason == f"Promotion: {promotion_id}" + + +def test_order_line_update_gift_promotion( + draft_order, + staff_api_client, + permission_group_manage_orders, + gift_promotion_rule, +): + # given + query = ORDER_LINE_UPDATE_MUTATION + permission_group_manage_orders.user_set.add(staff_api_client.user) + order = draft_order + rule = gift_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + order.lines.last().delete() + line = order.lines.first() + line_id = graphene.Node.to_global_id("OrderLine", line.id) + quantity = 4 + + variables = {"lineId": line_id, "quantity": quantity} + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["orderLineUpdate"] + + line = data["orderLine"] + assert line["quantity"] == quantity + assert line["unitDiscount"]["amount"] == 0 + assert line["unitDiscountValue"] == 0 + + gift_line_db = order.lines.get(is_gift=True) + gift_price = gift_line_db.variant.channel_listings.get( + channel=order.channel + ).price_amount + assert gift_line_db.unit_discount_amount == gift_price + assert gift_line_db.unit_price_gross_amount == Decimal(0) + + assert not data["order"]["discounts"] + + discount = gift_line_db.discounts.get() + assert discount.promotion_rule == rule + assert discount.amount_value == gift_price + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" diff --git a/saleor/graphql/order/tests/mutations/test_order_lines_create.py b/saleor/graphql/order/tests/mutations/test_order_lines_create.py index da5680b5891..9c320976a7c 100644 --- a/saleor/graphql/order/tests/mutations/test_order_lines_create.py +++ b/saleor/graphql/order/tests/mutations/test_order_lines_create.py @@ -57,6 +57,12 @@ currency } } + unitDiscount { + amount + } + unitDiscountType + unitDiscountValue + isGift } order { total { @@ -64,6 +70,11 @@ amount } } + discounts { + amount { + amount + } + } } } } @@ -611,7 +622,6 @@ def test_order_lines_create_variant_on_promotion( line_data = data["orderLines"][0] assert line_data["productSku"] == variant.sku assert line_data["quantity"] == quantity - assert line_data["quantity"] == quantity assert ( line_data["unitPrice"]["gross"]["amount"] @@ -644,6 +654,152 @@ def test_order_lines_create_variant_on_promotion( ) +@pytest.mark.parametrize("status", [OrderStatus.DRAFT, OrderStatus.UNCONFIRMED]) +@patch("saleor.plugins.manager.PluginsManager.draft_order_updated") +@patch("saleor.plugins.manager.PluginsManager.order_updated") +def test_order_lines_create_order_promotion( + order_updated_webhook_mock, + draft_order_updated_webhook_mock, + status, + order_with_lines, + permission_group_manage_orders, + staff_api_client, + variant_with_many_stocks, + order_promotion_rule, +): + # given + query = ORDER_LINES_CREATE_MUTATION + + order = order_with_lines + order.status = status + order.save(update_fields=["status"]) + order.lines.all().delete() + + rule = order_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + assert rule.reward_value_type == RewardValueType.PERCENTAGE + assert rule.reward_value == Decimal("25") + + variant = variant_with_many_stocks + quantity = 5 + order_id = graphene.Node.to_global_id("Order", order.id) + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + variant_channel_listing = variant.channel_listings.get(channel=order.channel) + expected_discount = round( + quantity * variant_channel_listing.discounted_price.amount * Decimal(0.25), 2 + ) + expected_unit_discount = round(expected_discount / quantity, 2) + + variables = {"orderId": order_id, "variantId": variant_id, "quantity": quantity} + permission_group_manage_orders.user_set.add(staff_api_client.user) + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + assert_proper_webhook_called_once( + order, status, draft_order_updated_webhook_mock, order_updated_webhook_mock + ) + assert OrderEvent.objects.count() == 1 + assert OrderEvent.objects.last().type == order_events.OrderEvents.ADDED_PRODUCTS + content = get_graphql_content(response) + data = content["data"]["orderLinesCreate"] + + line_data = data["orderLines"][0] + assert line_data["productSku"] == variant.sku + assert line_data["quantity"] == quantity + assert line_data["unitDiscount"]["amount"] == 0.00 + assert ( + line_data["unitPrice"]["gross"]["amount"] + == variant_channel_listing.price_amount - expected_unit_discount + ) + assert ( + line_data["unitPrice"]["net"]["amount"] + == variant_channel_listing.price_amount - expected_unit_discount + ) + + line = order.lines.get(product_sku=variant.sku) + assert line.unit_discount_amount == 0 + assert ( + line.unit_price_gross_amount + == variant_channel_listing.discounted_price.amount - expected_unit_discount + ) + + assert len(data["order"]["discounts"]) == 1 + discount = data["order"]["discounts"][0] + assert discount["amount"]["amount"] == expected_discount + + discount = order.discounts.get() + assert discount.promotion_rule == rule + assert discount.amount_value == expected_discount + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + +@pytest.mark.parametrize("status", [OrderStatus.DRAFT, OrderStatus.UNCONFIRMED]) +@patch("saleor.plugins.manager.PluginsManager.draft_order_updated") +@patch("saleor.plugins.manager.PluginsManager.order_updated") +def test_order_lines_create_gift_promotion( + order_updated_webhook_mock, + draft_order_updated_webhook_mock, + status, + order_with_lines, + permission_group_manage_orders, + staff_api_client, + variant_with_many_stocks, + gift_promotion_rule, +): + # given + query = ORDER_LINES_CREATE_MUTATION + + order = order_with_lines + order.status = status + order.save(update_fields=["status"]) + order.lines.all().delete() + + rule = gift_promotion_rule + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + variant = variant_with_many_stocks + quantity = 5 + order_id = graphene.Node.to_global_id("Order", order.id) + variant_id = graphene.Node.to_global_id("ProductVariant", variant.id) + + variables = {"orderId": order_id, "variantId": variant_id, "quantity": quantity} + permission_group_manage_orders.user_set.add(staff_api_client.user) + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + assert_proper_webhook_called_once( + order, status, draft_order_updated_webhook_mock, order_updated_webhook_mock + ) + assert OrderEvent.objects.count() == 1 + assert OrderEvent.objects.last().type == order_events.OrderEvents.ADDED_PRODUCTS + content = get_graphql_content(response) + data = content["data"]["orderLinesCreate"] + + lines = data["orderLines"] + # gift line is not returned + assert len(lines) == 1 + + gift_line_db = order.lines.get(is_gift=True) + gift_price = gift_line_db.variant.channel_listings.get( + channel=order.channel + ).price_amount + assert gift_line_db.unit_discount_amount == gift_price + assert gift_line_db.unit_price_gross_amount == Decimal(0) + + assert not data["order"]["discounts"] + + discount = gift_line_db.discounts.get() + assert discount.promotion_rule == rule + assert discount.amount_value == gift_price + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + @pytest.mark.parametrize("status", [OrderStatus.DRAFT, OrderStatus.UNCONFIRMED]) @patch("saleor.plugins.manager.PluginsManager.draft_order_updated") @patch("saleor.plugins.manager.PluginsManager.order_updated") diff --git a/saleor/graphql/order/tests/queries/test_draft_order_with_filter.py b/saleor/graphql/order/tests/queries/test_draft_order_with_filter.py index d7c15b56ac7..de79b9412f3 100644 --- a/saleor/graphql/order/tests/queries/test_draft_order_with_filter.py +++ b/saleor/graphql/order/tests/queries/test_draft_order_with_filter.py @@ -3,6 +3,7 @@ import graphene import pytest +from django.core.exceptions import ValidationError from freezegun import freeze_time from .....core.postgres import FlatConcatSearchVector @@ -13,7 +14,9 @@ prepare_order_search_vector_value, update_order_search_vector, ) +from ....core.connection import where_filter_qs from ....tests.utils import get_graphql_content +from ...filters import OrderDiscountedObjectWhere @pytest.fixture @@ -218,3 +221,243 @@ def test_draft_orders_query_with_filter_search( response = staff_api_client.post_graphql(draft_orders_query_with_filter, variables) content = get_graphql_content(response) assert content["data"]["draftOrders"]["totalCount"] == count + + +@pytest.mark.parametrize(("gte", "count"), [(20, 1), (0, 1), (500, 0), (20.01, 0)]) +def test_draft_orders_query_with_filter_base_total_price_range(draft_order, gte, count): + # given + order = draft_order + currency = order.currency + order.total_net_amount = Decimal("20") + order.save(update_fields=["total_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "currency": currency, + "base_total_price": { + "range": { + "gte": gte, + } + }, + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == count + if count: + assert result.first() == order + + +@pytest.mark.parametrize(("gte", "count"), [(20, 1), (0, 1), (500, 0), (20.01, 0)]) +def test_draft_orders_query_with_filter_base_subtotal_price_range( + draft_order, gte, count +): + # given + order = draft_order + currency = order.currency + order.subtotal_net_amount = Decimal("20") + order.save(update_fields=["subtotal_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "currency": currency, + "base_subtotal_price": { + "range": { + "gte": gte, + } + }, + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == count + if count: + assert result.first() == order + + +@pytest.mark.parametrize( + ("one_of", "count"), [([1, 20, 70], 1), ([3, 20.1], 0), ([-3, 0], 0)] +) +def test_draft_orders_query_with_filter_base_total_price_one_of( + draft_order, one_of, count +): + # given + order = draft_order + currency = order.currency + order.total_net_amount = Decimal("20") + order.save(update_fields=["total_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "currency": currency, + "base_total_price": {"one_of": one_of}, + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == count + if count: + assert result.first() == order + + +@pytest.mark.parametrize( + ("one_of", "count"), [([1, 20, 70], 1), ([3, 20.1], 0), ([-3, 0], 0)] +) +def test_draft_orders_query_with_filter_base_subtotal_price_one_of( + draft_order, one_of, count +): + # given + order = draft_order + currency = order.currency + order.subtotal_net_amount = Decimal("20") + order.save(update_fields=["subtotal_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "currency": currency, + "base_subtotal_price": {"one_of": one_of}, + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == count + if count: + assert result.first() == order + + +def test_draft_orders_query_with_filter_base_total_price_missing_currency(draft_order): + # given + order = draft_order + order.total_net_amount = Decimal("20") + order.save(update_fields=["total_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "base_total_price": { + "range": { + "gte": 20, + } + }, + } + + # when + with pytest.raises(ValidationError) as validation_error: + where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert validation_error.value.code == "required" + + +def test_draft_orders_query_with_filter_base_subtotal_price_missing_currency( + draft_order, +): + # given + order = draft_order + order.subtotal_net_amount = Decimal("20") + order.save(update_fields=["subtotal_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "base_subtotal_price": { + "range": { + "gte": 20, + } + }, + } + + # when + with pytest.raises(ValidationError) as validation_error: + where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert validation_error.value.code == "required" + + +def test_draft_orders_query_with_filter_price_with_and_or(draft_order): + # given + order = draft_order + currency = order.currency + order.total_net_amount = Decimal("20") + order.save(update_fields=["total_net_amount"]) + + qs = Order.objects.all() + predicate_data = { + "AND": [ + { + "OR": [ + { + "currency": currency, + "base_total_price": { + "range": { + "gte": 20, + } + }, + }, + { + "currency": currency, + "base_total_price": { + "range": { + "gte": 10, + } + }, + }, + ] + } + ], + } + + # when + result = where_filter_qs( + qs, + {}, + OrderDiscountedObjectWhere, + predicate_data, + None, + ) + + # then + assert result.count() == 1 diff --git a/saleor/graphql/order/types.py b/saleor/graphql/order/types.py index d581bc2853b..57154d85fa7 100644 --- a/saleor/graphql/order/types.py +++ b/saleor/graphql/order/types.py @@ -25,6 +25,7 @@ from ...graphql.utils import get_user_or_app_from_context from ...graphql.warehouse.dataloaders import StockByIdLoader, WarehouseByIdLoader from ...order import OrderStatus, calculations, models +from ...order.calculations import fetch_order_prices_if_expired from ...order.models import FulfillmentStatus from ...order.utils import ( get_order_country, @@ -1526,7 +1527,11 @@ def resolve_token(root: models.Order, info): @staticmethod def resolve_discounts(root: models.Order, info): - return OrderDiscountsByOrderIDLoader(info.context).load(root.id) + def with_manager(manager): + fetch_order_prices_if_expired(root, manager) + return OrderDiscountsByOrderIDLoader(info.context).load(root.id) + + return get_plugin_manager_promise(info.context).then(with_manager) @staticmethod @traced_resolver diff --git a/saleor/order/base_calculations.py b/saleor/order/base_calculations.py index 219b77f2c84..e5848c04eb8 100644 --- a/saleor/order/base_calculations.py +++ b/saleor/order/base_calculations.py @@ -109,6 +109,13 @@ def propagate_order_discount_on_order_prices( currency=currency, price_to_discount=subtotal, ) + elif order_discount.type == DiscountType.ORDER_PROMOTION: + subtotal = apply_discount_to_value( + value=order_discount.value, + value_type=order_discount.value_type, + currency=currency, + price_to_discount=subtotal, + ) elif order_discount.type == DiscountType.MANUAL: if order_discount.value_type == DiscountValueType.PERCENTAGE: subtotal = apply_discount_to_value( @@ -123,7 +130,7 @@ def propagate_order_discount_on_order_prices( currency=currency, price_to_discount=shipping_price, ) - else: + elif order_discount.value_type == DiscountValueType.FIXED: temporary_undiscounted_total = subtotal + shipping_price if temporary_undiscounted_total.amount > 0: temporary_total = apply_discount_to_value( diff --git a/saleor/order/calculations.py b/saleor/order/calculations.py index bc23505a2d3..4ac30e918f5 100644 --- a/saleor/order/calculations.py +++ b/saleor/order/calculations.py @@ -11,6 +11,7 @@ from ..core.prices import quantize_price from ..core.taxes import TaxData, TaxEmptyData, TaxError, zero_taxed_money from ..discount import DiscountType +from ..discount.utils import create_or_update_discount_objects_from_promotion_for_order from ..payment.model_helpers import get_subtotal from ..plugins import PLUGIN_IDENTIFIER_PREFIX from ..plugins.manager import PluginsManager @@ -25,6 +26,7 @@ ) from . import ORDER_EDITABLE_STATUS from .base_calculations import apply_order_discounts, base_order_line_total +from .fetch import DraftOrderLineInfo, fetch_draft_order_lines_info from .interface import OrderTaxedPricesData from .models import Order, OrderLine @@ -49,18 +51,18 @@ def fetch_order_prices_if_expired( if not force_update and not order.should_refresh_prices: return order, lines - if lines is None: - lines = list( - order.lines.using(database_connection_name).select_related( - "variant__product__product_type" - ) - ) - else: - prefetch_related_objects(lines, "variant__product__product_type") + # handle promotions + lines_info: list[DraftOrderLineInfo] = fetch_draft_order_lines_info(order, lines) + create_or_update_discount_objects_from_promotion_for_order( + order, lines_info, database_connection_name + ) + lines = [line_info.line for line_info in lines_info] + _update_order_discount_for_voucher(order) - order.should_refresh_prices = False + _clear_prefetched_discounts(order, lines) + prefetch_related_objects([order], "discounts") - _update_order_discount_for_voucher(order) + # handle taxes _recalculate_prices( order, manager, @@ -68,7 +70,7 @@ def fetch_order_prices_if_expired( database_connection_name=database_connection_name, ) - order.subtotal = get_subtotal(lines, order.currency) + order.should_refresh_prices = False with transaction.atomic(savepoint=False): with allow_writer(): order.save( @@ -98,6 +100,11 @@ def fetch_order_prices_if_expired( "undiscounted_total_price_net_amount", "undiscounted_total_price_gross_amount", "tax_rate", + "unit_discount_amount", + "unit_discount_reason", + "unit_discount_type", + "unit_discount_value", + "base_unit_price_amount", ], ) @@ -128,14 +135,14 @@ def _update_order_discount_for_voucher(order: Order): voucher_code=order.voucher_code, ) - # Prefetch has to be cleared and refreshed to avoid returning cached discounts - if ( - hasattr(order, "_prefetched_objects_cache") - and "discounts" in order._prefetched_objects_cache - ): - del order._prefetched_objects_cache["discounts"] - prefetch_related_objects([order], "discounts") +def _clear_prefetched_discounts(order, lines): + if hasattr(order, "_prefetched_objects_cache"): + order._prefetched_objects_cache.pop("discounts", None) + + for line in lines: + if hasattr(line, "_prefetched_objects_cache"): + line._prefetched_objects_cache.pop("discounts", None) def _recalculate_prices( @@ -338,6 +345,7 @@ def _recalculate_with_plugins( order.undiscounted_total = undiscounted_subtotal + TaxedMoney( net=order.base_shipping_price, gross=order.base_shipping_price ) + order.subtotal = get_subtotal(lines, order.currency) order.total = manager.calculate_order_total(order, lines, plugin_ids=plugin_ids) @@ -387,12 +395,14 @@ def _apply_tax_data( order_line.tax_rate = normalize_tax_rate_for_db(tax_line.tax_rate) subtotal += line_total_price + order.subtotal = subtotal order.total = shipping_price + subtotal def _remove_tax(order, lines): order.total_gross_amount = order.total_net_amount order.undiscounted_total_gross_amount = order.undiscounted_total_net_amount + order.subtotal_gross_amount = order.subtotal_net_amount order.shipping_price_gross_amount = order.shipping_price_net_amount order.shipping_tax_rate = Decimal("0.00") diff --git a/saleor/order/fetch.py b/saleor/order/fetch.py index 9d1b364e4a1..5bab6e8ab90 100644 --- a/saleor/order/fetch.py +++ b/saleor/order/fetch.py @@ -1,14 +1,21 @@ from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import Optional, cast from uuid import UUID -if TYPE_CHECKING: - from ..channel.models import Channel - from ..discount.models import OrderLineDiscount - from ..payment.models import Payment - from ..product.models import DigitalContent, ProductVariant - from .models import Order, OrderLine +from django.db.models import prefetch_related_objects + +from ..channel.models import Channel +from ..discount import DiscountType +from ..discount.interface import VariantPromotionRuleInfo, fetch_variant_rules_info +from ..discount.models import OrderLineDiscount, Voucher +from ..payment.models import Payment +from ..product.models import ( + DigitalContent, + ProductVariant, + ProductVariantChannelListing, +) +from .models import Order, OrderLine @dataclass @@ -63,3 +70,78 @@ def fetch_order_lines(order: "Order") -> list[OrderLineInfo]: ) return lines_info + + +@dataclass +class DraftOrderLineInfo: + line: "OrderLine" + variant: "ProductVariant" + channel_listing: "ProductVariantChannelListing" + discounts: list["OrderLineDiscount"] + rules_info: list["VariantPromotionRuleInfo"] + channel: "Channel" + voucher: Optional["Voucher"] = None + + def get_promotion_discounts(self) -> list["OrderLineDiscount"]: + return [ + discount + for discount in self.discounts + if discount.type in [DiscountType.PROMOTION, DiscountType.ORDER_PROMOTION] + ] + + def get_catalogue_discounts(self) -> list["OrderLineDiscount"]: + return [ + discount + for discount in self.discounts + if discount.type == DiscountType.PROMOTION + ] + + +def fetch_draft_order_lines_info( + order: "Order", lines: Optional[Iterable["OrderLine"]] = None +) -> list[DraftOrderLineInfo]: + prefetch_related_fields = [ + "discounts__promotion_rule__promotion", + "variant__channel_listings__variantlistingpromotionrule__promotion_rule__promotion__translations", + "variant__channel_listings__variantlistingpromotionrule__promotion_rule__translations", + ] + if lines is None: + lines = list(order.lines.prefetch_related(*prefetch_related_fields)) + else: + prefetch_related_objects(lines, *prefetch_related_fields) + + lines_info = [] + channel = order.channel + for line in lines: + variant = cast(ProductVariant, line.variant) + variant_channel_listing = get_prefetched_variant_listing(variant, channel.id) + if not variant_channel_listing: + continue + + rules_info = ( + fetch_variant_rules_info(variant_channel_listing, order.language_code) + if not line.is_gift + else [] + ) + lines_info.append( + DraftOrderLineInfo( + line=line, + variant=variant, + channel_listing=variant_channel_listing, + discounts=list(line.discounts.all()), + rules_info=rules_info, + channel=channel, + ) + ) + return lines_info + + +def get_prefetched_variant_listing( + variant: Optional[ProductVariant], channel_id: int +) -> Optional[ProductVariantChannelListing]: + if not variant: + return None + for channel_listing in variant.channel_listings.all(): + if channel_listing.channel_id == channel_id: + return channel_listing + return None diff --git a/saleor/order/tests/benchmark/__init__.py b/saleor/order/tests/benchmark/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/saleor/order/tests/benchmark/test_fetch_order_prices.py b/saleor/order/tests/benchmark/test_fetch_order_prices.py new file mode 100644 index 00000000000..3a2a866ad80 --- /dev/null +++ b/saleor/order/tests/benchmark/test_fetch_order_prices.py @@ -0,0 +1,144 @@ +from decimal import Decimal + +import pytest +from prices import Money, TaxedMoney + +from ....discount import RewardValueType +from ....discount.models import OrderDiscount, OrderLineDiscount, PromotionRule +from ....product.models import Product +from ....product.utils.variant_prices import update_discounted_prices_for_promotion +from ....product.utils.variants import fetch_variants_for_promotion_rules +from ....tax import TaxCalculationStrategy +from ....warehouse.models import Stock +from ... import OrderStatus +from ...calculations import fetch_order_prices_if_expired +from ...models import OrderLine + + +@pytest.fixture +def order_with_lines(order_with_lines): + order_with_lines.status = OrderStatus.UNCONFIRMED + return order_with_lines + + +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_fetch_order_prices_catalogue_discount( + order_with_lines_and_catalogue_promotion, + plugins_manager, + django_assert_num_queries, + count_queries, +): + # given + OrderLineDiscount.objects.all().delete() + order = order_with_lines_and_catalogue_promotion + channel = order.channel + + tc = channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + + # when + with django_assert_num_queries(38): + fetch_order_prices_if_expired(order, plugins_manager, None, True) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + + +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_fetch_order_prices_multiple_catalogue_discounts( + order_with_lines, + catalogue_promotion_without_rules, + plugins_manager, + product_variant_list, + django_assert_num_queries, + count_queries, +): + # given + Stock.objects.update(quantity=100) + order = order_with_lines + channel = order.channel + + variants = product_variant_list + variants.extend(line.variant for line in order.lines.all()) + variant_global_ids = [variant.get_global_id() for variant in variants] + + # create many rules + promotion = catalogue_promotion_without_rules + rules = [] + catalogue_predicate = {"variantPredicate": {"ids": variant_global_ids}} + for idx in range(5): + reward_value = 2 + idx + rules.append( + PromotionRule( + name=f"Catalogue rule fixed {reward_value}", + promotion=promotion, + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(reward_value), + ) + ) + for idx in range(5): + reward_value = idx * 10 + 25 + rules.append( + PromotionRule( + name=f"Catalogue rule percentage {reward_value}", + promotion=promotion, + catalogue_predicate=catalogue_predicate, + reward_value_type=RewardValueType.PERCENTAGE, + reward_value=Decimal(reward_value), + ) + ) + rules = PromotionRule.objects.bulk_create(rules) + for rule in rules: + rule.channels.add(channel) + + fetch_variants_for_promotion_rules(PromotionRule.objects.all()) + + # update prices + update_discounted_prices_for_promotion(Product.objects.all()) + + # create lines + new_order_lines = [] + for idx, variant in enumerate(product_variant_list): + base_price = variant.channel_listings.first().discounted_price + currency = base_price.currency + gross = Money(amount=base_price.amount * Decimal(1.23), currency=currency) + quantity = 4 + idx + unit_price = TaxedMoney(net=base_price, gross=gross) + new_order_lines.append( + OrderLine( + order=order, + product_name=str(variant.product), + variant_name=str(variant), + product_sku=variant.sku, + product_variant_id=variant.get_global_id(), + is_shipping_required=variant.is_shipping_required(), + is_gift_card=variant.is_gift_card(), + quantity=quantity, + variant=variant, + unit_price=unit_price, + total_price=unit_price * quantity, + tax_rate=Decimal("0.23"), + ) + ) + OrderLine.objects.bulk_create(new_order_lines) + + tc = channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + + # when + with django_assert_num_queries(38): + fetch_order_prices_if_expired(order, plugins_manager, None, True) + + # then + assert OrderLineDiscount.objects.count() == 7 + assert not OrderDiscount.objects.exists() diff --git a/saleor/order/tests/test_discount_calculations.py b/saleor/order/tests/test_apply_order_discount.py similarity index 94% rename from saleor/order/tests/test_discount_calculations.py rename to saleor/order/tests/test_apply_order_discount.py index 66767e5d094..f22c5dc56f8 100644 --- a/saleor/order/tests/test_discount_calculations.py +++ b/saleor/order/tests/test_apply_order_discount.py @@ -1,7 +1,7 @@ from decimal import Decimal import pytest -from prices import Money, TaxedMoney +from prices import Money from ...core.prices import quantize_price from ...core.taxes import zero_money @@ -9,51 +9,7 @@ from ...order.base_calculations import ( apply_order_discounts, apply_subtotal_discount_to_order_lines, - base_order_line_total, - base_order_total, ) -from ...order.interface import OrderTaxedPricesData - - -def test_base_order_total(order_with_lines): - # given - order = order_with_lines - lines = order.lines.all() - shipping_price = order.shipping_price.net - subtotal = zero_money(order.currency) - for line in lines: - subtotal += line.base_unit_price * line.quantity - undiscounted_total = subtotal + shipping_price - - # when - order_total = base_order_total(order, lines) - - # then - assert order_total == undiscounted_total - - -def test_base_order_line_total(order_with_lines): - # given - line = order_with_lines.lines.all().first() - - # when - order_total = base_order_line_total(line) - - # then - base_line_unit_price = line.base_unit_price - quantity = line.quantity - expected_price_with_discount = ( - TaxedMoney(base_line_unit_price, base_line_unit_price) * quantity - ) - base_line_undiscounted_unit_price = line.undiscounted_base_unit_price - expected_undiscounted_price = ( - TaxedMoney(base_line_undiscounted_unit_price, base_line_undiscounted_unit_price) - * quantity - ) - assert order_total == OrderTaxedPricesData( - price_with_discounts=expected_price_with_discount, - undiscounted_price=expected_undiscounted_price, - ) def test_apply_order_discounts_voucher_entire_order(order_with_lines, voucher): diff --git a/saleor/order/tests/test_fetch.py b/saleor/order/tests/test_fetch.py new file mode 100644 index 00000000000..684d97115d1 --- /dev/null +++ b/saleor/order/tests/test_fetch.py @@ -0,0 +1,73 @@ +from decimal import Decimal + +import pytest + +from ...product.models import ProductVariantChannelListing +from ..fetch import fetch_draft_order_lines_info + + +@pytest.mark.django_db +@pytest.mark.count_queries(autouse=False) +def test_fetch_draft_order_lines_info( + draft_order_and_promotions, django_assert_num_queries, count_queries +): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + channel = order.channel + lines = order.lines.all() + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + + manual_discount = line_1.discounts.create( + value=Decimal(1), + amount_value=Decimal(1), + currency=channel.currency_code, + name="Manual line discount", + ) + rule_translation = "Rule translation" + rule_catalogue.translations.create( + language_code=order.language_code, name=rule_translation + ) + promotion_translation = "Promotion" + rule_catalogue.promotion.translations.create( + language_code=order.language_code, name=promotion_translation + ) + + # when + with django_assert_num_queries(9): + lines_info = fetch_draft_order_lines_info(order) + + # then + line_info_1 = [line_info for line_info in lines_info if line_info.line == line_1][0] + line_info_2 = [line_info for line_info in lines_info if line_info.line == line_2][0] + + variant_1 = line_1.variant + assert line_info_1.variant == variant_1 + assert ( + line_info_1.channel_listing + == ProductVariantChannelListing.objects.filter( + channel=channel, variant=variant_1 + ).first() + ) + assert line_info_1.discounts == [manual_discount] + assert line_info_1.channel == channel + assert line_info_1.voucher is None + assert line_info_1.rules_info == [] + + variant_2 = line_2.variant + assert line_info_2.variant == variant_2 + assert ( + line_info_2.channel_listing + == ProductVariantChannelListing.objects.filter( + channel=channel, variant=variant_2 + ).first() + ) + assert line_info_2.discounts == [] + assert line_info_2.channel == channel + assert line_info_2.voucher is None + rule_info_2 = line_info_2.rules_info[0] + assert rule_info_2.rule == rule_catalogue + assert rule_info_2.variant_listing_promotion_rule + assert rule_info_2.promotion == rule_catalogue.promotion + assert rule_info_2.promotion_translation.name == promotion_translation + assert rule_info_2.rule_translation.name == rule_translation diff --git a/saleor/order/tests/test_fetch_order_prices.py b/saleor/order/tests/test_fetch_order_prices.py new file mode 100644 index 00000000000..5ccf22d535b --- /dev/null +++ b/saleor/order/tests/test_fetch_order_prices.py @@ -0,0 +1,1394 @@ +from decimal import Decimal + +import before_after +import graphene +import pytest + +from ...core.prices import quantize_price +from ...discount import DiscountType, DiscountValueType +from ...discount.models import ( + OrderDiscount, + OrderLineDiscount, + PromotionRule, +) +from ...tax import TaxCalculationStrategy +from ...tests.utils import round_down, round_up +from .. import OrderStatus, calculations + + +@pytest.fixture +def order_with_lines(order_with_lines): + order_with_lines.status = OrderStatus.UNCONFIRMED + return order_with_lines + + +@pytest.mark.parametrize("create_new_discounts", [True, False]) +def test_fetch_order_prices_catalogue_discount_flat_rates( + order_with_lines_and_catalogue_promotion, + plugins_manager, + create_new_discounts, +): + # given + if create_new_discounts: + OrderLineDiscount.objects.all().delete() + + order = order_with_lines_and_catalogue_promotion + channel = order.channel + rule = PromotionRule.objects.get() + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + reward_value = rule.reward_value + + tc = channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + + discount = line_1.discounts.get() + reward_amount = reward_value * line_1.quantity + assert discount.amount_value == reward_amount + assert discount.value == reward_value + assert discount.value_type == DiscountValueType.FIXED + assert discount.type == DiscountType.PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=channel) + variant_1_unit_price = variant_1_listing.discounted_price_amount + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + assert variant_1_undiscounted_unit_price - variant_1_unit_price == reward_value + + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert ( + line_1.base_unit_price_amount + == variant_1_undiscounted_unit_price - reward_value + ) + assert ( + line_1.unit_price_net_amount == variant_1_undiscounted_unit_price - reward_value + ) + assert line_1.unit_price_gross_amount == line_1.unit_price_net_amount * tax_rate + assert ( + line_1.total_price_net_amount == line_1.unit_price_net_amount * line_1.quantity + ) + assert line_1.total_price_gross_amount == line_1.total_price_net_amount * tax_rate + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.unit_price_net_amount == variant_2_undiscounted_unit_price + assert ( + line_2.unit_price_gross_amount == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.total_price_net_amount == line_2.undiscounted_total_price_net_amount + assert ( + line_2.total_price_gross_amount == line_2.undiscounted_total_price_gross_amount + ) + + shipping_net_price = order.shipping_price_net_amount + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_net_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == order.undiscounted_total_net_amount - reward_amount + assert order.total_gross_amount == order.total_net_amount * tax_rate + assert order.subtotal_net_amount == order.total_net_amount - shipping_net_price + assert order.subtotal_gross_amount == order.subtotal_net_amount * tax_rate + + assert line_1.unit_discount_amount == reward_value + assert line_1.unit_discount_reason == f"Promotion: {promotion_id}" + assert line_1.unit_discount_type == DiscountValueType.FIXED + assert line_1.unit_discount_value == reward_value + + +@pytest.mark.parametrize("create_new_discounts", [True, False]) +def test_fetch_order_prices_order_discount_flat_rates( + order_with_lines_and_order_promotion, + plugins_manager, + create_new_discounts, +): + # given + if create_new_discounts: + OrderDiscount.objects.all().delete() + + order = order_with_lines_and_order_promotion + currency = order.currency + rule = PromotionRule.objects.get() + reward_amount = rule.reward_value + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert not OrderLineDiscount.objects.exists() + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + discount = OrderDiscount.objects.get() + + line_1_base_total = line_1.quantity * line_1.base_unit_price_amount + line_2_base_total = line_2.quantity * line_2.base_unit_price_amount + base_total = line_1_base_total + line_2_base_total + line_1_order_discount_portion = reward_amount * line_1_base_total / base_total + line_2_order_discount_portion = reward_amount - line_1_order_discount_portion + + assert discount.order == order + assert discount.amount_value == reward_amount + assert discount.value == reward_amount + assert discount.value_type == DiscountValueType.FIXED + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + line_1.undiscounted_total_price_net_amount - line_1_order_discount_portion, + currency, + ) + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == round_down( + line_1_total_net_amount * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.unit_price_net_amount == line_1_total_net_amount / line_1.quantity + assert line_1.unit_price_gross_amount == quantize_price( + line_1.unit_price_net_amount * tax_rate, currency + ) + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount - line_2_order_discount_portion, + currency, + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == round_up( + line_2_total_net_amount * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == round_down( + line_2.unit_price_net_amount * tax_rate + ) + + shipping_price = order.shipping_price_net_amount + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert ( + order.total_net_amount + == line_1_total_net_amount + line_2_total_net_amount + shipping_price + ) + assert order.total_gross_amount == order.total_net_amount * tax_rate + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == order.subtotal_net_amount * tax_rate + + +@pytest.mark.parametrize("create_new_discounts", [True, False]) +def test_fetch_order_prices_gift_discount_flat_rates( + order_with_lines_and_gift_promotion, + plugins_manager, + create_new_discounts, +): + # given + if create_new_discounts: + OrderLineDiscount.objects.all().delete() + + order = order_with_lines_and_gift_promotion + rule = PromotionRule.objects.get() + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert len(lines) == 3 + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + gift_line = [line for line in lines if line.is_gift][0] + assert not line_1.discounts.exists() + assert not line_2.discounts.exists() + discount = OrderLineDiscount.objects.get() + + variant_gift = gift_line.variant + variant_gift_listing = variant_gift.channel_listings.get(channel=order.channel) + variant_gift_undiscounted_unit_price = variant_gift_listing.price_amount + + assert discount.line == gift_line + assert discount.amount_value == variant_gift_undiscounted_unit_price + assert discount.value == variant_gift_undiscounted_unit_price + assert discount.value_type == DiscountValueType.FIXED + assert discount.type == DiscountType.ORDER_PROMOTION + assert discount.reason == f"Promotion: {promotion_id}" + + assert gift_line.unit_discount_amount == variant_gift_undiscounted_unit_price + assert gift_line.unit_discount_reason == f"Promotion: {promotion_id}" + assert gift_line.unit_discount_type == DiscountValueType.FIXED + assert gift_line.unit_discount_value == variant_gift_undiscounted_unit_price + assert gift_line.undiscounted_total_price_net_amount == Decimal(0) + assert gift_line.undiscounted_total_price_gross_amount == Decimal(0) + assert gift_line.undiscounted_unit_price_net_amount == Decimal(0) + assert gift_line.undiscounted_unit_price_gross_amount == Decimal(0) + assert gift_line.total_price_net_amount == Decimal(0) + assert gift_line.total_price_gross_amount == Decimal(0) + assert gift_line.base_unit_price_amount == Decimal(0) + assert gift_line.unit_price_net_amount == Decimal(0) + assert gift_line.unit_price_gross_amount == Decimal(0) + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.total_price_net_amount == line_1.undiscounted_total_price_net_amount + assert ( + line_1.total_price_gross_amount == line_1.undiscounted_total_price_gross_amount + ) + assert line_1.base_unit_price_amount == line_1.undiscounted_unit_price_net_amount + assert line_1.unit_price_net_amount == line_1.undiscounted_unit_price_net_amount + assert line_1.unit_price_gross_amount == line_1.undiscounted_unit_price_gross_amount + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.total_price_net_amount == line_2.undiscounted_total_price_net_amount + assert ( + line_2.total_price_gross_amount == line_2.undiscounted_total_price_gross_amount + ) + assert line_2.base_unit_price_amount == line_2.undiscounted_unit_price_net_amount + assert line_2.unit_price_net_amount == line_2.undiscounted_unit_price_net_amount + assert line_2.unit_price_gross_amount == line_2.undiscounted_unit_price_gross_amount + + shipping_price = order.shipping_price_net_amount + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == order.undiscounted_total_net_amount + assert order.total_gross_amount == order.undiscounted_total_gross_amount + assert ( + order.subtotal_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + ) + assert order.subtotal_gross_amount == order.subtotal_net_amount * tax_rate + + +def test_fetch_order_prices_catalogue_and_order_discounts_flat_rates( + draft_order_and_promotions, + plugins_manager, +): + # given + order, rule_catalogue, rule_total, _ = draft_order_and_promotions + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + order_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_total.promotion_id + ) + rule_catalogue_reward = rule_catalogue.reward_value + rule_total_reward = rule_total.reward_value + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + catalogue_discount = OrderLineDiscount.objects.get() + order_discount = OrderDiscount.objects.get() + + line_1_base_total = line_1.quantity * line_1.base_unit_price_amount + line_2_base_total = line_2.quantity * line_2.base_unit_price_amount + base_total = line_1_base_total + line_2_base_total + line_1_order_discount_portion = rule_total_reward * line_1_base_total / base_total + line_2_order_discount_portion = rule_total_reward - line_1_order_discount_portion + + assert order_discount.order == order + assert order_discount.amount_value == rule_total_reward + assert order_discount.value == rule_total_reward + assert order_discount.value_type == DiscountValueType.FIXED + assert order_discount.type == DiscountType.ORDER_PROMOTION + assert order_discount.reason == f"Promotion: {order_promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + line_1.undiscounted_total_price_net_amount - line_1_order_discount_portion, + currency, + ) + assert not line_1.discounts.exists() + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == round_up( + line_1_total_net_amount * tax_rate + ) + assert line_1.unit_price_net_amount == quantize_price( + line_1_total_net_amount / line_1.quantity, currency + ) + assert line_1.unit_price_gross_amount == round_up( + line_1.unit_price_net_amount * tax_rate + ) + + assert catalogue_discount.line == line_2 + assert catalogue_discount.amount_value == rule_catalogue_reward * line_2.quantity + assert catalogue_discount.value == rule_catalogue_reward + assert catalogue_discount.value_type == DiscountValueType.FIXED + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.reason == f"Promotion: {catalogue_promotion_id}" + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount + - line_2_order_discount_portion + - catalogue_discount.amount_value, + currency, + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert ( + line_2.base_unit_price_amount + == variant_2_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == round_down( + line_2_total_net_amount * tax_rate + ) + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + shipping_price = order.shipping_price_net_amount + total_net_amount = quantize_price( + order.undiscounted_total_net_amount + - order_discount.amount_value + - catalogue_discount.amount_value, + currency, + ) + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == quantize_price( + total_net_amount * tax_rate, currency + ) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_catalogue_and_gift_discounts_flat_rates( + draft_order_and_promotions, + plugins_manager, +): + # given + order, rule_catalogue, rule_total, rule_gift = draft_order_and_promotions + rule_total.reward_value = Decimal(0) + rule_total.save(update_fields=["reward_value"]) + + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + gift_promotion_id = graphene.Node.to_global_id("Promotion", rule_gift.promotion_id) + rule_catalogue_reward = rule_catalogue.reward_value + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert len(lines) == 3 + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + gift_line = [line for line in lines if line.is_gift][0] + + assert OrderLineDiscount.objects.count() == 2 + gift_discount = gift_line.discounts.get() + catalogue_discount = line_2.discounts.get() + + variant_gift = gift_line.variant + variant_gift_listing = variant_gift.channel_listings.get(channel=order.channel) + variant_gift_undiscounted_unit_price = variant_gift_listing.price_amount + + assert gift_discount.line == gift_line + assert gift_discount.amount_value == variant_gift_undiscounted_unit_price + assert gift_discount.value == variant_gift_undiscounted_unit_price + assert gift_discount.value_type == DiscountValueType.FIXED + assert gift_discount.type == DiscountType.ORDER_PROMOTION + assert gift_discount.reason == f"Promotion: {gift_promotion_id}" + + assert gift_line.unit_discount_amount == variant_gift_undiscounted_unit_price + assert gift_line.unit_discount_reason == f"Promotion: {gift_promotion_id}" + assert gift_line.unit_discount_type == DiscountValueType.FIXED + assert gift_line.unit_discount_value == variant_gift_undiscounted_unit_price + assert gift_line.undiscounted_total_price_net_amount == Decimal(0) + assert gift_line.undiscounted_total_price_gross_amount == Decimal(0) + assert gift_line.undiscounted_unit_price_net_amount == Decimal(0) + assert gift_line.undiscounted_unit_price_gross_amount == Decimal(0) + assert gift_line.total_price_net_amount == Decimal(0) + assert gift_line.total_price_gross_amount == Decimal(0) + assert gift_line.base_unit_price_amount == Decimal(0) + assert gift_line.unit_price_net_amount == Decimal(0) + assert gift_line.unit_price_gross_amount == Decimal(0) + + assert not line_1.discounts.exists() + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = line_1.undiscounted_total_price_net_amount + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.total_price_net_amount == line_1.undiscounted_total_price_net_amount + assert ( + line_1.total_price_gross_amount == line_1.undiscounted_total_price_gross_amount + ) + assert line_1.base_unit_price_amount == line_1.undiscounted_unit_price_net_amount + assert line_1.unit_price_net_amount == line_1.undiscounted_unit_price_net_amount + assert line_1.unit_price_gross_amount == line_1.undiscounted_unit_price_gross_amount + + assert catalogue_discount.line == line_2 + assert catalogue_discount.amount_value == rule_catalogue_reward * line_2.quantity + assert catalogue_discount.value == rule_catalogue_reward + assert catalogue_discount.value_type == DiscountValueType.FIXED + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.reason == f"Promotion: {catalogue_promotion_id}" + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount - catalogue_discount.amount_value, + currency, + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert ( + line_2.base_unit_price_amount + == variant_2_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == quantize_price( + line_2_total_net_amount * tax_rate, currency + ) + assert ( + line_2.unit_price_net_amount + == variant_2_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + shipping_price = order.shipping_price_net_amount + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + total_net_amount = quantize_price( + order.undiscounted_total_net_amount - catalogue_discount.amount_value, + currency, + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == quantize_price( + total_net_amount * tax_rate, currency + ) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_catalogue_and_order_discounts_exceed_total_flat_rates( + draft_order_and_promotions, + plugins_manager, +): + # given + order, rule_catalogue, rule_total, _ = draft_order_and_promotions + rule_total.reward_value = Decimal(100000) + rule_total.save(update_fields=["reward_value"]) + catalogue_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_catalogue.promotion_id + ) + order_promotion_id = graphene.Node.to_global_id( + "Promotion", rule_total.promotion_id + ) + rule_catalogue_reward = rule_catalogue.reward_value + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + catalogue_discount = OrderLineDiscount.objects.get() + order_discount = OrderDiscount.objects.get() + + shipping_price = order.shipping_price_net_amount + rule_total_reward = quantize_price( + order.undiscounted_total_net_amount + - shipping_price + - rule_catalogue_reward * line_2.quantity, + currency, + ) + assert order_discount.order == order + assert order_discount.amount_value == rule_total_reward + assert order_discount.value == rule_total.reward_value + assert order_discount.value_type == DiscountValueType.FIXED + assert order_discount.type == DiscountType.ORDER_PROMOTION + assert order_discount.reason == f"Promotion: {order_promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + assert not line_1.discounts.exists() + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.total_price_net_amount == Decimal(0) + assert line_1.total_price_gross_amount == Decimal(0) + assert line_1.unit_price_net_amount == Decimal(0) + assert line_1.unit_price_gross_amount == Decimal(0) + + assert catalogue_discount.line == line_2 + assert catalogue_discount.amount_value == rule_catalogue_reward * line_2.quantity + assert catalogue_discount.value == rule_catalogue_reward + assert catalogue_discount.value_type == DiscountValueType.FIXED + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.reason == f"Promotion: {catalogue_promotion_id}" + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert ( + line_2.base_unit_price_amount + == variant_2_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_2.total_price_net_amount == Decimal(0) + assert line_2.total_price_gross_amount == Decimal(0) + assert line_2.unit_price_net_amount == Decimal(0) + assert line_2.unit_price_gross_amount == Decimal(0) + + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == shipping_price + assert order.total_gross_amount == shipping_price * tax_rate + assert order.subtotal_net_amount == Decimal(0) + assert order.subtotal_gross_amount == Decimal(0) + + +def test_fetch_order_prices_manual_discount_and_order_discount_flat_rates( + order_with_lines_and_order_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_order_promotion + assert OrderDiscount.objects.exists() + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + discount_value = Decimal("50") + manual_discount = order.discounts.create( + value_type=DiscountValueType.PERCENTAGE, + value=discount_value, + name="Manual order discount", + type=DiscountType.MANUAL, + ) + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert not OrderLineDiscount.objects.exists() + assert OrderDiscount.objects.count() == 1 + manual_discount.refresh_from_db() + + assert manual_discount.order == order + assert manual_discount.amount_value == Decimal( + order.undiscounted_total_net_amount / 2 + ) + assert manual_discount.value == discount_value + assert manual_discount.value_type == DiscountValueType.PERCENTAGE + assert manual_discount.type == DiscountType.MANUAL + assert not manual_discount.reason + + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + line_1.undiscounted_total_price_net_amount * discount_value / 100, currency + ) + + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == quantize_price( + line_1_total_net_amount * tax_rate, currency + ) + assert line_1.unit_price_net_amount == quantize_price( + line_1_total_net_amount / line_1.quantity, currency + ) + assert line_1.unit_price_gross_amount == quantize_price( + line_1.unit_price_net_amount * tax_rate, currency + ) + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount * discount_value / 100, currency + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == quantize_price( + line_2_total_net_amount * tax_rate, currency + ) + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + undiscounted_shipping_price = order.base_shipping_price_amount + total_net_amount = quantize_price( + order.undiscounted_total_net_amount * discount_value / 100, currency + ) + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + undiscounted_shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == quantize_price( + total_net_amount * tax_rate, currency + ) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_manual_discount_and_gift_discount_flat_rates( + order_with_lines_and_gift_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_gift_promotion + assert OrderLineDiscount.objects.exists() + currency = order.currency + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + discount_value = Decimal("50") + manual_discount = order.discounts.create( + value_type=DiscountValueType.PERCENTAGE, + value=discount_value, + name="Manual order discount", + type=DiscountType.MANUAL, + ) + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert not OrderLineDiscount.objects.exists() + assert OrderDiscount.objects.count() == 1 + assert len(lines) == 2 + manual_discount.refresh_from_db() + + assert manual_discount.order == order + assert manual_discount.amount_value == Decimal( + order.undiscounted_total_net_amount / 2 + ) + assert manual_discount.value == discount_value + assert manual_discount.value_type == DiscountValueType.PERCENTAGE + assert manual_discount.type == DiscountType.MANUAL + assert not manual_discount.reason + + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + assert not [line for line in lines if line.is_gift] + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + line_1.undiscounted_total_price_net_amount * discount_value / 100, currency + ) + + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert line_1.base_unit_price_amount == variant_1_undiscounted_unit_price + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == quantize_price( + line_1_total_net_amount * tax_rate, currency + ) + assert line_1.unit_price_net_amount == quantize_price( + line_1_total_net_amount / line_1.quantity, currency + ) + assert line_1.unit_price_gross_amount == quantize_price( + line_1.unit_price_net_amount * tax_rate, currency + ) + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount * discount_value / 100, currency + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == quantize_price( + line_2_total_net_amount * tax_rate, currency + ) + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + undiscounted_shipping_price = order.base_shipping_price_amount + total_net_amount = quantize_price( + order.undiscounted_total_net_amount * discount_value / 100, currency + ) + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + undiscounted_shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == quantize_price( + total_net_amount * tax_rate, currency + ) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + assert ( + order.shipping_price_net_amount + == undiscounted_shipping_price * discount_value / 100 + ) + assert order.shipping_price_gross_amount == quantize_price( + order.shipping_price_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_manual_discount_and_catalogue_discount_flat_rates( + order_with_lines_and_catalogue_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_catalogue_promotion + currency = order.currency + rule = PromotionRule.objects.get() + rule_catalogue_reward = rule.reward_value + promotion_id = graphene.Node.to_global_id("Promotion", rule.promotion_id) + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + tax_rate = Decimal("1.23") + + manual_discount_value = Decimal("50") + manual_discount = order.discounts.create( + value_type=DiscountValueType.PERCENTAGE, + value=manual_discount_value, + name="Manual order discount", + type=DiscountType.MANUAL, + ) + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + catalogue_discount = OrderLineDiscount.objects.get() + assert OrderDiscount.objects.count() == 1 + + manual_discount.refresh_from_db() + manual_discount_amount = Decimal( + (order.undiscounted_total_net_amount - catalogue_discount.amount_value) + * manual_discount_value + / 100 + ) + assert manual_discount.order == order + assert manual_discount.amount_value == manual_discount_amount + assert manual_discount.value == manual_discount_value + assert manual_discount.value_type == DiscountValueType.PERCENTAGE + assert manual_discount.type == DiscountType.MANUAL + assert not manual_discount.reason + + line_1 = [line for line in lines if line.quantity == 3][0] + line_2 = [line for line in lines if line.quantity == 2][0] + + assert catalogue_discount.line == line_1 + assert catalogue_discount.amount_value == rule_catalogue_reward * line_1.quantity + assert catalogue_discount.value == rule_catalogue_reward + assert catalogue_discount.value_type == DiscountValueType.FIXED + assert catalogue_discount.type == DiscountType.PROMOTION + assert catalogue_discount.reason == f"Promotion: {promotion_id}" + + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + variant_1_undiscounted_unit_price = variant_1_listing.price_amount + line_1_total_net_amount = quantize_price( + (variant_1_undiscounted_unit_price - rule_catalogue_reward) + * line_1.quantity + * manual_discount_value + / 100, + currency, + ) + assert ( + line_1.undiscounted_total_price_net_amount + == variant_1_undiscounted_unit_price * line_1.quantity + ) + assert ( + line_1.undiscounted_total_price_gross_amount + == line_1.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_1.undiscounted_unit_price_net_amount == variant_1_undiscounted_unit_price + ) + assert ( + line_1.undiscounted_unit_price_gross_amount + == variant_1_undiscounted_unit_price * tax_rate + ) + assert ( + line_1.base_unit_price_amount + == variant_1_undiscounted_unit_price - rule_catalogue_reward + ) + assert line_1.total_price_net_amount == line_1_total_net_amount + assert line_1.total_price_gross_amount == quantize_price( + line_1_total_net_amount * tax_rate, currency + ) + assert line_1.unit_price_net_amount == quantize_price( + line_1_total_net_amount / line_1.quantity, currency + ) + assert line_1.unit_price_gross_amount == round_up( + line_1.unit_price_net_amount * tax_rate + ) + assert line_1.unit_discount_amount == rule_catalogue_reward + assert line_1.unit_discount_reason == f"Promotion: {promotion_id}" + assert line_1.unit_discount_value == rule_catalogue_reward + assert line_1.unit_discount_type == DiscountValueType.FIXED + + variant_2 = line_2.variant + variant_2_listing = variant_2.channel_listings.get(channel=order.channel) + variant_2_undiscounted_unit_price = variant_2_listing.price_amount + line_2_total_net_amount = quantize_price( + line_2.undiscounted_total_price_net_amount * manual_discount_value / 100, + currency, + ) + assert ( + line_2.undiscounted_total_price_net_amount + == variant_2_undiscounted_unit_price * line_2.quantity + ) + assert ( + line_2.undiscounted_total_price_gross_amount + == line_2.undiscounted_total_price_net_amount * tax_rate + ) + assert ( + line_2.undiscounted_unit_price_net_amount == variant_2_undiscounted_unit_price + ) + assert ( + line_2.undiscounted_unit_price_gross_amount + == variant_2_undiscounted_unit_price * tax_rate + ) + assert line_2.base_unit_price_amount == variant_2_undiscounted_unit_price + assert line_2.total_price_net_amount == line_2_total_net_amount + assert line_2.total_price_gross_amount == quantize_price( + line_2_total_net_amount * tax_rate, currency + ) + assert line_2.unit_price_net_amount == quantize_price( + line_2_total_net_amount / line_2.quantity, currency + ) + assert line_2.unit_price_gross_amount == quantize_price( + line_2.unit_price_net_amount * tax_rate, currency + ) + + undiscounted_shipping_price = order.base_shipping_price_amount + total_net_amount = quantize_price( + (order.undiscounted_total_net_amount - catalogue_discount.amount_value) + * manual_discount_value + / 100, + currency, + ) + assert ( + order.undiscounted_total_net_amount + == line_1.undiscounted_total_price_net_amount + + line_2.undiscounted_total_price_net_amount + + undiscounted_shipping_price + ) + assert ( + order.undiscounted_total_gross_amount + == order.undiscounted_total_net_amount * tax_rate + ) + assert order.total_net_amount == total_net_amount + assert order.total_gross_amount == round_up(total_net_amount * tax_rate) + assert ( + order.subtotal_net_amount == line_1_total_net_amount + line_2_total_net_amount + ) + assert order.subtotal_gross_amount == quantize_price( + order.subtotal_net_amount * tax_rate, currency + ) + assert ( + order.shipping_price_net_amount + == undiscounted_shipping_price * manual_discount_value / 100 + ) + assert order.shipping_price_gross_amount == quantize_price( + order.shipping_price_net_amount * tax_rate, currency + ) + + +def test_fetch_order_prices_manual_line_discount_and_catalogue_discount_flat_rates( + order_with_lines_and_catalogue_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_catalogue_promotion + + tc = order.channel.tax_configuration + tc.country_exceptions.all().delete() + tc.prices_entered_with_tax = False + tc.tax_calculation_strategy = TaxCalculationStrategy.FLAT_RATES + tc.save() + + line_1 = order.lines.get(quantity=3) + variant_1 = line_1.variant + variant_1_listing = variant_1.channel_listings.get(channel=order.channel) + + manual_discount_value = Decimal("5") + manual_discount_value_type = DiscountValueType.FIXED + manual_discount_reason = "Manual line discount" + manual_discount = line_1.discounts.create( + value_type=manual_discount_value_type, + value=manual_discount_value, + name="Manual order line discount", + type=DiscountType.MANUAL, + reason=manual_discount_reason, + ) + + # when + order, lines = calculations.fetch_order_prices_if_expired( + order, plugins_manager, None, True + ) + + # then + assert OrderLineDiscount.objects.count() == 1 + assert not OrderDiscount.objects.exists() + manual_discount.refresh_from_db() + + line_1 = [line for line in lines if line.quantity == 3][0] + + assert line_1.base_unit_price_amount == variant_1_listing.price_amount + assert manual_discount.line == line_1 + assert manual_discount.value == manual_discount_value + assert manual_discount.value_type == manual_discount_value_type + assert manual_discount.type == DiscountType.MANUAL + assert manual_discount.reason == manual_discount_reason + + # TODO https://github.com/saleor/saleor/issues/15517 + # line_1_total_net_amount = quantize_price( + # (variant_1_listing.price_amount - manual_discount_value) * line_1.quantity, + # order.currency + # ) + # assert line_1.total_price_net_amount == line_1_total_net_amount + + +def test_fetch_order_prices_catalogue_discount_race_condition( + order_with_lines_and_catalogue_promotion, + plugins_manager, +): + # given + order = order_with_lines_and_catalogue_promotion + OrderLineDiscount.objects.all().delete() + + # when + def call_before_creating_catalogue_line_discount(*args, **kwargs): + calculations.fetch_order_prices_if_expired(order, plugins_manager, None, True) + + with before_after.before( + "saleor.discount.utils.prepare_line_discount_objects_for_catalogue_promotions", + call_before_creating_catalogue_line_discount, + ): + calculations.fetch_order_prices_if_expired(order, plugins_manager, None, True) + + # then + assert OrderLineDiscount.objects.count() == 1 diff --git a/saleor/order/utils.py b/saleor/order/utils.py index 9e88fa5f8e0..f159dc5b1ec 100644 --- a/saleor/order/utils.py +++ b/saleor/order/utils.py @@ -841,6 +841,9 @@ def update_discount_for_order_line( value: Optional[Decimal], ): """Update discount fields for order line. Apply discount to the price.""" + # TODO: Move price calculation to fetch_order_prices_if_expired function. + # Here we should only create order line discount object + # https://github.com/saleor/saleor/issues/15517 current_value = order_line.unit_discount_value current_value_type = order_line.unit_discount_type value = value or current_value @@ -893,6 +896,52 @@ def update_discount_for_order_line( # from db order_line.save(update_fields=fields_to_update) + _update_manual_order_line_discount_object( + value, value_type, reason, order_line, order.currency + ) + + +def _update_manual_order_line_discount_object( + value, value_type, reason, order_line, currency +): + discount_to_update = None + discount_to_delete_ids = [] + discounts = order_line.discounts.all() + for discount in discounts: + if discount.type == DiscountType.MANUAL and not discount_to_update: + discount_to_update = discount + else: + discount_to_delete_ids.append(discount.pk) + + if discount_to_delete_ids: + OrderLineDiscount.objects.filter(id__in=discount_to_delete_ids).delete() + + amount_value = quantize_price( + order_line.unit_discount.amount * order_line.quantity, currency + ) + if not discount_to_update: + order_line.discounts.create( + type=DiscountType.MANUAL, + value_type=value_type, + value=value, + amount_value=amount_value, + currency=currency, + reason=reason, + ) + else: + update_fields = [] + if discount_to_update.value_type != value_type: + discount_to_update.value_type = value_type + update_fields.append("value_type") + if discount_to_update.value != value: + discount_to_update.value = value + discount_to_update.amount_value = amount_value + update_fields.extend(["value", "amount_value"]) + if discount_to_update.reason != reason: + discount_to_update.reason = reason + update_fields.append("reason") + discount_to_update.save(update_fields=update_fields) + def remove_discount_from_order_line(order_line: OrderLine, order: "Order"): """Drop discount applied to order line. Restore undiscounted price.""" @@ -922,6 +971,7 @@ def remove_discount_from_order_line(order_line: OrderLine, order: "Order"): "tax_rate", ] ) + order_line.discounts.all().delete() def update_order_charge_status(order: Order, granted_refund_amount: Decimal): diff --git a/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_gift_promotion.yaml b/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_gift_promotion.yaml new file mode 100644 index 00000000000..a4a6bf4fe7c --- /dev/null +++ b/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_gift_promotion.yaml @@ -0,0 +1,81 @@ +interactions: +- request: + body: '{"createTransactionModel": {"companyCode": "DEFAULT", "type": "SalesInvoice", + "lines": [{"quantity": 3, "amount": "30.000", "taxCode": "O9999999", "taxIncluded": + true, "itemCode": "SKU_AA", "discounted": false, "description": "Test product"}, + {"quantity": 2, "amount": "40.000", "taxCode": "O9999999", "taxIncluded": true, + "itemCode": "SKU_B", "discounted": false, "description": "Test product 2"}, + {"quantity": 1, "amount": "0.000", "taxCode": "O9999999", "taxIncluded": true, + "itemCode": "SKU_A", "discounted": false, "description": "Test product"}, {"quantity": + 1, "amount": "10.000", "taxCode": "FR000000", "taxIncluded": true, "itemCode": + "Shipping", "discounted": false, "description": null}], "code": "368cf947-7561-44f7-b39a-32a8dcc8482e", + "date": "2024-02-26", "customerCode": 0, "discount": null, "addresses": {"shipFrom": + {"line1": "Teczowa 7", "line2": "", "city": "Wroclaw", "region": "", "country": + "PL", "postalCode": "53-601"}, "shipTo": {"line1": "T\u0119czowa 7", "line2": + "", "city": "WROC\u0141AW", "region": "", "country": "PL", "postalCode": "53-601"}}, + "commit": false, "currencyCode": "USD", "email": "test@example.com"}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br + Authorization: + - Basic Og== + Connection: + - keep-alive + Content-Length: + - '1145' + User-Agent: + - Saleor/3.20 + method: POST + uri: https://sandbox-rest.avatax.com/api/v2/transactions/createoradjust + response: + body: + string: '{"id":85050296660452,"code":"368cf947-7561-44f7-b39a-32a8dcc8482e","companyId":7799660,"date":"2024-02-26","status":"Saved","type":"SalesInvoice","batchCode":"","currencyCode":"USD","exchangeRateCurrencyCode":"USD","customerUsageType":"","entityUseCode":"","customerVendorCode":"0","customerCode":"0","exemptNo":"","reconciled":false,"locationCode":"","reportingLocationCode":"","purchaseOrderNo":"","referenceCode":"","salespersonCode":"","taxOverrideType":"None","taxOverrideAmount":0.0,"taxOverrideReason":"","totalAmount":65.04,"totalExempt":0.0,"totalDiscount":0.0,"totalTax":14.96,"totalTaxable":65.04,"totalTaxCalculated":14.96,"adjustmentReason":"NotAdjusted","adjustmentDescription":"","locked":false,"region":"","country":"PL","version":1,"softwareVersion":"24.2.0.0","originAddressId":85050296660454,"destinationAddressId":85050296660453,"exchangeRateEffectiveDate":"2024-02-26","exchangeRate":1.0,"description":"","email":"test@example.com","businessIdentificationNo":"","modifiedDate":"2024-02-26T18:58:42.8269135Z","modifiedUserId":6479978,"taxDate":"2024-02-26","lines":[{"id":85050296660458,"transactionId":85050296660452,"lineNumber":"1","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product","destinationAddressId":85050296660453,"originAddressId":85050296660454,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"SKU_AA","lineAmount":24.3900,"quantity":3.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":5.61,"taxableAmount":24.39,"taxCalculated":5.61,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296660482,"transactionLineId":85050296660458,"transactionId":85050296660452,"addressId":85050296660453,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":5.6100,"taxableAmount":24.3900,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":5.6100,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":24.3900,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":24.39,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":5.61,"reportingTaxCalculated":5.61,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296660463,"documentLineId":85050296660458,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296660464,"documentLineId":85050296660458,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296660459,"transactionId":85050296660452,"lineNumber":"2","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product 2","destinationAddressId":85050296660453,"originAddressId":85050296660454,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"SKU_B","lineAmount":32.5200,"quantity":2.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":7.48,"taxableAmount":32.52,"taxCalculated":7.48,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296660503,"transactionLineId":85050296660459,"transactionId":85050296660452,"addressId":85050296660453,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":7.4800,"taxableAmount":32.5200,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":7.4800,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":32.5200,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":32.52,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":7.48,"reportingTaxCalculated":7.48,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296660484,"documentLineId":85050296660459,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296660485,"documentLineId":85050296660459,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296660460,"transactionId":85050296660452,"lineNumber":"3","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product","destinationAddressId":85050296660453,"originAddressId":85050296660454,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":false,"isSSTP":false,"itemCode":"SKU_A","lineAmount":0.0,"quantity":1.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":0.0,"taxableAmount":0.0,"taxCalculated":0.0,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296660524,"transactionLineId":85050296660460,"transactionId":85050296660452,"addressId":85050296660453,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":0.0000,"taxableAmount":0.0000,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":0.0000,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":0.0000,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":0.0,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":0.0,"reportingTaxCalculated":0.0,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296660505,"documentLineId":85050296660460,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296660506,"documentLineId":85050296660460,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296660461,"transactionId":85050296660452,"lineNumber":"4","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"","destinationAddressId":85050296660453,"originAddressId":85050296660454,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"Shipping","lineAmount":8.1300,"quantity":1.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":1.87,"taxableAmount":8.13,"taxCalculated":1.87,"taxCode":"FR000000","taxCodeId":8550,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296660545,"transactionLineId":85050296660461,"transactionId":85050296660452,"addressId":85050296660453,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":1.8700,"taxableAmount":8.1300,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":1.8700,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":8.1300,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":8.13,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":1.87,"reportingTaxCalculated":1.87,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296660526,"documentLineId":85050296660461,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296660527,"documentLineId":85050296660461,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230D","vatNumberTypeId":0}],"addresses":[{"id":85050296660453,"transactionId":85050296660452,"boundaryLevel":"Zip5","line1":"Teczowa + 7","line2":"","line3":"","city":"WROCLAW","region":"","postalCode":"53-601","country":"PL","taxRegionId":205102},{"id":85050296660454,"transactionId":85050296660452,"boundaryLevel":"Zip5","line1":"Teczowa + 7","line2":"","line3":"","city":"Wroclaw","region":"","postalCode":"53-601","country":"PL","taxRegionId":205102}],"locationTypes":[{"documentLocationTypeId":85050296660456,"documentId":85050296660452,"documentAddressId":85050296660454,"locationTypeCode":"ShipFrom"},{"documentLocationTypeId":85050296660457,"documentId":85050296660452,"documentAddressId":85050296660453,"locationTypeCode":"ShipTo"}],"summary":[{"country":"PL","region":"PL","jurisType":"Country","jurisCode":"PL","jurisName":"POLAND","taxAuthorityType":45,"stateAssignedNo":"","taxType":"Output","taxSubType":"O","taxName":"Standard + Rate","rateType":"Standard","taxable":65.04,"rate":0.230000,"tax":14.96,"taxCalculated":14.96,"nonTaxable":0.00,"exemption":0.00}]}' + headers: + Connection: + - keep-alive + Content-Type: + - application/json; charset=utf-8 + Date: + - Mon, 26 Feb 2024 18:58:42 GMT + Location: + - /api/v2/companies/7799660/transactions/85050296660452 + ServerDuration: + - '00:00:00.0815874' + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + api-supported-versions: + - '2.0' + cache-control: + - private, no-cache, no-store + referrer-policy: + - same-origin + strict-transport-security: + - max-age=31536000; includeSubdomains + x-avalara-uid: + - df8331af-ab28-4015-9bee-4d945f48634c + x-correlation-id: + - df8331af-ab28-4015-9bee-4d945f48634c + x-frame-options: + - sameorigin + x-permitted-cross-domain-policies: + - none + x-xss-protection: + - 1; mode=block + status: + code: 201 + message: Created +version: 1 diff --git a/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_order_promotion.yaml b/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_order_promotion.yaml new file mode 100644 index 00000000000..0ff21af4c2f --- /dev/null +++ b/saleor/plugins/avatax/tests/cassettes/test_avatax/test_calculate_order_total_order_promotion.yaml @@ -0,0 +1,77 @@ +interactions: +- request: + body: '{"createTransactionModel": {"companyCode": "DEFAULT", "type": "SalesInvoice", + "lines": [{"quantity": 3, "amount": "30.000", "taxCode": "O9999999", "taxIncluded": + true, "itemCode": "SKU_AA", "discounted": true, "description": "Test product"}, + {"quantity": 2, "amount": "40.000", "taxCode": "O9999999", "taxIncluded": true, + "itemCode": "SKU_B", "discounted": true, "description": "Test product 2"}, {"quantity": + 1, "amount": "10.000", "taxCode": "FR000000", "taxIncluded": true, "itemCode": + "Shipping", "discounted": false, "description": null}], "code": "a01b5217-42b3-48dc-889f-954154d1c3f6", + "date": "2024-02-26", "customerCode": 0, "discount": "25.000", "addresses": + {"shipFrom": {"line1": "Teczowa 7", "line2": "", "city": "Wroclaw", "region": + "", "country": "PL", "postalCode": "53-601"}, "shipTo": {"line1": "T\u0119czowa + 7", "line2": "", "city": "WROC\u0141AW", "region": "", "country": "PL", "postalCode": + "53-601"}}, "commit": false, "currencyCode": "USD", "email": "test@example.com"}}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br + Authorization: + - Basic Og== + Connection: + - keep-alive + Content-Length: + - '994' + User-Agent: + - Saleor/3.20 + method: POST + uri: https://sandbox-rest.avatax.com/api/v2/transactions/createoradjust + response: + body: + string: '{"id":85050296358456,"code":"a01b5217-42b3-48dc-889f-954154d1c3f6","companyId":7799660,"date":"2024-02-26","status":"Saved","type":"SalesInvoice","batchCode":"","currencyCode":"USD","exchangeRateCurrencyCode":"USD","customerUsageType":"","entityUseCode":"","customerVendorCode":"0","customerCode":"0","exemptNo":"","reconciled":false,"locationCode":"","reportingLocationCode":"","purchaseOrderNo":"","referenceCode":"","salespersonCode":"","taxOverrideType":"None","taxOverrideAmount":0.0,"taxOverrideReason":"","totalAmount":69.71,"totalExempt":0.0,"totalDiscount":25.0,"totalTax":10.29,"totalTaxable":44.71,"totalTaxCalculated":10.29,"adjustmentReason":"NotAdjusted","adjustmentDescription":"","locked":false,"region":"","country":"PL","version":1,"softwareVersion":"24.2.0.0","originAddressId":85050296358458,"destinationAddressId":85050296358457,"exchangeRateEffectiveDate":"2024-02-26","exchangeRate":1.0,"description":"","email":"test@example.com","businessIdentificationNo":"","modifiedDate":"2024-02-26T18:49:50.8851355Z","modifiedUserId":6479978,"taxDate":"2024-02-26","lines":[{"id":85050296358462,"transactionId":85050296358456,"lineNumber":"1","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product","destinationAddressId":85050296358457,"originAddressId":85050296358458,"discountAmount":10.71,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"SKU_AA","lineAmount":26.3900,"quantity":3.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":3.61,"taxableAmount":15.68,"taxCalculated":3.61,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296358485,"transactionLineId":85050296358462,"transactionId":85050296358456,"addressId":85050296358457,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":3.6100,"taxableAmount":15.6800,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":3.6100,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":15.6800,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":15.68,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":3.61,"reportingTaxCalculated":3.61,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296358466,"documentLineId":85050296358462,"documentAddressId":85050296358458,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296358467,"documentLineId":85050296358462,"documentAddressId":85050296358457,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296358463,"transactionId":85050296358456,"lineNumber":"2","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"Test + product 2","destinationAddressId":85050296358457,"originAddressId":85050296358458,"discountAmount":14.29,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"SKU_B","lineAmount":35.1900,"quantity":2.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":4.81,"taxableAmount":20.9,"taxCalculated":4.81,"taxCode":"O9999999","taxCodeId":9111,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296358506,"transactionLineId":85050296358463,"transactionId":85050296358456,"addressId":85050296358457,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":4.8100,"taxableAmount":20.9000,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":4.8100,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":20.9000,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":20.9,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":4.81,"reportingTaxCalculated":4.81,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296358487,"documentLineId":85050296358463,"documentAddressId":85050296358458,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296358488,"documentLineId":85050296358463,"documentAddressId":85050296358457,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230C","vatNumberTypeId":0},{"id":85050296358464,"transactionId":85050296358456,"lineNumber":"3","boundaryOverrideId":0,"customerUsageType":"","entityUseCode":"","description":"","destinationAddressId":85050296358457,"originAddressId":85050296358458,"discountAmount":0.0,"discountTypeId":0,"exemptAmount":0.0,"exemptCertId":0,"exemptNo":"","isItemTaxable":true,"isSSTP":false,"itemCode":"Shipping","lineAmount":8.1300,"quantity":1.0,"ref1":"","ref2":"","reportingDate":"2024-02-26","revAccount":"","sourcing":"Destination","tax":1.87,"taxableAmount":8.13,"taxCalculated":1.87,"taxCode":"FR000000","taxCodeId":8550,"taxDate":"2024-02-26","taxEngine":"","taxOverrideType":"None","businessIdentificationNo":"","taxOverrideAmount":0.0,"taxOverrideReason":"","taxIncluded":true,"details":[{"id":85050296358527,"transactionLineId":85050296358464,"transactionId":85050296358456,"addressId":85050296358457,"country":"PL","region":"PL","countyFIPS":"","stateFIPS":"PL","exemptAmount":0.0000,"exemptReasonId":4,"inState":true,"jurisCode":"PL","jurisName":"POLAND","jurisdictionId":200102,"signatureCode":"","stateAssignedNo":"","jurisType":"CNT","jurisdictionType":"Country","nonTaxableAmount":0.0000,"nonTaxableRuleId":0,"nonTaxableType":"RateRule","rate":0.230000,"rateRuleId":411502,"rateSourceId":0,"serCode":"","sourcing":"Destination","tax":1.8700,"taxableAmount":8.1300,"taxType":"Output","taxSubTypeId":"O","taxTypeGroupId":"InputAndOutput","taxName":"Standard + Rate","taxAuthorityTypeId":45,"taxRegionId":205102,"taxCalculated":1.8700,"taxOverride":0.0000,"rateType":"Standard","rateTypeCode":"S","taxableUnits":8.1300,"nonTaxableUnits":0.0000,"exemptUnits":0.0000,"unitOfBasis":"PerCurrencyUnit","isNonPassThru":false,"isFee":false,"reportingTaxableUnits":8.13,"reportingNonTaxableUnits":0.0,"reportingExemptUnits":0.0,"reportingTax":1.87,"reportingTaxCalculated":1.87,"liabilityType":"Seller","chargedTo":"Buyer"}],"nonPassthroughDetails":[],"lineLocationTypes":[{"documentLineLocationTypeId":85050296358508,"documentLineId":85050296358464,"documentAddressId":85050296358458,"locationTypeCode":"ShipFrom"},{"documentLineLocationTypeId":85050296358509,"documentLineId":85050296358464,"documentAddressId":85050296358457,"locationTypeCode":"ShipTo"}],"parameters":[{"name":"Transport","value":"None"},{"name":"IsMarketplace","value":"False"},{"name":"IsTriangulation","value":"false"},{"name":"IsGoodsSecondHand","value":"false"}],"hsCode":"","costInsuranceFreight":0.0,"vatCode":"PLS-230D","vatNumberTypeId":0}],"addresses":[{"id":85050296358457,"transactionId":85050296358456,"boundaryLevel":"Zip5","line1":"Teczowa + 7","line2":"","line3":"","city":"WROCLAW","region":"","postalCode":"53-601","country":"PL","taxRegionId":205102},{"id":85050296358458,"transactionId":85050296358456,"boundaryLevel":"Zip5","line1":"Teczowa + 7","line2":"","line3":"","city":"Wroclaw","region":"","postalCode":"53-601","country":"PL","taxRegionId":205102}],"locationTypes":[{"documentLocationTypeId":85050296358460,"documentId":85050296358456,"documentAddressId":85050296358458,"locationTypeCode":"ShipFrom"},{"documentLocationTypeId":85050296358461,"documentId":85050296358456,"documentAddressId":85050296358457,"locationTypeCode":"ShipTo"}],"summary":[{"country":"PL","region":"PL","jurisType":"Country","jurisCode":"PL","jurisName":"POLAND","taxAuthorityType":45,"stateAssignedNo":"","taxType":"Output","taxSubType":"O","taxName":"Standard + Rate","rateType":"Standard","taxable":44.71,"rate":0.230000,"tax":10.29,"taxCalculated":10.29,"nonTaxable":0.00,"exemption":0.00}]}' + headers: + Connection: + - keep-alive + Content-Type: + - application/json; charset=utf-8 + Date: + - Mon, 26 Feb 2024 18:49:50 GMT + Location: + - /api/v2/companies/7799660/transactions/85050296358456 + ServerDuration: + - '00:00:00.0712472' + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + api-supported-versions: + - '2.0' + cache-control: + - private, no-cache, no-store + referrer-policy: + - same-origin + strict-transport-security: + - max-age=31536000; includeSubdomains + x-avalara-uid: + - 33c0538c-1f94-4fad-b391-ed49bb1ad281 + x-correlation-id: + - 33c0538c-1f94-4fad-b391-ed49bb1ad281 + x-frame-options: + - sameorigin + x-permitted-cross-domain-policies: + - none + x-xss-protection: + - 1; mode=block + status: + code: 201 + message: Created +version: 1 diff --git a/saleor/plugins/avatax/tests/test_avatax.py b/saleor/plugins/avatax/tests/test_avatax.py index d81b293880f..32275c3878d 100644 --- a/saleor/plugins/avatax/tests/test_avatax.py +++ b/saleor/plugins/avatax/tests/test_avatax.py @@ -2465,6 +2465,56 @@ def test_calculate_order_total_for_JPY( assert price == TaxedMoney(net=Money("3496", "JPY"), gross=Money("4300", "JPY")) +@pytest.mark.vcr +@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) +def test_calculate_order_total_order_promotion( + order_with_lines_and_order_promotion, + shipping_zone, + site_settings, + address, + plugin_configuration, +): + plugin_configuration() + manager = get_plugins_manager(allow_replica=False) + order = order_with_lines_and_order_promotion + method = shipping_zone.shipping_methods.get() + order.shipping_address = order.billing_address.get_copy() + order_set_shipping_method(order, method) + order.save() + + site_settings.company_address = address + site_settings.save() + + price = manager.calculate_order_total(order, order.lines.all()) + price = quantize_price(price, price.currency) + assert price == TaxedMoney(net=Money("44.71", "USD"), gross=Money("55.00", "USD")) + + +@pytest.mark.vcr +@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) +def test_calculate_order_total_gift_promotion( + order_with_lines_and_gift_promotion, + shipping_zone, + site_settings, + address, + plugin_configuration, +): + plugin_configuration() + manager = get_plugins_manager(allow_replica=False) + order = order_with_lines_and_gift_promotion + method = shipping_zone.shipping_methods.get() + order.shipping_address = order.billing_address.get_copy() + order_set_shipping_method(order, method) + order.save() + + site_settings.company_address = address + site_settings.save() + + price = manager.calculate_order_total(order, order.lines.all()) + price = quantize_price(price, price.currency) + assert price == TaxedMoney(net=Money("65.04", "USD"), gross=Money("80.00", "USD")) + + @pytest.mark.vcr @override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) def test_calculate_order_shipping_entire_order_voucher( diff --git a/saleor/tax/calculations/order.py b/saleor/tax/calculations/order.py index 0d88de8cb5f..d05c7bf1745 100644 --- a/saleor/tax/calculations/order.py +++ b/saleor/tax/calculations/order.py @@ -68,18 +68,20 @@ def update_order_prices_with_flat_rates( ) order.shipping_tax_rate = normalize_tax_rate_for_db(shipping_tax_rate) - # Calculate order total. - order.undiscounted_total = undiscounted_subtotal + order.base_shipping_price - order.total = _calculate_order_total( - order, lines, database_connection_name=database_connection_name + _set_order_totals( + order, + lines, + prices_entered_with_tax, + database_connection_name=database_connection_name, ) -def _calculate_order_total( +def _set_order_totals( order: "Order", lines: Iterable["OrderLine"], + prices_entered_with_tax: bool, database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, -) -> TaxedMoney: +): currency = order.currency default_value = base_calculations.base_order_total( @@ -87,16 +89,28 @@ def _calculate_order_total( ) default_value = TaxedMoney(default_value, default_value) if default_value <= zero_taxed_money(currency): - return quantize_price(default_value, currency) + order.total = quantize_price(default_value, currency) + order.undiscounted_total = quantize_price(default_value, currency) + order.subtotal = quantize_price(default_value, currency) + return - total = zero_taxed_money(currency) + subtotal = zero_taxed_money(currency) undiscounted_subtotal = zero_taxed_money(currency) for line in lines: - total += line.total_price + subtotal += line.total_price undiscounted_subtotal += line.undiscounted_total_price - total += order.shipping_price - return quantize_price(max(total, zero_taxed_money(currency)), currency) + shipping_tax_rate = order.shipping_tax_rate or 0 + undiscounted_shipping_price = calculate_flat_rate_tax( + order.base_shipping_price, + Decimal(shipping_tax_rate * 100), + prices_entered_with_tax, + ) + undiscounted_total = undiscounted_subtotal + undiscounted_shipping_price + + order.total = quantize_price(subtotal + order.shipping_price, currency) + order.undiscounted_total = quantize_price(undiscounted_total, currency) + order.subtotal = quantize_price(subtotal, currency) def _calculate_order_shipping( diff --git a/saleor/tax/tests/test_checkout_calculations.py b/saleor/tax/tests/test_checkout_calculations.py index 9b9d07239a9..c026cda0be6 100644 --- a/saleor/tax/tests/test_checkout_calculations.py +++ b/saleor/tax/tests/test_checkout_calculations.py @@ -7,7 +7,7 @@ from ...checkout.utils import add_variant_to_checkout from ...core.prices import quantize_price from ...core.taxes import zero_taxed_money -from ...discount.utils import create_discount_objects_for_order_promotions +from ...discount.utils import create_checkout_discount_objects_for_order_promotions from ...plugins.manager import get_plugins_manager from ...tax.models import TaxClassCountryRate from .. import TaxCalculationStrategy @@ -880,7 +880,7 @@ def test_calculate_checkout_line_total_discount_from_order_promotion( lines, _ = fetch_checkout_lines(checkout) checkout_info = fetch_checkout_info(checkout, lines, manager) checkout_line_info = lines[0] - create_discount_objects_for_order_promotions(checkout_info, lines) + create_checkout_discount_objects_for_order_promotions(checkout_info, lines) # when line_price = calculate_checkout_line_total( @@ -928,7 +928,7 @@ def test_calculate_checkout_line_total_discount_for_gift_line( lines, _ = fetch_checkout_lines(checkout) checkout_info = fetch_checkout_info(checkout, lines, manager) checkout_line_info = [line_info for line_info in lines if line_info.line.is_gift][0] - create_discount_objects_for_order_promotions(checkout_info, lines) + create_checkout_discount_objects_for_order_promotions(checkout_info, lines) # when line_price = calculate_checkout_line_total( diff --git a/saleor/tax/tests/test_order_calculations.py b/saleor/tax/tests/test_order_calculations.py index e8daf7c6123..1074037fcd2 100644 --- a/saleor/tax/tests/test_order_calculations.py +++ b/saleor/tax/tests/test_order_calculations.py @@ -77,7 +77,7 @@ def test_calculations_calculate_order_undiscounted_total( # then assert order.undiscounted_total == TaxedMoney( - net=Money("80.00", "USD"), gross=Money("80.00", "USD") + net=Money("65.04", "USD"), gross=Money("80.00", "USD") ) diff --git a/saleor/tests/e2e/orders/discounts/test_order_product_with_percentage_promotion.py b/saleor/tests/e2e/orders/discounts/test_order_product_with_percentage_promotion.py index 13edade7401..2bf3b3e34cf 100644 --- a/saleor/tests/e2e/orders/discounts/test_order_product_with_percentage_promotion.py +++ b/saleor/tests/e2e/orders/discounts/test_order_product_with_percentage_promotion.py @@ -124,8 +124,8 @@ def test_order_products_on_percentage_promotion_CORE_2103( product_price = order_line["undiscountedUnitPrice"]["gross"]["amount"] assert product_price == float(product_variant_price) assert discount == order_line["unitDiscount"]["amount"] - assert order_line["unitDiscountType"] == "FIXED" - assert order_line["unitDiscountValue"] == discount + assert order_line["unitDiscountType"] == "PERCENTAGE" + assert order_line["unitDiscountValue"] == discount_value assert order_line["unitDiscountReason"] == promotion_reason product_discounted_price = product_price - discount assert product_discounted_price == order_line["unitPrice"]["gross"]["amount"] diff --git a/saleor/tests/e2e/orders/discounts/test_order_products_on_percentage_sale.py b/saleor/tests/e2e/orders/discounts/test_order_products_on_percentage_sale.py index e6b719f323c..7fd8086ee15 100644 --- a/saleor/tests/e2e/orders/discounts/test_order_products_on_percentage_sale.py +++ b/saleor/tests/e2e/orders/discounts/test_order_products_on_percentage_sale.py @@ -148,8 +148,8 @@ def test_order_products_on_percentage_sale_CORE_1003( order_line = order["order"]["lines"][0] assert order_line["unitDiscount"]["amount"] == discount - assert order_line["unitDiscountValue"] == discount - assert order_line["unitDiscountType"] == "FIXED" + assert order_line["unitDiscountValue"] == sale_discount_value + assert order_line["unitDiscountType"] == "PERCENTAGE" assert draft_line["unitDiscountReason"] == f"Sale: {sale_id}" product_price = order_line["undiscountedUnitPrice"]["gross"]["amount"] assert product_price == undiscounted_price diff --git a/saleor/tests/e2e/orders/discounts/test_order_products_on_promotion_and_manual_order_discount.py b/saleor/tests/e2e/orders/discounts/test_order_products_on_promotion_and_manual_order_discount.py index 8965e3f9a07..4877ca7dd4d 100644 --- a/saleor/tests/e2e/orders/discounts/test_order_products_on_promotion_and_manual_order_discount.py +++ b/saleor/tests/e2e/orders/discounts/test_order_products_on_promotion_and_manual_order_discount.py @@ -161,8 +161,8 @@ def test_order_products_on_promotion_and_manual_order_discount_CORE_2108( ) assert product_price == product_variant_price assert order_line["unitDiscount"]["amount"] == promotion_value - assert order_line["unitDiscountType"] == "FIXED" - assert order_line["unitDiscountValue"] == promotion_value + assert order_line["unitDiscountType"] == "PERCENTAGE" + assert order_line["unitDiscountValue"] == promotion_discount_value assert order_line["unitDiscountReason"] == promotion_reason product_discounted_price = product_price - promotion_value shipping_amount = quantize_price( diff --git a/saleor/tests/fixtures.py b/saleor/tests/fixtures.py index 1c24734f9bc..f4f558fc7dd 100644 --- a/saleor/tests/fixtures.py +++ b/saleor/tests/fixtures.py @@ -488,8 +488,6 @@ def checkout_with_item_on_promotion(checkout_with_item): variant = line.variant - channel = checkout_with_item.channel - reward_value = Decimal("5") rule = promotion.rules.create( catalogue_predicate={ @@ -5071,7 +5069,122 @@ def order_with_lines_for_cc( @pytest.fixture -def order_fulfill_data(order_with_lines, warehouse): +def order_with_lines_and_catalogue_promotion( + order_with_lines, channel_USD, catalogue_promotion_without_rules +): + order = order_with_lines + promotion = catalogue_promotion_without_rules + line = order.lines.get(quantity=3) + variant = line.variant + reward_value = Decimal(3) + rule = promotion.rules.create( + name="Catalogue rule fixed", + catalogue_predicate={ + "variantPredicate": { + "ids": [graphene.Node.to_global_id("ProductVariant", variant)] + } + }, + reward_value_type=RewardValueType.FIXED, + reward_value=reward_value, + ) + rule.channels.add(channel_USD) + + listing = variant.channel_listings.get(channel=channel_USD) + listing.discounted_price_amount = listing.price_amount - reward_value + listing.save(update_fields=["discounted_price_amount"]) + listing.variantlistingpromotionrule.create( + promotion_rule=rule, + discount_amount=reward_value, + currency=order.currency, + ) + + line.discounts.create( + type=DiscountType.PROMOTION, + value_type=RewardValueType.FIXED, + value=reward_value, + amount_value=reward_value * line.quantity, + currency=order.currency, + promotion_rule=rule, + ) + return order + + +@pytest.fixture +def order_with_lines_and_order_promotion( + order_with_lines, + channel_USD, + order_promotion_without_rules, +): + order = order_with_lines + promotion = order_promotion_without_rules + rule = promotion.rules.create( + name="Fixed subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_value_type=RewardValueType.FIXED, + reward_value=Decimal(25), + reward_type=RewardType.SUBTOTAL_DISCOUNT, + ) + rule.channels.add(channel_USD) + + order.discounts.create( + promotion_rule=rule, + type=DiscountType.ORDER_PROMOTION, + value_type=rule.reward_value_type, + value=rule.reward_value, + amount_value=rule.reward_value, + currency=order.currency, + ) + return order + + +@pytest.fixture +def order_with_lines_and_gift_promotion( + order_with_lines, + channel_USD, + order_promotion_without_rules, + variant_with_many_stocks, +): + order = order_with_lines + variant = variant_with_many_stocks + variant_listing = variant.channel_listings.get(channel=channel_USD) + promotion = order_promotion_without_rules + rule = promotion.rules.create( + name="Gift subtotal rule", + order_predicate={ + "discountedObjectPredicate": {"baseSubtotalPrice": {"range": {"gte": 10}}} + }, + reward_type=RewardType.GIFT, + ) + rule.channels.add(channel_USD) + rule.gifts.set([variant]) + + gift_line = order.lines.create( + quantity=1, + variant=variant, + is_gift=True, + currency=order.currency, + unit_price_net_amount=0, + unit_price_gross_amount=0, + total_price_net_amount=0, + total_price_gross_amount=0, + is_shipping_required=True, + is_gift_card=False, + ) + gift_line.discounts.create( + promotion_rule=rule, + type=DiscountType.ORDER_PROMOTION, + value_type=RewardValueType.FIXED, + value=variant_listing.price_amount, + amount_value=variant_listing.price_amount, + currency=order.currency, + ) + return order + + +@pytest.fixture +def order_fulfill_data(order_with_lines, warehouse, checkout): FulfillmentData = namedtuple("FulfillmentData", "order variables warehouse") order = order_with_lines order_id = graphene.Node.to_global_id("Order", order.id) diff --git a/saleor/tests/utils.py b/saleor/tests/utils.py index bd18d48214c..bb677fc9c12 100644 --- a/saleor/tests/utils.py +++ b/saleor/tests/utils.py @@ -1,4 +1,6 @@ import json +import math +from decimal import Decimal from django.db import connections, transaction @@ -21,3 +23,11 @@ def flush_post_commit_hooks(): def dummy_editorjs(text, json_format=False): data = {"blocks": [{"data": {"text": text}, "type": "paragraph"}]} return json.dumps(data) if json_format else data + + +def round_down(price: Decimal) -> Decimal: + return Decimal(math.floor(price * 100)) / 100 + + +def round_up(price: Decimal) -> Decimal: + return Decimal(math.ceil(price * 100)) / 100 From 9117ab5e775bea27e0749b98f90730f117d0ebb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Wed, 17 Apr 2024 10:46:29 +0200 Subject: [PATCH 30/35] Deprecate `taxTypes` query (#15802) * Deprecate query * Add changelog --- CHANGELOG.md | 2 +- saleor/graphql/core/schema.py | 2 ++ saleor/graphql/schema.graphql | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f6c83ed37fd..5031460a04e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,8 @@ All notable, unreleased changes to this project will be documented in this file. ### GraphQL API - Add `translatableContent` to all translation types; add translated object id to all translatable content types - #15617 by @zedzior - - Add a `taxConfiguration` to a `Channel` - #15610 by @Air-t +- Deprecate the `taxTypes` query - #15802 by @maarcingebala ### Saleor Apps diff --git a/saleor/graphql/core/schema.py b/saleor/graphql/core/schema.py index 3e6f856b4a0..fac28232cb8 100644 --- a/saleor/graphql/core/schema.py +++ b/saleor/graphql/core/schema.py @@ -1,5 +1,6 @@ import graphene +from ..core.descriptions import DEPRECATED_IN_3X_FIELD from ..core.doc_category import DOC_CATEGORY_TAXES from ..core.fields import BaseField from ..plugins.dataloaders import get_plugin_manager_promise @@ -13,6 +14,7 @@ class CoreQueries(graphene.ObjectType): NonNullList(TaxType), description="List of all tax rates available from tax gateway.", doc_category=DOC_CATEGORY_TAXES, + deprecation_reason=f"{DEPRECATED_IN_3X_FIELD} Use `taxClasses` field instead.", ) def resolve_tax_types(self, info: ResolveInfo): diff --git a/saleor/graphql/schema.graphql b/saleor/graphql/schema.graphql index 800908aa588..633ca308253 100644 --- a/saleor/graphql/schema.graphql +++ b/saleor/graphql/schema.graphql @@ -1282,7 +1282,7 @@ type Query { ): ExportFileCountableConnection """List of all tax rates available from tax gateway.""" - taxTypes: [TaxType!] @doc(category: "Taxes") + taxTypes: [TaxType!] @doc(category: "Taxes") @deprecated(reason: "This field will be removed in Saleor 4.0. Use `taxClasses` field instead.") """ Look up a checkout by id. From d1b19098c20921e9748f0342c1c3b159caf96285 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20G=C4=99bala?= <5421321+maarcingebala@users.noreply.github.com> Date: Wed, 17 Apr 2024 10:55:45 +0200 Subject: [PATCH 31/35] Optimize plugin manager initialization (#15769) (#15807) * WIP Optimize initializing plugins for channels * Fix tests * Add flag loaded_all_channels * Handle invalid channel slug * Add backwards compatibility for list_payment_gateways * Fix saving plugin configuration * Pass channel_slug to manager methods to properly choose plugins --- saleor/core/tests/test_dataloaders.py | 2 + saleor/graphql/account/resolvers.py | 8 +- .../tests/benchmark/test_permission_group.py | 4 +- .../graphql/account/tests/queries/test_me.py | 4 +- saleor/graphql/account/types.py | 4 +- .../benchmark/test_checkout_mutations.py | 22 +- .../graphql/checkout/tests/test_checkout.py | 8 +- saleor/graphql/checkout/types.py | 4 +- .../tests/benchmark/test_promotion_create.py | 2 +- .../tests/benchmark/test_promotion_delete.py | 2 +- .../benchmark/test_promotion_rule_create.py | 2 +- .../test_voucher_code_bulk_delete.py | 4 +- ...payment_gateway_initialize_tokenization.py | 2 +- .../payment_method_intialize_tokenization.py | 1 + .../payment_method_process_tokenization.py | 1 + .../payment_method_request_delete.py | 2 +- .../mutations/stored_payment_methods/utils.py | 6 +- ...payment_gateway_initialize_tokenization.py | 6 +- ..._payment_method_initialize_tokenization.py | 6 +- ...est_payment_method_process_tokenization.py | 6 +- ...st_stored_payment_method_request_delete.py | 8 +- saleor/graphql/plugins/resolvers.py | 2 +- saleor/graphql/product/mutations/utils.py | 7 +- .../tests/benchmark/test_collection.py | 4 +- .../product/tests/benchmark/test_product.py | 4 +- .../product/tests/benchmark/test_variant.py | 4 +- .../product/tests/deprecated/test_utils.py | 4 +- .../tests/mutations/test_product_create.py | 2 +- .../tests/mutations/test_product_update.py | 2 +- saleor/graphql/product/types/products.py | 4 +- .../tests/benchmark/test_stock_bulk_update.py | 4 +- saleor/payment/gateway.py | 2 +- .../payment/gateways/adyen/tests/conftest.py | 1 + .../gateways/authorize_net/tests/conftest.py | 1 + .../gateways/np_atobarai/tests/conftest.py | 1 + .../payment/gateways/stripe/tests/conftest.py | 1 + saleor/plugins/admin_email/tests/conftest.py | 1 + .../plugins/admin_email/tests/test_plugin.py | 1 + saleor/plugins/avatax/plugin.py | 30 +- .../avatax/tests/deprecated/__init__.py | 0 .../deprecated/test_tax_code_mutations.py | 259 ------- saleor/plugins/avatax/tests/test_avatax.py | 147 +--- saleor/plugins/base_plugin.py | 15 +- saleor/plugins/manager.py | 696 +++++++++++++----- .../plugins/openid_connect/tests/conftest.py | 1 + saleor/plugins/sendgrid/tests/conftest.py | 1 + saleor/plugins/tests/sample_plugins.py | 4 +- saleor/plugins/tests/test_manager.py | 70 +- saleor/plugins/user_email/tests/conftest.py | 1 + .../plugins/user_email/tests/test_plugin.py | 1 + saleor/plugins/webhook/conftest.py | 1 + saleor/plugins/webhook/plugin.py | 4 +- saleor/plugins/webhook/tests/conftest.py | 1 + 53 files changed, 690 insertions(+), 690 deletions(-) delete mode 100644 saleor/plugins/avatax/tests/deprecated/__init__.py delete mode 100644 saleor/plugins/avatax/tests/deprecated/test_tax_code_mutations.py diff --git a/saleor/core/tests/test_dataloaders.py b/saleor/core/tests/test_dataloaders.py index 943cdda80bc..5dd142b08fd 100644 --- a/saleor/core/tests/test_dataloaders.py +++ b/saleor/core/tests/test_dataloaders.py @@ -13,6 +13,7 @@ def test_plugins_manager_loader_loads_requestor_in_plugin(rf, customer_user, set handler.load_middleware() handler.get_response(request) manager = get_plugin_manager_promise(request).get() + manager.get_all_plugins() plugin = manager.all_plugins.pop() assert isinstance(plugin.requestor, type(customer_user)) @@ -31,6 +32,7 @@ def test_plugins_manager_loader_requestor_in_plugin_when_no_app_and_user_in_req_ handler.load_middleware() handler.get_response(request) manager = get_plugin_manager_promise(request).get() + manager.get_all_plugins() plugin = manager.all_plugins.pop() assert not plugin.requestor diff --git a/saleor/graphql/account/resolvers.py b/saleor/graphql/account/resolvers.py index 046c0282dc3..ab1ec5fc5a7 100644 --- a/saleor/graphql/account/resolvers.py +++ b/saleor/graphql/account/resolvers.py @@ -178,11 +178,13 @@ def resolve_address_validation_rules( @traced_resolver -def resolve_payment_sources(_info, user: models.User, manager, channel_slug: str): - stored_customer_accounts = ( +def resolve_payment_sources( + _info, user: models.User, manager, channel_slug: Optional[str] +): + stored_customer_accounts = [ (gtw.id, fetch_customer_id(user, gtw.id)) for gtw in gateway.list_gateways(manager, channel_slug) - ) + ] return list( chain( *[ diff --git a/saleor/graphql/account/tests/benchmark/test_permission_group.py b/saleor/graphql/account/tests/benchmark/test_permission_group.py index db890373a96..1e684cb1cb2 100644 --- a/saleor/graphql/account/tests/benchmark/test_permission_group.py +++ b/saleor/graphql/account/tests/benchmark/test_permission_group.py @@ -348,7 +348,7 @@ def test_groups_for_federation_query_count( ], } - with django_assert_num_queries(2): + with django_assert_num_queries(1): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 1 @@ -363,7 +363,7 @@ def test_groups_for_federation_query_count( ], } - with django_assert_num_queries(2): + with django_assert_num_queries(1): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 3 diff --git a/saleor/graphql/account/tests/queries/test_me.py b/saleor/graphql/account/tests/queries/test_me.py index d012191d319..2ed22c3e10c 100644 --- a/saleor/graphql/account/tests/queries/test_me.py +++ b/saleor/graphql/account/tests/queries/test_me.py @@ -375,7 +375,9 @@ def test_me_query_stored_payment_methods( ) # then - mocked_list_stored_payment_methods.assert_called_once_with(request_data) + mocked_list_stored_payment_methods.assert_called_once_with( + request_data, channel_slug=channel_USD.slug + ) content = get_graphql_content(response) data = content["data"]["me"] diff --git a/saleor/graphql/account/types.py b/saleor/graphql/account/types.py index beeb8341f7c..11502adaa7d 100644 --- a/saleor/graphql/account/types.py +++ b/saleor/graphql/account/types.py @@ -707,7 +707,9 @@ def get_stored_payment_methods(data): user=root, channel=channel_obj, ) - return manager.list_stored_payment_methods(request_data) + return manager.list_stored_payment_methods( + request_data, channel_slug=channel + ) return Promise.all( [ diff --git a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py index e0cba93fdc0..5d0c084f38b 100644 --- a/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py +++ b/saleor/graphql/checkout/tests/benchmark/test_checkout_mutations.py @@ -417,7 +417,7 @@ def test_create_checkout_with_reservations( } } - with django_assert_num_queries(71): + with django_assert_num_queries(72): response = api_client.post_graphql(query, variables) assert get_graphql_content(response)["data"]["checkoutCreate"] assert Checkout.objects.first().lines.count() == 1 @@ -435,7 +435,7 @@ def test_create_checkout_with_reservations( } } - with django_assert_num_queries(71): + with django_assert_num_queries(72): response = api_client.post_graphql(query, variables) assert get_graphql_content(response)["data"]["checkoutCreate"] assert Checkout.objects.first().lines.count() == 10 @@ -566,7 +566,7 @@ def test_create_checkout_with_order_promotion( } # when - with django_assert_num_queries(76): + with django_assert_num_queries(77): response = user_api_client.post_graphql(MUTATION_CHECKOUT_CREATE, variables) # then @@ -821,7 +821,7 @@ def test_update_checkout_lines_with_reservations( reservation_length=5, ) - with django_assert_num_queries(90): + with django_assert_num_queries(91): variant_id = graphene.Node.to_global_id("ProductVariant", variants[0].pk) variables = { "id": to_global_id_or_none(checkout), @@ -835,7 +835,7 @@ def test_update_checkout_lines_with_reservations( assert not data["errors"] # Updating multiple lines in checkout has same query count as updating one - with django_assert_num_queries(90): + with django_assert_num_queries(91): variables = { "id": to_global_id_or_none(checkout), "lines": [], @@ -1080,7 +1080,7 @@ def test_add_checkout_lines_with_reservations( new_lines.append({"quantity": 2, "variantId": variant_id}) # Adding multiple lines to checkout has same query count as adding one - with django_assert_num_queries(89): + with django_assert_num_queries(90): variables = { "id": Node.to_global_id("Checkout", checkout.pk), "lines": [new_lines[0]], @@ -1093,7 +1093,7 @@ def test_add_checkout_lines_with_reservations( checkout.lines.exclude(id=line.id).delete() - with django_assert_num_queries(89): + with django_assert_num_queries(90): variables = { "id": Node.to_global_id("Checkout", checkout.pk), "lines": new_lines, @@ -1143,7 +1143,7 @@ def test_add_checkout_lines_catalogue_discount_applies( } # when - with django_assert_num_queries(81): + with django_assert_num_queries(82): response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) # then @@ -1228,7 +1228,7 @@ def test_add_checkout_lines_multiple_catalogue_discount_applies( } # when - with django_assert_num_queries(81): + with django_assert_num_queries(82): response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) # then @@ -1263,7 +1263,7 @@ def test_add_checkout_lines_order_discount_applies( } # when - with django_assert_num_queries(84): + with django_assert_num_queries(85): response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) # then @@ -1297,7 +1297,7 @@ def test_add_checkout_lines_gift_discount_applies( } # when - with django_assert_num_queries(111): + with django_assert_num_queries(112): response = user_api_client.post_graphql(MUTATION_CHECKOUT_LINES_ADD, variables) # then diff --git a/saleor/graphql/checkout/tests/test_checkout.py b/saleor/graphql/checkout/tests/test_checkout.py index 2d449948662..3d342aa9cf5 100644 --- a/saleor/graphql/checkout/tests/test_checkout.py +++ b/saleor/graphql/checkout/tests/test_checkout.py @@ -2708,7 +2708,9 @@ def test_checkout_with_stored_payment_methods_empty_response( # then content = get_graphql_content(response) - mocked_list_stored_payment_methods.assert_called_once_with(request_data) + mocked_list_stored_payment_methods.assert_called_once_with( + request_data, channel_slug=checkout_with_prices.channel.slug + ) assert content["data"]["checkout"]["storedPaymentMethods"] == [] @@ -2779,7 +2781,9 @@ def test_checkout_with_stored_payment_methods( # then content = get_graphql_content(response) - mocked_list_stored_payment_methods.assert_called_once_with(request_data) + mocked_list_stored_payment_methods.assert_called_once_with( + request_data, channel_slug=checkout_with_prices.channel.slug + ) assert content["data"]["checkout"]["storedPaymentMethods"] == [ { "id": payment_method_id, diff --git a/saleor/graphql/checkout/types.py b/saleor/graphql/checkout/types.py index c2bb387e79f..bfccba5d092 100644 --- a/saleor/graphql/checkout/types.py +++ b/saleor/graphql/checkout/types.py @@ -1288,7 +1288,9 @@ def _resolve_stored_payment_methods(data): user=user, channel=channel, ) - return manager.list_stored_payment_methods(request_data) + return manager.list_stored_payment_methods( + request_data, channel_slug=channel.slug + ) manager = get_plugin_manager_promise(info.context) channel_loader = ChannelByIdLoader(info.context).load(root.channel_id) diff --git a/saleor/graphql/discount/tests/benchmark/test_promotion_create.py b/saleor/graphql/discount/tests/benchmark/test_promotion_create.py index 6f7d0bcaa4e..a2c4f35bc24 100644 --- a/saleor/graphql/discount/tests/benchmark/test_promotion_create.py +++ b/saleor/graphql/discount/tests/benchmark/test_promotion_create.py @@ -196,7 +196,7 @@ def test_promotion_create_order_promotion( } # when - with django_assert_num_queries(37): + with django_assert_num_queries(36): response = staff_api_client.post_graphql(PROMOTION_CREATE_MUTATION, variables) # then diff --git a/saleor/graphql/discount/tests/benchmark/test_promotion_delete.py b/saleor/graphql/discount/tests/benchmark/test_promotion_delete.py index b16ff78476f..718b36f042f 100644 --- a/saleor/graphql/discount/tests/benchmark/test_promotion_delete.py +++ b/saleor/graphql/discount/tests/benchmark/test_promotion_delete.py @@ -50,7 +50,7 @@ def test_gift_promotion_delete( } # when - with django_assert_num_queries(25): + with django_assert_num_queries(24): content = get_graphql_content( staff_api_client.post_graphql(PROMOTION_DELETE_MUTATION, variables) ) diff --git a/saleor/graphql/discount/tests/benchmark/test_promotion_rule_create.py b/saleor/graphql/discount/tests/benchmark/test_promotion_rule_create.py index a4e1217291e..7a54e10bd21 100644 --- a/saleor/graphql/discount/tests/benchmark/test_promotion_rule_create.py +++ b/saleor/graphql/discount/tests/benchmark/test_promotion_rule_create.py @@ -133,7 +133,7 @@ def test_promotion_rule_create_gift( } # when - with django_assert_num_queries(17): + with django_assert_num_queries(16): content = get_graphql_content( staff_api_client.post_graphql(PROMOTION_RULE_CREATE_MUTATION, variables) ) diff --git a/saleor/graphql/discount/tests/benchmark/test_voucher_code_bulk_delete.py b/saleor/graphql/discount/tests/benchmark/test_voucher_code_bulk_delete.py index c24b6855e96..c46c3439f0e 100644 --- a/saleor/graphql/discount/tests/benchmark/test_voucher_code_bulk_delete.py +++ b/saleor/graphql/discount/tests/benchmark/test_voucher_code_bulk_delete.py @@ -29,7 +29,7 @@ def test_voucher_code_bulk_delete_queries( variables = {"ids": ids[:1]} # when - with django_assert_num_queries(10): + with django_assert_num_queries(9): response = staff_api_client.post_graphql( VOUCHER_CODE_BULK_DELETE_MUTATION, variables ) @@ -38,7 +38,7 @@ def test_voucher_code_bulk_delete_queries( variables = {"ids": ids[1:]} - with django_assert_num_queries(10): + with django_assert_num_queries(9): response = staff_api_client.post_graphql( VOUCHER_CODE_BULK_DELETE_MUTATION, variables ) diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/payment_gateway_initialize_tokenization.py b/saleor/graphql/payment/mutations/stored_payment_methods/payment_gateway_initialize_tokenization.py index e31c3c699d1..47e81c59662 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/payment_gateway_initialize_tokenization.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/payment_gateway_initialize_tokenization.py @@ -70,7 +70,7 @@ def _perform_mutation(cls, root, info, id, channel, data=None): manager = get_plugin_manager_promise(info.context).get() is_active = manager.is_event_active_for_any_plugin( - "payment_gateway_initialize_tokenization" + "payment_gateway_initialize_tokenization", channel_slug=channel.slug ) if not is_active: diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_intialize_tokenization.py b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_intialize_tokenization.py index 535b13728a3..6a52443cfee 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_intialize_tokenization.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_intialize_tokenization.py @@ -82,6 +82,7 @@ def _perform_mutation( payment_flow_to_support=payment_flow_to_support, ), PaymentMethodInitializeTokenizationErrorCode, + channel, ) return cls( diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_process_tokenization.py b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_process_tokenization.py index e2164989a08..d5dbf3543e6 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_process_tokenization.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_process_tokenization.py @@ -66,6 +66,7 @@ def _perform_mutation(cls, root, info, id, channel, data=None): id=id, user=user, channel=channel, data=data ), PaymentMethodProcessTokenizationErrorCode, + channel, ) return cls( result=response.result, data=response.data, errors=errors, id=response.id diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_request_delete.py b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_request_delete.py index 5a2743c733a..f2eb1ac366d 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_request_delete.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/payment_method_request_delete.py @@ -72,7 +72,7 @@ def _perform_mutation(cls, root, info, id, channel): manager = get_plugin_manager_promise(info.context).get() is_active = manager.is_event_active_for_any_plugin( - "stored_payment_method_request_delete" + "stored_payment_method_request_delete", channel_slug=channel.slug ) if not is_active: raise ValidationError( diff --git a/saleor/graphql/payment/mutations/stored_payment_methods/utils.py b/saleor/graphql/payment/mutations/stored_payment_methods/utils.py index a4670de05b7..edf19270ee6 100644 --- a/saleor/graphql/payment/mutations/stored_payment_methods/utils.py +++ b/saleor/graphql/payment/mutations/stored_payment_methods/utils.py @@ -1,5 +1,6 @@ from django.core.exceptions import ValidationError +from .....channel.models import Channel from .....payment.interface import ( PaymentMethodTokenizationBaseRequestData, PaymentMethodTokenizationResponseData, @@ -14,9 +15,12 @@ def handle_payment_method_action( manager_func_name: str, request_data: PaymentMethodTokenizationBaseRequestData, error_type_class, + channel: "Channel", ) -> tuple[PaymentMethodTokenizationResponseData, list[dict]]: manager = get_plugin_manager_promise(info.context).get() - is_active = manager.is_event_active_for_any_plugin(manager_func_name) + is_active = manager.is_event_active_for_any_plugin( + manager_func_name, channel_slug=channel.slug + ) if not is_active: raise ValidationError( diff --git a/saleor/graphql/payment/tests/mutations/test_payment_gateway_initialize_tokenization.py b/saleor/graphql/payment/tests/mutations/test_payment_gateway_initialize_tokenization.py index 2b1c703856f..742172ea1b0 100644 --- a/saleor/graphql/payment/tests/mutations/test_payment_gateway_initialize_tokenization.py +++ b/saleor/graphql/payment/tests/mutations/test_payment_gateway_initialize_tokenization.py @@ -77,7 +77,7 @@ def test_payment_gateway_initialize_tokenization( assert response_data["data"] == expected_output_data mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_gateway_initialize_tokenization" + "payment_gateway_initialize_tokenization", channel_slug=channel_USD.slug ) mocked_payment_gateway_initialize_tokenization.assert_called_once_with( request_data=PaymentGatewayInitializeTokenizationRequestData( @@ -168,7 +168,7 @@ def test_payment_gateway_initialize_tokenization_not_app_or_plugin_subscribed_to ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_gateway_initialize_tokenization" + "payment_gateway_initialize_tokenization", channel_slug=channel_USD.slug ) assert not mocked_payment_gateway_initialize_tokenization.called @@ -248,7 +248,7 @@ def test_payment_gateway_initialize_tokenization_failure_from_app( assert error["message"] == error_message mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_gateway_initialize_tokenization" + "payment_gateway_initialize_tokenization", channel_slug=channel_USD.slug ) mocked_payment_gateway_initialize_tokenization.assert_called_once_with( request_data=PaymentGatewayInitializeTokenizationRequestData( diff --git a/saleor/graphql/payment/tests/mutations/test_payment_method_initialize_tokenization.py b/saleor/graphql/payment/tests/mutations/test_payment_method_initialize_tokenization.py index f514e352842..a2c647c680d 100644 --- a/saleor/graphql/payment/tests/mutations/test_payment_method_initialize_tokenization.py +++ b/saleor/graphql/payment/tests/mutations/test_payment_method_initialize_tokenization.py @@ -87,7 +87,7 @@ def test_payment_method_initialize_tokenization( assert response_data["data"] == expected_output_data assert response_data["id"] == expected_payment_method_id mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_initialize_tokenization" + "payment_method_initialize_tokenization", channel_slug=channel_USD.slug ) mocked_payment_method_initialize_tokenization.assert_called_once_with( request_data=PaymentMethodInitializeTokenizationRequestData( @@ -191,7 +191,7 @@ def test_payment_method_initialize_tokenization_not_app_or_plugin_subscribed_to_ ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_initialize_tokenization" + "payment_method_initialize_tokenization", channel_slug=channel_USD.slug ) assert not mocked_payment_method_initialize_tokenization.called @@ -275,7 +275,7 @@ def test_payment_method_initialize_tokenization_failure_from_app( assert error["message"] == error_message mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_initialize_tokenization" + "payment_method_initialize_tokenization", channel_slug=channel_USD.slug ) mocked_payment_method_initialize_tokenization.assert_called_once_with( request_data=PaymentMethodInitializeTokenizationRequestData( diff --git a/saleor/graphql/payment/tests/mutations/test_payment_method_process_tokenization.py b/saleor/graphql/payment/tests/mutations/test_payment_method_process_tokenization.py index a54566fd2f8..02c429728be 100644 --- a/saleor/graphql/payment/tests/mutations/test_payment_method_process_tokenization.py +++ b/saleor/graphql/payment/tests/mutations/test_payment_method_process_tokenization.py @@ -83,7 +83,7 @@ def test_payment_method_process_tokenization( assert response_data["data"] == expected_output_data assert response_data["id"] == expected_payment_method_id mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_process_tokenization" + "payment_method_process_tokenization", channel_slug=channel_USD.slug ) mocked_payment_method_process_tokenization.assert_called_once_with( request_data=PaymentMethodProcessTokenizationRequestData( @@ -177,7 +177,7 @@ def test_payment_method_process_tokenization_not_app_or_plugin_subscribed_to_eve ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_process_tokenization" + "payment_method_process_tokenization", channel_slug=channel_USD.slug ) assert not mocked_payment_method_process_tokenization.called @@ -256,7 +256,7 @@ def test_payment_method_process_tokenization_failure_from_app( assert error["message"] == error_message mocked_is_event_active_for_any_plugin.assert_called_once_with( - "payment_method_process_tokenization" + "payment_method_process_tokenization", channel_slug=channel_USD.slug ) mocked_payment_method_process_tokenization.assert_called_once_with( request_data=PaymentMethodProcessTokenizationRequestData( diff --git a/saleor/graphql/payment/tests/mutations/test_stored_payment_method_request_delete.py b/saleor/graphql/payment/tests/mutations/test_stored_payment_method_request_delete.py index a1c3aaa91b3..718391f3a13 100644 --- a/saleor/graphql/payment/tests/mutations/test_stored_payment_method_request_delete.py +++ b/saleor/graphql/payment/tests/mutations/test_stored_payment_method_request_delete.py @@ -56,14 +56,14 @@ def test_stored_payment_method_request_delete( ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "stored_payment_method_request_delete" + "stored_payment_method_request_delete", channel_slug=channel_USD.slug ) mocked_stored_payment_method_request_delete.assert_called_once_with( request_delete_data=StoredPaymentMethodRequestDeleteData( payment_method_id=expected_id, user=user_api_client.user, channel=channel_USD, - ) + ), ) @@ -104,7 +104,7 @@ def test_stored_payment_method_request_delete_app_returned_failure_event( assert error["message"] == expected_error_message mocked_is_event_active_for_any_plugin.assert_called_once_with( - "stored_payment_method_request_delete" + "stored_payment_method_request_delete", channel_slug=channel_USD.slug ) mocked_stored_payment_method_request_delete.assert_called_once_with( request_delete_data=StoredPaymentMethodRequestDeleteData( @@ -212,7 +212,7 @@ def test_stored_payment_method_request_delete_not_app_or_plugin_subscribed_to_ev ) mocked_is_event_active_for_any_plugin.assert_called_once_with( - "stored_payment_method_request_delete" + "stored_payment_method_request_delete", channel_slug=channel_USD.slug ) assert not mocked_stored_payment_method_request_delete.called diff --git a/saleor/graphql/plugins/resolvers.py b/saleor/graphql/plugins/resolvers.py index d96ded2c721..de72daf219c 100644 --- a/saleor/graphql/plugins/resolvers.py +++ b/saleor/graphql/plugins/resolvers.py @@ -43,7 +43,7 @@ def aggregate_plugins_configuration( plugins_per_channel: dict[str, list[BasePlugin]] = defaultdict(list) global_plugins: dict[str, BasePlugin] = {} - for plugin in manager.all_plugins: + for plugin in manager.get_all_plugins(): hide_private_configuration_fields(plugin.configuration, plugin.CONFIG_STRUCTURE) if plugin.HIDDEN is True: continue diff --git a/saleor/graphql/product/mutations/utils.py b/saleor/graphql/product/mutations/utils.py index ba9ac603080..9f1f15a443f 100644 --- a/saleor/graphql/product/mutations/utils.py +++ b/saleor/graphql/product/mutations/utils.py @@ -9,8 +9,7 @@ def clean_tax_code(cleaned_input: dict, manager: PluginsManager): This function provides backwards compatibility for the `taxCode` input field. If the `taxClass` is not provided but the `taxCode` is, try to find a tax class with given - tax code and assign it to the product type. If no matching tax class is found, - create one with the given tax code. + tax code and assign it to the product type. """ tax_code = cleaned_input.get("tax_code") if tax_code and "tax_class" not in cleaned_input: @@ -19,8 +18,4 @@ def clean_tax_code(cleaned_input: dict, manager: PluginsManager): | Q(metadata__contains={"avatax.code": tax_code}) | Q(metadata__contains={"vatlayer.code": tax_code}) ).first() - if not tax_class: - tax_class = TaxClass.objects.create(name=tax_code) - manager.assign_tax_code_to_object_meta(tax_class, tax_code) - tax_class.save(update_fields=["metadata"]) cleaned_input["tax_class"] = tax_class diff --git a/saleor/graphql/product/tests/benchmark/test_collection.py b/saleor/graphql/product/tests/benchmark/test_collection.py index de81b3ba626..752b1bf0a91 100644 --- a/saleor/graphql/product/tests/benchmark/test_collection.py +++ b/saleor/graphql/product/tests/benchmark/test_collection.py @@ -418,7 +418,7 @@ def test_collections_for_federation_query_count( ], } - with django_assert_num_queries(3): + with django_assert_num_queries(2): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 1 @@ -436,7 +436,7 @@ def test_collections_for_federation_query_count( ], } - with django_assert_num_queries(3): + with django_assert_num_queries(2): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 3 diff --git a/saleor/graphql/product/tests/benchmark/test_product.py b/saleor/graphql/product/tests/benchmark/test_product.py index 6b61e6c3de4..1e5c13b09de 100644 --- a/saleor/graphql/product/tests/benchmark/test_product.py +++ b/saleor/graphql/product/tests/benchmark/test_product.py @@ -743,7 +743,7 @@ def test_products_for_federation_query_count( ], } - with django_assert_num_queries(6): + with django_assert_num_queries(5): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 1 @@ -765,7 +765,7 @@ def test_products_for_federation_query_count( ], } - with django_assert_num_queries(6): + with django_assert_num_queries(5): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 2 diff --git a/saleor/graphql/product/tests/benchmark/test_variant.py b/saleor/graphql/product/tests/benchmark/test_variant.py index c904290478e..fa5f8d2775e 100644 --- a/saleor/graphql/product/tests/benchmark/test_variant.py +++ b/saleor/graphql/product/tests/benchmark/test_variant.py @@ -384,7 +384,7 @@ def test_products_variants_for_federation_query_count( ], } - with django_assert_num_queries(5): + with django_assert_num_queries(4): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 1 @@ -400,7 +400,7 @@ def test_products_variants_for_federation_query_count( ], } - with django_assert_num_queries(5): + with django_assert_num_queries(4): response = api_client.post_graphql(query, variables) content = get_graphql_content(response) assert len(content["data"]["_entities"]) == 4 diff --git a/saleor/graphql/product/tests/deprecated/test_utils.py b/saleor/graphql/product/tests/deprecated/test_utils.py index a17f1d7b5ca..158f70681a2 100644 --- a/saleor/graphql/product/tests/deprecated/test_utils.py +++ b/saleor/graphql/product/tests/deprecated/test_utils.py @@ -82,6 +82,4 @@ def test_clean_tax_code_when_tax_class_does_not_exists(): clean_tax_code(data, manager) # then - assert TaxClass.objects.count() == 1 - tax_class = TaxClass.objects.first() - assert data["tax_class"] == tax_class + assert data["tax_class"] is None diff --git a/saleor/graphql/product/tests/mutations/test_product_create.py b/saleor/graphql/product/tests/mutations/test_product_create.py index 747862a9edf..c1884af71c8 100644 --- a/saleor/graphql/product/tests/mutations/test_product_create.py +++ b/saleor/graphql/product/tests/mutations/test_product_create.py @@ -119,7 +119,7 @@ def test_create_product( monkeypatch.setattr( PluginsManager, "get_tax_code_from_object_meta", - lambda self, x: TaxType(description="", code=product_tax_rate), + lambda self, x, channel_slug: TaxType(description="", code=product_tax_rate), ) # Default attribute defined in product_type fixture diff --git a/saleor/graphql/product/tests/mutations/test_product_update.py b/saleor/graphql/product/tests/mutations/test_product_update.py index dd0b477de68..7a348dd1b32 100644 --- a/saleor/graphql/product/tests/mutations/test_product_update.py +++ b/saleor/graphql/product/tests/mutations/test_product_update.py @@ -132,7 +132,7 @@ def test_update_product( monkeypatch.setattr( PluginsManager, "get_tax_code_from_object_meta", - lambda self, x: TaxType(description="", code=product_tax_rate), + lambda self, x, channel_slug: TaxType(description="", code=product_tax_rate), ) attribute_id = graphene.Node.to_global_id("Attribute", color_attribute.pk) diff --git a/saleor/graphql/product/types/products.py b/saleor/graphql/product/types/products.py index 6322671a2ba..67319708e39 100644 --- a/saleor/graphql/product/types/products.py +++ b/saleor/graphql/product/types/products.py @@ -1052,7 +1052,9 @@ def resolve_description_json(root: ChannelContext[models.Product], _info): def resolve_tax_type(root: ChannelContext[models.Product], info): def with_tax_class(data): tax_class, manager = data - tax_data = manager.get_tax_code_from_object_meta(tax_class) + tax_data = manager.get_tax_code_from_object_meta( + tax_class, channel_slug=root.channel_slug + ) return TaxType(tax_code=tax_data.code, description=tax_data.description) if root.node.tax_class_id: diff --git a/saleor/graphql/warehouse/tests/benchmark/test_stock_bulk_update.py b/saleor/graphql/warehouse/tests/benchmark/test_stock_bulk_update.py index a4a3209698b..6aa21c31e2e 100644 --- a/saleor/graphql/warehouse/tests/benchmark/test_stock_bulk_update.py +++ b/saleor/graphql/warehouse/tests/benchmark/test_stock_bulk_update.py @@ -71,7 +71,7 @@ def test_stocks_bulk_update_queries_count( ] # test number of queries when single object is updated - with django_assert_num_queries(12): + with django_assert_num_queries(11): staff_api_client.user.user_permissions.add(permission_manage_products) response = staff_api_client.post_graphql( STOCKS_BULK_UPDATE_MUTATION, {"stocks": stocks_input} @@ -107,7 +107,7 @@ def test_stocks_bulk_update_queries_count( ] # Test number of queries when multiple objects are updated - with django_assert_num_queries(12): + with django_assert_num_queries(11): staff_api_client.user.user_permissions.add(permission_manage_products) response = staff_api_client.post_graphql( STOCKS_BULK_UPDATE_MUTATION, {"stocks": stocks_input} diff --git a/saleor/payment/gateway.py b/saleor/payment/gateway.py index 9a834d9e9f4..32ec48b58f6 100644 --- a/saleor/payment/gateway.py +++ b/saleor/payment/gateway.py @@ -487,7 +487,7 @@ def list_payment_sources( gateway: str, customer_id: str, manager: "PluginsManager", - channel_slug: str, + channel_slug: Optional[str], ) -> list["CustomerSource"]: return manager.list_payment_sources(gateway, customer_id, channel_slug=channel_slug) diff --git a/saleor/payment/gateways/adyen/tests/conftest.py b/saleor/payment/gateways/adyen/tests/conftest.py index c94aa53ab8d..9e1cd43dbe5 100644 --- a/saleor/payment/gateways/adyen/tests/conftest.py +++ b/saleor/payment/gateways/adyen/tests/conftest.py @@ -64,6 +64,7 @@ def fun( ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/payment/gateways/authorize_net/tests/conftest.py b/saleor/payment/gateways/authorize_net/tests/conftest.py index c93faddc8d3..831481cb27c 100644 --- a/saleor/payment/gateways/authorize_net/tests/conftest.py +++ b/saleor/payment/gateways/authorize_net/tests/conftest.py @@ -71,4 +71,5 @@ def authorize_net_plugin(_, settings, channel_USD, authorize_net_gateway_config) ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] diff --git a/saleor/payment/gateways/np_atobarai/tests/conftest.py b/saleor/payment/gateways/np_atobarai/tests/conftest.py index c342200c4db..6e31fe500ed 100644 --- a/saleor/payment/gateways/np_atobarai/tests/conftest.py +++ b/saleor/payment/gateways/np_atobarai/tests/conftest.py @@ -57,6 +57,7 @@ def fun( ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/payment/gateways/stripe/tests/conftest.py b/saleor/payment/gateways/stripe/tests/conftest.py index 12097220b22..8937ff0cc21 100644 --- a/saleor/payment/gateways/stripe/tests/conftest.py +++ b/saleor/payment/gateways/stripe/tests/conftest.py @@ -110,6 +110,7 @@ def fun( ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/plugins/admin_email/tests/conftest.py b/saleor/plugins/admin_email/tests/conftest.py index dc619b27f39..b373c6d8cdf 100644 --- a/saleor/plugins/admin_email/tests/conftest.py +++ b/saleor/plugins/admin_email/tests/conftest.py @@ -123,6 +123,7 @@ def fun( }, ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.global_plugins[0] return fun diff --git a/saleor/plugins/admin_email/tests/test_plugin.py b/saleor/plugins/admin_email/tests/test_plugin.py index f88910fac09..8f971960589 100644 --- a/saleor/plugins/admin_email/tests/test_plugin.py +++ b/saleor/plugins/admin_email/tests/test_plugin.py @@ -281,6 +281,7 @@ def test_plugin_manager_doesnt_load_email_templates_from_db( ): settings.PLUGINS = ["saleor.plugins.admin_email.plugin.AdminEmailPlugin"] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() plugin = manager.all_plugins[0] email_config_item = None diff --git a/saleor/plugins/avatax/plugin.py b/saleor/plugins/avatax/plugin.py index 31c067f12bd..c6456613f9b 100644 --- a/saleor/plugins/avatax/plugin.py +++ b/saleor/plugins/avatax/plugin.py @@ -822,31 +822,10 @@ def _get_shipping_tax_rate( ) return base_rate - def assign_tax_code_to_object_meta( + def get_tax_code_from_object_meta( self, - obj: "TaxClass", - tax_code: Optional[str], + obj: Union["Product", "ProductType", "TaxClass"], previous_value: Any, - ): - if not self.active: - return previous_value - - if tax_code is None and obj.pk: - obj.delete_value_from_metadata(META_CODE_KEY) - obj.delete_value_from_metadata(META_DESCRIPTION_KEY) - return previous_value - - codes = get_cached_tax_codes_or_fetch(self.config) - if tax_code not in codes: - return previous_value - - tax_description = codes.get(tax_code) - tax_item = {META_CODE_KEY: tax_code, META_DESCRIPTION_KEY: tax_description} - obj.store_value_in_metadata(items=tax_item) - return previous_value - - def get_tax_code_from_object_meta( - self, obj: Union["Product", "ProductType", "TaxClass"], previous_value: Any ) -> TaxType: if not self.active: return previous_value @@ -867,11 +846,6 @@ def get_tax_code_from_object_meta( description=tax_description, ) - def show_taxes_on_storefront(self, previous_value: bool) -> bool: - if not self.active: - return previous_value - return False - @classmethod def validate_authentication(cls, plugin_configuration: "PluginConfiguration"): conf = { diff --git a/saleor/plugins/avatax/tests/deprecated/__init__.py b/saleor/plugins/avatax/tests/deprecated/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/saleor/plugins/avatax/tests/deprecated/test_tax_code_mutations.py b/saleor/plugins/avatax/tests/deprecated/test_tax_code_mutations.py deleted file mode 100644 index 4922a0e8acb..00000000000 --- a/saleor/plugins/avatax/tests/deprecated/test_tax_code_mutations.py +++ /dev/null @@ -1,259 +0,0 @@ -import graphene -from django.test import override_settings - -from .....graphql.tests.utils import get_graphql_content -from .....tax.models import TaxClass - -PRODUCT_TYPE_CREATE_MUTATION_TAX_CODE = """ - mutation createProductType($name: String, $taxCode: String) { - productTypeCreate(input: {name: $name, taxCode: $taxCode}) { - productType { - name - slug - taxClass { - id - name - metadata { - key - value - } - } - } - errors { - field - message - code - } - } - } -""" - - -@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) -def test_product_type_create_tax_code_creates_new_tax_class( - staff_api_client, - permission_manage_product_types_and_attributes, - plugin_configuration, - monkeypatch, -): - # given - plugin_configuration() - tax_code = "P0000000" - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: "desc"}, - ) - variables = {"name": "New product type", "taxCode": tax_code} - TaxClass.objects.all().delete() - - # when - response = staff_api_client.post_graphql( - PRODUCT_TYPE_CREATE_MUTATION_TAX_CODE, - variables, - permissions=[permission_manage_product_types_and_attributes], - ) - - # then - content = get_graphql_content(response) - tax_class = TaxClass.objects.first() - assert tax_class - assert tax_class.name == tax_code - assert tax_class.metadata - assert tax_class.metadata["avatax.code"] == tax_code - assert content["data"]["productTypeCreate"]["productType"]["taxClass"][ - "id" - ] == graphene.Node.to_global_id("TaxClass", tax_class.pk) - - -PRODUCT_TYPE_UPDATE_MUTATION_TAX_CODE = """ - mutation updateProductType($id: ID!, $taxCode: String) { - productTypeUpdate(id: $id, input: {taxCode: $taxCode}) { - productType { - name - slug - taxClass { - id - name - metadata { - key - value - } - } - } - errors { - field - message - code - } - } - } -""" - - -@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) -def test_product_type_update_tax_code_creates_new_tax_class( - staff_api_client, - product_type, - permission_manage_product_types_and_attributes, - plugin_configuration, - monkeypatch, -): - # given - plugin_configuration() - tax_code = "P0000000" - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: "desc"}, - ) - variables = { - "id": graphene.Node.to_global_id("ProductType", product_type.pk), - "taxCode": tax_code, - } - - # when - response = staff_api_client.post_graphql( - PRODUCT_TYPE_UPDATE_MUTATION_TAX_CODE, - variables, - permissions=[permission_manage_product_types_and_attributes], - ) - - # then - content = get_graphql_content(response) - product_type.refresh_from_db() - tax_class = product_type.tax_class - assert tax_class - assert tax_class.metadata["avatax.code"] == tax_code - assert content["data"]["productTypeUpdate"]["productType"]["taxClass"][ - "id" - ] == graphene.Node.to_global_id("TaxClass", tax_class.pk) - - -PRODUCT_CREATE_MUTATION_TAX_CODE = """ - mutation ProductCreate($name: String!, $productTypeId: ID!, $taxCode: String) { - productCreate( - input: {name: $name, productType: $productTypeId, taxCode: $taxCode} - ) { - errors { - field - message - } - product { - id - name - taxClass { - id - name - metadata { - key - value - } - } - } - } - } -""" - - -@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) -def test_product_create_tax_code_creates_new_tax_class( - staff_api_client, - product_type, - permission_manage_products, - plugin_configuration, - monkeypatch, -): - # given - plugin_configuration() - tax_code = "P0000000" - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: "desc"}, - ) - variables = { - "name": "New product", - "productTypeId": graphene.Node.to_global_id("ProductType", product_type.pk), - "taxCode": tax_code, - } - TaxClass.objects.all().delete() - - # when - response = staff_api_client.post_graphql( - PRODUCT_CREATE_MUTATION_TAX_CODE, - variables, - permissions=[permission_manage_products], - ) - - # then - content = get_graphql_content(response) - tax_class = TaxClass.objects.first() - assert tax_class - assert tax_class.name == tax_code - assert tax_class.metadata - assert tax_class.metadata["avatax.code"] == tax_code - assert content["data"]["productCreate"]["product"]["taxClass"][ - "id" - ] == graphene.Node.to_global_id("TaxClass", tax_class.pk) - - -PRODUCT_UPDATE_MUTATION_TAX_CODE = """ - mutation ProductUpdate($id: ID!, $taxCode: String) { - productUpdate(id: $id, input: {taxCode: $taxCode}) { - errors { - field - message - } - product { - id - name - taxClass { - id - name - metadata { - key - value - } - } - } - } - } -""" - - -@override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) -def test_product_update_tax_code_creates_new_tax_class( - staff_api_client, - permission_manage_products, - product, - plugin_configuration, - monkeypatch, -): - # given - plugin_configuration() - tax_code = "P0000000" - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: "desc"}, - ) - variables = { - "id": graphene.Node.to_global_id("Product", product.pk), - "taxCode": tax_code, - } - - # when - response = staff_api_client.post_graphql( - PRODUCT_UPDATE_MUTATION_TAX_CODE, - variables, - permissions=[permission_manage_products], - ) - - # then - content = get_graphql_content(response) - product.refresh_from_db() - tax_class = product.tax_class - assert tax_class - assert tax_class.name == tax_code - assert tax_class.metadata - assert tax_class.metadata["avatax.code"] == tax_code - assert content["data"]["productUpdate"]["product"]["taxClass"][ - "id" - ] == graphene.Node.to_global_id("TaxClass", tax_class.pk) diff --git a/saleor/plugins/avatax/tests/test_avatax.py b/saleor/plugins/avatax/tests/test_avatax.py index 32275c3878d..eb76f92509a 100644 --- a/saleor/plugins/avatax/tests/test_avatax.py +++ b/saleor/plugins/avatax/tests/test_avatax.py @@ -74,6 +74,11 @@ def order_set_shipping_method(order, shipping_method): ) +def assign_tax_code_to_object_meta(obj: "TaxClass", tax_code: str): + tax_item = {META_CODE_KEY: tax_code} + obj.store_value_in_metadata(items=tax_item) + + @pytest.mark.vcr @pytest.mark.parametrize( ("expected_net", "expected_gross", "prices_entered_with_tax"), @@ -1218,7 +1223,7 @@ def test_calculate_checkout_total_uses_default_calculation( line = checkout_with_item.lines.first() product = line.variant.product product.metadata = {} - manager.assign_tax_code_to_object_meta(product.product_type.tax_class, "PC040156") + assign_tax_code_to_object_meta(product.product_type.tax_class, "PC040156") product.save() product.product_type.save() product_with_single_variant.tax_class = tax_class_zero_rates @@ -1286,7 +1291,7 @@ def test_calculate_checkout_total_uses_default_calculation_with_promotion( line = checkout.lines.first() product = line.variant.product product.metadata = {} - manager.assign_tax_code_to_object_meta(product.product_type.tax_class, "PC040156") + assign_tax_code_to_object_meta(product.product_type.tax_class, "PC040156") product.save() product.product_type.save() product_with_single_variant.tax_class = tax_class_zero_rates @@ -1358,7 +1363,7 @@ def test_calculate_checkout_total( line = checkout_with_item.lines.first() product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1418,7 +1423,7 @@ def test_calculate_checkout_total_with_order_promotion( line = checkout.lines.first() product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1480,7 +1485,7 @@ def test_calculate_checkout_total_with_gift_promotion( line = checkout.lines.get(is_gift=False) product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1559,7 +1564,7 @@ def test_calculate_checkout_total_with_promotion( line = checkout.lines.first() product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1641,7 +1646,7 @@ def test_calculate_checkout_total_for_JPY( line = checkout.lines.first() product = line.variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -1715,7 +1720,7 @@ def test_calculate_checkout_total_for_JPY_with_promotion( variant = line.variant product = variant.product - manager.assign_tax_code_to_object_meta(product.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.tax_class, "PS081282") product.save() product.tax_class.save() @@ -2058,7 +2063,7 @@ def test_calculate_checkout_total_not_charged_product_and_shipping_with_0_price( variant = line.variant product = variant.product product.metadata = {} - manager.assign_tax_code_to_object_meta(product.product_type.tax_class, "PS081282") + assign_tax_code_to_object_meta(product.product_type.tax_class, "PS081282") product.save() product.product_type.save() @@ -4408,7 +4413,7 @@ def test_get_order_shipping_tax_rate_skip_plugin( def test_get_plugin_configuration(settings, channel_USD): settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] manager = get_plugins_manager(allow_replica=False) - plugin = manager.get_plugin(AvataxPlugin.PLUGIN_ID) + plugin = manager.get_plugin(AvataxPlugin.PLUGIN_ID, channel_slug=channel_USD.slug) configuration_fields = [ configuration_item["name"] for configuration_item in plugin.configuration @@ -4492,12 +4497,6 @@ def test_save_plugin_configuration_cannot_be_enabled_without_config( ) -def test_show_taxes_on_storefront(plugin_configuration): - plugin_configuration() - manager = get_plugins_manager(allow_replica=False) - assert manager.show_taxes_on_storefront() is False - - @patch("saleor.plugins.avatax.plugin.api_post_request_task.delay") @override_settings(PLUGINS=["saleor.plugins.avatax.plugin.AvataxPlugin"]) def test_order_confirmed( @@ -4796,11 +4795,13 @@ def test_plugin_uses_configuration_from_db( manager.preprocess_order_creation(checkout_info, lines) -def test_skip_disabled_plugin(settings, plugin_configuration): +def test_skip_disabled_plugin(settings, plugin_configuration, channel_USD): plugin_configuration(username=None, password=None) settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] manager = get_plugins_manager(allow_replica=False) - plugin: AvataxPlugin = manager.get_plugin(AvataxPlugin.PLUGIN_ID) + plugin: AvataxPlugin = manager.get_plugin( + AvataxPlugin.PLUGIN_ID, channel_slug=channel_USD.slug + ) assert ( plugin._skip_plugin( @@ -4810,14 +4811,18 @@ def test_skip_disabled_plugin(settings, plugin_configuration): ) -def test_get_tax_code_from_object_meta(product, settings, plugin_configuration): +def test_get_tax_code_from_object_meta( + product, settings, plugin_configuration, channel_USD +): product.tax_class.store_value_in_metadata( {META_CODE_KEY: "KEY", META_DESCRIPTION_KEY: "DESC"} ) plugin_configuration(username=None, password=None) settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] manager = get_plugins_manager(allow_replica=False) - tax_type = manager.get_tax_code_from_object_meta(product.tax_class) + tax_type = manager.get_tax_code_from_object_meta( + product.tax_class, channel_USD.slug + ) assert isinstance(tax_type, TaxType) assert tax_type.code == "KEY" @@ -5977,108 +5982,6 @@ def test_get_order_lines_data_adds_lines_with_taxes_disabled_for_line( assert len(lines_data) == len(order_with_lines.lines.all()) -def test_assign_tax_code_to_object_meta( - settings, channel_USD, plugin_configuration, product, monkeypatch -): - # given - settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] - plugin_configuration(channel=channel_USD) - - tax_code = "standard" - description = "desc" - - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: description}, - ) - - manager = get_plugins_manager(allow_replica=False) - - # when - manager.assign_tax_code_to_object_meta(product.tax_class, tax_code) - - # then - assert product.tax_class.metadata == { - META_CODE_KEY: tax_code, - META_DESCRIPTION_KEY: description, - } - - -def test_assign_tax_code_to_object_meta_none_as_tax_code( - settings, channel_USD, plugin_configuration, product, monkeypatch -): - # given - settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] - plugin_configuration(channel=channel_USD) - - tax_code = None - description = "desc" - - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {"standard": description}, - ) - manager = get_plugins_manager(allow_replica=False) - - # when - manager.assign_tax_code_to_object_meta(product.tax_class, tax_code) - - # then - assert product.metadata == {} - - -def test_assign_tax_code_to_object_meta_no_obj_id_and_none_as_tax_code( - settings, channel_USD, plugin_configuration, monkeypatch -): - # given - settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] - plugin_configuration(channel=channel_USD) - - tax_code = None - description = "desc" - - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {"standard": description}, - ) - - tax_class = TaxClass(name="A new tax class.") - manager = get_plugins_manager(allow_replica=False) - - # when - manager.assign_tax_code_to_object_meta(tax_class, tax_code) - - # then - assert tax_class.metadata == {} - - -def test_assign_tax_code_to_object_meta_no_obj_id( - settings, channel_USD, plugin_configuration, monkeypatch -): - # given - settings.PLUGINS = ["saleor.plugins.avatax.plugin.AvataxPlugin"] - plugin_configuration(channel=channel_USD) - - tax_code = "standard" - description = "desc" - - monkeypatch.setattr( - "saleor.plugins.avatax.plugin.get_cached_tax_codes_or_fetch", - lambda _: {tax_code: description}, - ) - tax_class = TaxClass(name="A new product.") - manager = get_plugins_manager(allow_replica=False) - - # when - manager.assign_tax_code_to_object_meta(tax_class, tax_code) - - # then - assert tax_class.metadata == { - META_CODE_KEY: tax_code, - META_DESCRIPTION_KEY: description, - } - - @patch("saleor.plugins.avatax.plugin.get_checkout_tax_data") def test_calculate_checkout_shipping_validates_checkout( mocked_func, settings, channel_USD, plugin_configuration, checkout_with_item diff --git a/saleor/plugins/base_plugin.py b/saleor/plugins/base_plugin.py index 9fa29eb73cd..0d65d718d32 100644 --- a/saleor/plugins/base_plugin.py +++ b/saleor/plugins/base_plugin.py @@ -244,9 +244,6 @@ def __str__(self): # status is changed. app_status_changed: Callable[["App", None], None] - # Assign tax code dedicated to plugin. - assign_tax_code_to_object_meta: Callable[["TaxClass", Union[str, None], Any], Any] - # Trigger when attribute is created. # # Overwrite this method if you need to trigger specific logic after an attribute is @@ -625,7 +622,11 @@ def __str__(self): # Return tax code from object meta. get_tax_code_from_object_meta: Callable[ - [Union["Product", "ProductType", "TaxClass"], "TaxType"], "TaxType" + [ + Union["Product", "ProductType", "TaxClass"], + "TaxType", + ], + "TaxType", ] # Return list of all tax categories. @@ -1144,12 +1145,6 @@ def __str__(self): # metadata is updated. shipping_zone_metadata_updated: Callable[["ShippingZone", None], None] - # Define if storefront should add info about taxes to the price. - # - # It is used only by the old storefront. The returned value determines if - # storefront should append info to the price about "including/excluding X% VAT". - show_taxes_on_storefront: Callable[[bool], bool] - # Trigger when staff user is created. # # Overwrite this method if you need to trigger specific logic after a staff user is diff --git a/saleor/plugins/manager.py b/saleor/plugins/manager.py index 169eb881690..6d550932235 100644 --- a/saleor/plugins/manager.py +++ b/saleor/plugins/manager.py @@ -134,69 +134,82 @@ def _load_plugin( def __init__(self, plugins: list[str], requestor_getter=None, allow_replica=True): with opentracing.global_tracer().start_active_span("PluginsManager.__init__"): + self.plugins = plugins self._allow_replica = allow_replica self.all_plugins = [] self.global_plugins = [] self.plugins_per_channel = defaultdict(list) + self.loaded_all_channels = False + self.loaded_channels: set[str] = set() + self.loaded_global = False + self.requestor_getter = requestor_getter - channel_map = self._get_channel_map() - global_db_configs, channel_db_configs = self._get_db_plugin_configs( - channel_map - ) + def _ensure_channel_plugins_loaded( + self, channel_slug: Optional[str], channel: Optional[Channel] = None + ): + if channel_slug is None and not self.loaded_global: + global_db_config = self._get_db_plugin_configs(None) - for plugin_path in plugins: + for plugin_path in self.plugins: with opentracing.global_tracer().start_active_span(f"{plugin_path}"): PluginClass = import_string(plugin_path) if not getattr(PluginClass, "CONFIGURATION_PER_CHANNEL", False): plugin = self._load_plugin( PluginClass, - global_db_configs, - requestor_getter=requestor_getter, - allow_replica=allow_replica, + global_db_config, + requestor_getter=self.requestor_getter, + allow_replica=self._allow_replica, ) self.global_plugins.append(plugin) self.all_plugins.append(plugin) - else: - for channel in channel_map.values(): - channel_configs = channel_db_configs.get(channel, {}) - plugin = self._load_plugin( - PluginClass, - channel_configs, - channel, - requestor_getter, - allow_replica, - ) - self.plugins_per_channel[channel.slug].append(plugin) - self.all_plugins.append(plugin) - - for channel in channel_map.values(): - self.plugins_per_channel[channel.slug].extend(self.global_plugins) - - def _get_db_plugin_configs(self, channel_map): + self.loaded_global = True + + if channel_slug is not None and channel_slug not in self.loaded_channels: + if channel is None: + channel = ( + Channel.objects.using(self.database) + .filter(slug=channel_slug) + .first() + ) + if not channel: + return + + channel_db_config = self._get_db_plugin_configs(channel) + + for plugin_path in self.plugins: + with opentracing.global_tracer().start_active_span(f"{plugin_path}"): + PluginClass = import_string(plugin_path) + if getattr(PluginClass, "CONFIGURATION_PER_CHANNEL", False): + plugin = self._load_plugin( + PluginClass, + channel_db_config, + channel=channel, + requestor_getter=self.requestor_getter, + allow_replica=self._allow_replica, + ) + self.plugins_per_channel[channel_slug].append(plugin) + self.all_plugins.append(plugin) + + self._ensure_channel_plugins_loaded(None) + self.plugins_per_channel[channel_slug].extend(self.global_plugins) + self.loaded_channels.add(channel_slug) + + def _get_db_plugin_configs(self, channel: Optional[Channel]): with opentracing.global_tracer().start_active_span("_get_db_plugin_configs"): plugin_manager_configs = PluginConfiguration.objects.using( self.database - ).all() - channel_configs: defaultdict[Channel, dict] = defaultdict(dict) - global_configs = {} + ).filter(channel=channel) + configs = {} for db_plugin_config in plugin_manager_configs.iterator(): - channel = channel_map.get(db_plugin_config.channel_id) - if channel is None: - global_configs[db_plugin_config.identifier] = db_plugin_config - else: - db_plugin_config.channel = channel - channel_configs[channel][db_plugin_config.identifier] = ( - db_plugin_config - ) - - return global_configs, channel_configs + configs[db_plugin_config.identifier] = db_plugin_config + return configs def __run_method_on_plugins( self, method_name: str, default_value: Any, *args, - channel_slug: Optional[str] = None, + channel_slug: Optional[str], plugin_ids: Optional[list[str]] = None, **kwargs, ): @@ -249,7 +262,13 @@ def change_user_address( ) -> "Address": default_value = address return self.__run_method_on_plugins( - "change_user_address", default_value, address, address_type, user, save + "change_user_address", + default_value, + address, + address_type, + user, + save, + channel_slug=None, ) def calculate_checkout_total( @@ -615,11 +634,15 @@ def get_order_line_tax_rate( def get_tax_rate_type_choices(self) -> list[TaxType]: default_value: list = [] - return self.__run_method_on_plugins("get_tax_rate_type_choices", default_value) - - def show_taxes_on_storefront(self) -> bool: - default_value = False - return self.__run_method_on_plugins("show_taxes_on_storefront", default_value) + plugins = self.get_all_plugins() + return ( + self.__run_plugin_method_until_first_success( + "get_tax_rate_type_choices", + channel_slug=None, + plugins=plugins, + ) + or default_value + ) def get_taxes_for_checkout( self, checkout_info, lines, app_identifier @@ -656,96 +679,131 @@ def preprocess_order_creation( def customer_created(self, customer: "User"): default_value = None - return self.__run_method_on_plugins("customer_created", default_value, customer) + return self.__run_method_on_plugins( + "customer_created", default_value, customer, channel_slug=None + ) def customer_deleted(self, customer: "User", webhooks=None): default_value = None return self.__run_method_on_plugins( - "customer_deleted", default_value, customer, webhooks=webhooks + "customer_deleted", + default_value, + customer, + webhooks=webhooks, + channel_slug=None, ) def customer_updated(self, customer: "User", webhooks=None): default_value = None return self.__run_method_on_plugins( - "customer_updated", default_value, customer, webhooks=webhooks + "customer_updated", + default_value, + customer, + webhooks=webhooks, + channel_slug=None, ) def customer_metadata_updated(self, customer: "User", webhooks=None): default_value = None return self.__run_method_on_plugins( - "customer_metadata_updated", default_value, customer, webhooks=webhooks + "customer_metadata_updated", + default_value, + customer, + webhooks=webhooks, + channel_slug=None, ) def collection_created(self, collection: "Collection"): default_value = None return self.__run_method_on_plugins( - "collection_created", default_value, collection + "collection_created", default_value, collection, channel_slug=None ) def collection_updated(self, collection: "Collection"): default_value = None return self.__run_method_on_plugins( - "collection_updated", default_value, collection + "collection_updated", default_value, collection, channel_slug=None ) def collection_deleted(self, collection: "Collection", webhooks=None): default_value = None return self.__run_method_on_plugins( - "collection_deleted", default_value, collection, webhooks=webhooks + "collection_deleted", + default_value, + collection, + webhooks=webhooks, + channel_slug=None, ) def collection_metadata_updated(self, collection: "Collection"): default_value = None return self.__run_method_on_plugins( - "collection_metadata_updated", default_value, collection + "collection_metadata_updated", default_value, collection, channel_slug=None ) def product_created(self, product: "Product", webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_created", default_value, product, webhooks=webhooks + "product_created", + default_value, + product, + webhooks=webhooks, + channel_slug=None, ) def product_updated(self, product: "Product", webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_updated", default_value, product, webhooks=webhooks + "product_updated", + default_value, + product, + webhooks=webhooks, + channel_slug=None, ) def product_deleted(self, product: "Product", variants: list[int], webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_deleted", default_value, product, variants, webhooks=webhooks + "product_deleted", + default_value, + product, + variants, + webhooks=webhooks, + channel_slug=None, ) def product_media_created(self, media: "ProductMedia"): default_value = None return self.__run_method_on_plugins( - "product_media_created", default_value, media + "product_media_created", default_value, media, channel_slug=None ) def product_media_updated(self, media: "ProductMedia"): default_value = None return self.__run_method_on_plugins( - "product_media_updated", default_value, media + "product_media_updated", default_value, media, channel_slug=None ) def product_media_deleted(self, media: "ProductMedia"): default_value = None return self.__run_method_on_plugins( - "product_media_deleted", default_value, media + "product_media_deleted", default_value, media, channel_slug=None ) def product_metadata_updated(self, product: "Product"): default_value = None return self.__run_method_on_plugins( - "product_metadata_updated", default_value, product + "product_metadata_updated", default_value, product, channel_slug=None ) def product_variant_created(self, product_variant: "ProductVariant", webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_variant_created", default_value, product_variant, webhooks=webhooks + "product_variant_created", + default_value, + product_variant, + webhooks=webhooks, + channel_slug=None, ) def product_variant_updated( @@ -758,42 +816,62 @@ def product_variant_updated( product_variant, webhooks=webhooks, **kwargs, + channel_slug=None, ) def product_variant_deleted(self, product_variant: "ProductVariant", webhooks=None): default_value = None return self.__run_method_on_plugins( - "product_variant_deleted", default_value, product_variant, webhooks=webhooks + "product_variant_deleted", + default_value, + product_variant, + webhooks=webhooks, + channel_slug=None, ) def product_variant_out_of_stock(self, stock: "Stock", webhooks=None): default_value = None self.__run_method_on_plugins( - "product_variant_out_of_stock", default_value, stock, webhooks=webhooks + "product_variant_out_of_stock", + default_value, + stock, + webhooks=webhooks, + channel_slug=None, ) def product_variant_back_in_stock(self, stock: "Stock", webhooks=None): default_value = None self.__run_method_on_plugins( - "product_variant_back_in_stock", default_value, stock, webhooks=webhooks + "product_variant_back_in_stock", + default_value, + stock, + webhooks=webhooks, + channel_slug=None, ) def product_variant_stock_updated(self, stock: "Stock", webhooks=None): default_value = None self.__run_method_on_plugins( - "product_variant_stock_updated", default_value, stock, webhooks=webhooks + "product_variant_stock_updated", + default_value, + stock, + webhooks=webhooks, + channel_slug=None, ) def product_variant_metadata_updated(self, product_variant: "ProductVariant"): default_value = None self.__run_method_on_plugins( - "product_variant_metadata_updated", default_value, product_variant + "product_variant_metadata_updated", + default_value, + product_variant, + channel_slug=None, ) def product_export_completed(self, export: "ExportFile"): default_value = None return self.__run_method_on_plugins( - "product_export_completed", default_value, export + "product_export_completed", default_value, export, channel_slug=None ) def order_created(self, order: "Order"): @@ -805,7 +883,7 @@ def order_created(self, order: "Order"): def event_delivery_retry(self, event_delivery: "EventDelivery"): default_value = None return self.__run_method_on_plugins( - "event_delivery_retry", default_value, event_delivery + "event_delivery_retry", default_value, event_delivery, channel_slug=None ) def order_confirmed(self, order: "Order"): @@ -835,73 +913,100 @@ def draft_order_deleted(self, order: "Order"): def sale_created(self, sale: "Promotion", current_catalogue): default_value = None return self.__run_method_on_plugins( - "sale_created", default_value, sale, current_catalogue + "sale_created", default_value, sale, current_catalogue, channel_slug=None ) def sale_deleted(self, sale: "Promotion", previous_catalogue, webhooks=None): default_value = None return self.__run_method_on_plugins( - "sale_deleted", default_value, sale, previous_catalogue, webhooks=webhooks + "sale_deleted", + default_value, + sale, + previous_catalogue, + webhooks=webhooks, + channel_slug=None, ) def sale_updated(self, sale: "Promotion", previous_catalogue, current_catalogue): default_value = None return self.__run_method_on_plugins( - "sale_updated", default_value, sale, previous_catalogue, current_catalogue + "sale_updated", + default_value, + sale, + previous_catalogue, + current_catalogue, + channel_slug=None, ) def sale_toggle(self, sale: "Promotion", catalogue, webhooks=None): default_value = None return self.__run_method_on_plugins( - "sale_toggle", default_value, sale, catalogue, webhooks=webhooks + "sale_toggle", + default_value, + sale, + catalogue, + webhooks=webhooks, + channel_slug=None, ) def promotion_created(self, promotion: "Promotion"): default_value = None return self.__run_method_on_plugins( - "promotion_created", default_value, promotion + "promotion_created", default_value, promotion, channel_slug=None ) def promotion_updated(self, promotion: "Promotion"): default_value = None return self.__run_method_on_plugins( - "promotion_updated", default_value, promotion + "promotion_updated", default_value, promotion, channel_slug=None ) def promotion_deleted(self, promotion: "Promotion", webhooks=None): default_value = None return self.__run_method_on_plugins( - "promotion_deleted", default_value, promotion, webhooks=webhooks + "promotion_deleted", + default_value, + promotion, + webhooks=webhooks, + channel_slug=None, ) def promotion_started(self, promotion: "Promotion", webhooks=None): default_value = None return self.__run_method_on_plugins( - "promotion_started", default_value, promotion, webhooks=webhooks + "promotion_started", + default_value, + promotion, + webhooks=webhooks, + channel_slug=None, ) def promotion_ended(self, promotion: "Promotion", webhooks=None): default_value = None return self.__run_method_on_plugins( - "promotion_ended", default_value, promotion, webhooks=webhooks + "promotion_ended", + default_value, + promotion, + webhooks=webhooks, + channel_slug=None, ) def promotion_rule_created(self, promotion_rule: "PromotionRule"): default_value = None return self.__run_method_on_plugins( - "promotion_rule_created", default_value, promotion_rule + "promotion_rule_created", default_value, promotion_rule, channel_slug=None ) def promotion_rule_updated(self, promotion_rule: "PromotionRule"): default_value = None return self.__run_method_on_plugins( - "promotion_rule_updated", default_value, promotion_rule + "promotion_rule_updated", default_value, promotion_rule, channel_slug=None ) def promotion_rule_deleted(self, promotion_rule: "PromotionRule"): default_value = None return self.__run_method_on_plugins( - "promotion_rule_deleted", default_value, promotion_rule + "promotion_rule_deleted", default_value, promotion_rule, channel_slug=None ) def invoice_request( @@ -1000,12 +1105,17 @@ def order_fulfilled(self, order: "Order"): def order_metadata_updated(self, order: "Order"): default_value = None return self.__run_method_on_plugins( - "order_metadata_updated", default_value, order + "order_metadata_updated", + default_value, + order, + channel_slug=order.channel.slug, ) def order_bulk_created(self, orders: list["Order"]): default_value = None - return self.__run_method_on_plugins("order_bulk_created", default_value, orders) + return self.__run_method_on_plugins( + "order_bulk_created", default_value, orders, channel_slug=None + ) def fulfillment_created( self, fulfillment: "Fulfillment", notify_customer: Optional[bool] = True @@ -1043,7 +1153,10 @@ def fulfillment_approved( def fulfillment_metadata_updated(self, fulfillment: "Fulfillment"): default_value = None return self.__run_method_on_plugins( - "fulfillment_metadata_updated", default_value, fulfillment + "fulfillment_metadata_updated", + default_value, + fulfillment, + channel_slug=fulfillment.order.channel.slug, ) def tracking_number_updated(self, fulfillment: "Fulfillment"): @@ -1085,55 +1198,64 @@ def checkout_fully_paid(self, checkout: "Checkout"): def checkout_metadata_updated(self, checkout: "Checkout"): default_value = None return self.__run_method_on_plugins( - "checkout_metadata_updated", default_value, checkout + "checkout_metadata_updated", + default_value, + checkout, + channel_slug=checkout.channel.slug, ) def page_created(self, page: "Page"): default_value = None - return self.__run_method_on_plugins("page_created", default_value, page) + return self.__run_method_on_plugins( + "page_created", default_value, page, channel_slug=None + ) def page_updated(self, page: "Page"): default_value = None - return self.__run_method_on_plugins("page_updated", default_value, page) + return self.__run_method_on_plugins( + "page_updated", default_value, page, channel_slug=None + ) def page_deleted(self, page: "Page"): default_value = None - return self.__run_method_on_plugins("page_deleted", default_value, page) + return self.__run_method_on_plugins( + "page_deleted", default_value, page, channel_slug=None + ) def page_type_created(self, page_type: "PageType"): default_value = None return self.__run_method_on_plugins( - "page_type_created", default_value, page_type + "page_type_created", default_value, page_type, channel_slug=None ) def page_type_updated(self, page_type: "PageType"): default_value = None return self.__run_method_on_plugins( - "page_type_updated", default_value, page_type + "page_type_updated", default_value, page_type, channel_slug=None ) def page_type_deleted(self, page_type: "PageType"): default_value = None return self.__run_method_on_plugins( - "page_type_deleted", default_value, page_type + "page_type_deleted", default_value, page_type, channel_slug=None ) def permission_group_created(self, group: "Group"): default_value = None return self.__run_method_on_plugins( - "permission_group_created", default_value, group + "permission_group_created", default_value, group, channel_slug=None ) def permission_group_updated(self, group: "Group"): default_value = None return self.__run_method_on_plugins( - "permission_group_updated", default_value, group + "permission_group_updated", default_value, group, channel_slug=None ) def permission_group_deleted(self, group: "Group"): default_value = None return self.__run_method_on_plugins( - "permission_group_deleted", default_value, group + "permission_group_deleted", default_value, group, channel_slug=None ) def transaction_charge_requested( @@ -1212,12 +1334,17 @@ def transaction_process_session( def transaction_item_metadata_updated(self, transaction_item: "TransactionItem"): default_value = None return self.__run_method_on_plugins( - "transaction_item_metadata_updated", default_value, transaction_item + "transaction_item_metadata_updated", + default_value, + transaction_item, + channel_slug=None, ) def account_confirmed(self, user: "User"): default_value = None - return self.__run_method_on_plugins("account_confirmed", default_value, user) + return self.__run_method_on_plugins( + "account_confirmed", default_value, user, channel_slug=None + ) def account_confirmation_requested( self, user: "User", channel_slug: str, token: str, redirect_url: Optional[str] @@ -1230,6 +1357,7 @@ def account_confirmation_requested( channel_slug, token=token, redirect_url=redirect_url, + channel_slug=channel_slug, ) def account_change_email_requested( @@ -1249,6 +1377,7 @@ def account_change_email_requested( token=token, redirect_url=redirect_url, new_email=new_email, + channel_slug=channel_slug, ) def account_email_changed( @@ -1260,6 +1389,7 @@ def account_email_changed( "account_email_changed", default_value, user, + channel_slug=None, ) def account_set_password_requested( @@ -1277,6 +1407,7 @@ def account_set_password_requested( channel_slug, token=token, redirect_url=redirect_url, + channel_slug=channel_slug, ) def account_delete_requested( @@ -1290,132 +1421,195 @@ def account_delete_requested( channel_slug, token=token, redirect_url=redirect_url, + channel_slug=channel_slug, ) def account_deleted(self, user: "User"): default_value = None - return self.__run_method_on_plugins("account_deleted", default_value, user) + return self.__run_method_on_plugins( + "account_deleted", default_value, user, channel_slug=None + ) def address_created(self, address: "Address"): default_value = None - return self.__run_method_on_plugins("address_created", default_value, address) + return self.__run_method_on_plugins( + "address_created", default_value, address, channel_slug=None + ) def address_updated(self, address: "Address"): default_value = None - return self.__run_method_on_plugins("address_updated", default_value, address) + return self.__run_method_on_plugins( + "address_updated", default_value, address, channel_slug=None + ) def address_deleted(self, address: "Address"): default_value = None - return self.__run_method_on_plugins("address_deleted", default_value, address) + return self.__run_method_on_plugins( + "address_deleted", default_value, address, channel_slug=None + ) def app_installed(self, app: "App"): default_value = None - return self.__run_method_on_plugins("app_installed", default_value, app) + return self.__run_method_on_plugins( + "app_installed", default_value, app, channel_slug=None + ) def app_updated(self, app: "App"): default_value = None - return self.__run_method_on_plugins("app_updated", default_value, app) + return self.__run_method_on_plugins( + "app_updated", default_value, app, channel_slug=None + ) def app_deleted(self, app: "App"): default_value = None - return self.__run_method_on_plugins("app_deleted", default_value, app) + return self.__run_method_on_plugins( + "app_deleted", default_value, app, channel_slug=None + ) def app_status_changed(self, app: "App"): default_value = None - return self.__run_method_on_plugins("app_status_changed", default_value, app) + return self.__run_method_on_plugins( + "app_status_changed", default_value, app, channel_slug=None + ) def attribute_created(self, attribute: "Attribute"): default_value = None return self.__run_method_on_plugins( - "attribute_created", default_value, attribute + "attribute_created", default_value, attribute, channel_slug=None ) def attribute_updated(self, attribute: "Attribute", webhooks=None): default_value = None return self.__run_method_on_plugins( - "attribute_updated", default_value, attribute, webhooks=webhooks + "attribute_updated", + default_value, + attribute, + webhooks=webhooks, + channel_slug=None, ) def attribute_deleted(self, attribute: "Attribute", webhooks=None): default_value = None return self.__run_method_on_plugins( - "attribute_deleted", default_value, attribute, webhooks=webhooks + "attribute_deleted", + default_value, + attribute, + webhooks=webhooks, + channel_slug=None, ) def attribute_value_created(self, attribute_value: "AttributeValue", webhooks=None): default_value = None return self.__run_method_on_plugins( - "attribute_value_created", default_value, attribute_value, webhooks=webhooks + "attribute_value_created", + default_value, + attribute_value, + webhooks=webhooks, + channel_slug=None, ) def attribute_value_updated(self, attribute_value: "AttributeValue"): default_value = None return self.__run_method_on_plugins( - "attribute_value_updated", default_value, attribute_value + "attribute_value_updated", default_value, attribute_value, channel_slug=None ) def attribute_value_deleted(self, attribute_value: "AttributeValue", webhooks=None): default_value = None return self.__run_method_on_plugins( - "attribute_value_deleted", default_value, attribute_value, webhooks=webhooks + "attribute_value_deleted", + default_value, + attribute_value, + webhooks=webhooks, + channel_slug=None, ) def category_created(self, category: "Category"): default_value = None - return self.__run_method_on_plugins("category_created", default_value, category) + return self.__run_method_on_plugins( + "category_created", default_value, category, channel_slug=None + ) def category_updated(self, category: "Category"): default_value = None - return self.__run_method_on_plugins("category_updated", default_value, category) + return self.__run_method_on_plugins( + "category_updated", default_value, category, channel_slug=None + ) def category_deleted(self, category: "Category", webhooks=None): default_value = None return self.__run_method_on_plugins( - "category_deleted", default_value, category, webhooks=webhooks + "category_deleted", + default_value, + category, + webhooks=webhooks, + channel_slug=None, ) def channel_created(self, channel: "Channel"): default_value = None - return self.__run_method_on_plugins("channel_created", default_value, channel) + return self.__run_method_on_plugins( + "channel_created", default_value, channel, channel_slug=channel.slug + ) def channel_updated(self, channel: "Channel", webhooks=None): default_value = None return self.__run_method_on_plugins( - "channel_updated", default_value, channel, webhooks=webhooks + "channel_updated", + default_value, + channel, + webhooks=webhooks, + channel_slug=channel.slug, ) def channel_deleted(self, channel: "Channel"): default_value = None - return self.__run_method_on_plugins("channel_deleted", default_value, channel) + return self.__run_method_on_plugins( + "channel_deleted", default_value, channel, channel_slug=None + ) def channel_status_changed(self, channel: "Channel"): default_value = None return self.__run_method_on_plugins( - "channel_status_changed", default_value, channel + "channel_status_changed", default_value, channel, channel_slug=channel.slug ) def channel_metadata_updated(self, channel: "Channel"): default_value = None return self.__run_method_on_plugins( - "channel_metadata_updated", default_value, channel + "channel_metadata_updated", + default_value, + channel, + channel_slug=channel.slug, ) def gift_card_created(self, gift_card: "GiftCard", webhooks=None): default_value = None return self.__run_method_on_plugins( - "gift_card_created", default_value, gift_card, webhooks=webhooks + "gift_card_created", + default_value, + gift_card, + webhooks=webhooks, + channel_slug=None, ) def gift_card_updated(self, gift_card: "GiftCard"): default_value = None return self.__run_method_on_plugins( - "gift_card_updated", default_value, gift_card + "gift_card_updated", + default_value, + gift_card, + channel_slug=None, ) def gift_card_deleted(self, gift_card: "GiftCard", webhooks=None): default_value = None return self.__run_method_on_plugins( - "gift_card_deleted", default_value, gift_card, webhooks=webhooks + "gift_card_deleted", + default_value, + gift_card, + webhooks=webhooks, + channel_slug=None, ) def gift_card_sent(self, gift_card: "GiftCard", channel_slug: str, email: str): @@ -1426,112 +1620,184 @@ def gift_card_sent(self, gift_card: "GiftCard", channel_slug: str, email: str): gift_card, channel_slug, email, + channel_slug=channel_slug, ) def gift_card_status_changed(self, gift_card: "GiftCard", webhooks=None): default_value = None return self.__run_method_on_plugins( - "gift_card_status_changed", default_value, gift_card, webhooks=webhooks + "gift_card_status_changed", + default_value, + gift_card, + webhooks=webhooks, + channel_slug=None, ) def gift_card_metadata_updated(self, gift_card: "GiftCard"): default_value = None return self.__run_method_on_plugins( - "gift_card_metadata_updated", default_value, gift_card + "gift_card_metadata_updated", + default_value, + gift_card, + channel_slug=None, ) def gift_card_export_completed(self, export: "ExportFile"): default_value = None return self.__run_method_on_plugins( - "gift_card_export_completed", default_value, export + "gift_card_export_completed", + default_value, + export, + channel_slug=None, ) def menu_created(self, menu: "Menu"): default_value = None - return self.__run_method_on_plugins("menu_created", default_value, menu) + return self.__run_method_on_plugins( + "menu_created", + default_value, + menu, + channel_slug=None, + ) def menu_updated(self, menu: "Menu"): default_value = None - return self.__run_method_on_plugins("menu_updated", default_value, menu) + return self.__run_method_on_plugins( + "menu_updated", + default_value, + menu, + channel_slug=None, + ) def menu_deleted(self, menu: "Menu", webhooks=None): default_value = None return self.__run_method_on_plugins( - "menu_deleted", default_value, menu, webhooks=webhooks + "menu_deleted", + default_value, + menu, + webhooks=webhooks, + channel_slug=None, ) def menu_item_created(self, menu_item: "MenuItem"): default_value = None return self.__run_method_on_plugins( - "menu_item_created", default_value, menu_item + "menu_item_created", + default_value, + menu_item, + channel_slug=None, ) def menu_item_updated(self, menu_item: "MenuItem"): default_value = None return self.__run_method_on_plugins( - "menu_item_updated", default_value, menu_item + "menu_item_updated", + default_value, + menu_item, + channel_slug=None, ) def menu_item_deleted(self, menu_item: "MenuItem", webhooks=None): default_value = None return self.__run_method_on_plugins( - "menu_item_deleted", default_value, menu_item, webhooks=webhooks + "menu_item_deleted", + default_value, + menu_item, + webhooks=webhooks, + channel_slug=None, ) def shipping_price_created(self, shipping_method: "ShippingMethod"): default_value = None return self.__run_method_on_plugins( - "shipping_price_created", default_value, shipping_method + "shipping_price_created", + default_value, + shipping_method, + channel_slug=None, ) def shipping_price_updated(self, shipping_method: "ShippingMethod"): default_value = None return self.__run_method_on_plugins( - "shipping_price_updated", default_value, shipping_method + "shipping_price_updated", + default_value, + shipping_method, + channel_slug=None, ) def shipping_price_deleted(self, shipping_method: "ShippingMethod", webhooks=None): default_value = None return self.__run_method_on_plugins( - "shipping_price_deleted", default_value, shipping_method, webhooks=webhooks + "shipping_price_deleted", + default_value, + shipping_method, + webhooks=webhooks, + channel_slug=None, ) def shipping_zone_created(self, shipping_zone: "ShippingZone"): default_value = None return self.__run_method_on_plugins( - "shipping_zone_created", default_value, shipping_zone + "shipping_zone_created", + default_value, + shipping_zone, + channel_slug=None, ) def shipping_zone_updated(self, shipping_zone: "ShippingZone"): default_value = None return self.__run_method_on_plugins( - "shipping_zone_updated", default_value, shipping_zone + "shipping_zone_updated", + default_value, + shipping_zone, + channel_slug=None, ) def shipping_zone_deleted(self, shipping_zone: "ShippingZone", webhooks=None): default_value = None return self.__run_method_on_plugins( - "shipping_zone_deleted", default_value, shipping_zone, webhooks=webhooks + "shipping_zone_deleted", + default_value, + shipping_zone, + webhooks=webhooks, + channel_slug=None, ) def shipping_zone_metadata_updated(self, shipping_zone: "ShippingZone"): default_value = None return self.__run_method_on_plugins( - "shipping_zone_metadata_updated", default_value, shipping_zone + "shipping_zone_metadata_updated", + default_value, + shipping_zone, + channel_slug=None, ) def staff_created(self, staff_user: "User"): default_value = None - return self.__run_method_on_plugins("staff_created", default_value, staff_user) + return self.__run_method_on_plugins( + "staff_created", + default_value, + staff_user, + channel_slug=None, + ) def staff_updated(self, staff_user: "User"): default_value = None - return self.__run_method_on_plugins("staff_updated", default_value, staff_user) + return self.__run_method_on_plugins( + "staff_updated", + default_value, + staff_user, + channel_slug=None, + ) def staff_deleted(self, staff_user: "User", webhooks=None): default_value = None return self.__run_method_on_plugins( - "staff_deleted", default_value, staff_user, webhooks=webhooks + "staff_deleted", + default_value, + staff_user, + webhooks=webhooks, + channel_slug=None, ) def staff_set_password_requested( @@ -1545,6 +1811,7 @@ def staff_set_password_requested( channel_slug, token=token, redirect_url=redirect_url, + channel_slug=channel_slug, ) def thumbnail_created( @@ -1553,79 +1820,124 @@ def thumbnail_created( ): default_value = None return self.__run_method_on_plugins( - "thumbnail_created", default_value, thumbnail + "thumbnail_created", + default_value, + thumbnail, + channel_slug=None, ) def warehouse_created(self, warehouse: "Warehouse"): default_value = None return self.__run_method_on_plugins( - "warehouse_created", default_value, warehouse + "warehouse_created", + default_value, + warehouse, + channel_slug=None, ) def warehouse_updated(self, warehouse: "Warehouse"): default_value = None return self.__run_method_on_plugins( - "warehouse_updated", default_value, warehouse + "warehouse_updated", + default_value, + warehouse, + channel_slug=None, ) def warehouse_deleted(self, warehouse: "Warehouse"): default_value = None return self.__run_method_on_plugins( - "warehouse_deleted", default_value, warehouse + "warehouse_deleted", + default_value, + warehouse, + channel_slug=None, ) def warehouse_metadata_updated(self, warehouse: "Warehouse"): default_value = None return self.__run_method_on_plugins( - "warehouse_metadata_updated", default_value, warehouse + "warehouse_metadata_updated", + default_value, + warehouse, + channel_slug=None, ) def voucher_created(self, voucher: "Voucher", code: str): default_value = None return self.__run_method_on_plugins( - "voucher_created", default_value, voucher, code + "voucher_created", + default_value, + voucher, + code, + channel_slug=None, ) def voucher_updated(self, voucher: "Voucher", code: str): default_value = None return self.__run_method_on_plugins( - "voucher_updated", default_value, voucher, code + "voucher_updated", + default_value, + voucher, + code, + channel_slug=None, ) def voucher_deleted(self, voucher: "Voucher", code: str, webhooks=None): default_value = None return self.__run_method_on_plugins( - "voucher_deleted", default_value, voucher, code, webhooks=webhooks + "voucher_deleted", + default_value, + voucher, + code, + webhooks=webhooks, + channel_slug=None, ) def voucher_codes_created(self, voucher_codes: list["VoucherCode"], webhooks=None): default_value = None return self.__run_method_on_plugins( - "voucher_codes_created", default_value, voucher_codes, webhooks=webhooks + "voucher_codes_created", + default_value, + voucher_codes, + webhooks=webhooks, + channel_slug=None, ) def voucher_codes_deleted(self, voucher_codes: list["VoucherCode"], webhooks=None): default_value = None return self.__run_method_on_plugins( - "voucher_codes_deleted", default_value, voucher_codes, webhooks=webhooks + "voucher_codes_deleted", + default_value, + voucher_codes, + webhooks=webhooks, + channel_slug=None, ) def voucher_metadata_updated(self, voucher: "Voucher"): default_value = None return self.__run_method_on_plugins( - "voucher_metadata_updated", default_value, voucher + "voucher_metadata_updated", + default_value, + voucher, + channel_slug=None, ) def voucher_code_export_completed(self, export: "ExportFile"): default_value = None return self.__run_method_on_plugins( - "voucher_code_export_completed", default_value, export + "voucher_code_export_completed", + default_value, + export, + channel_slug=None, ) def shop_metadata_updated(self, shop: "SiteSettings"): default_value = None return self.__run_method_on_plugins( - "shop_metadata_updated", default_value, shop + "shop_metadata_updated", + default_value, + shop, + channel_slug=None, ) def initialize_payment( @@ -1642,6 +1954,7 @@ def initialize_payment( method_name, previous_value=default_value, payment_data=payment_data, + channel_slug=channel_slug, ) def authorize_payment( @@ -1717,7 +2030,7 @@ def list_payment_sources( self, gateway: str, customer_id: str, - channel_slug: str, + channel_slug: Optional[str], ) -> list["CustomerSource"]: default_value: list = [] gtw = self.get_plugin(gateway, channel_slug=channel_slug) @@ -1728,13 +2041,15 @@ def list_payment_sources( raise Exception(f"Payment plugin {gateway} is inaccessible!") def list_stored_payment_methods( - self, list_stored_payment_methods_data: "ListStoredPaymentMethodsRequestData" + self, + list_stored_payment_methods_data: "ListStoredPaymentMethodsRequestData", ) -> list["PaymentMethodData"]: default_value: list = [] return self.__run_method_on_plugins( "list_stored_payment_methods", default_value, list_stored_payment_methods_data, + channel_slug=list_stored_payment_methods_data.channel.slug, ) def stored_payment_method_request_delete( @@ -1749,6 +2064,7 @@ def stored_payment_method_request_delete( "stored_payment_method_request_delete", default_response, request_delete_data, + channel_slug=request_delete_data.channel.slug, ) return response @@ -1766,6 +2082,7 @@ def payment_gateway_initialize_tokenization( "payment_gateway_initialize_tokenization", default_response, request_data, + channel_slug=request_data.channel.slug, ) return response @@ -1783,6 +2100,7 @@ def payment_method_initialize_tokenization( "payment_method_initialize_tokenization", default_response, request_data, + channel_slug=request_data.channel.slug, ) return response @@ -1800,21 +2118,30 @@ def payment_method_process_tokenization( "payment_method_process_tokenization", default_response, request_data, + channel_slug=request_data.channel.slug, ) return response def translation_created(self, translation: "Translation"): default_value = None return self.__run_method_on_plugins( - "translation_created", default_value, translation + "translation_created", default_value, translation, channel_slug=None ) def translation_updated(self, translation: "Translation"): default_value = None return self.__run_method_on_plugins( - "translation_updated", default_value, translation + "translation_updated", default_value, translation, channel_slug=None ) + def get_all_plugins(self, active_only=False): + if not self.loaded_all_channels: + channels = Channel.objects.using(self.database).all() + for channel in channels.iterator(): + self._ensure_channel_plugins_loaded(channel.slug, channel=channel) + self.loaded_all_channels = True + return self.get_plugins(active_only=active_only) + def get_plugins( self, channel_slug: Optional[str] = None, @@ -1822,9 +2149,11 @@ def get_plugins( plugin_ids: Optional[list[str]] = None, ) -> list["BasePlugin"]: """Return list of plugins for a given channel.""" - if channel_slug: + if channel_slug is not None: + self._ensure_channel_plugins_loaded(channel_slug) plugins = self.plugins_per_channel[channel_slug] else: + self._ensure_channel_plugins_loaded(None) plugins = self.all_plugins if active_only: @@ -1844,7 +2173,17 @@ def list_payment_gateways( active_only: bool = True, ) -> list["PaymentGateway"]: channel_slug = checkout_info.channel.slug if checkout_info else channel_slug - plugins = self.get_plugins(channel_slug=channel_slug, active_only=active_only) + + if channel_slug is not None: + plugins = self.get_plugins( + channel_slug=channel_slug, active_only=active_only + ) + else: + # Backwards compatibility for: https://github.com/saleor/saleor/pull/15769/ + # Load all channel plugins and global plugins if channel_slug is None, as + # it was done before the mentioned PR. + plugins = self.get_all_plugins(active_only=active_only) + payment_plugins = [ plugin for plugin in plugins if "process_payment" in type(plugin).__dict__ ] @@ -1939,9 +2278,11 @@ def __run_plugin_method_until_first_success( self, method_name: str, *args, - channel_slug: Optional[str] = None, + channel_slug: Optional[str], + plugins: Optional[list["BasePlugin"]] = None, ): - plugins = self.get_plugins(channel_slug=channel_slug) + if plugins is None: + plugins = self.get_plugins(channel_slug=channel_slug) for plugin in plugins: result = self.__run_method_on_single_plugin( plugin, method_name, None, *args @@ -1972,18 +2313,17 @@ def _get_all_plugin_configs(self): # FIXME these methods should be more generic - def assign_tax_code_to_object_meta(self, obj: "TaxClass", tax_code: Optional[str]): - default_value = None - return self.__run_method_on_plugins( - "assign_tax_code_to_object_meta", default_value, obj, tax_code - ) - def get_tax_code_from_object_meta( - self, obj: Union["Product", "ProductType", "TaxClass"] + self, + obj: Union["Product", "ProductType", "TaxClass"], + channel_slug: Optional[str], ) -> TaxType: default_value = TaxType(code="", description="") return self.__run_method_on_plugins( - "get_tax_code_from_object_meta", default_value, obj + "get_tax_code_from_object_meta", + default_value, + obj, + channel_slug=channel_slug, ) def save_plugin_configuration( @@ -1998,7 +2338,7 @@ def save_plugin_configuration( return None else: channel = None - plugins = self.global_plugins + plugins = self.get_plugins() for plugin in plugins: if plugin.PLUGIN_ID == plugin_id: @@ -2039,6 +2379,10 @@ def webhook_endpoint_without_channel( default_value = HttpResponseNotFound() plugin = self.get_plugin(plugin_id) + if not plugin: + self.get_all_plugins() + plugin = self.get_plugin(plugin_id) + if not plugin: return default_value return self.__run_method_on_single_plugin( @@ -2046,7 +2390,7 @@ def webhook_endpoint_without_channel( ) def webhook( - self, request: SaleorContext, plugin_id: str, channel_slug: Optional[str] = None + self, request: SaleorContext, plugin_id: str, channel_slug: Optional[str] ) -> HttpResponse: split_path = request.path.split(plugin_id, maxsplit=1) path = None @@ -2125,7 +2469,9 @@ def external_refresh( def authenticate_user(self, request: SaleorContext) -> Optional["User"]: """Authenticate user which should be assigned to the request.""" default_value = None - return self.__run_method_on_plugins("authenticate_user", default_value, request) + return self.__run_method_on_plugins( + "authenticate_user", default_value, request, channel_slug=None + ) def external_logout( self, plugin_id: str, data: dict, request: SaleorContext @@ -2196,11 +2542,13 @@ def perform_mutation( root=root, info=info, data=data, + channel_slug=None, ) def is_event_active_for_any_plugin( self, event: str, channel_slug: Optional[str] = None ) -> bool: + self._ensure_channel_plugins_loaded(channel_slug) """Check if any plugin supports defined event.""" plugins = ( self.plugins_per_channel[channel_slug] if channel_slug else self.all_plugins @@ -2208,12 +2556,6 @@ def is_event_active_for_any_plugin( only_active_plugins = [plugin for plugin in plugins if plugin.active] return any([plugin.is_event_active(event) for plugin in only_active_plugins]) - def _get_channel_map(self): - return { - channel.pk: channel - for channel in Channel.objects.using(self.database).all().iterator() - } - def get_plugins_manager( allow_replica: bool, diff --git a/saleor/plugins/openid_connect/tests/conftest.py b/saleor/plugins/openid_connect/tests/conftest.py index bf665936d36..a2ca010de2d 100644 --- a/saleor/plugins/openid_connect/tests/conftest.py +++ b/saleor/plugins/openid_connect/tests/conftest.py @@ -86,6 +86,7 @@ def fun( }, ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.all_plugins[0] return fun diff --git a/saleor/plugins/sendgrid/tests/conftest.py b/saleor/plugins/sendgrid/tests/conftest.py index d45a7810cf0..fc8b1b69679 100644 --- a/saleor/plugins/sendgrid/tests/conftest.py +++ b/saleor/plugins/sendgrid/tests/conftest.py @@ -102,6 +102,7 @@ def fun( }, ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/plugins/tests/sample_plugins.py b/saleor/plugins/tests/sample_plugins.py index ba777550bed..52d678edeb9 100644 --- a/saleor/plugins/tests/sample_plugins.py +++ b/saleor/plugins/tests/sample_plugins.py @@ -178,9 +178,6 @@ def calculate_order_line_unit( def get_tax_rate_type_choices(self, previous_value): return [TaxType(code="123", description="abc")] - def show_taxes_on_storefront(self, previous_value: bool) -> bool: - return True - def external_authentication_url( self, data: dict, request: WSGIRequest, previous_value ) -> dict: @@ -480,6 +477,7 @@ class SampleAuthorizationPlugin(BasePlugin): PLUGIN_ID = "saleor.sample.authorization" PLUGIN_NAME = "SampleAuthorization" DEFAULT_ACTIVE = True + CONFIGURATION_PER_CHANNEL = False def authenticate_user(self, request, previous_value) -> Optional[User]: # This function will be mocked in test diff --git a/saleor/plugins/tests/test_manager.py b/saleor/plugins/tests/test_manager.py index 3ab391901a0..2ee6b88e011 100644 --- a/saleor/plugins/tests/test_manager.py +++ b/saleor/plugins/tests/test_manager.py @@ -54,6 +54,7 @@ def test_get_plugins_manager(settings): plugin_path = "saleor.plugins.tests.sample_plugins.PluginSample" settings.PLUGINS = [plugin_path] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() assert isinstance(manager, PluginsManager) assert len(manager.all_plugins) == 1 @@ -66,6 +67,8 @@ def test_manager_with_default_configuration_for_channel_plugins( "saleor.plugins.tests.sample_plugins.PluginSample", ] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() + assert len(manager.global_plugins) == 1 assert isinstance(manager.global_plugins[0], PluginSample) assert {channel_PLN.slug, channel_USD.slug} == set( @@ -92,6 +95,7 @@ def test_manager_with_channel_plugins( "saleor.plugins.tests.sample_plugins.ChannelPluginSample", ] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() assert {channel_PLN.slug, channel_USD.slug} == set( manager.plugins_per_channel.keys() @@ -539,14 +543,6 @@ def sample_none_data(obj): return None -@pytest.mark.parametrize( - ("plugins", "show_taxes"), - [(["saleor.plugins.tests.sample_plugins.PluginSample"], True), ([], False)], -) -def test_manager_show_taxes_on_storefront(plugins, show_taxes): - assert show_taxes == PluginsManager(plugins=plugins).show_taxes_on_storefront() - - @pytest.mark.parametrize( ("plugins", "expected_tax_data"), [ @@ -818,7 +814,7 @@ def test_manager_webhook(rf): plugin_path = "/webhook/paid" request = rf.post(path=f"/plugins/{PluginSample.PLUGIN_ID}{plugin_path}") - response = manager.webhook(request, PluginSample.PLUGIN_ID) + response = manager.webhook(request, PluginSample.PLUGIN_ID, channel_slug=None) assert isinstance(response, JsonResponse) assert response.status_code == 200 assert response.content.decode() == json.dumps({"received": True, "paid": True}) @@ -832,7 +828,7 @@ def test_manager_webhook_plugin_doesnt_have_webhook_support(rf): manager = PluginsManager(plugins=plugins) plugin_path = "/webhook/paid" request = rf.post(path=f"/plugins/{PluginInactive.PLUGIN_ID}{plugin_path}") - response = manager.webhook(request, PluginSample.PLUGIN_ID) + response = manager.webhook(request, PluginSample.PLUGIN_ID, channel_slug=None) assert isinstance(response, HttpResponseNotFound) assert response.status_code == 404 @@ -845,7 +841,7 @@ def test_manager_inncorrect_plugin(rf): manager = PluginsManager(plugins=plugins) plugin_path = "/webhook/paid" request = rf.post(path=f"/plugins/incorrect.plugin.id{plugin_path}") - response = manager.webhook(request, "incorrect.plugin.id") + response = manager.webhook(request, "incorrect.plugin.id", channel_slug=None) assert isinstance(response, HttpResponseNotFound) assert response.status_code == 404 @@ -966,8 +962,7 @@ def test_list_external_authentications_active_only(channel_USD): def test_run_method_on_plugins_default_value(plugins_manager): default_value = "default" value = plugins_manager._PluginsManager__run_method_on_plugins( - method_name="test_method", - default_value=default_value, + method_name="test_method", default_value=default_value, channel_slug=None ) assert value == default_value @@ -980,6 +975,7 @@ def test_run_method_on_plugins_default_value_when_not_existing_method_is_called( value = all_plugins_manager._PluginsManager__run_method_on_plugins( method_name="test_method", default_value=default_value, + channel_slug=channel_USD.slug, ) assert value == default_value @@ -992,6 +988,7 @@ def test_run_method_on_plugins_value_overridden_by_plugin_method( value = all_plugins_manager._PluginsManager__run_method_on_plugins( method_name="get_supported_currencies", default_value="default_value", + channel_slug=channel_USD.slug, ) assert value == expected @@ -1006,6 +1003,7 @@ def test_run_method_on_plugins_only_on_active_ones( all_plugins_manager._PluginsManager__run_method_on_plugins( method_name="test_method_name", default_value="default_value", + channel_slug=channel_USD.slug, ) active_plugins_count = len(ACTIVE_PLUGINS) @@ -1017,7 +1015,11 @@ def test_run_method_on_plugins_only_on_active_ones( assert mocked_method.call_count == active_plugins_count called_plugins_id = [arg.args[0].PLUGIN_ID for arg in mocked_method.call_args_list] - expected_active_plugins_id = [p.PLUGIN_ID for p in ACTIVE_PLUGINS] + expected_active_plugins_id = [ + p.PLUGIN_ID + for p in all_plugins_manager.plugins_per_channel[channel_USD.slug] + if p.active + ] assert called_plugins_id == expected_active_plugins_id @@ -1039,7 +1041,7 @@ def test_run_method_on_plugins_only_for_given_channel( ) pln_plugin = ChannelPluginSample(active=True, channel=channel_PLN, configuration=[]) - plugins_manager.plugins = [usd_plugin_1, usd_plugin_2, pln_plugin] + plugins_manager.all_plugins = [usd_plugin_1, usd_plugin_2, pln_plugin] plugins_manager.plugins_per_channel[channel_USD.slug] = [usd_plugin_1, usd_plugin_2] plugins_manager.plugins_per_channel[channel_PLN.slug] = [pln_plugin] @@ -1125,6 +1127,7 @@ def fake_request_getter(mock): manager = PluginsManager( plugins=plugins, requestor_getter=partial(fake_request_getter, user_mock) ) + manager.get_all_plugins() user_mock.assert_not_called() plugin = manager.all_plugins.pop() @@ -1548,19 +1551,34 @@ def test_plugin_manager_database(allow_replica, expected_connection_name): assert manager.database == expected_connection_name -def test_plugin_manager__get_channel_map( - channel_USD, channel_PLN, channel_JPY, other_channel_USD -): +def test_loaded_all_channels(channel_USD, channel_PLN, django_assert_num_queries): + # given + plugins = [ + "saleor.plugins.tests.sample_plugins.PluginSample", + ] + manager = PluginsManager(plugins=plugins) + + # then + with django_assert_num_queries(4): + plugins = manager.get_all_plugins() + assert plugins + + with django_assert_num_queries(0): + plugins = manager.get_all_plugins() + assert plugins + + +def test_get_plugin_invalid_channel(): # given - manager = PluginsManager(["saleor.plugins.tests.sample_plugins.PluginSample"]) + plugins = [ + "saleor.plugins.tests.sample_plugins.PluginSample", + ] + manager = PluginsManager(plugins=plugins) # when - channel_map = manager._get_channel_map() + plugin = manager.get_plugin( + "saleor.plugins.tests.sample_plugins.PluginSample", channel_slug="invalid" + ) # then - assert channel_map == { - channel_USD.pk: channel_USD, - channel_PLN.pk: channel_PLN, - channel_JPY.pk: channel_JPY, - other_channel_USD.pk: other_channel_USD, - } + assert plugin is None diff --git a/saleor/plugins/user_email/tests/conftest.py b/saleor/plugins/user_email/tests/conftest.py index 1e288faa247..415f9cdded0 100644 --- a/saleor/plugins/user_email/tests/conftest.py +++ b/saleor/plugins/user_email/tests/conftest.py @@ -240,6 +240,7 @@ def fun( }, ) manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.plugins_per_channel[channel_USD.slug][0] return fun diff --git a/saleor/plugins/user_email/tests/test_plugin.py b/saleor/plugins/user_email/tests/test_plugin.py index c115046a11d..0bb68742d69 100644 --- a/saleor/plugins/user_email/tests/test_plugin.py +++ b/saleor/plugins/user_email/tests/test_plugin.py @@ -302,6 +302,7 @@ def test_plugin_manager_doesnt_load_email_templates_from_db( ): settings.PLUGINS = ["saleor.plugins.user_email.plugin.UserEmailPlugin"] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() plugin = manager.all_plugins[0] email_config_item = None diff --git a/saleor/plugins/webhook/conftest.py b/saleor/plugins/webhook/conftest.py index 13423ed389b..3ad3b31085d 100644 --- a/saleor/plugins/webhook/conftest.py +++ b/saleor/plugins/webhook/conftest.py @@ -71,6 +71,7 @@ def webhook_plugin(settings): def factory(): settings.PLUGINS = ["saleor.plugins.webhook.plugin.WebhookPlugin"] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.global_plugins[0] return factory diff --git a/saleor/plugins/webhook/plugin.py b/saleor/plugins/webhook/plugin.py index d3fcff9b704..04b6d9c93ab 100644 --- a/saleor/plugins/webhook/plugin.py +++ b/saleor/plugins/webhook/plugin.py @@ -3119,7 +3119,9 @@ def get_shipping_methods_for_checkout( return methods def get_tax_code_from_object_meta( - self, obj: Union["Product", "ProductType", "TaxClass"], previous_value: Any + self, + obj: Union["Product", "ProductType", "TaxClass"], + previous_value: Any, ): """Get tax code and description for a product or product type. diff --git a/saleor/plugins/webhook/tests/conftest.py b/saleor/plugins/webhook/tests/conftest.py index 91258c140dd..1edbd8341c9 100644 --- a/saleor/plugins/webhook/tests/conftest.py +++ b/saleor/plugins/webhook/tests/conftest.py @@ -18,6 +18,7 @@ def webhook_plugin(settings): def factory() -> WebhookPlugin: settings.PLUGINS = ["saleor.plugins.webhook.plugin.WebhookPlugin"] manager = get_plugins_manager(allow_replica=False) + manager.get_all_plugins() return manager.global_plugins[0] return factory From 1f0d9ee3bdf0d6c9b463b7db40f7c6f619d3b8c9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 09:54:19 +0200 Subject: [PATCH 32/35] Bump gunicorn from 21.2.0 to 22.0.0 (#15808) Bumps [gunicorn](https://github.com/benoitc/gunicorn) from 21.2.0 to 22.0.0. - [Release notes](https://github.com/benoitc/gunicorn/releases) - [Commits](https://github.com/benoitc/gunicorn/compare/21.2.0...22.0.0) --- updated-dependencies: - dependency-name: gunicorn dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- poetry.lock | 14 ++++++++------ pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 2e667b4d91b..a604b4b2c4a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2090,22 +2090,23 @@ protobuf = ">=4.21.6" [[package]] name = "gunicorn" -version = "21.2.0" +version = "22.0.0" description = "WSGI HTTP Server for UNIX" optional = false -python-versions = ">=3.5" +python-versions = ">=3.7" files = [ - {file = "gunicorn-21.2.0-py3-none-any.whl", hash = "sha256:3213aa5e8c24949e792bcacfc176fef362e7aac80b76c56f6b5122bf350722f0"}, - {file = "gunicorn-21.2.0.tar.gz", hash = "sha256:88ec8bff1d634f98e61b9f65bc4bf3cd918a90806c6f5c48bc5603849ec81033"}, + {file = "gunicorn-22.0.0-py3-none-any.whl", hash = "sha256:350679f91b24062c86e386e198a15438d53a7a8207235a78ba1b53df4c4378d9"}, + {file = "gunicorn-22.0.0.tar.gz", hash = "sha256:4a0b436239ff76fb33f11c07a16482c521a7e09c1ce3cc293c2330afe01bec63"}, ] [package.dependencies] packaging = "*" [package.extras] -eventlet = ["eventlet (>=0.24.1)"] +eventlet = ["eventlet (>=0.24.1,!=0.36.0)"] gevent = ["gevent (>=1.4.0)"] setproctitle = ["setproctitle"] +testing = ["coverage", "eventlet", "gevent", "pytest", "pytest-cov"] tornado = ["tornado (>=0.2)"] [[package]] @@ -3846,6 +3847,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -5400,4 +5402,4 @@ test = ["pytest"] [metadata] lock-version = "2.0" python-versions = "~3.9" -content-hash = "59f3b92f1122bc8f6a11be9e6d6c94fc4cd60f8a161ed392a3f69554040b92c5" +content-hash = "1363addca1d09bfe3f1716c221ebfc88014c9584ad943b94a9f866ac9a8f3ff6" diff --git a/pyproject.toml b/pyproject.toml index e7089fdb93a..c5cc05d0985 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ documentation = "https://docs.saleor.io/" graphene = "<3.0" graphql-core = "^2.3.2" graphql-relay = "^2.0.1" - gunicorn = "^21.2.0" + gunicorn = "^22.0.0" html2text = "^2020.1.16" html-to-draftjs = "^1.0.1" jaeger-client = "^4.5.0" From 09acf7bd0d210850257edd87dcf82f45e544386f Mon Sep 17 00:00:00 2001 From: Krzysztof Waliczek Date: Fri, 19 Apr 2024 12:46:10 +0200 Subject: [PATCH 33/35] remove unique constraint (#15779) --- .../migrations/0040_clear_assignedattributes.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/saleor/attribute/migrations/0040_clear_assignedattributes.py b/saleor/attribute/migrations/0040_clear_assignedattributes.py index 5684e04fd14..7e336f5bf05 100644 --- a/saleor/attribute/migrations/0040_clear_assignedattributes.py +++ b/saleor/attribute/migrations/0040_clear_assignedattributes.py @@ -19,10 +19,6 @@ class Migration(migrations.Migration): ALTER TABLE attribute_assignedproductattributevalue ALTER COLUMN product_id SET NOT NULL; - ALTER TABLE attribute_assignedproductattributevalue - ADD CONSTRAINT attribute_assignedproduc_value_id_product_id_6f6deb31_uniq - UNIQUE (value_id, product_id); - DROP TABLE attribute_assignedproductattribute; """, reverse_sql=""" @@ -32,9 +28,6 @@ class Migration(migrations.Migration): ALTER TABLE attribute_assignedproductattributevalue ALTER COLUMN product_id DROP NOT NULL; - ALTER TABLE attribute_assignedproductattributevalue - DROP CONSTRAINT IF EXISTS attribute_assignedproduc_value_id_product_id_6f6deb31_uniq; - CREATE TABLE attribute_assignedproductattribute ( id serial NOT NULL PRIMARY KEY, assignment_id integer, @@ -50,10 +43,6 @@ class Migration(migrations.Migration): ALTER TABLE attribute_assignedpageattributevalue ALTER COLUMN page_id SET NOT NULL; - ALTER TABLE attribute_assignedpageattributevalue - ADD CONSTRAINT attribute_assignedpageat_value_id_page_id_851cd501_uniq - UNIQUE (value_id, page_id); - DROP TABLE attribute_assignedpageattribute; """, reverse_sql=""" @@ -63,9 +52,6 @@ class Migration(migrations.Migration): ALTER TABLE attribute_assignedpageattributevalue ALTER COLUMN page_id DROP NOT NULL; - ALTER TABLE attribute_assignedpageattributevalue - DROP CONSTRAINT IF EXISTS attribute_assignedpageat_value_id_page_id_851cd501_uniq; - CREATE TABLE attribute_assignedpageattribute ( id serial NOT NULL PRIMARY KEY, assignment_id integer, From c5bc3ef90074345b8f5f8088aa974a34d65e6edd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Kwa=C5=9Bniak?= Date: Fri, 19 Apr 2024 13:21:37 +0200 Subject: [PATCH 34/35] Add deepcoopy when updating CONFIG_STRUCTURE for AdminEmailPlugin (#15813) --- saleor/plugins/admin_email/plugin.py | 3 +- .../plugins/admin_email/tests/test_plugin.py | 35 +++++++++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/saleor/plugins/admin_email/plugin.py b/saleor/plugins/admin_email/plugin.py index a13fa41194d..9b3aef56081 100644 --- a/saleor/plugins/admin_email/plugin.py +++ b/saleor/plugins/admin_email/plugin.py @@ -1,4 +1,5 @@ import logging +from copy import deepcopy from dataclasses import asdict from typing import Union @@ -143,7 +144,7 @@ class AdminEmailPlugin(BasePlugin): "label": "CSV export failed template", }, } - CONFIG_STRUCTURE.update(DEFAULT_EMAIL_CONFIG_STRUCTURE) + CONFIG_STRUCTURE.update(deepcopy(DEFAULT_EMAIL_CONFIG_STRUCTURE)) CONFIG_STRUCTURE["host"]["help_text"] += ( " Leave it blank if you want to use system environment - EMAIL_HOST." ) diff --git a/saleor/plugins/admin_email/tests/test_plugin.py b/saleor/plugins/admin_email/tests/test_plugin.py index 8f971960589..8d35df6035b 100644 --- a/saleor/plugins/admin_email/tests/test_plugin.py +++ b/saleor/plugins/admin_email/tests/test_plugin.py @@ -8,7 +8,11 @@ from ....core.notify_events import NotifyEventType from ....graphql.tests.utils import get_graphql_content -from ...email_common import DEFAULT_EMAIL_VALUE, get_email_template +from ...email_common import ( + DEFAULT_EMAIL_CONFIG_STRUCTURE, + DEFAULT_EMAIL_VALUE, + get_email_template, +) from ...manager import get_plugins_manager from ...models import PluginConfiguration from ..constants import ( @@ -24,7 +28,7 @@ send_staff_order_confirmation, send_staff_reset_password, ) -from ..plugin import get_admin_event_map +from ..plugin import AdminEmailPlugin, get_admin_event_map def test_event_map(): @@ -293,3 +297,30 @@ def test_plugin_manager_doesnt_load_email_templates_from_db( # email template from DB but returns default email value. assert email_config_item assert email_config_item["value"] == DEFAULT_EMAIL_VALUE + + +def test_plugin_dont_change_default_help_text_config_value(): + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["host"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["host"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["port"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["port"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["username"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["username"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["password"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["password"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["use_tls"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["use_tls"]["help_text"] + ) + assert ( + AdminEmailPlugin.CONFIG_STRUCTURE["use_ssl"]["help_text"] + != DEFAULT_EMAIL_CONFIG_STRUCTURE["use_ssl"]["help_text"] + ) From d3fa53c48cd3b60fde9a8c0c48764c3f315208ea Mon Sep 17 00:00:00 2001 From: Teddy Ondieki Date: Fri, 19 Apr 2024 14:22:48 +0300 Subject: [PATCH 35/35] Prevent name overwritting of Product Variants when Updating Product Types (#15670) * prevent overwrite of variant names if name exists already --- CHANGELOG.md | 1 + saleor/product/tasks.py | 5 +++-- saleor/product/tests/test_tasks.py | 23 +++++++++++++++-------- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5031460a04e..0817889b10c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ All notable, unreleased changes to this project will be documented in this file. - Don't raise InsufficientStock for track_inventory=False variants #15475 by @carlosa54 - DB performance improvements in attribute dataloaders - #15474 by @AjmalPonneth - Calculate order promotions in draft orders - #15459 by @zedzior +- Prevent name overwriting of Product Variants when Updating Product Types - #15670 by @teddyondieki # 3.19.0 diff --git a/saleor/product/tasks.py b/saleor/product/tasks.py index 1d0e3daf3d6..5de5ee0a636 100644 --- a/saleor/product/tasks.py +++ b/saleor/product/tasks.py @@ -58,8 +58,8 @@ def _variants_in_batches(variants_qs): def _update_variants_names(instance: ProductType, saved_attributes: Iterable): """Product variant names are created from names of assigned attributes. - After change in attribute value name, for all product variants using this - attributes we need to update the names. + After change in attribute value name, we update the names for all product variants + that lack names and use these attributes. """ initial_attributes = set(instance.variant_attributes.all()) attributes_changed = initial_attributes.intersection(saved_attributes) @@ -67,6 +67,7 @@ def _update_variants_names(instance: ProductType, saved_attributes: Iterable): return variants = ProductVariant.objects.filter( + name="", product__in=instance.products.all(), product__product_type__variant_attributes__in=attributes_changed, ) diff --git a/saleor/product/tests/test_tasks.py b/saleor/product/tests/test_tasks.py index 6238e8a0784..9c1346d9b64 100644 --- a/saleor/product/tests/test_tasks.py +++ b/saleor/product/tests/test_tasks.py @@ -5,6 +5,7 @@ import pytest from django.utils import timezone +from faker import Faker from ...discount import PromotionType, RewardValueType from ...discount.models import Promotion, PromotionRule @@ -230,17 +231,23 @@ def test_recalculate_discounted_price_for_products_task_re_trigger_task( assert recalculate_discounted_price_for_products_task_mock.called -@patch("saleor.product.tasks._update_variants_names") -def test_update_variants_names( - update_variants_names_mock, product_type, size_attribute -): +def test_update_variants_names(product_variant_list, size_attribute): + # given + variant_without_name = product_variant_list[0] + variant_with_name = product_variant_list[1] + random_name = Faker().word() + variant_with_name.name = random_name + variant_with_name.save() + product = variant_without_name.product + # when - update_variants_names(product_type.id, [size_attribute.id]) + update_variants_names(product.product_type_id, [size_attribute.id]) # then - args, _ = update_variants_names_mock.call_args - assert args[0] == product_type - assert {arg.pk for arg in args[1]} == {size_attribute.pk} + variant_without_name.refresh_from_db() + variant_with_name.refresh_from_db() + assert variant_without_name.name == variant_without_name.sku + assert variant_with_name.name == random_name def test_update_variants_names_product_type_does_not_exist(caplog):