Skip to content

Core Models & Data

The core module contains the deep learning implementations and data transformation logic.

Models

syntho_hive.core.models.ctgan.CTGAN

Bases: ConditionalGenerativeModel

Conditional Tabular GAN with entity embeddings and parent context.

Source code in syntho_hive/core/models/ctgan.py
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
class CTGAN(ConditionalGenerativeModel):
    """Conditional Tabular GAN with entity embeddings and parent context."""

    def __init__(
        self,
        metadata: Any,
        embedding_dim: int = 128,
        generator_dim: Tuple[int, int] = (256, 256),
        discriminator_dim: Tuple[int, int] = (256, 256),
        batch_size: int = 500,
        epochs: int = 300,
        device: str = "cpu",
        embedding_threshold: int = 50,
        discriminator_steps: int = 5,
        legacy_context_conditioning: bool = False,
    ):
        """Configure a CTGAN instance for conditional tabular synthesis.

        Args:
            metadata: Table metadata describing columns and constraints.
            embedding_dim: Dimension of the generator's input noise vector.
            generator_dim: Hidden layer widths for the generator.
            discriminator_dim: Hidden layer widths for the discriminator.
            batch_size: Training batch size.
            epochs: Number of training epochs.
            device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.
            embedding_threshold: Cardinality threshold above which categorical
                columns switch to entity embeddings.
            discriminator_steps: Discriminator updates per generator update.
            legacy_context_conditioning: If True, reuses the discriminator's
                batch context in the generator step (legacy behavior). Default
                False resamples context independently, which prevents FK
                cardinality drift.
        """
        # Plain hyper-parameter bookkeeping.
        self.metadata = metadata
        self.embedding_dim = embedding_dim
        self.generator_dim = generator_dim
        self.discriminator_dim = discriminator_dim
        self.batch_size = batch_size
        self.epochs = epochs
        self.discriminator_steps = discriminator_steps
        self.embedding_threshold = embedding_threshold
        self.legacy_context_conditioning = legacy_context_conditioning
        self.device = torch.device(device)

        # Networks are built lazily once data/context dims are known (see fit()).
        self.generator = None
        self.discriminator = None

        # Separate transformers: one for the child table, one for parent context.
        self.transformer = DataTransformer(
            metadata, embedding_threshold=self.embedding_threshold
        )
        self.context_transformer = DataTransformer(
            metadata, embedding_threshold=self.embedding_threshold
        )

        # Entity-embedding layers and the compiled column layout; populated by
        # _compile_layout() during model build.
        self.embedding_layers = nn.ModuleDict()
        self.data_column_info = []  # list of per-column layout dicts

    def _compile_layout(self, transformer):
        """Build the column layout map and (re)create embedding layers.

        Walks the fitted transformer's column info and records, per column,
        where it sits in the transformed tensor and how wide it is on the
        generator-input vs. discriminator-input side.

        Args:
            transformer: Fitted ``DataTransformer`` for the child table.
        """
        self.embedding_layers = nn.ModuleDict()
        self.data_column_info = []

        offset = 0
        for name, info in transformer._column_info.items():
            if info["type"] == "categorical_embedding":
                n_cats = info["num_categories"]
                # Embedding width heuristic: half the cardinality, capped at 50.
                emb_dim = min(50, (n_cats + 1) // 2)

                self.embedding_layers[name] = EntityEmbeddingLayer(
                    n_cats, emb_dim
                ).to(self.device)

                entry = {
                    "name": name,
                    "type": "embedding",
                    "input_idx": offset,
                    "input_dim": 1,  # real data stores a single index column
                    "output_dim": emb_dim,
                    "num_categories": n_cats,
                }
                width = 1
            else:
                entry = {
                    "name": name,
                    "type": "normal",
                    "input_idx": offset,
                    "input_dim": info["dim"],
                    "output_dim": info["dim"],
                }
                width = info["dim"]

            self.data_column_info.append(entry)
            offset += width

    def _apply_embeddings(self, data, is_fake=False):
        """Map a transformer-encoded tensor into the discriminator's input space.

        Continuous/"normal" columns pass through unchanged; categorical
        embedding columns (stored as integer indices) are replaced by their
        dense embedding vectors.

        Note: real data encodes an embedded categorical as a single index
        column, while generator output encodes it as ``num_categories``
        logit columns — the two cases have different input widths, so one
        layout cannot slice both. The training loop in ``fit`` handles the
        generator-output path inline; this helper supports real data only.

        Args:
            data: Tensor of transformer output, shape ``(N, transformer_dim)``.
            is_fake: Must be ``False``. The generator-output (logits) path is
                not implemented here.

        Returns:
            Tensor with embeddings applied to categorical columns.

        Raises:
            NotImplementedError: If ``is_fake`` is True and the layout contains
                an embedding column. (Previously this case silently skipped the
                column, returning a malformed tensor.)
        """
        parts = []
        for info in self.data_column_info:
            idx = info["input_idx"]
            dim = info["input_dim"]
            col_data = data[:, idx : idx + dim]

            if info["type"] == "embedding":
                if is_fake:
                    # Fail loudly instead of silently dropping the column:
                    # generator logits are wider than the index slice taken
                    # above, so this layout cannot process fake data.
                    raise NotImplementedError(
                        "_apply_embeddings does not support generator output; "
                        "the fake-data embedding path is handled inline in fit()."
                    )
                # Real data: (N, 1) integer indices -> (N, emb_dim) embeddings.
                layer = self.embedding_layers[info["name"]]
                parts.append(layer(col_data.long().squeeze(1)))
            else:
                parts.append(col_data)

        return torch.cat(parts, dim=1)

    def _build_model(self, transformer_output_dim: int, context_dim: int = 0):
        """Construct the generator and discriminator networks.

        Args:
            transformer_output_dim: Flattened width of the transformed child data.
            context_dim: Flattened width of the transformed context (0 if none).
        """
        # Map column slices and embedding layers before sizing the networks.
        self._compile_layout(self.transformer)

        # The generator emits raw logits for embedded categoricals, while the
        # discriminator consumes their dense embeddings — so the two widths
        # are accumulated separately.
        gen_output_dim = 0
        disc_input_dim_base = 0
        for info in self.data_column_info:
            disc_input_dim_base += info["output_dim"]
            if info["type"] == "embedding":
                gen_output_dim += info["num_categories"]
            else:
                gen_output_dim += info["output_dim"]

        # Generator: [noise | context] -> data (logits / values).
        gen_input_dim = self.embedding_dim + context_dim
        self.generator = nn.Sequential(
            ResidualLayer(gen_input_dim, self.generator_dim[0]),
            ResidualLayer(self.generator_dim[0], self.generator_dim[1]),
            nn.Linear(self.generator_dim[1], gen_output_dim),
        ).to(self.device)

        # Discriminator: [data embeddings | context] -> critic score.
        self.discriminator = Discriminator(
            disc_input_dim_base + context_dim,
            self.discriminator_dim[0],
            self.discriminator_dim[1],
        ).to(self.device)

    def fit(
        self,
        data: pd.DataFrame,
        context: Optional[pd.DataFrame] = None,
        table_name: Optional[str] = None,
        checkpoint_dir: Optional[str] = None,
        log_metrics: bool = True,
        seed: Optional[int] = None,
        progress_bar: bool = True,
        checkpoint_interval: int = 10,
        **kwargs: Any,
    ) -> None:
        """Train the CTGAN model on tabular data using a WGAN-GP objective.

        Args:
            data: Child table data (target) to model.
            context: Parent attributes to condition on (aligned row-wise).
            table_name: Table name for metadata lookup and constraint handling.
            checkpoint_dir: Directory to save checkpoints (best model, metrics). Defaults to None.
            log_metrics: Whether to save training metrics to a CSV file. Defaults to True.
            seed: Integer seed for deterministic training. When None, an integer is
                  auto-generated and logged so the run can be reproduced later.
            progress_bar: If True (default), display a tqdm progress bar to stderr during
                  training. Structured log events always emit regardless of this flag.
            checkpoint_interval: Save a validation checkpoint every N epochs. Default 10.
            **kwargs: Extra training options (unused placeholder for compatibility).

        Raises:
            ValueError: If ``context`` is provided but its row count differs
                from ``data``.
        """
        import random as _random

        # Seed handling — auto-generate when not provided so every run is reproducible.
        if seed is None:
            seed = _random.randint(0, 2**31 - 1)
            log.info(
                "training_seed",
                seed=seed,
                message="No seed provided — auto-generated. Log this value to reproduce this run.",
            )
        else:
            log.info("training_seed", seed=seed)

        _set_seed(seed)

        # 0. Setup Checkpointing
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

        history = []

        # Validation-metric checkpoint state (QUAL-03)
        _validator = None
        best_val_metric = float("inf")
        best_epoch = -1
        best_checkpoint_path = None

        if checkpoint_dir:
            from syntho_hive.validation.statistical import StatisticalValidator

            _validator = StatisticalValidator()
        # 1. Fit and Transform Data
        self.transformer.fit(data, table_name=table_name, seed=seed)
        train_data = self.transformer.transform(data)
        train_data = torch.from_numpy(train_data).float().to(self.device)

        # 2. Handle Context
        if context is not None:
            if len(data) != len(context):
                raise ValueError(
                    f"Data and context must have same number of rows, "
                    f"got {len(data)} and {len(context)}"
                )

            # Use dedicated transformer for context.
            # NOTE: We abuse metadata here slightly. Ideally context comes from a known table (Parent),
            # but context might be a mix of parent columns. For fit, we pass
            # table_name=None to fit on just the columns present in the context df.
            self.context_transformer.fit(context)
            context_transformed = self.context_transformer.transform(context)
            context_data = torch.from_numpy(context_transformed).float().to(self.device)
            context_dim = context_data.shape[1]
        else:
            context_data = None
            context_dim = 0

        data_dim = train_data.shape[1]

        # 3. Build Model (lazily — only on the first fit call)
        if self.generator is None:
            self._build_model(data_dim, context_dim)

        # Embedding layers are trained jointly with the generator.
        all_gen_params = list(self.generator.parameters()) + list(
            self.embedding_layers.parameters()
        )
        optimizer_G = optim.Adam(all_gen_params, lr=2e-4, betas=(0.5, 0.9))
        optimizer_D = optim.Adam(
            self.discriminator.parameters(), lr=2e-4, betas=(0.5, 0.9)
        )

        # 4. Training Loop (WGAN-GP)
        steps_per_epoch = max(len(train_data) // self.batch_size, 1)

        # Emit training_start event
        log.info(
            "training_start",
            total_epochs=self.epochs,
            batch_size=self.batch_size,
            embedding_dim=self.embedding_dim,
            checkpoint_interval=checkpoint_interval,
        )
        _start_time = time.time()

        # trange instead of a bare for-loop (disable=True suppresses the bar;
        # structured log events always fire regardless).
        pbar = trange(
            self.epochs,
            desc="Training",
            file=sys.stderr,
            leave=True,
            disable=not progress_bar,
        )

        for epoch in pbar:
            for i in range(steps_per_epoch):
                # --- Train Discriminator (n_critic = discriminator_steps) ---
                for _ in range(self.discriminator_steps):
                    optimizer_D.zero_grad()

                    # Sample real data (with replacement)
                    idx = np.random.randint(0, len(train_data), self.batch_size)
                    real_data_batch = train_data[idx]
                    if context_data is not None:
                        real_context_batch = context_data[idx]
                        real_input = torch.cat(
                            [real_data_batch, real_context_batch], dim=1
                        )
                    else:
                        real_context_batch = None
                        real_input = real_data_batch

                    # Generate fake data conditioned on the same context rows
                    noise = torch.randn(
                        self.batch_size, self.embedding_dim, device=self.device
                    )
                    if real_context_batch is not None:
                        gen_input = torch.cat([noise, real_context_batch], dim=1)
                    else:
                        gen_input = noise

                    fake_raw = self.generator(gen_input)

                    # Map generator output into discriminator space:
                    # logits -> softmax -> soft embedding for categorical columns.
                    fake_parts = []
                    fake_ptr = 0
                    for info in self.data_column_info:
                        if info["type"] == "embedding":
                            dim = info["num_categories"]
                            logits = fake_raw[:, fake_ptr : fake_ptr + dim]
                            fake_ptr += dim

                            # Plain softmax (not Gumbel) keeps the mapping fully
                            # differentiable, which WGAN training prefers.
                            probs = F.softmax(logits, dim=1)
                            emb_vect = self.embedding_layers[info["name"]].forward_soft(
                                probs
                            )
                            fake_parts.append(emb_vect)
                        else:
                            dim = info["output_dim"]
                            val = fake_raw[:, fake_ptr : fake_ptr + dim]
                            fake_ptr += dim
                            fake_parts.append(val)

                    fake_data_batch = torch.cat(fake_parts, dim=1)

                    # Map real data into the same space: index columns become
                    # hard embedding lookups; normal columns pass through.
                    real_parts = []
                    real_ptr = 0
                    for info in self.data_column_info:
                        dim = info["input_dim"]  # 1 for embedding (index)
                        col_data = real_data_batch[:, real_ptr : real_ptr + dim]
                        real_ptr += dim

                        if info["type"] == "embedding":
                            emb_vect = self.embedding_layers[info["name"]](
                                col_data.long().squeeze(1)
                            )
                            real_parts.append(emb_vect)
                        else:
                            real_parts.append(col_data)

                    real_data_processed = torch.cat(real_parts, dim=1)

                    if real_context_batch is not None:
                        fake_input = torch.cat(
                            [fake_data_batch, real_context_batch], dim=1
                        )
                        real_input_processed = torch.cat(
                            [real_data_processed, real_context_batch], dim=1
                        )
                    else:
                        fake_input = fake_data_batch
                        real_input_processed = real_data_processed

                    # Compute WGAN critic scores
                    d_real = self.discriminator(real_input_processed)
                    d_fake = self.discriminator(fake_input)

                    # Gradient Penalty (WGAN-GP)
                    gp = compute_gradient_penalty(
                        self.discriminator,
                        real_input_processed,
                        fake_input,
                        self.device,
                    )

                    # Critic loss with gradient-penalty weight lambda = 10
                    loss_D = -torch.mean(d_real) + torch.mean(d_fake) + 10.0 * gp

                    loss_D.backward()
                    torch.nn.utils.clip_grad_norm_(
                        self.discriminator.parameters(), max_norm=1.0
                    )
                    optimizer_D.step()

                # --- Train Generator ---
                # One generator update after n_critic discriminator steps
                noise = torch.randn(
                    self.batch_size, self.embedding_dim, device=self.device
                )
                if context_data is not None:
                    if self.legacy_context_conditioning:
                        # Backwards-compatible: reuse last discriminator batch context
                        gen_context_batch = real_context_batch
                    else:
                        # Correct: independently sample fresh context for generator step
                        gen_ctx_idx = np.random.randint(
                            0, len(context_data), self.batch_size
                        )
                        gen_context_batch = context_data[gen_ctx_idx]
                    gen_input = torch.cat([noise, gen_context_batch], dim=1)
                else:
                    gen_input = noise

                fake_raw = self.generator(gen_input)

                # Apply Embeddings / Softmax (same mapping as the critic step)
                fake_parts = []
                fake_ptr = 0
                for info in self.data_column_info:
                    if info["type"] == "embedding":
                        dim = info["num_categories"]
                        logits = fake_raw[:, fake_ptr : fake_ptr + dim]
                        fake_ptr += dim
                        probs = F.softmax(logits, dim=1)
                        emb_vect = self.embedding_layers[info["name"]].forward_soft(
                            probs
                        )
                        fake_parts.append(emb_vect)
                    else:
                        dim = info["output_dim"]
                        val = fake_raw[:, fake_ptr : fake_ptr + dim]
                        fake_ptr += dim
                        fake_parts.append(val)

                fake_data_batch = torch.cat(fake_parts, dim=1)

                if context_data is not None:
                    fake_input = torch.cat([fake_data_batch, gen_context_batch], dim=1)
                else:
                    fake_input = fake_data_batch

                d_fake = self.discriminator(fake_input)
                loss_G = -torch.mean(d_fake)

                optimizer_G.zero_grad()
                loss_G.backward()
                torch.nn.utils.clip_grad_norm_(
                    list(self.generator.parameters())
                    + list(self.embedding_layers.parameters()),
                    max_norm=1.0,
                )
                optimizer_G.step()

            # --- Checkpointing & Logging ---
            # NOTE: these are the losses of the LAST batch in the epoch, not
            # an epoch average.
            current_loss_g = loss_G.item()
            current_loss_d = loss_D.item()

            # ETA calculation (linear extrapolation)
            _elapsed = time.time() - _start_time
            _epochs_done = epoch + 1
            _elapsed_per_epoch = _elapsed / _epochs_done
            _remaining_epochs = self.epochs - _epochs_done
            eta_seconds = (
                _elapsed_per_epoch * _remaining_epochs
            )  # 0.0 on final epoch — correct

            # Update tqdm postfix (visual only; fires regardless of disable state)
            pbar.set_postfix(
                {
                    "g_loss": f"{current_loss_g:.4f}",
                    "d_loss": f"{current_loss_d:.4f}",
                    "eta": f"{int(eta_seconds)}s",
                }
            )

            # Prepare epoch_end log fields (val_metric added below on checkpoint epochs)
            epoch_log_fields = dict(
                epoch=epoch,
                g_loss=current_loss_g,
                d_loss=current_loss_d,
                eta_seconds=eta_seconds,
            )

            # Checkpoint validation (every checkpoint_interval epochs, only when checkpoint_dir set)
            _is_checkpoint_epoch = (
                checkpoint_dir is not None and (epoch + 1) % checkpoint_interval == 0
            )

            if _is_checkpoint_epoch:
                if context_dim > 0:
                    # sample() below would need aligned context rows, which we
                    # cannot fabricate here — skip validation for conditioned models.
                    log.debug(
                        "checkpoint_validation_skipped",
                        epoch=epoch,
                        note="Skipping checkpoint validation for context-conditioned model",
                    )
                    val_metric = float("inf")
                else:
                    # Generate a small validation sample from current generator state
                    self.generator.eval()
                    self.discriminator.eval()
                    with torch.no_grad():
                        val_synth = self.sample(min(len(data), 500))
                    self.generator.train()
                    self.discriminator.train()

                    # Align columns: drop columns not present in synthetic output (FK/PK)
                    real_for_val = data[
                        [c for c in data.columns if c in val_synth.columns]
                    ].copy()

                    results = _validator.compare_columns(real_for_val, val_synth)
                    stats = [
                        v["statistic"]
                        for v in results.values()
                        if isinstance(v, dict) and "statistic" in v
                    ]

                    if stats:
                        # Lower mean statistic = closer match to real data
                        val_metric = sum(stats) / len(stats)
                        # Include val_metric in epoch_end only on checkpoint epochs
                        epoch_log_fields["val_metric"] = val_metric
                    else:
                        log.warning(
                            "checkpoint_validation_empty",
                            epoch=epoch,
                            note="compare_columns returned no valid stats — skipping checkpoint",
                        )
                        val_metric = float("inf")

                if val_metric < best_val_metric:
                    best_val_metric = val_metric
                    best_epoch = epoch
                    best_cp = os.path.join(checkpoint_dir, "best_checkpoint")
                    self.save(best_cp, overwrite=True)
                    best_checkpoint_path = best_cp
                    log.info(
                        "new_best_checkpoint",
                        epoch=epoch,
                        val_metric=val_metric,
                        path=best_cp,
                    )

            # Emit epoch_end event (always, independent of progress_bar flag)
            # val_metric included in epoch_log_fields only on checkpoint epochs
            log.info("epoch_end", **epoch_log_fields)

            if log_metrics:
                history.append(
                    {"epoch": epoch, "loss_g": current_loss_g, "loss_d": current_loss_d}
                )

        # End of training: Save final checkpoint and metrics
        if checkpoint_dir:
            final_cp = os.path.join(checkpoint_dir, "final_checkpoint")
            self.save(final_cp, overwrite=True)
            # If no validation checkpoint was ever saved (no checkpoint epoch ran),
            # fall back: treat final as best
            if best_checkpoint_path is None:
                best_checkpoint_path = final_cp
                best_epoch = self.epochs - 1
                best_val_metric = float("inf")

        # Emit training_complete with real validation-metric values (QUAL-03)
        log.info(
            "training_complete",
            best_epoch=best_epoch,
            best_val_metric=best_val_metric,
            total_epochs=self.epochs,
            checkpoint_path=str(best_checkpoint_path) if best_checkpoint_path else None,
        )

        if checkpoint_dir and log_metrics and history:
            metrics_path = os.path.join(checkpoint_dir, "training_metrics.csv")
            keys = history[0].keys()
            with open(metrics_path, "w", newline="") as f:
                dict_writer = csv.DictWriter(f, fieldnames=keys)
                dict_writer.writeheader()
                dict_writer.writerows(history)
            print(f"Training metrics saved to {metrics_path}")

    def sample(
        self,
        num_rows: int,
        context: Optional[pd.DataFrame] = None,
        seed: Optional[int] = None,
        enforce_constraints: bool = False,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Generate synthetic samples, optionally conditioned on parent context.

        Args:
            num_rows: Number of rows to generate.
            context: Optional parent attributes aligned to the requested rows.
            seed: Optional integer seed for deterministic sampling. Only applied
                  when provided; no auto-generation (fits and samples may use
                  independent seeds per CONTEXT.md decision).
            enforce_constraints: When True, audits generated rows against column
                  constraints defined in the table's Metadata config and raises
                  ConstraintViolationError naming each violated column and the
                  observed value. When False (default), constraint checking is
                  skipped entirely — this matches the pre-existing behavior where
                  inverse_transform() already clips values within each column's
                  defined range.
            **kwargs: Additional sampling controls (unused placeholder).

        Returns:
            DataFrame of synthetic rows mapped back to original schema.

        Raises:
            ValueError: If num_rows is not positive, or context does not have
                exactly num_rows rows.
            ConstraintViolationError: When enforce_constraints=True and any
                generated value falls outside a configured min/max constraint.
        """
        if seed is not None:
            _set_seed(seed)

        if num_rows <= 0:
            raise ValueError(f"num_rows must be positive, got {num_rows}")

        # Remember train/eval state so sampling is side-effect free for callers
        # that sample mid-training (e.g. checkpoint validation).
        was_training = self.generator.training
        self.generator.eval()
        with torch.no_grad():
            noise = torch.randn(num_rows, self.embedding_dim, device=self.device)

            if context is not None:
                if len(context) != num_rows:
                    raise ValueError(
                        f"context must have exactly num_rows={num_rows} rows, "
                        f"got {len(context)}"
                    )

                # Transform context using the fitted context transformer
                context_transformed = self.context_transformer.transform(context)
                context_data = (
                    torch.from_numpy(context_transformed).float().to(self.device)
                )

                gen_input = torch.cat([noise, context_data], dim=1)
            else:
                gen_input = noise

            fake_raw = self.generator(gen_input)

            # Post-process: categorical logits become integer indices; continuous
            # outputs pass through unchanged. Slices must follow the layout in
            # self.data_column_info exactly.
            output_parts = []
            fake_ptr = 0
            for info in self.data_column_info:
                if info["type"] == "embedding":
                    dim = info["num_categories"]
                    logits = fake_raw[:, fake_ptr : fake_ptr + dim]
                    fake_ptr += dim

                    # Argmax to get index
                    indices = torch.argmax(logits, dim=1, keepdim=True)
                    output_parts.append(indices.cpu().numpy())
                else:
                    dim = info["output_dim"]
                    val = fake_raw[:, fake_ptr : fake_ptr + dim]
                    fake_ptr += dim
                    output_parts.append(val.cpu().numpy())

            fake_data_np = np.concatenate(output_parts, axis=1)

        result_df = self.transformer.inverse_transform(fake_data_np)

        # Constraint violation checking (opt-in via enforce_constraints=True).
        # Note: inverse_transform() already clips values within defined column ranges,
        # so enforce_constraints=True is primarily useful for post-hoc auditing or
        # catching any residual violations before returning rows to the caller.
        if enforce_constraints:
            self._check_constraints(result_df)

        if was_training:
            self.generator.train()

        return result_df

    def _check_constraints(self, result_df: pd.DataFrame) -> None:
        """Audit generated rows against Metadata min/max column constraints.

        Looks up the table config via self.metadata (when present) using the
        transformer's table_name. A no-op when no config or no constraints are
        defined. Non-numeric columns that fail coercion are skipped with a
        structlog warning rather than aborting the audit.

        Raises:
            ConstraintViolationError: Listing every violated column with the
                worst observed value, per REQUIREMENTS.md QUAL-04.
        """
        table_config = None
        table_name = getattr(self.transformer, "table_name", None)
        if hasattr(self, "metadata") and table_name:
            try:
                table_config = self.metadata.get_table(table_name)
            except Exception as exc:
                log.warning(
                    "constraint_config_lookup_failed",
                    table_name=table_name,
                    error=str(exc),
                    note="Skipping constraint enforcement — table config could not be retrieved",
                )
                table_config = None

        # If the table has no constraints configured this audit is a no-op.
        if table_config is None or not table_config.constraints:
            return

        violations = []
        for col_name, constraint in table_config.constraints.items():
            if col_name not in result_df.columns:
                continue
            col_data = result_df[col_name]

            # Check min constraint
            min_val = constraint.min
            if min_val is not None:
                try:
                    col_numeric = pd.to_numeric(col_data, errors="coerce")
                    bad = col_numeric < min_val
                    if bad.any():
                        observed = col_numeric[bad].min()
                        violations.append(
                            f"{col_name}: got {observed:.4g} (min={min_val})"
                        )
                except Exception as exc:
                    log.warning(
                        "constraint_min_check_skipped",
                        column=col_name,
                        error=str(exc),
                        note="Column is non-numeric or comparison failed — skipping min check",
                    )

            # Check max constraint
            max_val = constraint.max
            if max_val is not None:
                try:
                    col_numeric = pd.to_numeric(col_data, errors="coerce")
                    bad = col_numeric > max_val
                    if bad.any():
                        observed = col_numeric[bad].max()
                        violations.append(
                            f"{col_name}: got {observed:.4g} (max={max_val})"
                        )
                except Exception as exc:
                    log.warning(
                        "constraint_max_check_skipped",
                        column=col_name,
                        error=str(exc),
                        note="Column is non-numeric or comparison failed — skipping max check",
                    )

        if violations:
            summary = "; ".join(violations)
            # ROADMAP success criterion 4 and REQUIREMENTS.md QUAL-04 require
            # violations to "raise with the column name and observed value".
            # Use sample(enforce_constraints=False) (the default) if you want
            # the previous warn-and-return behavior.
            raise ConstraintViolationError(
                f"ConstraintViolationError: {len(violations)} violation(s) found — "
                f"{summary}"
            )

    def save(self, path: str, *, overwrite: bool = False) -> None:
        """Persist full model state to a directory checkpoint.

        Writes everything needed for a cold load-and-sample without the
        original training data: both network state_dicts, the fitted
        DataTransformer and context transformer, the entity embedding layers,
        the column layout, and a human-readable metadata file.

        Directory layout:
            - generator.pt — generator state_dict
            - discriminator.pt — discriminator state_dict
            - transformer.joblib — fitted DataTransformer for child table
            - context_transformer.joblib — fitted DataTransformer for context
            - embedding_layers.joblib — nn.ModuleDict with entity embedding weights
            - data_column_info.joblib — column layout list
            - metadata.json — hyperparameters and version info

        Args:
            path: Directory path to save into.
            overwrite: If False (default), raises SerializationError if path already exists.

        Raises:
            SerializationError: If path exists and overwrite=False, or if any
                component fails to serialize.
        """
        import joblib
        import json
        from pathlib import Path
        from datetime import datetime, timezone

        target = Path(path)
        if target.exists() and not overwrite:
            raise SerializationError(
                f"SerializationError: Save path '{path}' already exists. "
                f"Pass overwrite=True to replace it."
            )

        try:
            target.mkdir(parents=True, exist_ok=True)

            # Network weights go through torch's native format.
            for filename, module in (
                ("generator.pt", self.generator),
                ("discriminator.pt", self.discriminator),
            ):
                torch.save(module.state_dict(), target / filename)

            # joblib handles the sklearn/NumPy-heavy components efficiently;
            # the embedding ModuleDict and plain column-layout list ride along
            # via pickle.
            for filename, component in (
                ("transformer.joblib", self.transformer),
                ("context_transformer.joblib", self.context_transformer),
                ("embedding_layers.joblib", self.embedding_layers),
                ("data_column_info.joblib", self.data_column_info),
            ):
                joblib.dump(component, target / filename)

            # Version stamp enables mismatch detection on load; fall back to
            # "unknown" rather than failing the save.
            try:
                from syntho_hive import __version__ as current_version
            except Exception as exc:
                log.warning(
                    "version_lookup_failed",
                    error=str(exc),
                    note="Could not determine SynthoHive version — using 'unknown'",
                )
                current_version = "unknown"

            meta = {
                "synthohive_version": current_version,
                "embedding_dim": self.embedding_dim,
                "generator_dim": list(self.generator_dim),
                "discriminator_dim": list(self.discriminator_dim),
                "legacy_context_conditioning": self.legacy_context_conditioning,
                "saved_at": datetime.now(timezone.utc).isoformat(),
            }
            (target / "metadata.json").write_text(json.dumps(meta, indent=2))

            log.info("model_saved", path=str(target))

        except SerializationError:
            raise
        except Exception as exc:
            raise SerializationError(
                f"SerializationError: Failed to save model to '{path}'. "
                f"Original error: {exc}"
            ) from exc

    def load(self, path: str) -> None:
        """Load full model state from a directory checkpoint.

        Reconstructs the complete model — DataTransformer, context_transformer,
        embedding_layers, column layout, and network weights — without requiring
        the original training data.

        The restore order matters: transformers are loaded first (they supply
        the dimensions `_build_model()` needs), then the architecture is rebuilt,
        then the saved column layout and embedding weights overwrite the
        freshly-initialised ones, and only then are the network state_dicts
        applied.

        Args:
            path: Directory path produced by save().

        Raises:
            SerializationError: If path does not exist, is missing required files,
                or if any component fails to deserialize.
        """
        import joblib
        import json
        from pathlib import Path

        p = Path(path)
        if not p.exists():
            raise SerializationError(
                f"SerializationError: Checkpoint path '{path}' does not exist."
            )

        # Fail fast with a complete list of missing files rather than erroring
        # one file at a time during deserialization.
        required_files = [
            "generator.pt",
            "discriminator.pt",
            "transformer.joblib",
            "context_transformer.joblib",
            "embedding_layers.joblib",
            "data_column_info.joblib",
        ]
        missing = [f for f in required_files if not (p / f).exists()]
        if missing:
            raise SerializationError(
                f"SerializationError: Checkpoint at '{path}' is incomplete. "
                f"Missing files: {', '.join(missing)}. "
                f"The checkpoint may have been saved by an older version or is corrupt."
            )

        saved_version = "unknown"
        try:
            # Version check — warn but do not fail
            # metadata.json is optional (older checkpoints may lack it).
            meta_path = p / "metadata.json"
            if meta_path.exists():
                with open(meta_path) as f:
                    meta = json.load(f)
                try:
                    from syntho_hive import __version__

                    current_version = __version__
                except Exception as exc:
                    log.warning(
                        "version_lookup_failed",
                        error=str(exc),
                        note="Could not determine SynthoHive version — using 'unknown'",
                    )
                    current_version = "unknown"
                saved_version = meta.get("synthohive_version", "unknown")
                if saved_version != current_version:
                    log.warning(
                        "checkpoint_version_mismatch",
                        saved_version=saved_version,
                        current_version=current_version,
                        path=str(p),
                        note="Attempting load — schema changes between versions may cause failures",
                    )
                # Restore hyperparams from metadata so _build_model() uses correct dims
                if "embedding_dim" in meta:
                    self.embedding_dim = meta["embedding_dim"]
                if "generator_dim" in meta:
                    self.generator_dim = tuple(meta["generator_dim"])
                if "discriminator_dim" in meta:
                    self.discriminator_dim = tuple(meta["discriminator_dim"])
                # Default False for forward compatibility with old checkpoints that lack this key
                self.legacy_context_conditioning = meta.get(
                    "legacy_context_conditioning", False
                )

            # Load sklearn objects first — transformer must be in place before _build_model()
            self.transformer = joblib.load(p / "transformer.joblib")
            self.context_transformer = joblib.load(p / "context_transformer.joblib")

            # Load saved column layout and embedding layers (will be restored after _build_model)
            saved_data_column_info = joblib.load(p / "data_column_info.joblib")
            saved_embedding_layers = joblib.load(p / "embedding_layers.joblib")

            # Validate transformer round-trip integrity
            if (
                not hasattr(self.transformer, "output_dim")
                or self.transformer.output_dim <= 0
            ):
                raise SerializationError(
                    f"SerializationError: Loaded transformer has invalid output_dim "
                    f"({getattr(self.transformer, 'output_dim', 'missing')}). "
                    f"The checkpoint may be corrupt."
                )

            # Derive dimensions needed to reconstruct the generator/discriminator architecture.
            # context_transformer.output_dim is 0 when no context was used during training.
            data_dim = self.transformer.output_dim
            context_dim = getattr(self.context_transformer, "output_dim", 0)

            # Reconstruct generator/discriminator architecture.
            # _build_model() internally calls _compile_layout(self.transformer) which overwrites
            # self.data_column_info and self.embedding_layers with freshly-initialised layers.
            # We restore the saved values immediately after so weights can be loaded correctly.
            self._build_model(data_dim, context_dim)

            # Restore saved column layout and trained embedding weights (overwrite fresh ones)
            self.data_column_info = saved_data_column_info
            self.embedding_layers = saved_embedding_layers

            # Load network weights — weights_only=False REQUIRED for PyTorch 2.6+
            # (PyTorch 2.6 changed default to weights_only=True; custom objects fail without False)
            # SECURITY WARNING: weights_only=False uses pickle deserialization under the hood,
            # which can execute arbitrary code. Only load checkpoints from trusted sources.
            # Restructuring to weights_only=True requires registering all custom types with
            # torch.serialization.add_safe_globals() and is a non-trivial migration.
            self.generator.load_state_dict(
                torch.load(p / "generator.pt", weights_only=False)
            )
            self.discriminator.load_state_dict(
                torch.load(p / "discriminator.pt", weights_only=False)
            )

            # Set model to eval mode for inference
            self.generator.eval()
            self.discriminator.eval()

            log.info("model_loaded", path=str(p), version=saved_version)

        except SerializationError:
            raise
        except Exception as exc:
            raise SerializationError(
                f"SerializationError: Failed to load model from '{path}'. "
                f"Original error: {exc}"
            ) from exc

fit

fit(data: DataFrame, context: Optional[DataFrame] = None, table_name: Optional[str] = None, checkpoint_dir: Optional[str] = None, log_metrics: bool = True, seed: Optional[int] = None, progress_bar: bool = True, checkpoint_interval: int = 10, **kwargs: Any) -> None

Train the CTGAN model on tabular data.

Parameters:

Name Type Description Default
data DataFrame

Child table data (target) to model.

required
context Optional[DataFrame]

Parent attributes to condition on (aligned row-wise).

None
table_name Optional[str]

Table name for metadata lookup and constraint handling.

None
checkpoint_dir Optional[str]

Directory to save checkpoints (best model, metrics). Defaults to None.

None
log_metrics bool

Whether to save training metrics to a CSV file. Defaults to True.

True
seed Optional[int]

Integer seed for deterministic training. When None, an integer is auto-generated and logged so the run can be reproduced later.

None
progress_bar bool

If True (default), display a tqdm progress bar to stderr during training. Structured log events always emit regardless of this flag.

True
checkpoint_interval int

Save a validation checkpoint every N epochs. Default 10.

10
**kwargs Any

Extra training options (unused placeholder for compatibility).

{}
Source code in syntho_hive/core/models/ctgan.py
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
def fit(
    self,
    data: pd.DataFrame,
    context: Optional[pd.DataFrame] = None,
    table_name: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    log_metrics: bool = True,
    seed: Optional[int] = None,
    progress_bar: bool = True,
    checkpoint_interval: int = 10,
    **kwargs: Any,
) -> None:
    """Train the CTGAN model on tabular data.

    Args:
        data: Child table data (target) to model.
        context: Parent attributes to condition on (aligned row-wise).
        table_name: Table name for metadata lookup and constraint handling.
        checkpoint_dir: Directory to save checkpoints (best model, metrics). Defaults to None.
        log_metrics: Whether to save training metrics to a CSV file. Defaults to True.
        seed: Integer seed for deterministic training. When None, an integer is
              auto-generated and logged so the run can be reproduced later.
        progress_bar: If True (default), display a tqdm progress bar to stderr during
              training. Structured log events always emit regardless of this flag.
        checkpoint_interval: Save a validation checkpoint every N epochs. Default 10.
        **kwargs: Extra training options (unused placeholder for compatibility).
    """
    import random as _random

    # Seed handling — auto-generate when not provided so every run is reproducible.
    if seed is None:
        seed = _random.randint(0, 2**31 - 1)
        log.info(
            "training_seed",
            seed=seed,
            message="No seed provided — auto-generated. Log this value to reproduce this run.",
        )
    else:
        log.info("training_seed", seed=seed)

    _set_seed(seed)

    # 0. Setup Checkpointing
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    history = []

    # Validation-metric checkpoint state (QUAL-03)
    _validator = None
    best_val_metric = float("inf")
    best_epoch = -1
    best_checkpoint_path = None

    if checkpoint_dir:
        from syntho_hive.validation.statistical import StatisticalValidator

        _validator = StatisticalValidator()
    # 1. Fit and Transform Data
    self.transformer.fit(data, table_name=table_name, seed=seed)
    train_data = self.transformer.transform(data)
    train_data = torch.from_numpy(train_data).float().to(self.device)

    # 2. Handle Context
    if context is not None:
        if len(data) != len(context):
            raise ValueError(
                f"Data and context must have same number of rows, "
                f"got {len(data)} and {len(context)}"
            )

        # Use dedicated transformer for context
        # NOTE: We abuse metadata here slightly. Ideally context comes from a known table (Parent).
        # But context might be a mix of parent columns.
        # For fit, we pass table_name=None to fit on just the columns present in context df.
        self.context_transformer.fit(context)
        context_transformed = self.context_transformer.transform(context)
        context_data = torch.from_numpy(context_transformed).float().to(self.device)
        context_dim = context_data.shape[1]
    else:
        context_data = None
        context_dim = 0

    data_dim = train_data.shape[1]

    # 3. Build Model
    if self.generator is None:
        self._build_model(data_dim, context_dim)

    all_gen_params = list(self.generator.parameters()) + list(
        self.embedding_layers.parameters()
    )
    optimizer_G = optim.Adam(all_gen_params, lr=2e-4, betas=(0.5, 0.9))
    optimizer_D = optim.Adam(
        self.discriminator.parameters(), lr=2e-4, betas=(0.5, 0.9)
    )

    # 4. Training Loop (WGAN-GP)
    steps_per_epoch = max(len(train_data) // self.batch_size, 1)

    # Emit training_start event
    log.info(
        "training_start",
        total_epochs=self.epochs,
        batch_size=self.batch_size,
        embedding_dim=self.embedding_dim,
        checkpoint_interval=checkpoint_interval,
    )
    _start_time = time.time()

    # Replace bare for-loop with trange (disable=True suppresses bar; log events always fire)
    pbar = trange(
        self.epochs,
        desc="Training",
        file=sys.stderr,
        leave=True,
        disable=not progress_bar,
    )

    for epoch in pbar:
        for i in range(steps_per_epoch):
            # --- Train Discriminator ---
            for _ in range(self.discriminator_steps):
                optimizer_D.zero_grad()

                # Sample real data
                idx = np.random.randint(0, len(train_data), self.batch_size)
                real_data_batch = train_data[idx]
                if context_data is not None:
                    real_context_batch = context_data[idx]
                    real_input = torch.cat(
                        [real_data_batch, real_context_batch], dim=1
                    )
                else:
                    real_context_batch = None
                    real_input = real_data_batch

                # Generate fake data
                noise = torch.randn(
                    self.batch_size, self.embedding_dim, device=self.device
                )
                if real_context_batch is not None:
                    gen_input = torch.cat([noise, real_context_batch], dim=1)
                else:
                    gen_input = noise

                fake_raw = self.generator(gen_input)

                # Apply Embeddings / Softmax to Fake Data
                fake_parts = []
                fake_ptr = 0
                for info in self.data_column_info:
                    if info["type"] == "embedding":
                        dim = info["num_categories"]
                        logits = fake_raw[:, fake_ptr : fake_ptr + dim]
                        fake_ptr += dim

                        # Gumbel Softmax or Softmax? WGAN prefers generic softmax for differentiability
                        # Note: Gumbel Softmax allows hard sampling with gradients.
                        probs = F.softmax(logits, dim=1)
                        emb_vect = self.embedding_layers[info["name"]].forward_soft(
                            probs
                        )
                        fake_parts.append(emb_vect)
                    else:
                        dim = info["output_dim"]
                        val = fake_raw[:, fake_ptr : fake_ptr + dim]
                        fake_ptr += dim
                        fake_parts.append(val)

                fake_data_batch = torch.cat(fake_parts, dim=1)

                # Apply Embeddings to Real Data
                real_parts = []
                real_ptr = 0
                # Need to iterate column info again to slice real data correctly
                # Real data from transformer is concatenated (Indices, Values...)
                for info in self.data_column_info:
                    dim = info["input_dim"]  # 1 for embedding (index)
                    col_data = real_data_batch[:, real_ptr : real_ptr + dim]
                    real_ptr += dim

                    if info["type"] == "embedding":
                        emb_vect = self.embedding_layers[info["name"]](
                            col_data.long().squeeze(1)
                        )
                        real_parts.append(emb_vect)
                    else:
                        real_parts.append(col_data)

                real_data_processed = torch.cat(real_parts, dim=1)

                if real_context_batch is not None:
                    fake_input = torch.cat(
                        [fake_data_batch, real_context_batch], dim=1
                    )
                    real_input_processed = torch.cat(
                        [real_data_processed, real_context_batch], dim=1
                    )
                else:
                    fake_input = fake_data_batch
                    real_input_processed = real_data_processed

                # Compute WGAN loss
                d_real = self.discriminator(real_input_processed)
                d_fake = self.discriminator(fake_input)

                # Gradient Penalty
                gp = compute_gradient_penalty(
                    self.discriminator,
                    real_input_processed,
                    fake_input,
                    self.device,
                )

                loss_D = -torch.mean(d_real) + torch.mean(d_fake) + 10.0 * gp

                loss_D.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.discriminator.parameters(), max_norm=1.0
                )
                optimizer_D.step()

            # --- Train Generator ---
            # Train generator once after n_critic discriminator steps
            noise = torch.randn(
                self.batch_size, self.embedding_dim, device=self.device
            )
            if context_data is not None:
                if self.legacy_context_conditioning:
                    # Backwards-compatible: reuse last discriminator batch context
                    gen_context_batch = real_context_batch
                else:
                    # Correct: independently sample fresh context for generator step
                    gen_ctx_idx = np.random.randint(
                        0, len(context_data), self.batch_size
                    )
                    gen_context_batch = context_data[gen_ctx_idx]
                gen_input = torch.cat([noise, gen_context_batch], dim=1)
            else:
                gen_input = noise

            fake_raw = self.generator(gen_input)

            # Apply Embeddings / Softmax (Same logic as above)
            fake_parts = []
            fake_ptr = 0
            for info in self.data_column_info:
                if info["type"] == "embedding":
                    dim = info["num_categories"]
                    logits = fake_raw[:, fake_ptr : fake_ptr + dim]
                    fake_ptr += dim
                    probs = F.softmax(logits, dim=1)
                    emb_vect = self.embedding_layers[info["name"]].forward_soft(
                        probs
                    )
                    fake_parts.append(emb_vect)
                else:
                    dim = info["output_dim"]
                    val = fake_raw[:, fake_ptr : fake_ptr + dim]
                    fake_ptr += dim
                    fake_parts.append(val)

            fake_data_batch = torch.cat(fake_parts, dim=1)

            if context_data is not None:
                fake_input = torch.cat([fake_data_batch, gen_context_batch], dim=1)
            else:
                fake_input = fake_data_batch

            d_fake = self.discriminator(fake_input)
            loss_G = -torch.mean(d_fake)

            optimizer_G.zero_grad()
            loss_G.backward()
            torch.nn.utils.clip_grad_norm_(
                list(self.generator.parameters())
                + list(self.embedding_layers.parameters()),
                max_norm=1.0,
            )
            optimizer_G.step()

        # --- Checkpointing & Logging ---
        current_loss_g = loss_G.item()
        current_loss_d = loss_D.item()

        # ETA calculation (linear extrapolation)
        _elapsed = time.time() - _start_time
        _epochs_done = epoch + 1
        _elapsed_per_epoch = _elapsed / _epochs_done
        _remaining_epochs = self.epochs - _epochs_done
        eta_seconds = (
            _elapsed_per_epoch * _remaining_epochs
        )  # 0.0 on final epoch — correct

        # Update tqdm postfix (visual only; fires regardless of disable state)
        pbar.set_postfix(
            {
                "g_loss": f"{current_loss_g:.4f}",
                "d_loss": f"{current_loss_d:.4f}",
                "eta": f"{int(eta_seconds)}s",
            }
        )

        # Prepare epoch_end log fields (val_metric added below on checkpoint epochs)
        epoch_log_fields = dict(
            epoch=epoch,
            g_loss=current_loss_g,
            d_loss=current_loss_d,
            eta_seconds=eta_seconds,
        )

        # Checkpoint validation (every checkpoint_interval epochs, only when checkpoint_dir set)
        _is_checkpoint_epoch = (
            checkpoint_dir is not None and (epoch + 1) % checkpoint_interval == 0
        )

        if _is_checkpoint_epoch:
            if context_dim > 0:
                log.debug(
                    "checkpoint_validation_skipped",
                    epoch=epoch,
                    note="Skipping checkpoint validation for context-conditioned model",
                )
                val_metric = float("inf")
            else:
                # Generate a small validation sample from current generator state
                self.generator.eval()
                self.discriminator.eval()
                with torch.no_grad():
                    val_synth = self.sample(min(len(data), 500))
                self.generator.train()
                self.discriminator.train()

                # Align columns: drop columns not present in synthetic output (FK/PK)
                real_for_val = data[
                    [c for c in data.columns if c in val_synth.columns]
                ].copy()

                results = _validator.compare_columns(real_for_val, val_synth)
                stats = [
                    v["statistic"]
                    for v in results.values()
                    if isinstance(v, dict) and "statistic" in v
                ]

                if stats:
                    val_metric = sum(stats) / len(stats)
                    # Include val_metric in epoch_end only on checkpoint epochs
                    epoch_log_fields["val_metric"] = val_metric
                else:
                    log.warning(
                        "checkpoint_validation_empty",
                        epoch=epoch,
                        note="compare_columns returned no valid stats — skipping checkpoint",
                    )
                    val_metric = float("inf")

            if val_metric < best_val_metric:
                best_val_metric = val_metric
                best_epoch = epoch
                best_cp = os.path.join(checkpoint_dir, "best_checkpoint")
                self.save(best_cp, overwrite=True)
                best_checkpoint_path = best_cp
                log.info(
                    "new_best_checkpoint",
                    epoch=epoch,
                    val_metric=val_metric,
                    path=best_cp,
                )

        # Emit epoch_end event (always, independent of progress_bar flag)
        # val_metric included in epoch_log_fields only on checkpoint epochs
        log.info("epoch_end", **epoch_log_fields)

        if log_metrics:
            history.append(
                {"epoch": epoch, "loss_g": current_loss_g, "loss_d": current_loss_d}
            )

    # End of training: Save final checkpoint and metrics
    if checkpoint_dir:
        final_cp = os.path.join(checkpoint_dir, "final_checkpoint")
        self.save(final_cp, overwrite=True)
        # If no validation checkpoint was ever saved (no checkpoint epoch ran),
        # fall back: treat final as best
        if best_checkpoint_path is None:
            best_checkpoint_path = final_cp
            best_epoch = self.epochs - 1
            best_val_metric = float("inf")

    # Emit training_complete with real validation-metric values (QUAL-03)
    log.info(
        "training_complete",
        best_epoch=best_epoch,
        best_val_metric=best_val_metric,
        total_epochs=self.epochs,
        checkpoint_path=str(best_checkpoint_path) if best_checkpoint_path else None,
    )

    if checkpoint_dir and log_metrics and history:
        metrics_path = os.path.join(checkpoint_dir, "training_metrics.csv")
        keys = history[0].keys()
        with open(metrics_path, "w", newline="") as f:
            dict_writer = csv.DictWriter(f, fieldnames=keys)
            dict_writer.writeheader()
            dict_writer.writerows(history)
        print(f"Training metrics saved to {metrics_path}")

sample

sample(num_rows: int, context: Optional[DataFrame] = None, seed: Optional[int] = None, enforce_constraints: bool = False, **kwargs: Any) -> pd.DataFrame

Generate synthetic samples, optionally conditioned on parent context.

Parameters:

Name Type Description Default
num_rows int

Number of rows to generate.

required
context Optional[DataFrame]

Optional parent attributes aligned to the requested rows.

None
seed Optional[int]

Optional integer seed for deterministic sampling. Only applied when provided; no auto-generation (fits and samples may use independent seeds per CONTEXT.md decision).

None
enforce_constraints bool

When True, inspects generated rows against column constraints defined in the table's Metadata config. Any min/max constraint violation raises a ConstraintViolationError listing each violating column name and observed value. When False (default), constraint checking is skipped entirely — this matches the pre-existing behavior where inverse_transform() already clips values within each column's defined range.

False
**kwargs Any

Additional sampling controls (unused placeholder).

{}

Returns:

Type Description
DataFrame

DataFrame of synthetic rows mapped back to original schema.

Source code in syntho_hive/core/models/ctgan.py
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
def sample(
    self,
    num_rows: int,
    context: Optional[pd.DataFrame] = None,
    seed: Optional[int] = None,
    enforce_constraints: bool = False,
    **kwargs: Any,
) -> pd.DataFrame:
    """Generate synthetic samples, optionally conditioned on parent context.

    Args:
        num_rows: Number of rows to generate.
        context: Optional parent attributes aligned to the requested rows.
        seed: Optional integer seed for deterministic sampling. Only applied
              when provided; no auto-generation (fits and samples may use
              independent seeds per CONTEXT.md decision).
        enforce_constraints: When True, scans generated rows against column
              constraints defined in the table's Metadata config and raises
              ConstraintViolationError listing every min/max violation found
              (column name and observed value). When False (default),
              constraint checking is skipped entirely — inverse_transform()
              already clips values within each column's defined range.
        **kwargs: Additional sampling controls (unused placeholder).

    Returns:
        DataFrame of synthetic rows mapped back to original schema.

    Raises:
        ValueError: If num_rows is not positive, or if context is provided
            with a row count different from num_rows.
        ConstraintViolationError: If enforce_constraints=True and any
            generated value violates a configured min/max constraint
            (ROADMAP success criterion 4 / REQUIREMENTS.md QUAL-04).
    """
    if seed is not None:
        _set_seed(seed)

    if num_rows <= 0:
        raise ValueError(f"num_rows must be positive, got {num_rows}")

    was_training = self.generator.training
    self.generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(num_rows, self.embedding_dim, device=self.device)

            if context is not None:
                if len(context) != num_rows:
                    raise ValueError(
                        f"context must have exactly num_rows={num_rows} rows, "
                        f"got {len(context)}"
                    )

                # Transform context using the fitted context transformer
                context_transformed = self.context_transformer.transform(context)
                context_data = (
                    torch.from_numpy(context_transformed).float().to(self.device)
                )

                gen_input = torch.cat([noise, context_data], dim=1)
            else:
                gen_input = noise

            fake_raw = self.generator(gen_input)

            # Post-process logits to indices for output
            output_parts = []
            fake_ptr = 0
            for info in self.data_column_info:
                if info["type"] == "embedding":
                    dim = info["num_categories"]
                    logits = fake_raw[:, fake_ptr : fake_ptr + dim]
                    fake_ptr += dim

                    # Argmax to get index
                    indices = torch.argmax(logits, dim=1, keepdim=True)
                    output_parts.append(indices.cpu().numpy())
                else:
                    dim = info["output_dim"]
                    val = fake_raw[:, fake_ptr : fake_ptr + dim]
                    fake_ptr += dim
                    output_parts.append(val.cpu().numpy())

            fake_data_np = np.concatenate(output_parts, axis=1)
    finally:
        # BUGFIX: restore train mode unconditionally. Previously the restore
        # ran only after constraint checking, so a ConstraintViolationError
        # left the generator stuck in eval mode for subsequent fit() calls.
        if was_training:
            self.generator.train()

    result_df = self.transformer.inverse_transform(fake_data_np)

    # Constraint violation checking (opt-in via enforce_constraints=True).
    # inverse_transform() already clips values within defined column ranges,
    # so this is primarily a post-hoc audit catching residual violations.
    if enforce_constraints:
        table_config = None
        table_name = getattr(self.transformer, "table_name", None)
        if hasattr(self, "metadata") and table_name:
            try:
                table_config = self.metadata.get_table(table_name)
            except Exception as exc:
                log.warning(
                    "constraint_config_lookup_failed",
                    table_name=table_name,
                    error=str(exc),
                    note="Skipping constraint enforcement — table config could not be retrieved",
                )
                table_config = None

        # If the table has constraints defined, scan generated rows.
        # If no constraints are configured this block is a no-op.
        if table_config is not None and table_config.constraints:
            violations: list = []

            def _scan_bound(col_name, col_data, limit, kind: str) -> None:
                # Record a violation message when any value falls outside
                # `limit`; `kind` is "min" or "max". Non-numeric columns are
                # logged and skipped rather than raising.
                try:
                    col_numeric = pd.to_numeric(col_data, errors="coerce")
                    if kind == "min":
                        bad = col_numeric < limit
                    else:
                        bad = col_numeric > limit
                    if bad.any():
                        observed = (
                            col_numeric[bad].min()
                            if kind == "min"
                            else col_numeric[bad].max()
                        )
                        violations.append(
                            f"{col_name}: got {observed:.4g} ({kind}={limit})"
                        )
                except Exception as exc:
                    log.warning(
                        f"constraint_{kind}_check_skipped",
                        column=col_name,
                        error=str(exc),
                        note=f"Column is non-numeric or comparison failed — skipping {kind} check",
                    )

            for col_name, constraint in table_config.constraints.items():
                if col_name not in result_df.columns:
                    continue
                if constraint.min is not None:
                    _scan_bound(col_name, result_df[col_name], constraint.min, "min")
                if constraint.max is not None:
                    _scan_bound(col_name, result_df[col_name], constraint.max, "max")

            if violations:
                summary = "; ".join(violations)
                # ROADMAP success criterion 4 and REQUIREMENTS.md QUAL-04 require
                # violations to "raise with the column name and observed value".
                # Use sample(enforce_constraints=False) (the default) if you want
                # the previous warn-and-return behavior.
                raise ConstraintViolationError(
                    f"ConstraintViolationError: {len(violations)} violation(s) found — "
                    f"{summary}"
                )

    return result_df

save

save(path: str, *, overwrite: bool = False) -> None

Persist full model state to a directory checkpoint.

Saves all components required for a cold load-and-sample without the original training data: network weights, DataTransformer state, context_transformer state, embedding layer weights, column layout, and human-readable metadata.

The directory contains:
  • generator.pt — generator state_dict
  • discriminator.pt — discriminator state_dict
  • transformer.joblib — fitted DataTransformer for child table
  • context_transformer.joblib — fitted DataTransformer for context
  • embedding_layers.joblib — nn.ModuleDict with entity embedding weights
  • data_column_info.joblib — column layout list
  • metadata.json — hyperparameters and version info

Parameters:

Name Type Description Default
path str

Directory path to save into.

required
overwrite bool

If False (default), raises SerializationError if path already exists.

False

Raises:

Type Description
SerializationError

If path exists and overwrite=False, or if any component fails to serialize.

Source code in syntho_hive/core/models/ctgan.py
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
def save(self, path: str, *, overwrite: bool = False) -> None:
    """Write every component of the model to a directory checkpoint.

    The checkpoint is fully self-contained: a later load-and-sample needs
    none of the original training data. It captures the network weights,
    both fitted DataTransformers, the entity-embedding weights, the column
    layout, and a human-readable metadata file.

    Files written into the directory:
        - generator.pt — generator state_dict
        - discriminator.pt — discriminator state_dict
        - transformer.joblib — fitted DataTransformer for the child table
        - context_transformer.joblib — fitted DataTransformer for context
        - embedding_layers.joblib — nn.ModuleDict with entity embedding weights
        - data_column_info.joblib — column layout list
        - metadata.json — hyperparameters and version info

    Args:
        path: Directory path to save into.
        overwrite: If False (default), raises SerializationError if path already exists.

    Raises:
        SerializationError: If path exists and overwrite=False, or if any
            component fails to serialize.
    """
    import joblib
    import json
    from pathlib import Path
    from datetime import datetime, timezone

    target = Path(path)
    if target.exists() and not overwrite:
        raise SerializationError(
            f"SerializationError: Save path '{path}' already exists. "
            f"Pass overwrite=True to replace it."
        )

    try:
        target.mkdir(parents=True, exist_ok=True)

        # Network weights go through torch's native format.
        for filename, module in (
            ("generator.pt", self.generator),
            ("discriminator.pt", self.discriminator),
        ):
            torch.save(module.state_dict(), target / filename)

        # Everything sklearn/NumPy-heavy (plus the nn.ModuleDict, which joblib
        # pickles) goes through joblib for efficient array serialization.
        for filename, component in (
            ("transformer.joblib", self.transformer),
            ("context_transformer.joblib", self.context_transformer),
            ("embedding_layers.joblib", self.embedding_layers),
            ("data_column_info.joblib", self.data_column_info),
        ):
            joblib.dump(component, target / filename)

        # Metadata is human-readable JSON; the recorded version enables
        # mismatch detection at load time.
        try:
            from syntho_hive import __version__ as current_version
        except Exception as exc:
            log.warning(
                "version_lookup_failed",
                error=str(exc),
                note="Could not determine SynthoHive version — using 'unknown'",
            )
            current_version = "unknown"

        checkpoint_meta = {
            "synthohive_version": current_version,
            "embedding_dim": self.embedding_dim,
            "generator_dim": list(self.generator_dim),
            "discriminator_dim": list(self.discriminator_dim),
            "legacy_context_conditioning": self.legacy_context_conditioning,
            "saved_at": datetime.now(timezone.utc).isoformat(),
        }
        (target / "metadata.json").write_text(json.dumps(checkpoint_meta, indent=2))

        log.info("model_saved", path=str(target))

    except SerializationError:
        raise
    except Exception as exc:
        raise SerializationError(
            f"SerializationError: Failed to save model to '{path}'. "
            f"Original error: {exc}"
        ) from exc

load

load(path: str) -> None

Load full model state from a directory checkpoint.

Reconstructs the complete model — DataTransformer, context_transformer, embedding_layers, column layout, and network weights — without requiring the original training data.

Parameters:

Name Type Description Default
path str

Directory path produced by save().

required

Raises:

Type Description
SerializationError

If path does not exist, is missing required files, or if any component fails to deserialize.

Source code in syntho_hive/core/models/ctgan.py
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
def load(self, path: str) -> None:
    """Load full model state from a directory checkpoint.

    Reconstructs the complete model — DataTransformer, context_transformer,
    embedding_layers, column layout, and network weights — without requiring
    the original training data.

    Load sequence (order matters):
        1. Validate that the checkpoint directory and all required files exist.
        2. Read metadata.json (when present): warn on version mismatch and
           restore hyperparameters *before* the networks are rebuilt so
           _build_model() uses the saved dimensions.
        3. Load the fitted transformers, then rebuild the architecture via
           _build_model(), then restore the saved column layout and embedding
           weights that _build_model() freshly re-initialised.
        4. Load network state_dicts and switch both networks to eval mode.

    Args:
        path: Directory path produced by save().

    Raises:
        SerializationError: If path does not exist, is missing required files,
            or if any component fails to deserialize.
    """
    import joblib
    import json
    from pathlib import Path

    p = Path(path)
    if not p.exists():
        raise SerializationError(
            f"SerializationError: Checkpoint path '{path}' does not exist."
        )

    # Every file save() writes except metadata.json, which is optional below.
    required_files = [
        "generator.pt",
        "discriminator.pt",
        "transformer.joblib",
        "context_transformer.joblib",
        "embedding_layers.joblib",
        "data_column_info.joblib",
    ]
    missing = [f for f in required_files if not (p / f).exists()]
    if missing:
        raise SerializationError(
            f"SerializationError: Checkpoint at '{path}' is incomplete. "
            f"Missing files: {', '.join(missing)}. "
            f"The checkpoint may have been saved by an older version or is corrupt."
        )

    # Default used for the final log line when metadata.json is absent.
    saved_version = "unknown"
    try:
        # Version check — warn but do not fail
        meta_path = p / "metadata.json"
        if meta_path.exists():
            with open(meta_path) as f:
                meta = json.load(f)
            try:
                from syntho_hive import __version__

                current_version = __version__
            except Exception as exc:
                log.warning(
                    "version_lookup_failed",
                    error=str(exc),
                    note="Could not determine SynthoHive version — using 'unknown'",
                )
                current_version = "unknown"
            saved_version = meta.get("synthohive_version", "unknown")
            if saved_version != current_version:
                log.warning(
                    "checkpoint_version_mismatch",
                    saved_version=saved_version,
                    current_version=current_version,
                    path=str(p),
                    note="Attempting load — schema changes between versions may cause failures",
                )
            # Restore hyperparams from metadata so _build_model() uses correct dims
            if "embedding_dim" in meta:
                self.embedding_dim = meta["embedding_dim"]
            if "generator_dim" in meta:
                self.generator_dim = tuple(meta["generator_dim"])
            if "discriminator_dim" in meta:
                self.discriminator_dim = tuple(meta["discriminator_dim"])
            # Default False for forward compatibility with old checkpoints that lack this key
            self.legacy_context_conditioning = meta.get(
                "legacy_context_conditioning", False
            )

        # Load sklearn objects first — transformer must be in place before _build_model()
        self.transformer = joblib.load(p / "transformer.joblib")
        self.context_transformer = joblib.load(p / "context_transformer.joblib")

        # Load saved column layout and embedding layers (will be restored after _build_model)
        saved_data_column_info = joblib.load(p / "data_column_info.joblib")
        saved_embedding_layers = joblib.load(p / "embedding_layers.joblib")

        # Validate transformer round-trip integrity
        if (
            not hasattr(self.transformer, "output_dim")
            or self.transformer.output_dim <= 0
        ):
            raise SerializationError(
                f"SerializationError: Loaded transformer has invalid output_dim "
                f"({getattr(self.transformer, 'output_dim', 'missing')}). "
                f"The checkpoint may be corrupt."
            )

        # Derive dimensions needed to reconstruct the generator/discriminator architecture.
        # context_transformer.output_dim is 0 when no context was used during training.
        data_dim = self.transformer.output_dim
        context_dim = getattr(self.context_transformer, "output_dim", 0)

        # Reconstruct generator/discriminator architecture.
        # _build_model() internally calls _compile_layout(self.transformer) which overwrites
        # self.data_column_info and self.embedding_layers with freshly-initialised layers.
        # We restore the saved values immediately after so weights can be loaded correctly.
        self._build_model(data_dim, context_dim)

        # Restore saved column layout and trained embedding weights (overwrite fresh ones)
        self.data_column_info = saved_data_column_info
        self.embedding_layers = saved_embedding_layers

        # Load network weights — weights_only=False REQUIRED for PyTorch 2.6+
        # (PyTorch 2.6 changed default to weights_only=True; custom objects fail without False)
        # SECURITY WARNING: weights_only=False uses pickle deserialization under the hood,
        # which can execute arbitrary code. Only load checkpoints from trusted sources.
        # Restructuring to weights_only=True requires registering all custom types with
        # torch.serialization.add_safe_globals() and is a non-trivial migration.
        self.generator.load_state_dict(
            torch.load(p / "generator.pt", weights_only=False)
        )
        self.discriminator.load_state_dict(
            torch.load(p / "discriminator.pt", weights_only=False)
        )

        # Set model to eval mode for inference
        self.generator.eval()
        self.discriminator.eval()

        log.info("model_loaded", path=str(p), version=saved_version)

    except SerializationError:
        raise
    except Exception as exc:
        raise SerializationError(
            f"SerializationError: Failed to load model from '{path}'. "
            f"Original error: {exc}"
        ) from exc

syntho_hive.core.models.base.ConditionalGenerativeModel

Bases: GenerativeModel

Contract for models that condition on parent context during training/sampling.

Constructor convention

Custom model classes passed as model_cls to StagedOrchestrator must accept the following constructor signature::

def __init__(self, metadata, batch_size=500, epochs=300, **kwargs):
    ...

The metadata positional argument and batch_size/epochs keyword arguments are forwarded by the orchestrator during fit_all(). Additional keyword arguments are forwarded from fit_all(**model_kwargs).

Python ABCs cannot enforce constructor signatures; this convention is documented here so custom implementations know what is expected.

Source code in syntho_hive/core/models/base.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class ConditionalGenerativeModel(GenerativeModel):
    """Interface for generative models that use parent context when training and sampling.

    Constructor convention:
        Any custom model class handed to ``StagedOrchestrator`` as ``model_cls``
        is expected to accept this constructor signature::

            def __init__(self, metadata, batch_size=500, epochs=300, **kwargs):
                ...

        During ``fit_all()`` the orchestrator forwards ``metadata`` positionally
        plus the ``batch_size``/``epochs`` keywords; anything passed via
        ``fit_all(**model_kwargs)`` is forwarded as extra keyword arguments.

        Python ABCs cannot enforce constructor signatures, so the convention is
        recorded here for implementers of custom models.
    """

    @abstractmethod
    def fit(
        self,
        data: pd.DataFrame,
        context: Optional[pd.DataFrame] = None,
        **kwargs: Any,
    ) -> None:
        """Train the model, optionally conditioning on parent context.

        Args:
            data: Child table data to learn from.
            context: Optional parent attributes used for conditioning.
            **kwargs: Model-specific training options.
        """
        pass  # pragma: no cover

    @abstractmethod
    def sample(
        self,
        num_rows: int,
        context: Optional[pd.DataFrame] = None,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Generate synthetic rows, optionally conditioning on parent context.

        Args:
            num_rows: Number of rows to generate.
            context: Optional parent attributes aligned to the requested rows.
            **kwargs: Additional sampling controls.

        Returns:
            DataFrame of synthetic samples aligned to the provided context (if any).
        """
        pass  # pragma: no cover

fit abstractmethod

fit(data: DataFrame, context: Optional[DataFrame] = None, **kwargs: Any) -> None

Train the model with optional parent context.

Parameters:

Name Type Description Default
data DataFrame

Child table data to learn from.

required
context Optional[DataFrame]

Optional parent attributes used for conditioning.

None
**kwargs Any

Model-specific training options.

{}
Source code in syntho_hive/core/models/base.py
68
69
70
71
72
73
74
75
76
77
@abstractmethod
def fit(
    self,
    data: pd.DataFrame,
    context: Optional[pd.DataFrame] = None,
    **kwargs: Any,
) -> None:
    """Train the model, optionally conditioning on parent context.

    Args:
        data: Child table data to learn from.
        context: Optional parent attributes used for conditioning.
        **kwargs: Model-specific training options.
    """
    pass  # pragma: no cover

sample abstractmethod

sample(num_rows: int, context: Optional[DataFrame] = None, **kwargs: Any) -> pd.DataFrame

Generate synthetic rows with optional conditioning context.

Parameters:

Name Type Description Default
num_rows int

Number of rows to generate.

required
context Optional[DataFrame]

Optional parent attributes aligned to the requested rows.

None
**kwargs Any

Additional sampling controls.

{}

Returns:

Type Description
DataFrame

DataFrame of synthetic samples aligned to the provided context (if any).

Source code in syntho_hive/core/models/base.py
79
80
81
82
83
84
85
86
87
88
89
90
91
@abstractmethod
def sample(
    self,
    num_rows: int,
    context: Optional[pd.DataFrame] = None,
    **kwargs: Any,
) -> pd.DataFrame:
    """Generate synthetic rows, optionally conditioning on parent context.

    Args:
        num_rows: Number of rows to generate.
        context: Optional parent attributes aligned to the requested rows.
        **kwargs: Additional sampling controls.

    Returns:
        DataFrame of synthetic samples aligned to the provided context (if any).
    """
    pass  # pragma: no cover

Data Transformation

syntho_hive.core.data.transformer.DataTransformer

Reversible transformer for tabular data.

Continuous columns use a Bayesian GMM-based normalizer, while categorical columns are either one-hot encoded or mapped to indices for embeddings.

Source code in syntho_hive/core/data/transformer.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
class DataTransformer:
    """Reversible transformer for tabular data.

    Continuous columns use a Bayesian GMM-based normalizer
    (``ClusterBasedNormalizer``), while categorical columns are either
    one-hot encoded (low cardinality) or label-encoded to integer indices
    for entity embeddings (high cardinality).
    """

    def __init__(self, metadata: Any, embedding_threshold: int = 50):
        """Create a transformer configured by table metadata.

        Args:
            metadata: Metadata object describing tables, keys, and constraints.
            embedding_threshold: Switch to embedding mode when cardinality exceeds this value.
        """
        self.metadata = metadata
        self.embedding_threshold = embedding_threshold
        # Set by fit(); used later to look up per-column constraints.
        self.table_name: Optional[str] = None
        self._transformers = {}
        self._column_info = {}  # Maps col_name -> {'type': str, 'dim': int, 'transformer': obj}
        self.output_dim = 0
        self._excluded_columns = []

    def _prepare_categorical(self, series: pd.Series) -> pd.Series:
        """Fill nulls with a sentinel value and ensure string type."""
        # Convert to object to handle mixed types (e.g. numbers and NaNs)
        series = series.astype(object)
        return series.fillna("<NAN>").astype(str)

    def _fit_continuous(self, col: str, col_data: pd.Series, seed: Optional[int]) -> None:
        """Fit a ClusterBasedNormalizer for one numeric column."""
        # Derive a per-column deterministic seed from the parent seed to avoid
        # correlated RNG sequences across columns when all columns share one seed.
        # C5: Use hashlib instead of Python's hash() for cross-process determinism.
        col_hash = int(hashlib.sha256(col.encode()).hexdigest(), 16) % 100_000
        col_seed = (seed + col_hash) if seed is not None else None
        transformer = ClusterBasedNormalizer(n_components=10, seed=col_seed)
        transformer.fit(col_data)

        # Dim is managed by the transformer (dynamic based on nulls).
        dim = transformer.output_dim
        self._transformers[col] = transformer
        self._column_info[col] = {
            "type": "continuous",
            "dim": dim,
            "transformer": transformer,
        }
        self.output_dim += dim

    def _fit_categorical(self, col: str, col_data: pd.Series) -> None:
        """Fit an encoder for one categorical column (embedding or one-hot)."""
        n_unique = col_data.nunique()
        # Fill Nulls with the "<NAN>" sentinel before encoding.
        col_data_filled = self._prepare_categorical(col_data)

        if n_unique > self.embedding_threshold:
            # High cardinality: use LabelEncoder indices for entity embeddings.
            from sklearn.preprocessing import LabelEncoder

            transformer = LabelEncoder()
            # LabelEncoder expects a 1D array.
            transformer.fit(col_data_filled)

            self._transformers[col] = transformer
            self._column_info[col] = {
                "type": "categorical_embedding",
                "dim": 1,  # Just the index.
                # classes_ includes the "<NAN>" sentinel when nulls were present.
                "num_categories": len(transformer.classes_),
                "transformer": transformer,
            }
            self.output_dim += 1
        else:
            # Low cardinality: one-hot encode.
            transformer = OneHotEncoder(
                sparse_output=False, handle_unknown="ignore"
            )
            values = col_data_filled.to_numpy(dtype=str).reshape(-1, 1)
            transformer.fit(values)

            dim = len(transformer.categories_[0])
            self._transformers[col] = transformer
            self._column_info[col] = {
                "type": "categorical",
                "dim": dim,
                "transformer": transformer,
            }
            self.output_dim += dim

    def fit(
        self,
        data: pd.DataFrame,
        table_name: Optional[str] = None,
        seed: Optional[int] = None,
    ):
        """Fit per-column transformers and collect column layout metadata.

        Args:
            data: DataFrame to profile and transform.
            table_name: Optional table name for applying PK/FK exclusions and constraints.
            seed: Optional integer seed propagated to each ``ClusterBasedNormalizer``
                  for deterministic BayesianGMM fitting.  A per-column seed is derived
                  from ``seed`` to avoid correlated RNG sequences across columns.

        Raises:
            ValueError: If metadata is missing table configurations or ``data`` is empty.
        """
        self.table_name = table_name  # Store for constraint application later
        if not self.metadata.tables:
            raise ValueError("Metadata must be populated with table configs")

        # H5: Reject empty DataFrames early
        if data.empty:
            raise ValueError("Cannot fit transformer on empty DataFrame")

        # C4: Reset ALL stale state from any previous fit() call. Resetting
        # _excluded_columns here prevents a previous table's PK/FK list from
        # leaking into a refit that has no table_name (or a table without keys).
        self._transformers = {}
        self._column_info = {}
        self._excluded_columns = []
        self.output_dim = 0

        columns_to_transform = data.columns.tolist()

        # Handle relational constraints if table_name is provided
        if table_name:
            table_config = self.metadata.get_table(table_name)
            if table_config:
                # Exclude PK and FKs from transformation
                pk = table_config.pk
                fks = list(table_config.fk.keys())
                self._excluded_columns = [pk] + fks
                columns_to_transform = [
                    c for c in columns_to_transform if c not in self._excluded_columns
                ]

        for col in columns_to_transform:
            col_data = data[col]
            if pd.api.types.is_numeric_dtype(col_data):
                self._fit_continuous(col, col_data, seed)
            else:
                self._fit_categorical(col, col_data)

    def transform(self, data: pd.DataFrame) -> np.ndarray:
        """Transform a dataframe into model-ready numpy arrays.

        Args:
            data: DataFrame with the same columns used during ``fit``.

        Raises:
            ValueError: If the transformer has not been fitted or a column is missing.

        Returns:
            Concatenated numpy array representing all transformed columns.
        """
        if not self._transformers:
            raise ValueError("Transformer has not been fitted.")

        output_arrays = []

        # Iterate in the same order as fit/stored in _column_info
        for col, info in self._column_info.items():
            if col not in data.columns:
                raise ValueError(f"Column {col} missing from input data")

            transformer = self._transformers[col]
            col_data = data[col]

            if info["type"] == "continuous":
                # Returns (N, n_components + 1 [+ 1 if nulls])
                transformed = transformer.transform(col_data)
            elif info["type"] == "categorical_embedding":
                # Returns (N, 1)
                col_data_filled = self._prepare_categorical(col_data)
                # H7: Map unseen categories to a known class to avoid ValueError
                known = set(transformer.classes_)
                col_data_safe = col_data_filled.map(
                    lambda x: x if x in known else transformer.classes_[0]
                )
                values = transformer.transform(col_data_safe)
                transformed = values.reshape(-1, 1)
            else:
                # Returns (N, n_categories)
                col_data_filled = self._prepare_categorical(col_data)
                values = col_data_filled.to_numpy(dtype=str).reshape(-1, 1)
                transformed = transformer.transform(values)

            output_arrays.append(transformed)

        return np.concatenate(output_arrays, axis=1)

    def _apply_constraints(self, col: str, original_values):
        """Apply metadata constraints (dtype / min / max) for one column.

        Order matters: coerce to numeric first (categorical decode yields
        strings), then round/cast integers, then clip into [min, max].
        Returns the (possibly converted) values; on cast failure the values
        are left untouched and a warning is logged.
        """
        if not (self.metadata and getattr(self, "table_name", None)):
            return original_values
        table_config = self.metadata.get_table(self.table_name)
        if not table_config or col not in table_config.constraints:
            return original_values

        constraint = table_config.constraints[col]

        _is_numeric_constraint = (
            (constraint.min is not None)
            or (constraint.max is not None)
            or (constraint.dtype in ["int", "float"])
        )

        if _is_numeric_constraint:
            # Ensure data is numeric. If it was processed as categorical (strings),
            # we need to convert it back to numbers to apply numeric constraints.
            try:
                if isinstance(original_values, np.ndarray):
                    # flatten to 1D for to_numeric
                    original_values = pd.to_numeric(
                        original_values.flatten(), errors="coerce"
                    )
                else:
                    original_values = pd.to_numeric(
                        original_values, errors="coerce"
                    )
            except Exception as exc:
                log.warning(
                    "column_cast_failed",
                    column=col,
                    target_type="numeric",
                    error=str(exc),
                )
                # Leave the column as-is (do not set to NaN silently)

        # 1. Rounding/Type
        if constraint.dtype == "int":
            original_values = np.round(original_values)
            # Only cast when NaN-free: casting NaNs to int would raise.
            try:
                if isinstance(original_values, pd.Series):
                    if not original_values.isnull().any():
                        original_values = original_values.astype(int)
                elif isinstance(original_values, np.ndarray):
                    if not np.isnan(original_values).any():
                        original_values = original_values.astype(int)
            except Exception as exc:
                log.warning(
                    "column_cast_failed",
                    column=col,
                    target_type="int",
                    error=str(exc),
                )
                # Leave the column as-is (do not set to NaN silently)

        # 2. Clipping
        if constraint.min is not None or constraint.max is not None:
            # Handle potential pandas Series or numpy array
            if isinstance(original_values, pd.Series):
                original_values = original_values.clip(
                    lower=constraint.min, upper=constraint.max
                )
            else:
                original_values = np.clip(
                    original_values, constraint.min, constraint.max
                )

        return original_values

    def inverse_transform(self, data: np.ndarray) -> pd.DataFrame:
        """Convert model outputs back to the original dataframe schema.

        Args:
            data: Numpy array produced by a model, aligned to transform layout.

        Raises:
            ValueError: If called before ``fit`` or if ``data`` has the wrong width.

        Returns:
            DataFrame with original column names and value types (constraints applied).
        """
        if not self._transformers:
            raise ValueError("Transformer has not been fitted.")
        # Fail fast on misaligned arrays instead of silently decoding garbage.
        if data.shape[1] != self.output_dim:
            raise ValueError(
                f"Expected array with {self.output_dim} columns, got {data.shape[1]}"
            )

        output_df = pd.DataFrame()
        start_idx = 0

        # Decode columns in fit order; start_idx walks the flat layout.
        for col, info in self._column_info.items():
            end_idx = start_idx + info["dim"]
            col_data = data[:, start_idx:end_idx]

            transformer = self._transformers[col]

            if info["type"] == "continuous":
                original_values = transformer.inverse_transform(col_data)
            elif info["type"] == "categorical_embedding":
                # col_data is (N, 1) floats/ints; LabelEncoder needs ints,
                # clamped into the valid label index range.
                indices = np.clip(
                    col_data.flatten().astype(int), 0, info["num_categories"] - 1
                )
                original_values = transformer.inverse_transform(indices)
                # Restore NaNs
                original_values = (
                    pd.Series(original_values).replace("<NAN>", np.nan).values
                )
            else:
                original_values = transformer.inverse_transform(col_data).flatten()
                # Restore NaNs
                original_values = (
                    pd.Series(original_values).replace("<NAN>", np.nan).values
                )

            output_df[col] = self._apply_constraints(col, original_values)
            start_idx = end_idx

        return output_df

fit

fit(data: DataFrame, table_name: Optional[str] = None, seed: Optional[int] = None)

Fit per-column transformers and collect column layout metadata.

Parameters:

Name Type Description Default
data DataFrame

DataFrame to profile and transform.

required
table_name Optional[str]

Optional table name for applying PK/FK exclusions and constraints.

None
seed Optional[int]

Optional integer seed propagated to each ClusterBasedNormalizer for deterministic BayesianGMM fitting. A per-column seed is derived from seed to avoid correlated RNG sequences across columns.

None

Raises:

Type Description
ValueError

If metadata is missing table configurations.

Source code in syntho_hive/core/data/transformer.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def fit(
    self,
    data: pd.DataFrame,
    table_name: Optional[str] = None,
    seed: Optional[int] = None,
):
    """Fit per-column transformers and collect column layout metadata.

    Args:
        data: DataFrame to profile and transform.
        table_name: Optional table name for applying PK/FK exclusions and constraints.
        seed: Optional integer seed propagated to each ``ClusterBasedNormalizer``
              for deterministic BayesianGMM fitting.  A per-column seed is derived
              from ``seed`` to avoid correlated RNG sequences across columns.

    Raises:
        ValueError: If metadata is missing table configurations or ``data`` is empty.
    """
    self.table_name = table_name  # Store for constraint application later
    if not self.metadata.tables:
        raise ValueError("Metadata must be populated with table configs")

    # H5: Reject empty DataFrames early
    if data.empty:
        raise ValueError("Cannot fit transformer on empty DataFrame")

    # C4: Reset ALL stale state from any previous fit() call. Resetting
    # _excluded_columns prevents a previous table's PK/FK list from leaking
    # into a refit that provides no table_name (or a keyless table).
    self._transformers = {}
    self._column_info = {}
    self._excluded_columns = []
    self.output_dim = 0

    columns_to_transform = data.columns.tolist()

    # Handle relational constraints if table_name is provided
    if table_name:
        table_config = self.metadata.get_table(table_name)
        if table_config:
            # Exclude PK and FKs from transformation
            pk = table_config.pk
            fks = list(table_config.fk.keys())
            self._excluded_columns = [pk] + fks
            columns_to_transform = [
                c for c in columns_to_transform if c not in self._excluded_columns
            ]

    for col in columns_to_transform:
        col_data = data[col]

        if pd.api.types.is_numeric_dtype(col_data):
            # Derive a per-column deterministic seed from the parent seed to avoid
            # correlated RNG sequences across columns when all columns share one seed.
            # C5: Use hashlib instead of Python's hash() for cross-process determinism.
            col_hash = int(hashlib.sha256(col.encode()).hexdigest(), 16) % 100_000
            col_seed = (seed + col_hash) if seed is not None else None
            # Continuous column
            transformer = ClusterBasedNormalizer(n_components=10, seed=col_seed)
            transformer.fit(col_data)

            # Dim is managed by the transformer now (dynamic based on nulls)
            dim = transformer.output_dim
            self._transformers[col] = transformer
            self._column_info[col] = {
                "type": "continuous",
                "dim": dim,
                "transformer": transformer,
            }
            self.output_dim += dim

        else:
            # Categorical column: check cardinality for embedding suggestion
            n_unique = col_data.nunique()
            if n_unique > self.embedding_threshold:
                # Use LabelEncoder for Entity Embeddings
                from sklearn.preprocessing import LabelEncoder

                transformer = LabelEncoder()

                # Fill Nulls
                col_data_filled = self._prepare_categorical(col_data)

                # LabelEncoder expects 1D array
                transformer.fit(col_data_filled)

                dim = 1  # Just the index
                num_categories = len(
                    transformer.classes_
                )  # include sentinel for nulls
                self._transformers[col] = transformer
                self._column_info[col] = {
                    "type": "categorical_embedding",
                    "dim": dim,
                    "num_categories": num_categories,
                    "transformer": transformer,
                }
                self.output_dim += dim
            else:
                # Use OneHotEncoder
                transformer = OneHotEncoder(
                    sparse_output=False, handle_unknown="ignore"
                )

                # Fill Nulls
                col_data_filled = self._prepare_categorical(col_data)

                values = col_data_filled.to_numpy(dtype=str).reshape(-1, 1)
                transformer.fit(values)

                dim = len(transformer.categories_[0])
                self._transformers[col] = transformer
                self._column_info[col] = {
                    "type": "categorical",
                    "dim": dim,
                    "transformer": transformer,
                }
                self.output_dim += dim

transform

transform(data: DataFrame) -> np.ndarray

Transform a dataframe into model-ready numpy arrays.

Parameters:

Name Type Description Default
data DataFrame

DataFrame with the same columns used during fit.

required

Raises:

Type Description
ValueError

If the transformer has not been fitted or a column is missing.

Returns:

Type Description
ndarray

Concatenated numpy array representing all transformed columns.

Source code in syntho_hive/core/data/transformer.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def transform(self, data: pd.DataFrame) -> np.ndarray:
    """Encode a dataframe into the fitted numeric representation.

    Args:
        data: DataFrame containing every column seen during ``fit``.

    Raises:
        ValueError: When called before ``fit`` or when a fitted column is absent.

    Returns:
        A single 2-D numpy array with all encoded columns laid out side by side.
    """
    if not self._transformers:
        raise ValueError("Transformer has not been fitted.")

    pieces = []

    # Preserve fit-time ordering so the flat layout matches _column_info.
    for name, meta in self._column_info.items():
        if name not in data.columns:
            raise ValueError(f"Column {name} missing from input data")

        encoder = self._transformers[name]
        series = data[name]
        kind = meta["type"]

        if kind == "continuous":
            # Normalizer yields (N, n_components + 1 [+ 1 if nulls]).
            block = encoder.transform(series)
        elif kind == "categorical_embedding":
            # Label-encode to a single index column; unseen labels are
            # remapped onto the first known class instead of raising.
            cleaned = self._prepare_categorical(series)
            vocabulary = set(encoder.classes_)
            fallback = encoder.classes_[0]
            safe = cleaned.map(lambda v: v if v in vocabulary else fallback)
            block = encoder.transform(safe).reshape(-1, 1)
        else:
            # One-hot path yields an (N, n_categories) matrix.
            cleaned = self._prepare_categorical(series)
            block = encoder.transform(cleaned.to_numpy(dtype=str).reshape(-1, 1))

        pieces.append(block)

    return np.concatenate(pieces, axis=1)

inverse_transform

inverse_transform(data: ndarray) -> pd.DataFrame

Convert model outputs back to the original dataframe schema.

Parameters:

Name Type Description Default
data ndarray

Numpy array produced by a model, aligned to transform layout.

required

Raises:

Type Description
ValueError

If called before fit.

Returns:

Type Description
DataFrame

DataFrame with original column names and value types (constraints applied).

Source code in syntho_hive/core/data/transformer.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def inverse_transform(self, data: np.ndarray) -> pd.DataFrame:
    """Convert model outputs back to the original dataframe schema.

    Columns are decoded in fit order by walking contiguous slices of the
    flat array, then per-column metadata constraints are applied in a fixed
    order: numeric coercion, integer rounding/cast, and finally clipping.

    Args:
        data: Numpy array produced by a model, aligned to transform layout.

    Raises:
        ValueError: If called before ``fit``.

    Returns:
        DataFrame with original column names and value types (constraints applied).
    """
    if not self._transformers:
        raise ValueError("Transformer has not been fitted.")

    output_df = pd.DataFrame()
    start_idx = 0

    # start_idx/end_idx walk the flat layout in the same order used by transform().
    for col, info in self._column_info.items():
        dim = info["dim"]
        end_idx = start_idx + dim
        col_data = data[:, start_idx:end_idx]

        transformer = self._transformers[col]

        if info["type"] == "continuous":
            original_values = transformer.inverse_transform(col_data)
        elif info["type"] == "categorical_embedding":
            # col_data is (N, 1) floats/ints.
            # We need ints for LabelEncoder; clamp into the valid label range
            # so out-of-range model outputs decode to a real category.
            indices = np.clip(
                col_data.flatten().astype(int), 0, info["num_categories"] - 1
            )
            original_values = transformer.inverse_transform(indices)
            # Restore NaNs (the "<NAN>" sentinel was injected before fitting)
            original_values = (
                pd.Series(original_values).replace("<NAN>", np.nan).values
            )
        else:
            original_values = transformer.inverse_transform(col_data).flatten()
            # Restore NaNs (the "<NAN>" sentinel was injected before fitting)
            original_values = (
                pd.Series(original_values).replace("<NAN>", np.nan).values
            )

        # Apply Constraints (only when fit() recorded a table_name to look up)
        if self.metadata and hasattr(self, "table_name") and self.table_name:
            table_config = self.metadata.get_table(self.table_name)
            if table_config and col in table_config.constraints:
                constraint = table_config.constraints[col]

                # Any min/max bound or int/float dtype implies numeric handling.
                _is_numeric_constraint = (
                    (constraint.min is not None)
                    or (constraint.max is not None)
                    or (constraint.dtype in ["int", "float"])
                )

                if _is_numeric_constraint:
                    # Ensure data is numeric. If it was processed as categorical (strings),
                    # we need to convert it back to numbers to apply numeric constraints.
                    try:
                        if isinstance(original_values, np.ndarray):
                            # flatten to 1D for to_numeric
                            original_values = pd.to_numeric(
                                original_values.flatten(), errors="coerce"
                            )
                        else:
                            original_values = pd.to_numeric(
                                original_values, errors="coerce"
                            )
                    except Exception as exc:
                        log.warning(
                            "column_cast_failed",
                            column=col,
                            target_type="numeric",
                            error=str(exc),
                        )
                        # Leave the column as-is (do not set to NaN silently)

                # 1. Rounding/Type
                if constraint.dtype == "int":
                    # Round first
                    original_values = np.round(original_values)
                    # Safe cast to int: only cast when NaN-free, since casting
                    # NaNs to int would crash.
                    try:
                        if isinstance(original_values, pd.Series):
                            if not original_values.isnull().any():
                                original_values = original_values.astype(int)
                        elif isinstance(original_values, np.ndarray):
                            if not np.isnan(original_values).any():
                                original_values = original_values.astype(int)
                    except Exception as exc:
                        log.warning(
                            "column_cast_failed",
                            column=col,
                            target_type="int",
                            error=str(exc),
                        )
                        # Leave the column as-is (do not set to NaN silently)

                # 2. Clipping
                if constraint.min is not None or constraint.max is not None:
                    # Handle potential pandas Series or numpy array
                    if isinstance(original_values, pd.Series):
                        original_values = original_values.clip(
                            lower=constraint.min, upper=constraint.max
                        )
                    else:
                        original_values = np.clip(
                            original_values, constraint.min, constraint.max
                        )

        output_df[col] = original_values
        start_idx = end_idx

    return output_df

syntho_hive.core.data.transformer.ClusterBasedNormalizer

VGM-based normalizer for continuous columns.

Projects a value to a cluster assignment and a normalized scalar relative to the chosen component.

Source code in syntho_hive/core/data/transformer.py
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
class ClusterBasedNormalizer:
    """VGM-based normalizer for continuous columns.

    Projects a value to a cluster assignment (one-hot over ``n_components``)
    and a normalized scalar relative to the chosen mixture component.

    Encoded row layout: ``[one_hot(n_components), scalar]`` plus a trailing
    null-indicator column when the fitted data contained nulls (or was
    entirely null).
    """

    def __init__(self, n_components: int = 10, seed: Optional[int] = None):
        """Configure the number of mixture components.

        Args:
            n_components: Number of Gaussian mixture components.
            seed: Random state for the BayesianGaussianMixture. When None,
                  defaults to 42 for backward compatibility.
        """
        self.n_components = n_components
        # Use provided seed; fall back to 42 for backward compatibility when no seed given.
        random_state = seed if seed is not None else 42
        self._seed = random_state
        # Initialized here so an unfitted instance is in a consistent state;
        # ``model`` stays None until fit() trains the mixture.
        self.model = None
        self.means = None
        self.stds = None
        self.has_nulls = False
        self.all_null = False
        self.fill_value = 0.0
        self.output_dim = 0

    def fit(self, data: pd.Series):
        """Fit the Bayesian GMM on a continuous series.

        Args:
            data: Continuous pandas Series to normalize.
        """
        # 1. Handle Nulls
        self.has_nulls = data.isnull().any()

        # C6: All-null numeric columns — skip GMM fitting entirely
        if data.dropna().empty:
            self.all_null = True
            self.fill_value = 0.0
            # Output dim: n_components (one-hot) + 1 (scalar) + 1 (null indicator)
            self.output_dim = self.n_components + 1 + 1
            return

        if self.has_nulls:
            self.fill_value = data.mean()
            # Impute for training GMM
            values = data.fillna(self.fill_value).to_numpy(dtype=float).reshape(-1, 1)
            self.output_dim = self.n_components + 1 + 1  # +1 for null indicator
        else:
            values = data.to_numpy(dtype=float).reshape(-1, 1)
            self.output_dim = self.n_components + 1

        # H6: Clamp n_components when fewer valid samples than components
        n_valid = data.dropna().shape[0]
        effective_components = min(self.n_components, max(n_valid, 1))

        self.model = BayesianGaussianMixture(
            n_components=effective_components,
            weight_concentration_prior_type="dirichlet_process",
            n_init=1,
            random_state=self._seed,
        )
        self.model.fit(values)
        self.means = self.model.means_.flatten()  # (effective_components,)
        self.stds = np.sqrt(
            self.model.covariances_
        ).flatten()  # (effective_components,)

    def transform(self, data: pd.Series) -> np.ndarray:
        """Project values to one-hot cluster assignment and normalized scalar.

        Args:
            data: Continuous pandas Series to transform.

        Returns:
            Numpy array of shape ``(N, n_components + 1 [+1])`` with one-hot cluster, scaled value, [null_ind].
        """
        n_samples = len(data)

        # C6: all-null column — return zeros + null indicator of all 1s
        if self.all_null:
            cluster_one_hot = np.zeros((n_samples, self.n_components))
            scalar_col = np.zeros((n_samples, 1))
            null_indicator = np.ones((n_samples, 1))
            return np.concatenate([cluster_one_hot, scalar_col, null_indicator], axis=1)

        if self.has_nulls:
            # 0. Create Null Indicator
            null_indicator = pd.isnull(data).to_numpy(dtype=float).reshape(-1, 1)

            # 1. Impute for projection
            values_clean = (
                data.fillna(self.fill_value).to_numpy(dtype=float).reshape(-1, 1)
            )
        else:
            values_clean = data.to_numpy(dtype=float).reshape(-1, 1)

        # 2. Get cluster probabilities: P(c|x)
        probs = self.model.predict_proba(values_clean)  # (N, effective_components)

        # 3. Sample component c ~ P(c|x) (Argmax for simplicity/determinism in this impl)
        # CTGAN uses argmax during interaction but sampling during training prep sometimes.
        # Using argmax is stable.
        cluster_assignments = np.argmax(probs, axis=1)

        # 4. Calculate normalized scalar: v = (x - mu_c) / (4 * sigma_c)
        # Clip to [-1, 1] usually, or roughly there.
        means = self.means[cluster_assignments]
        stds = self.stds[cluster_assignments]

        # H8: Add epsilon to avoid division by zero when std is zero
        normalized_values = (values_clean.flatten() - means) / (4 * stds + 1e-8)
        normalized_values = normalized_values.reshape(-1, 1)

        # 5. Create One-Hot encoding of cluster assignment
        # Assignments are < effective_components, so padding columns (if the
        # mixture was clamped) simply stay zero.
        cluster_one_hot = np.zeros((n_samples, self.n_components))
        cluster_one_hot[np.arange(n_samples), cluster_assignments] = 1

        # Output: [one_hot_cluster, scalar, (null_indicator)]
        if self.has_nulls:
            return np.concatenate(
                [cluster_one_hot, normalized_values, null_indicator], axis=1
            )
        else:
            return np.concatenate([cluster_one_hot, normalized_values], axis=1)

    def inverse_transform(self, data: np.ndarray) -> pd.Series:
        """Reconstruct approximate original values from normalized representation.

        Args:
            data: Array shaped ``(N, n_components + 1 [+1])`` produced by ``transform``.

        Returns:
            Pandas Series of reconstructed continuous values.
        """
        n_samples = data.shape[0]

        # C6: all-null column — return all NaNs
        if self.all_null:
            return pd.Series(np.full(n_samples, np.nan))

        # data shape: (N, n_components + 1 [+1])

        current_idx = 0

        # 1. Cluster One-Hot
        cluster_one_hot = data[:, current_idx : current_idx + self.n_components]
        current_idx += self.n_components

        # 2. Scalar
        scalars = data[:, current_idx]
        current_idx += 1

        # 3. Null Indicator
        if self.has_nulls:
            null_indicators = data[:, current_idx]
            current_idx += 1
        else:
            null_indicators = None

        # Identify cluster
        cluster_assignments = np.argmax(cluster_one_hot, axis=1)

        # Bug fix: fit() may clamp the mixture to fewer components than
        # n_components (H6), so a generated one-hot row can put its mass on a
        # padding column beyond the fitted range. Clip the index into the
        # valid range instead of raising IndexError on self.means/self.stds.
        cluster_assignments = np.minimum(cluster_assignments, len(self.means) - 1)

        means = self.means[cluster_assignments]
        stds = self.stds[cluster_assignments]

        # Reconstruct: x = v * 4 * sigma_c + mu_c
        reconstructed_values = scalars * 4 * stds + means

        # Apply Null Masking
        if null_indicators is not None:
            # Generator outputs are expected near 0 or 1; threshold at 0.5.
            is_null = null_indicators > 0.5
            reconstructed_values[is_null] = np.nan

        return pd.Series(reconstructed_values)

fit

fit(data: Series)

Fit the Bayesian GMM on a continuous series.

Parameters:

    data (Series, required):
        Continuous pandas Series to normalize.
Source code in syntho_hive/core/data/transformer.py
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
def fit(self, data: pd.Series):
    """Train the Bayesian Gaussian mixture on one continuous column.

    Args:
        data: Continuous pandas Series to normalize.
    """
    # Record whether any nulls are present; this widens the output layout.
    self.has_nulls = data.isnull().any()

    non_null = data.dropna()

    # Entirely-null column: no mixture can be fitted, so just record the
    # layout (one-hot block + scalar + null indicator) and bail out.
    if non_null.empty:
        self.all_null = True
        self.fill_value = 0.0
        self.output_dim = self.n_components + 2
        return

    if self.has_nulls:
        # Mean-impute so the GMM trains on a complete column.
        self.fill_value = data.mean()
        training = data.fillna(self.fill_value)
        self.output_dim = self.n_components + 2  # extra null-indicator slot
    else:
        training = data
        self.output_dim = self.n_components + 1

    values = training.to_numpy(dtype=float).reshape(-1, 1)

    # Never ask for more mixture components than there are valid samples.
    effective = min(self.n_components, max(len(non_null), 1))

    self.model = BayesianGaussianMixture(
        n_components=effective,
        weight_concentration_prior_type="dirichlet_process",
        n_init=1,
        random_state=self._seed,
    )
    self.model.fit(values)
    # Per-component statistics, flattened to shape (effective,).
    self.means = self.model.means_.flatten()
    self.stds = np.sqrt(self.model.covariances_).flatten()

inverse_transform

inverse_transform(data: ndarray) -> pd.Series

Reconstruct approximate original values from normalized representation.

Parameters:

    data (ndarray, required):
        Array shaped (N, n_components + 1 [+1]) produced by transform.

Returns:

    Series: Pandas Series of reconstructed continuous values.

Source code in syntho_hive/core/data/transformer.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
def inverse_transform(self, data: np.ndarray) -> pd.Series:
    """Reconstruct approximate original values from normalized representation.

    Args:
        data: Array shaped ``(N, n_components + 1 [+1])`` produced by ``transform``.

    Returns:
        Pandas Series of reconstructed continuous values.
    """
    row_count = data.shape[0]

    # All-null columns decode straight back to NaN.
    if self.all_null:
        return pd.Series(np.full(row_count, np.nan))

    # Slice each flat row into its pieces: one-hot block, scalar, and an
    # optional trailing null-indicator.
    one_hot = data[:, : self.n_components]
    scalars = data[:, self.n_components]
    null_flags = data[:, self.n_components + 1] if self.has_nulls else None

    # Chosen component per row, then its fitted statistics.
    chosen = np.argmax(one_hot, axis=1)
    mu = self.means[chosen]
    sigma = self.stds[chosen]

    # Undo v = (x - mu_c) / (4 * sigma_c): x = v * 4 * sigma_c + mu_c
    recovered = scalars * 4 * sigma + mu

    # A null-indicator above 0.5 marks the row as missing (generator outputs
    # are expected to land near 0 or 1).
    if null_flags is not None:
        recovered[null_flags > 0.5] = np.nan

    return pd.Series(recovered)

transform

transform(data: Series) -> np.ndarray

Project values to one-hot cluster assignment and normalized scalar.

Parameters:

    data (Series, required):
        Continuous pandas Series to transform.

Returns:

    ndarray: Numpy array of shape (N, n_components + 1 [+1]) with one-hot cluster, scaled value, [null_ind].

Source code in syntho_hive/core/data/transformer.py
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
def transform(self, data: pd.Series) -> np.ndarray:
    """Project values to one-hot cluster assignment and normalized scalar.

    Args:
        data: Continuous pandas Series to transform.

    Returns:
        Numpy array of shape ``(N, n_components + 1 [+1])`` with one-hot cluster, scaled value, [null_ind].
    """
    row_count = len(data)

    # All-null columns encode as zeros plus an always-on null indicator.
    if self.all_null:
        return np.concatenate(
            [
                np.zeros((row_count, self.n_components)),
                np.zeros((row_count, 1)),
                np.ones((row_count, 1)),
            ],
            axis=1,
        )

    values_raw = data.to_numpy(dtype=float).reshape(-1, 1)

    if self.has_nulls:
        # Flag the missing rows, then impute them so the GMM can project.
        null_indicator = pd.isnull(data).to_numpy(dtype=float).reshape(-1, 1)
        values_clean = (
            data.fillna(self.fill_value).to_numpy(dtype=float).reshape(-1, 1)
        )
    else:
        values_clean = values_raw

    # Posterior responsibilities P(c|x); take the most likely component.
    # Argmax keeps the encoding deterministic (CTGAN sometimes samples here,
    # but argmax is stable).
    responsibilities = self.model.predict_proba(values_clean)
    chosen = np.argmax(responsibilities, axis=1)

    # Scale relative to the chosen component: v = (x - mu_c) / (4 * sigma_c).
    # The epsilon guards against zero-variance components.
    mu = self.means[chosen]
    sigma = self.stds[chosen]
    scaled = ((values_clean.flatten() - mu) / (4 * sigma + 1e-8)).reshape(-1, 1)

    # One-hot encode the chosen component.
    one_hot = np.zeros((row_count, self.n_components))
    one_hot[np.arange(row_count), chosen] = 1

    # Row layout: [one_hot_cluster, scalar, (null_indicator)]
    pieces = [one_hot, scaled]
    if self.has_nulls:
        pieces.append(null_indicator)
    return np.concatenate(pieces, axis=1)