{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE BangPatterns #-}
{-
author: kstahl@mailbox.tu-berlin.de

compile with: ghc -O2 -fexcess-precision -funfolding-use-threshold=16 -o pclustering pclustering.hs
-}
module Main where

import Control.Monad
import System.Console.GetOpt
import System.IO
import System.FilePath
import System.Exit 
import System.Directory (createDirectoryIfMissing,doesFileExist)
import System.Environment (getArgs)
import Data.List
import Data.Maybe (fromMaybe, isNothing)
import Data.Char (isDigit)
import qualified Data.Text as T
import qualified Data.Text.IO as TIO
import qualified Data.Text.Read as TR
import qualified Data.Map as M
import Text.Parsec.Text
import qualified Text.Parsec.Token as P
import Text.Parsec.Language (haskellDef)
import Text.ParserCombinators.Parsec
import Control.DeepSeq

-- | One position of a pattern: 'Box' keeps the column value, 'Star' replaces it
-- by a placeholder (see 'applyPattern'). Constructor order defines 'Ord'.
data PatternSymbol = Box | Star
  deriving (Eq, Ord)

instance Show PatternSymbol where
  show sym = case sym of
    Box -> "_"
    Star -> "*"

-- | One parsed CSV cell. 'Text' holds the interning key assigned by 'parseField',
-- 'Number' an integer cell and 'FPNumber' a floating-point cell.
data Alphabet
  = Text {-# UNPACK #-} !Int
  | Number {-# UNPACK #-} !Int
  | FPNumber {-# UNPACK #-} !Double
  deriving (Read, Eq, Ord)

instance Show Alphabet where
  show cell = case cell of
    Text key -> show key
    Number n -> show n
    FPNumber x -> show x

-- | One input row, one cell per column.
type Record = [Alphabet]
-- | One cell of a record after a pattern was applied: kept values carry their
-- payload, 'StarField' marks a Star position and 'DropField' a fully dropped
-- record (see 'dropFields'). Constructor order defines 'Ord'.
data MatchedPosition
  = StarField
  | DropField
  | TextField {-# UNPACK #-} !Int
  | NumberField {-# UNPACK #-} !Int
  | FPNumberField {-# UNPACK #-} !Double
  deriving (Eq, Ord)

instance Show MatchedPosition where
  show field = case field of
    TextField key -> show key
    NumberField n -> show n
    FPNumberField x -> show x
    -- StarField and DropField both render like the Star symbol
    _ -> show Star

-- Every constructor field of Alphabet is strict (!), so reducing a value to
-- weak head normal form already evaluates its payload completely.
instance NFData Alphabet where
  rnf cell = cell `seq` ()

-- | A pattern vector: one symbol per column ('Box' keeps, 'Star' suppresses).
type Pattern = [PatternSymbol]
-- | A record after 'applyPattern' or 'dropFields' has been applied to it.
type MatchedRecord = [MatchedPosition]

-- | Quality measures of one run of the heuristic, as filled in by 'evaluateResult'.
data Evaluation = Evaluation
  { k                        :: !Int    -- ^ minimum class size the run was made for
  , usefulness               :: !Double -- ^ average useness over the equivalence classes
  , weightedUsefulness       :: !Double -- ^ the same, with column weights applied
  , noOfAequivalenceClasses  :: !Double -- ^ number of equivalence classes
  , avgAequivalenceClassSize :: !Double -- ^ records in the dataset / number of classes
  , noOfSuppressions         :: !Int    -- ^ suppressed cells summed over all class members
  , weightedNoOfSuppressions :: !Int    -- ^ the same, each cell counted with its column weight
  , sumOfSquaredACSize       :: !Int    -- ^ sum over classes of (class size)^2
  , maximumACSize            :: !Double -- ^ size of the largest class
  , noOfColumns              :: !Int    -- ^ columns per record
  , distcount                :: !Double -- ^ largest number of distinct values in a column
  , size                     :: !Int    -- ^ number of records in the dataset
  }

-- Human-readable "label: value" lines, followed by a blank line and a single
-- machine-readable line starting with '#' that lists the same values, in the
-- same order, separated by spaces.
instance Show Evaluation where
  show e =
    intercalate "\n" [label ++ value | (label, value) <- rows]
      ++ "\n\n#" ++ unwords (map snd rows) ++ "\n"
    where
      rows =
        [ ("k: ", show (k e))
        , ("no of suppressions: ", show (noOfSuppressions e))
        , ("weighted no of suppressions: ", show (weightedNoOfSuppressions e))
        , ("no of aequivalence classes (ac): ", show (noOfAequivalenceClasses e))
        , ("useness: ", show (usefulness e))
        , ("weighted useness: ", show (weightedUsefulness e))
        , ("avg. size of ac: ", show (avgAequivalenceClassSize e))
        , ("sum of squared ac sizes: ", show (sumOfSquaredACSize e))
        , ("maximum ac size: ", show (maximumACSize e))
        , ("no of columns: ", show (noOfColumns e))
        , ("max no of distinct elements: ", show (distcount e))
        , ("size of dataset: ", show (size e))
        ]

{-
Reads the file as strict Text and parses every line, split on the delimiter,
into a Record. The Text-interning maps (forward: text -> key, backward:
key -> text) are threaded through all rows and returned with the records.
A trailing '\r' is stripped from each line so that CRLF files parse like LF
files; otherwise the last cell of every row would keep the '\r' and, e.g., "5\r"
would not be recognised as a number by parseField.
-}
parseFile :: Char -> String -> IO ((M.Map T.Text Int, M.Map Int T.Text), [Record])
parseFile delimiter path = do
  contents <- TIO.readFile path
  let rows = map (T.split (== delimiter) . T.dropWhileEnd (== '\r')) (T.lines contents)
  return (mapAccumL (\(!mapping) (!row) -> parseEntry mapping row) (M.empty, M.empty) rows)

-- | True when both records agree on every Box position of the pattern (Star
-- positions are ignored). False as soon as the pattern and the two records do
-- not all have the same length.
cmpRecordsByPattern :: Pattern -> Record -> Record -> Bool
cmpRecordsByPattern (sym : syms) (x : xs) (y : ys)
  | sym == Box && x /= y = False
  | otherwise = cmpRecordsByPattern syms xs ys
cmpRecordsByPattern [] [] [] = True
cmpRecordsByPattern _ _ _ = False

-- | Ordering compatible with 'cmpRecordsByPattern': compares the records only at
-- Box positions. Once the pattern is exhausted, the remaining cells are compared
-- as whole lists. A shorter record sorts first.
orderRecordsByPattern :: Pattern -> Record -> Record -> Ordering
orderRecordsByPattern (sym : syms) (x : xs) (y : ys) = case sym of
  Star -> rest
  Box
    | x < y -> LT
    | x > y -> GT
    | otherwise -> rest
  where
    rest = orderRecordsByPattern syms xs ys
orderRecordsByPattern [] xs ys = compare xs ys
orderRecordsByPattern _ [] [] = EQ
orderRecordsByPattern _ [] _ = LT
orderRecordsByPattern _ _ [] = GT

-- | Applies a pattern to a record: Box positions keep the cell's value, Star
-- positions become 'StarField'. The result is as long as the shorter input.
applyPattern :: Pattern -> Record -> MatchedRecord
applyPattern = zipWith keep
  where
    keep Box cell = case cell of
      Text a -> TextField a
      Number a -> NumberField a
      FPNumber a -> FPNumberField a
    keep Star _ = StarField

-- | Suppresses every cell of the record ('DropField' for each column).
dropFields :: Record -> MatchedRecord
dropFields = map (const DropField)

-- | Takes the (big enough, too small) groups produced for pattern p. Each big
-- enough group becomes a class tagged with the pattern applied to its first
-- record; these are prepended to imres (last group first). The records of the
-- too-small groups are flattened into the second component.
extractComponents :: Pattern -> [(MatchedRecord,[Record])] ->  ([[Record]], [[Record]]) -> ([(MatchedRecord,[Record])], [Record])
extractComponents p imres (accepted, rejected) = (reverse classes ++ imres, concat rejected)
  where
    -- groups come from groupBy, so they are never empty
    classes = map (\grp -> (applyPattern p (head grp), grp)) accepted

{-
Simple greedy heuristic.

The patterns are tried in the given order. For each pattern, the records that
are not yet aggregated are sorted and grouped by the pattern's Box positions
(orderRecordsByPattern / cmpRecordsByPattern). Every group with at least k
members becomes an equivalence class and is prepended to the result; its
MatchedRecord is the pattern applied to the group's first record. The records
of the smaller groups are retried with the remaining patterns.

args:
k      - minimum no of entries to aggregate
(p:ps) - list of patterns to apply, in the order they are tried
list   - list of records not yet aggregated (no pattern applied)
res    - current result list (accumulator)
return: tuple: the equivalence classes (pattern applied, member records) and the
        records still left over when the pattern list runs out

NOTE(review): runEval, useness and evaluateResult in this file ignore the second
component, so left-over records are neither written nor counted. For the
generated pattern lists the last pattern is all-Star, which puts every
remaining record (at least k of them, see the guard below) into one class.
Presumably only user-supplied pattern files without an all-Star pattern are
affected.
-}

greedyHeuristic :: Int -> [Pattern] -> [Record] -> [(MatchedRecord,[Record])]
                    -> ([(MatchedRecord,[Record])], [Record])
greedyHeuristic _ [] !list !res = (res, list)
greedyHeuristic _ _ [] !res = (res, [])
greedyHeuristic !k (p:ps) list res
    -- fewer than k records left: no pattern can give them a group of size k,
    -- so instead of trying the remaining patterns they become one class with
    -- every field suppressed (DropField). Note that this class is smaller than k.
    | (length . take k) list < k = ((dropFields (head list),list):res,[])       -- take k: never walks more than k records
    | otherwise = 
        let (res2, rest) = extractComponents p res $ partition (\x -> length x >= k) . groupBy (cmpRecordsByPattern p) $ sortBy (orderRecordsByPattern p) list
        in greedyHeuristic k ps rest res2

{-
Creates all 2^n patterns of length n, enumerated in lexicographic order with
Box before Star (the same order as replicateM n [Box, Star]). The result is NOT
ordered by cost; callers that need a cost order sort it with cmpPattern.
Returns [] for n <= 0.
-}

createPatternPermutations :: Int -> [Pattern]
createPatternPermutations !n
    | n > 0 = sequence (replicate n [Box, Star])
    | otherwise = []


-- | All patterns with c columns, grouped by their number of Stars: first the
-- all-Box pattern, then every pattern with one Star, and so on, ending with the
-- all-Star pattern (i.e. ascending unweighted cost).
createPatternPermutations2 :: Int -> [Pattern]
createPatternPermutations2 !c = concatMap helper [(c - stars, stars) | stars <- [0 .. c]]

-- | helper (b, s): every pattern made of exactly b Boxes and s Stars. Patterns
-- starting with Box are listed before those starting with Star.
helper :: (Int, Int) -> [Pattern]
helper (boxes, stars)
  | stars == 0 = [replicate boxes Box]
  | boxes == 0 = [replicate stars Star]
  | otherwise = map (Box :) (helper (boxes - 1, stars))
             ++ map (Star :) (helper (boxes, stars - 1))

{-
Cost of a pattern vector. With no weights it is the number of Stars (*).
Otherwise it is the sum of the weights at the Star positions; weights and
pattern are zipped, so the shorter list decides how many positions count.
-}

costOfPattern :: [Int] -> Pattern -> Int
costOfPattern [] pattern = length (filter (== Star) pattern)
costOfPattern weights pattern = sum [w | (w, Star) <- zip weights pattern]


-- | Orders patterns by 'costOfPattern' under the given weights (cheapest first).
cmpPattern :: [Int] -> Pattern -> Pattern -> Ordering
cmpPattern weights p1 p2 = compare (cost p1) (cost p2)
  where
    cost = costOfPattern weights

-- | Number of suppressed cells (StarField or DropField) in a matched record.
-- With weights, it is the sum of the weights at the suppressed positions, zipped
-- with the record.
costOfMatchedRecord :: [Int] -> MatchedRecord -> Int
costOfMatchedRecord weights fields
  | null weights = length (filter suppressed fields)
  | otherwise = sum [w | (w, f) <- zip weights fields, suppressed f]
  where
    suppressed f = f == StarField || f == DropField

{-
Bundles the statistics of one heuristic run into an Evaluation.
args: k, column weights, size of the dataset (no of records), useness,
weighted useness, result of greedyHeuristic (only the equivalence classes are
used; the left-over records are ignored), no of columns, max no of distinct
values in a column
-}
evaluateResult :: Int -> [Int] -> Int -> Double -> Double ->  ([(MatchedRecord,[Record])], [Record]) -> Int -> Double -> Evaluation
evaluateResult !k weights !size !useness !wuseness (res, _) !nc !dcount =
  case costs weights res of
    ((!squared, (!suppr, !wsuppr)), (!noOfAC, !maxACS)) ->
      Evaluation k useness wuseness noOfAC (fromIntegral size / noOfAC)
                 suppr wsuppr squared maxACS nc dcount size

-- | One strict pass over the equivalence classes, returning
-- ((sum of squared class sizes, (suppressed cells, weighted suppressed cells)),
--  (number of classes, size of the largest class)).
-- Suppressed cells are counted once per class member. The largest size starts
-- at -Infinity, so it stays -Infinity for an empty class list.
costs :: [Int] -> [(MatchedRecord,[Record])] -> ((Int,(Int,Int)),(Double,Double))
costs weights = foldl' addClass ((0, (0, 0)), (0, -1 / 0))
  where
    addClass ((!squares, (!suppressed, !wSuppressed)), (!classes, !largest)) (mrec, members) =
      ( (squares + n ^ 2, (suppressed + costOfMatchedRecord [] mrec * n, wSuppressed + costOfMatchedRecord weights mrec * n))
      , (classes + 1, max largest (fromIntegral n)) )
      where
        !n = length members

{-
Parses one CSV cell. A cell starting with a digit is read as an Int (Number)
or, failing that, as a Double (FPNumber), but only if the whole cell is
consumed. Every other cell, including the empty cell, is interned as Text:
its first occurrence gets the next free key (the current size of the forward
map) and both maps are extended; later occurrences reuse the stored key.
-}
parseField :: (M.Map T.Text Int, M.Map Int T.Text) -> T.Text -> ((M.Map T.Text Int, M.Map Int T.Text), Alphabet)
parseField (m, bm) s
  | startsWithDigit = case TR.decimal s of
      Right (n, rest) | T.null rest -> ((m, bm), Number n)
      _ -> case TR.double s of
        Right (!x, rest) | T.null rest -> ((m, bm), FPNumber x)
        _ -> asText
  | otherwise = asText
  where
    -- T.head crashed on an empty cell (e.g. "a,,b"); uncons handles it
    startsWithDigit = maybe False (isDigit . fst) (T.uncons s)
    asText = case M.lookup s m of
      Just key -> ((m, bm), Text key)
      Nothing -> let !nkey = M.size m
                 in ((M.insert s nkey m, M.insert nkey s bm), Text nkey)

-- | Parses the cells of one row left to right, threading the interning maps
-- from cell to cell.
parseEntry :: (M.Map T.Text Int,M.Map Int T.Text) -> [T.Text] -> ((M.Map T.Text Int, M.Map Int T.Text), Record)
parseEntry mapping = mapAccumL step mapping
  where
    step !acc !cell = parseField acc cell

-- | True for numeric cells ('Number' or 'FPNumber'), False for interned 'Text'.
isNumber :: Alphabet -> Bool
isNumber (Text _) = False
isNumber _ = True

-- | Column summaries (see 'extractMM') for a row-wise list of records, using the
-- given per-column interval flags. Empty if either list is empty.
extractMinMax :: [Bool] -> [Record] -> [Either ((Double,Double), Double) Double]
extractMinMax flags records
  | null records || null flags = []
  | otherwise = extractMM flags (transpose records)

{- Initial file analysis: decides per column whether it is an interval column
(all cells numeric) and summarises every column with extractMM. With the
discrete flag set, every column is treated as discrete. Returns
(interval flags, column summaries); both are empty for an empty record list. -}
initialExtractMinMax :: Bool -> [Record] -> ([Bool], [Either ((Double,Double), Double) Double])
initialExtractMinMax _ [] = ([], [])
initialExtractMinMax discrete records@(firstRecord : _) = (flags, extractMM flags columns)
  where
    columns = transpose records
    flags
      | discrete = map (const False) firstRecord
      | otherwise = map (all isNumber) columns

{- Summarises each column, given column-wise (already transposed) and paired
with its interval flag:
Left ((min, max), no of distinct values) for interval columns, where min/max
are taken over the numeric cells only (see minmax);
Right (no of distinct values) for discrete columns.
Stops at the shorter of the two lists. -}
extractMM :: [Bool] -> [Record] -> [Either ((Double,Double),Double) Double]
extractMM (isInterval : flags) (column : columns) = summary : extractMM flags columns
  where
    -- sort + group instead of nub: O(n log n) rather than O(n^2) per column,
    -- which matters for the initial pass over the full dataset
    !distinct = (fromIntegral . length . group . sort) column
    summary
      | isInterval = Left (minmax column (1 / 0, -1 / 0), distinct)
      | otherwise = Right distinct
extractMM _ _ = []
  

-- | Widens the (min, max) accumulator by every numeric cell of the column;
-- Text cells are skipped.
minmax :: Record -> (Double,Double) -> (Double,Double)
minmax cells start = foldl' widen start cells
  where
    widen (!lo, !hi) cell = case cell of
      Number n -> let x = fromIntegral n in (min lo x, max hi x)
      FPNumber x -> (min lo x, max hi x)
      _ -> (lo, hi)

-- | Largest "no of distinct values" among the column summaries, starting from
-- the given lower bound.
findMaxDistcount :: [Either ((Double,Double), Double) Double] -> Double -> Double
findMaxDistcount summaries !start = foldl' max start (map (either snd id) summaries)

{-
calculates the useness (see paper: Capturing Data Usefulness and Privacy Protection in K-Anonymisation; and function da)
as a side effect: writes one line per equivalence class to file f: the class size followed by the class's cells
(suppressed interval cells as [min:max] over the class, other suppressed cells as *, see generateOutput)
returns: (average useness per class, average weighted useness per class); without weights both values are equal
-}
useness :: FilePath -> [Int] -> ([Bool],[Either ((Double,Double),Double) Double]) -> ([(MatchedRecord,[Record])], [Record]) -> M.Map Int T.Text -> IO (Double,Double)
useness f weights (b,ges) (m,_) mm = do
    -- withFile closes the handle even if writing or evaluating a class throws
    (x,xx) <- withFile f WriteMode (go 0 0 m)
    let classes = fromIntegral (length m)
    return (x / classes, xx / classes)
  where
    go :: Double -> Double -> [(MatchedRecord,[Record])] -> Handle -> IO (Double,Double)
    go !acc !wacc [] _ = return (acc,wacc)
    go !acc !wacc ((mp,result):rs) handle = do
        TIO.hPutStrLn handle (T.pack (show (length result)) `T.append` outp result mp)
        go unweighted weighted rs handle
      where
        groupStats = r result
        !unweighted = acc + sum (zipWith da ges groupStats)
        !weighted
          | null weights = unweighted
          | otherwise = wacc + sum (zipWith3 (da' avgweight) weights ges groupStats)
    -- loop-invariant: computed once instead of once per class
    avgweight :: Double
    avgweight = fromIntegral (sum weights) / fromIntegral (length weights)
    r = extractMinMax b
    outp result mp = generateOutput (r result) mp mm ""
     


-- | Per-column usefulness of a class (first argument: whole-dataset summary,
-- second: class summary).
-- Interval column: class range / dataset range (0 if the dataset range is 0).
-- Discrete column: class distinct count / dataset distinct count.
-- An interval dataset summary paired with a discrete class summary yields 0.
da :: Either ((Double,Double),Double) Double -> Either ((Double,Double),Double) Double -> Double
da (Right total) grp = either snd id grp / total
da (Left ((lo, hi), _)) (Left ((glo, ghi), _))
  | width /= 0 = (ghi - glo) / width
  | otherwise = 0
  where
    width = hi - lo
da (Left _) (Right _) = 0

-- | Weighted variant of 'da': the column's value is additionally multiplied by
-- weight / avgweight before dividing by the dataset range or distinct count.
da' :: Double -> Int -> Either ((Double,Double),Double) Double -> Either ((Double,Double),Double) Double -> Double
da' !avgweight !weight dages dagroup = case (dages, dagroup) of
  (Left ((lo, hi), _), Left ((glo, ghi), _))
    | hi - lo /= 0 -> (ghi - glo) * factor / (hi - lo)
    | otherwise -> 0
  (Left _, Right _) -> 0
  (Right total, grp) -> either snd id grp * factor / total
  where
    factor = fromIntegral weight / avgweight


-- | Right with the common column count when every record has the same length;
-- otherwise (including an empty record list) Left with an error message.
sanityCheck :: [Record] -> Either String Int
sanityCheck r = case nub (map length r) of
  [columns] -> Right columns
  _ -> Left "Invalid input: no of columns don't match"

{-
Pattern file grammar: one pattern per line, symbols separated by ','.
'_' or '1' keeps a column (Box), '*' or '0' drops it (Star).
LF and CRLF line endings are accepted, and the last line needs no trailing
newline. Blank lines yield no pattern, and whitespace at the very end of the
file is ignored. Any other unparsable text is an error instead of being
silently dropped.
-}
patternFile = do
  patterns <- sepEndBy line lineEnd
  spaces
  eof
  return (filter (not . null) patterns)
  where
    lineEnd = newline <|> (char '\r' >> newline)

-- one comma-separated pattern (may be empty on a blank line)
line = sepBy symbol (char ',')

symbol = parseBox <|> parseStar

parseBox = Box <$ (char '_' <|> char '1')

parseStar = Star <$ (char '*' <|> char '0')



-- | Command-line settings; parseOpts applies the parsed flags to 'defaultOptions'.
data Options = Options
  { optHelp     :: Bool         -- ^ -h / --help
  , optPattern  :: Maybe String -- ^ --p: path to a pattern file
  , optK        :: Maybe [Int]  -- ^ --k: list of k's; Nothing if it failed to parse
  , optWeight   :: Maybe [Int]  -- ^ --w: column weights; Nothing if it failed to parse
  , optDiscrete :: Bool         -- ^ -d: treat every value as discrete
  }
  deriving (Eq, Show)

-- | Settings used for everything not given on the command line
-- (k's 2,3,5,10; no weights; generated patterns; interval detection on).
defaultOptions :: Options
defaultOptions = Options
  { optK        = Just [2, 3, 5, 10]
  , optWeight   = Just []
  , optPattern  = Nothing
  , optDiscrete = False
  , optHelp     = False
  }

-- | Option table for getOpt. The --k and --w lists are parsed by parseIntList;
-- a list that fails to parse is stored as Nothing and rejected later in main.
-- The --w help text no longer claims weights are ignored with --p: main also
-- sorts the patterns read from the file by their weighted cost.
options :: [OptDescr (Options -> Options)]
options =
  [ 
    Option "h" ["help"]
      (NoArg (\ opts -> opts { optHelp = True })) "print this help"
  ,    Option "d" ["discrete"]
      (NoArg (\ opts -> opts { optDiscrete = True })) "every value will be treated as a discrete value (useness)"
  , Option [] ["k"]
      (OptArg ((\ f opts -> opts { optK = f }) . parseIntList . T.pack . fromMaybe "") 
                "k")
         "list of k's (e.g.: 1,2,3), default: 2,3,5,10"
  , Option [] ["w"]
      (OptArg ((\ f opts -> opts { optWeight = f }) . parseIntList . T.pack . fromMaybe "") 
                "w")
         "list of weights (e.g.: 1,2,3), default: 1, provide as\nmany weights as columns exist; \
         \cost of a pattern is the sum of its weights\n(patterns are sorted from lowest to highest cost, \
         \\nalso the patterns read with --p)"
  , Option [] ["p"]
      (OptArg (\ f opts -> opts { optPattern = f }) 
                "p")
         "path to pattern file (file format: * or 0 to drop entry,\n_ or 1 to keep, (e.g.: 0,_,*,1))"
  ]

{-
Parses the command line. Expects exactly one non-option argument, the input
file, which must exist (otherwise an error goes to stderr and the program
exits with failure). --help prints the usage text to stdout and exits
successfully; it is checked before the input file, so no file is needed.
Option errors or a wrong number of file arguments raise an IOError that
carries the usage text.
-}
parseOpts :: [String] -> IO (Options, [String])
parseOpts argv =
  case getOpt Permute options argv of
    (opts, nonOpts, [])
      | optHelp parsed -> putStr (usageInfo header options) >> exitSuccess
      | [fs] <- nonOpts -> do
          exists <- doesFileExist fs
          unless exists $ hPutStrLn stderr (failure fs) >> exitFailure
          return (parsed, [fs])
      where parsed = nargs opts
    (_, _, errs) -> ioError (userError (concat errs ++ usageInfo header options))
  where header = "Usage: pclustering [OPTION...] file"
        failure file = "File \"" ++ file ++ "\" doesn't exist"
        nargs = foldl' (flip id) defaultOptions

-- | Parses a comma-separated list of non-negative decimal Ints ("1,2,3").
-- Returns Nothing if any element is not a complete decimal number, including
-- for empty input.
parseIntList :: T.Text -> Maybe [Int]
parseIntList input = traverse toInt (T.split (== ',') input)
  where
    toInt piece = case TR.decimal piece of
      Right (n, rest) | T.null rest -> Just n
      _ -> Nothing

{-
Constructs the output of a given aequivalence class by appending one
","-prefixed cell per column to acc:
StarField in an interval column -> [min:max] of the class,
StarField in a discrete column and DropField -> *,
TextField -> the original text looked up in mmap ("" if the key is unknown),
number fields -> their shown value.
Stops at the shorter of the column-summary and field lists.
-}
generateOutput :: [Either ((Double,Double),Double) Double] -> MatchedRecord -> M.Map Int T.Text -> T.Text -> T.Text
generateOutput columns fields mmap !acc = T.concat (acc : zipWith cell columns fields)
  where
    cell column field = T.cons ',' $ case field of
      StarField -> case column of
        Left ((lo, hi), _) -> T.concat ["[", T.pack (show lo), ":", T.pack (show hi), "]"]
        Right _ -> "*"
      DropField -> "*"
      TextField key -> fromMaybe "" (M.lookup key mmap)
      _ -> T.pack (show field)


{-
Runs the heuristic once per k. For each k, the classes are written to
output/<input base name>_k=<k>.csv (via useness) and the Evaluation is printed
to stdout.
args: column weights ([] = unweighted), pattern list to try in order, list of
k's, parsed input, no of columns, key -> text map of the interned Text cells,
input file path, treat every column as discrete
-}
runEval :: [Int] -> [Pattern] -> [Int] -> [Record] -> Int -> M.Map Int T.Text -> FilePath -> Bool -> IO ()
runEval weights plist ks input !nc m f discrete = mapM_ evalFor ks
  where
    evalFor kval = do
      -- run the (expensive) heuristic once and share the result between the
      -- output file and the statistics; previously it was evaluated twice
      let clustering = greedyHeuristic kval plist input []
          ff = "output" </> (base ++ "_k=" ++ show kval ++ ".csv")
      (use, wuse) <- useness ff weights ex clustering m
      print (evaluateResult kval weights lges use wuse clustering nc dcount)
    ex = initialExtractMinMax discrete input
    dcount = findMaxDistcount (snd ex) 0
    lges = length input
    base = takeBaseName f

{-
Entry point: parses the command line and the input CSV, checks that all rows
have the same number of columns, builds the pattern list and runs runEval.
The pattern list comes from --p, or is generated for the column count; when
--w weights are given it is ordered by weighted cost.
Every validation failure is reported on stderr and ends with exitFailure.
-}
main :: IO ()
main = do
    (args, [file]) <- getArgs >>= parseOpts
    ((_,bm),linput) <- parseFile ',' file
    -- force the whole parsed input once it is first used
    let input = rnf linput `deepseq` linput
    createDirectoryIfMissing True "output"
    case sanityCheck input of
      Left err -> hPutStrLn stderr err >> exitFailure
      Right count -> case optK args of
        Nothing -> hPutStrLn stderr "invalid k-list supplied" >> exitFailure
        Just ks ->
          let run w plist = runEval w plist ks input count bm file (optDiscrete args)
              -- weights are zipped column by column, so one per column is required
              -- (now checked for --p as well, not only for generated patterns)
              checkWeights w =
                when (not (null w) && count /= length w) $
                  hPutStrLn stderr "invalid weights supplied (length mismatch)" >> exitFailure
          in case optPattern args of
               Just f -> do
                 pinput <- parseFromFile patternFile f
                 case pinput of
                   Left err -> hPrint stderr err >> exitFailure
                   Right p -> case optWeight args of
                     Nothing -> invalidWeightFormat
                     Just w -> do
                       checkWeights w
                       -- cmpRecordsByPattern never groups two records under a pattern
                       -- whose length differs from the record length
                       unless (all ((== count) . length) p) $
                         hPutStrLn stderr "invalid pattern file (pattern length mismatch)" >> exitFailure
                       run w (if null w then p else sortBy (cmpPattern w) p)
               Nothing -> case optWeight args of
                 Nothing -> invalidWeightFormat
                 Just w -> do
                   checkWeights w
                   run w (if null w
                            then createPatternPermutations2 count
                            else sortBy (cmpPattern w) (createPatternPermutations count))
  where
    invalidWeightFormat :: IO ()
    invalidWeightFormat = hPutStrLn stderr "invalid weights supplied (invalid format)" >> exitFailure