Skip to content

Commit 049e21d

Browse files
committed
Keep all casts, and cast (Signal a ~ a) where appropriate
1 parent ea4d877 commit 049e21d

File tree

9 files changed

+209
-97
lines changed

9 files changed

+209
-97
lines changed

clash-ghc/src-ghc/Clash/GHC/Evaluator.hs

+3
Original file line numberDiff line numberDiff line change
@@ -3465,6 +3465,9 @@ naturalLiteral v =
34653465
DC dc [Left (Literal (ByteArrayLiteral (Vector.Vector _ _ (ByteArray.ByteArray ba))))]
34663466
| dcTag dc == 2
34673467
-> Just (Jp# (BN# ba))
3468+
CastValue v0 _ _
3469+
| Just n <- naturalLiteral v0
3470+
-> Just n
34683471
_ -> Nothing
34693472

34703473
integerLiterals' :: [Value] -> [Integer]

clash-ghc/src-ghc/Clash/GHC/GHC2Core.hs

+60-30
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ import TyCon (AlgTyConRhs (..), TyCon, tyConName,
9191
tyConArity,
9292
tyConDataCons, tyConKind,
9393
tyConName, tyConUnique, isClassTyCon)
94-
import Type (mkTvSubstPrs, substTy, coreView)
94+
import Type (mkTvSubstPrs, substTy, coreView, piResultTys)
9595
import TyCoRep (Coercion (..), TyLit (..), Type (..))
9696
import Unique (Uniquable (..), Unique, getKey, hasKey)
9797
import Var (Id, TyVar, Var, idDetails,
@@ -288,71 +288,101 @@ coreToTerm primMap unlocs = term
288288
, let (nm, _) = RWS.evalRWS (qualifiedNameString (varName x))
289289
noSrcSpan
290290
emptyGHC2CoreState
291-
= go nm args
291+
= go nm (varType x) args
292292
| otherwise
293293
= term' e
294294
where
295295
-- Remove most Signal transformers
296-
go "Clash.Signal.Internal.mapSignal#" args
297-
| length args == 5
298-
= term (App (args!!3) (args!!4))
299-
go "Clash.Signal.Internal.signal#" args
300-
| length args == 3
301-
= term (args!!2)
302-
go "Clash.Signal.Internal.appSignal#" args
303-
| length args == 5
304-
= term (App (args!!3) (args!!4))
305-
go "Clash.Signal.Internal.joinSignal#" args
296+
go "Clash.Signal.Internal.mapSignal#" pTy args
297+
| [Type aTy, Type bTy, Type domTy, fTm, aSigTm] <- args
298+
= do
299+
let aSigTy = piResultTys pTy [bTy,aTy,domTy,aTy,aTy]
300+
bSigTy = piResultTys pTy [aTy,bTy,domTy,bTy,bTy]
301+
aTyC <- coreToType aTy
302+
bTyC <- coreToType bTy
303+
aSigTyC <- coreToType aSigTy
304+
bSigTyC <- coreToType bSigTy
305+
C.Cast <$> (C.App <$> term fTm
306+
<*> (C.Cast <$> term aSigTm
307+
<*> pure aSigTyC
308+
<*> pure aTyC))
309+
<*> pure bTyC
310+
<*> pure bSigTyC
311+
go "Clash.Signal.Internal.signal#" pty args
312+
| [Type aTy, Type domTy, aTm] <- args
313+
= let aSigTy = piResultTys pty [aTy,domTy,aTy]
314+
in C.Cast <$> term aTm <*> coreToType aTy <*> coreToType aSigTy
315+
go "Clash.Signal.Internal.appSignal#" pTy args
316+
| [Type domTy, Type aTy, Type bTy, fSigTm, aSigTm] <- args
317+
= do
318+
let aSigTy = piResultTys pTy [domTy,bTy,aTy,aTy,aTy]
319+
bSigTy = piResultTys pTy [domTy,aTy,bTy,bTy,bTy]
320+
fSigTy = piResultTys pTy [domTy,aTy,FunTy aTy bTy,aTy,aTy]
321+
aTyC <- coreToType aTy
322+
bTyC <- coreToType bTy
323+
aSigTyC <- coreToType aSigTy
324+
bSigTyC <- coreToType bSigTy
325+
fSigTyC <- coreToType fSigTy
326+
let fTyC = C.mkFunTy aTyC bTyC
327+
C.Cast <$> (C.App <$> (C.Cast <$> term fSigTm
328+
<*> pure fSigTyC
329+
<*> pure fTyC)
330+
<*> (C.Cast <$> term aSigTm
331+
<*> pure aSigTyC
332+
<*> pure aTyC))
333+
<*> pure bTyC
334+
<*> pure bSigTyC
335+
go "Clash.Signal.Internal.joinSignal#" _ args
306336
| length args == 3
307337
= term (args!!2)
308-
go "Clash.Signal.Bundle.vecBundle#" args
338+
go "Clash.Signal.Bundle.vecBundle#" _ args
309339
| length args == 4
310340
= term (args!!3)
311341
--- Remove `$`
312-
go "GHC.Base.$" args
342+
go "GHC.Base.$" _ args
313343
| length args == 5
314344
= term (App (args!!3) (args!!4))
315-
go "GHC.Magic.noinline" args -- noinline :: forall a. a -> a
345+
go "GHC.Magic.noinline" _ args -- noinline :: forall a. a -> a
316346
| [_ty, x] <- args
317347
= term x
318348
-- Remove most CallStack logic
319-
go "GHC.Stack.Types.PushCallStack" args = term (last args)
320-
go "GHC.Stack.Types.FreezeCallStack" args = term (last args)
321-
go "GHC.Stack.withFrozenCallStack" args
349+
go "GHC.Stack.Types.PushCallStack" _ args = term (last args)
350+
go "GHC.Stack.Types.FreezeCallStack" _ args = term (last args)
351+
go "GHC.Stack.withFrozenCallStack" _ args
322352
| length args == 3
323353
= term (App (args!!2) (args!!1))
324-
go "Clash.Class.BitPack.packXWith" args
354+
go "Clash.Class.BitPack.packXWith" _ args
325355
| [_nTy,_aTy,_kn,f] <- args
326356
= term f
327-
go "Clash.Sized.BitVector.Internal.checkUnpackUndef" args
357+
go "Clash.Sized.BitVector.Internal.checkUnpackUndef" _ args
328358
| [_nTy,_aTy,_kn,_typ,f] <- args
329359
= term f
330-
go "Clash.Magic.prefixName" args
360+
go "Clash.Magic.prefixName" _ args
331361
| [Type nmTy,_aTy,f] <- args
332362
= C.Tick <$> (C.NameMod C.PrefixName <$> coreToType nmTy) <*> term f
333-
go "Clash.Magic.suffixName" args
363+
go "Clash.Magic.suffixName" _ args
334364
| [Type nmTy,_aTy,f] <- args
335365
= C.Tick <$> (C.NameMod C.SuffixName <$> coreToType nmTy) <*> term f
336-
go "Clash.Magic.suffixNameFromNat" args
366+
go "Clash.Magic.suffixNameFromNat" _ args
337367
| [Type nmTy,_aTy,f] <- args
338368
= C.Tick <$> (C.NameMod C.SuffixName <$> coreToType nmTy) <*> term f
339-
go "Clash.Magic.suffixNameP" args
369+
go "Clash.Magic.suffixNameP" _ args
340370
| [Type nmTy,_aTy,f] <- args
341371
= C.Tick <$> (C.NameMod C.SuffixNameP <$> coreToType nmTy) <*> term f
342-
go "Clash.Magic.suffixNameFromNatP" args
372+
go "Clash.Magic.suffixNameFromNatP" _ args
343373
| [Type nmTy,_aTy,f] <- args
344374
= C.Tick <$> (C.NameMod C.SuffixNameP <$> coreToType nmTy) <*> term f
345-
go "Clash.Magic.setName" args
375+
go "Clash.Magic.setName" _ args
346376
| [Type nmTy,_aTy,f] <- args
347377
= C.Tick <$> (C.NameMod C.SetName <$> coreToType nmTy) <*> term f
348-
go "Clash.Magic.deDup" args
378+
go "Clash.Magic.deDup" _ args
349379
| [_aTy,f] <- args
350380
= C.Tick C.DeDup <$> term f
351-
go "Clash.Magic.noDeDup" args
381+
go "Clash.Magic.noDeDup" _ args
352382
| [_aTy,f] <- args
353383
= C.Tick C.NoDeDup <$> term f
354384

355-
go _ _ = term' e
385+
go _ _ _ = term' e
356386
term' (Var x) = var x
357387
term' (Lit l) = return $ C.Literal (coreToLiteral l)
358388
term' (App eFun (Type tyArg)) = C.TyApp <$> term eFun <*> coreToType tyArg
@@ -405,7 +435,7 @@ coreToTerm primMap unlocs = term
405435
case hasPrimCoM of
406436
Just _ | ty1_I || ty2_I
407437
-> C.Cast <$> term e <*> coreToType ty1 <*> coreToType ty2
408-
_ -> term e
438+
_ -> C.Cast <$> term e <*> coreToType ty1 <*> coreToType ty2
409439
term' (Tick (SourceNote rsp _) e) =
410440
C.Tick (C.SrcSpan (RealSrcSpan rsp)) <$> addUsefull (RealSrcSpan rsp) (term e)
411441
term' (Tick _ e) = term e

clash-lib/src/Clash/Core/Evaluator.hs

+12-6
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ unwindStack m
140140
let term = Tick sp (getTerm m')
141141
in unwindStack (setTerm term m')
142142

143+
Castish ty1 ty2 ->
144+
let term = Cast (getTerm m') ty1 ty2
145+
in unwindStack (setTerm term m')
146+
143147
-- | A single step in the partial evaluator. The result is the new heap and
144148
-- stack, and the next expression to be reduced.
145149
--
@@ -232,6 +236,8 @@ stepApp x y m tcm =
232236
GT -> let (m0, n) = newLetBinding tcm m y
233237
in Just . setTerm x $ stackPush (Apply n) m0
234238

239+
Cast {} -> error "stepApp QQ"
240+
235241
_ -> let (m0, n) = newLetBinding tcm m y
236242
in Just . setTerm x $ stackPush (Apply n) m0
237243
where
@@ -264,6 +270,8 @@ stepTyApp x ty m tcm =
264270
LT -> newBinder tys' (TyApp x ty) m tcm
265271
GT -> Just . setTerm x $ stackPush (Instantiate ty) m
266272

273+
Cast {} -> error "stepTyApp QQ"
274+
267275
_ -> Just . setTerm x $ stackPush (Instantiate ty) m
268276
where
269277
(term, args, _) = collectArgsTicks (TyApp x ty)
@@ -273,17 +281,14 @@ stepLetRec :: [LetBinding] -> Term -> Step
273281
stepLetRec bs x m _ = Just (allocate bs x m)
274282

275283
stepCase :: Term -> Type -> [Alt] -> Step
284+
stepCase (Cast {}) _ty _alts _m _ = error "stepCase QQ"
276285
stepCase scrut ty alts m _ =
277286
Just . setTerm scrut $ stackPush (Scrutinise ty alts) m
278287

279288
-- TODO Support stepwise evaluation of casts.
280289
--
281290
stepCast :: Term -> Type -> Type -> Step
282-
stepCast _ _ _ _ _ =
283-
flip trace Nothing $ unlines
284-
[ "WARNING: " <> $(curLoc) <> "Clash can't symbolically evaluate casts"
285-
, "Please file an issue at https://github.com/clash-lang/clash-compiler/issues"
286-
]
291+
stepCast x ty1 ty2 m _ = Just . setTerm x $ stackPush (Castish ty1 ty2) m
287292

288293
stepTick :: TickInfo -> Term -> Step
289294
stepTick tick x m _ =
@@ -356,7 +361,8 @@ unwind tcm m v = do
356361
go (Instantiate ty) = return . instantiate v ty
357362
go (PrimApply p tys vs tms) = mPrimUnwind m tcm p tys vs v tms
358363
go (Scrutinise _ as) = return . scrutinise v as
359-
go (Tickish _) = return . setTerm (valToTerm v)
364+
go (Tickish t) = flip (unwind tcm) (TickValue t v)
365+
go (Castish ty1 ty2) = flip (unwind tcm) (CastValue v ty1 ty2)
360366

361367
-- | Update the Heap with the evaluated term
362368
update :: IdScope -> Id -> Value -> Machine -> Machine

clash-lib/src/Clash/Core/Evaluator/Types.hs

+3
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ data StackFrame
117117
| PrimApply PrimInfo [Type] [Value] [Term]
118118
| Scrutinise Type [Alt]
119119
| Tickish TickInfo
120+
| Castish Type Type
120121
deriving Show
121122

122123
instance ClashPretty StackFrame where
@@ -134,6 +135,8 @@ instance ClashPretty StackFrame where
134135
fromPpr (Case (Literal (CharLiteral '_')) a b)]
135136
clashPretty (Tickish sp) =
136137
hsep ["Tick", fromPpr sp]
138+
clashPretty (Castish ty1 ty2) =
139+
hsep ["Cast", fromPpr ty1, fromPpr ty2]
137140

138141
-- Values
139142
data Value

clash-lib/src/Clash/Core/Type.hs

+3-3
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ coreView1 tcMap ty = case tyView ty of
196196
| nameOcc tcNm == "Clash.Signal.BiSignal.BiSignalOut"
197197
, [_,_,_,elTy] <- args
198198
-> Just elTy
199-
| nameOcc tcNm == "Clash.Signal.Internal.Signal"
200-
, [_,elTy] <- args
201-
-> Just elTy
199+
-- | nameOcc tcNm == "Clash.Signal.Internal.Signal"
200+
-- , [_,elTy] <- args
201+
-- -> Just elTy
202202
| otherwise
203203
-> case tcMap `lookupUniqMap'` tcNm of
204204
AlgTyCon {algTcRhs = (NewTyCon _ nt)}

clash-lib/src/Clash/Core/Util.hs

+11-7
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ import Clash.Core.Type
5050
coreView, coreView1, isFunTy, isPolyFunCoreTy, mkFunTy, splitFunTy, tyView,
5151
undefinedTy, isTypeFamilyApplication)
5252
import Clash.Core.TyCon
53-
(TyConMap, tyConDataCons)
53+
(TyConMap, TyConName, tyConDataCons)
5454
import Clash.Core.TysPrim (typeNatKind)
5555
import Clash.Core.Var
5656
(Id, TyVar, Var (..), isLocalId, mkLocalId, mkTyVar)
@@ -995,15 +995,17 @@ shouldSplit
995995
shouldSplit tcm (tyView -> TyConApp (nameOcc -> "Clash.Explicit.SimIO.SimIO") [tyArg]) =
996996
-- We also look through `SimIO` to find things like Files
997997
shouldSplit tcm tyArg
998-
shouldSplit tcm ty = shouldSplit0 tcm (tyView (coreView tcm ty))
998+
shouldSplit tcm ty = shouldSplit0 emptyUniqSet tcm (tyView (coreView tcm ty))
999999

10001000
-- | Worker of 'shouldSplit', works on 'TypeView' instead of 'Type'
10011001
shouldSplit0
1002-
:: TyConMap
1002+
:: UniqSet TyConName
1003+
-> TyConMap
10031004
-> TypeView
10041005
-> Maybe (Term,[Type])
1005-
shouldSplit0 tcm (TyConApp tcNm tyArgs)
1006-
| Just tc <- lookupUniqMap tcNm tcm
1006+
shouldSplit0 seen tcm (TyConApp tcNm tyArgs)
1007+
| tcNm `notElemUniqSet` seen
1008+
, Just tc <- lookupUniqMap tcNm tcm
10071009
, [dc] <- tyConDataCons tc
10081010
, let dcArgs = substArgTys dc tyArgs
10091011
, let dcArgVs = map (tyView . coreView tcm) dcArgs
@@ -1012,8 +1014,10 @@ shouldSplit0 tcm (TyConApp tcNm tyArgs)
10121014
else
10131015
Nothing
10141016
where
1017+
seen1 = extendUniqSet seen tcNm
1018+
10151019
shouldSplitTy :: TypeView -> Bool
1016-
shouldSplitTy ty = isJust (shouldSplit0 tcm ty) || splitTy ty
1020+
shouldSplitTy ty = isJust (shouldSplit0 seen1 tcm ty) || splitTy ty
10171021

10181022
-- Hidden constructs (HiddenClock, HiddenReset, ..) don't need to be split
10191023
-- because KnownDomain will be filtered anyway during netlist generation due
@@ -1046,7 +1050,7 @@ shouldSplit0 tcm (TyConApp tcNm tyArgs)
10461050
]
10471051
splitTy _ = False
10481052

1049-
shouldSplit0 _ _ = Nothing
1053+
shouldSplit0 _ _ _ = Nothing
10501054

10511055
-- | Potentially split apart a list of function argument types. e.g. given:
10521056
--

0 commit comments

Comments
 (0)