diff --git a/drivers/misc/pci_endpoint_test.c b/drivers/misc/pci_endpoint_test.c index cce4cdb5d1be..8eb6a78e6b0c 100644 --- a/drivers/misc/pci_endpoint_test.c +++ b/drivers/misc/pci_endpoint_test.c @@ -337,6 +337,17 @@ static bool pci_endpoint_test_msi_irq(struct pci_endpoint_test *test, return pci_irq_vector(pdev, msi_num - 1) == test->last_irq; } +static int pci_endpoint_test_validate_xfer_params(struct device *dev, + struct pci_endpoint_test_xfer_param *param, size_t alignment) +{ + if (param->size > SIZE_MAX - alignment) { + dev_dbg(dev, "Maximum transfer data size exceeded\n"); + return -EINVAL; + } + + return 0; +} + static bool pci_endpoint_test_copy(struct pci_endpoint_test *test, unsigned long arg) { @@ -368,9 +379,11 @@ static bool pci_endpoint_test_copy(struct pci_endpoint_test *test, return false; } + err = pci_endpoint_test_validate_xfer_params(dev, ¶m, alignment); + if (err) + return false; + size = param.size; - if (size > SIZE_MAX - alignment) - goto err; use_dma = !!(param.flags & PCITEST_FLAGS_USE_DMA); if (use_dma) @@ -502,9 +515,11 @@ static bool pci_endpoint_test_write(struct pci_endpoint_test *test, return false; } + err = pci_endpoint_test_validate_xfer_params(dev, ¶m, alignment); + if (err) + return false; + size = param.size; - if (size > SIZE_MAX - alignment) - goto err; use_dma = !!(param.flags & PCITEST_FLAGS_USE_DMA); if (use_dma) @@ -600,9 +615,11 @@ static bool pci_endpoint_test_read(struct pci_endpoint_test *test, return false; } + err = pci_endpoint_test_validate_xfer_params(dev, ¶m, alignment); + if (err) + return false; + size = param.size; - if (size > SIZE_MAX - alignment) - goto err; use_dma = !!(param.flags & PCITEST_FLAGS_USE_DMA); if (use_dma)