From ee800f1444713d167d9f6c54c58be876124dedf0 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Thu, 24 Aug 2023 11:00:17 +0200 Subject: [PATCH] Fix joinMany Instead of allocating an array of pointers, joinMany was allocating memory for just one pointer. This was making ArrayFire read out of bounds and fail with various errors. This commit fixes this issue by adding a helper withManyForeignPtr function that acts like withForeignPtr (not unsafeWithForeignPtr!), but for a list of ForeignPtrs. --- cabal.project | 1 + src/ArrayFire/Data.hs | 25 +++++++++++++------------ test/ArrayFire/DataSpec.hs | 5 +++++ 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/cabal.project b/cabal.project index 7f529ad..e5c6ff0 100644 --- a/cabal.project +++ b/cabal.project @@ -1,3 +1,4 @@ +packages: . ignore-project: False write-ghc-environment-files: always tests: True diff --git a/src/ArrayFire/Data.hs b/src/ArrayFire/Data.hs index ab3b8e6..c4fab4b 100644 --- a/src/ArrayFire/Data.hs +++ b/src/ArrayFire/Data.hs @@ -30,7 +30,6 @@ module ArrayFire.Data where import Control.Exception -import Control.Monad import Data.Complex import Data.Int import Data.Proxy @@ -38,6 +37,7 @@ import Data.Word import Foreign.C.Types import Foreign.ForeignPtr import Foreign.Marshal hiding (void) +import Foreign.Ptr (Ptr) import Foreign.Storable import System.IO.Unsafe import Unsafe.Coerce @@ -357,20 +357,21 @@ joinMany :: Int -> [Array a] -> Array a -joinMany (fromIntegral -> n) arrays = unsafePerformIO . mask_ $ do - fptrs <- forM arrays $ \(Array fptr) -> pure fptr - newPtr <- - alloca $ \fPtrsPtr -> do - forM_ fptrs $ \fptr -> - withForeignPtr fptr (poke fPtrsPtr) - alloca $ \aPtr -> do - zeroOutArray aPtr - throwAFError =<< af_join_many aPtr n nArrays fPtrsPtr - peek aPtr +joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerformIO . mask_ $ do + newPtr <- alloca $ \aPtr -> do + zeroOutArray aPtr + (throwAFError =<<) $ + withManyForeignPtr arrays $ \(fromIntegral -> nArrays) fPtrsPtr -> + af_join_many aPtr n nArrays fPtrsPtr + peek aPtr Array <$> newForeignPtr af_release_array_finalizer newPtr + +withManyForeignPtr :: [ForeignPtr a] -> (Int -> Ptr (Ptr a) -> IO b) -> IO b +withManyForeignPtr fptrs action = go [] fptrs where - nArrays = fromIntegral (length arrays) + go ptrs [] = withArrayLen (reverse ptrs) action + go ptrs (fptr:others) = withForeignPtr fptr $ \ptr -> go (ptr : ptrs) others -- | Tiles an Array according to specified dimensions -- diff --git a/test/ArrayFire/DataSpec.hs b/test/ArrayFire/DataSpec.hs index ab22e69..fcbd53f 100644 --- a/test/ArrayFire/DataSpec.hs +++ b/test/ArrayFire/DataSpec.hs @@ -32,3 +32,8 @@ spec = constant @(Complex Float) [1] (1.0 :+ 1.0) `shouldBe` constant @(Complex Float) [1] (1.0 :+ 1.0) + it "Should join Arrays along the specified dimension" $ do + join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2] + join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2] + joinMany 0 [constant @Int [1, 3] 1, constant @Int [1, 3] 2] `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2] + joinMany 1 [constant @Int [1, 2] 1, constant @Int [1, 1] 2, constant @Int [1, 3] 3] `shouldBe` mkArray @Int [1, 6] [1, 1, 2, 3, 3, 3]