Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cabal.project
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
packages: .
ignore-project: False
write-ghc-environment-files: always
tests: True
Expand Down
25 changes: 13 additions & 12 deletions src/ArrayFire/Data.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
module ArrayFire.Data where

import Control.Exception
import Control.Monad
import Data.Complex
import Data.Int
import Data.Proxy
import Data.Word
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Marshal hiding (void)
import Foreign.Ptr (Ptr)
import Foreign.Storable
import System.IO.Unsafe
import Unsafe.Coerce
Expand Down Expand Up @@ -357,20 +357,21 @@ joinMany
:: Int
-> [Array a]
-> Array a
joinMany (fromIntegral -> n) arrays = unsafePerformIO . mask_ $ do
fptrs <- forM arrays $ \(Array fptr) -> pure fptr
newPtr <-
alloca $ \fPtrsPtr -> do
forM_ fptrs $ \fptr ->
withForeignPtr fptr (poke fPtrsPtr)
alloca $ \aPtr -> do
zeroOutArray aPtr
throwAFError =<< af_join_many aPtr n nArrays fPtrsPtr
peek aPtr
joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerformIO . mask_ $ do
newPtr <- alloca $ \aPtr -> do
zeroOutArray aPtr
(throwAFError =<<) $
withManyForeignPtr arrays $ \(fromIntegral -> nArrays) fPtrsPtr ->
af_join_many aPtr n nArrays fPtrsPtr
peek aPtr
Array <$>
newForeignPtr af_release_array_finalizer newPtr

withManyForeignPtr :: [ForeignPtr a] -> (Int -> Ptr (Ptr a) -> IO b) -> IO b
withManyForeignPtr fptrs action = go [] fptrs
where
nArrays = fromIntegral (length arrays)
go ptrs [] = withArrayLen (reverse ptrs) action
go ptrs (fptr:others) = withForeignPtr fptr $ \ptr -> go (ptr : ptrs) others

-- | Tiles an Array according to specified dimensions
--
Expand Down
5 changes: 5 additions & 0 deletions test/ArrayFire/DataSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ spec =
constant @(Complex Float) [1] (1.0 :+ 1.0)
`shouldBe`
constant @(Complex Float) [1] (1.0 :+ 1.0)
it "Should join Arrays along the specified dimension" $ do
join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2]
joinMany 0 [constant @Int [1, 3] 1, constant @Int [1, 3] 2] `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
joinMany 1 [constant @Int [1, 2] 1, constant @Int [1, 1] 2, constant @Int [1, 3] 3] `shouldBe` mkArray @Int [1, 6] [1, 1, 2, 3, 3, 3]