How to generate SQL from Informatica job XML using an LLM

# Install dependencies (run once)
!pip install xmltodict langchain langchain-openai tiktoken

import os
import json
import xmltodict
import re
from math import ceil
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain

# ---------------------------
# Configuration
# ---------------------------
MODEL_NAME = "gpt-4.1-nano-2025-04-14"
# Conservative per-call token budget; keep well below the model's context window
MAX_TOKENS_PER_CALL = 20000   # tune lower if needed
# tokens reserved for prompt overhead and final assembly
PROMPT_OVERHEAD_TOKENS = 2000

# Path to XML (change if needed)
XML_PATH = "/content/large_job.xml"

# ---------------------------
# Tokenizer utilities
# ---------------------------
import tiktoken

def get_tokenizer():
    try:
        return tiktoken.get_encoding("cl100k_base")
    except Exception:
        # Rough fallback if tiktoken is unavailable: treats UTF-8 bytes as tokens
        # (an overestimate, which only makes the chunks smaller).
        class Fallback:
            def encode(self, s): return s.encode("utf-8")
            def decode(self, b): return b.decode("utf-8")
        return Fallback()

tokenizer = get_tokenizer()

def count_tokens(text: str) -> int:
    try:
        toks = tokenizer.encode(text)
        return len(toks)
    except Exception:
        return max(1, len(text) // 4)

def chunk_list_by_tokens(items, key_fn, max_tokens):
    chunks = []
    cur = []
    cur_tokens = 0
    for it in items:
        s = key_fn(it)
        t = count_tokens(s)
        if cur and (cur_tokens + t > max_tokens):
            chunks.append(cur)
            cur = [it]
            cur_tokens = t
        else:
            cur.append(it)
            cur_tokens += t
    if cur:
        chunks.append(cur)
    return chunks
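
# Quick sanity check of the chunker (toy, hypothetical strings; exact chunk sizes
# depend on the tokenizer, but each ~50-token item should land in its own chunk here):
_demo_items = ["alpha " * 50, "beta " * 50, "gamma " * 50]
_demo_chunks = chunk_list_by_tokens(_demo_items, key_fn=lambda s: s, max_tokens=60)
print("Demo chunk sizes:", [len(c) for c in _demo_chunks])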

# ---------------------------
# Parse XML and extract compact mapping
# ---------------------------
with open(XML_PATH, "r", encoding="utf-8") as f:
    xml_content = f.read()

etl_dict = xmltodict.parse(xml_content)

def safe_get(d, *keys):
    cur = d
    for k in keys:
        # Stop if the current node is missing or not a dict (e.g. a list or text node)
        if not isinstance(cur, dict):
            return None
        cur = cur.get(k)
    return cur

def extract_compact_mapping(etl):
    repo = safe_get(etl, "POWERMART", "REPOSITORY")
    folder = None
    if repo is None:
        return {"mappings": []}
    if isinstance(repo.get("FOLDER"), list):
        folder = repo["FOLDER"][0]
    else:
        folder = repo.get("FOLDER")
    result = {"mappings": []}
    if not folder:
        return result
    mappings = folder.get("MAPPING")
    if mappings is None:
        return result
    if not isinstance(mappings, list):
        mappings = [mappings]
    for m in mappings:
        mname = m.get("@NAME")
        compact = {"name": mname, "sources": [], "targets": [], "expressions": [], "filters": [], "aggregators": []}
        sources = folder.get("SOURCE") or []
        if not isinstance(sources, list):
            sources = [sources]
        for s in sources:
            fields = s.get("SOURCEFIELD") or []
            if not isinstance(fields, list):
                fields = [fields]
            sr = {"name": s.get("@NAME"), "owner": s.get("@OWNERNAME"), "fields": [ {"name":f.get("@NAME"), "datatype": f.get("@DATATYPE")} for f in fields ]}
            compact["sources"].append(sr)
        targets = folder.get("TARGET") or []
        if not isinstance(targets, list):
            targets = [targets]
        for t in targets:
            tf = t.get("TARGETFIELD") or []
            if not isinstance(tf, list):
                tf = [tf]
            tr = {"name": t.get("@NAME"), "owner": t.get("@OWNERNAME"), "fields": [ {"name":f.get("@NAME"), "datatype": f.get("@DATATYPE")} for f in tf ]}
            compact["targets"].append(tr)
        trans = m.get("TRANSFORMATION") or []
        if not isinstance(trans, list):
            trans = [trans]
        for t in trans:
            ttype = (t.get("@TYPE") or t.get("TYPE") or "").lower()
            tname = t.get("@NAME")
            tf = t.get("TRANSFORMFIELD") or []
            if not isinstance(tf, list):
                tf = [tf]
            if "expression" in ttype:
                exprs = []
                for f in tf:
                    exp = f.get("@EXPRESSION") or f.get("EXPRESSION")
                    if exp:
                        exprs.append({"field": f.get("@NAME"), "expression": " ".join(str(exp).split())})
                if exprs:
                    compact["expressions"].append({"name": tname, "exprs": exprs})
            elif "filter" in ttype:
                cond = None
                attrs = t.get("TABLEATTRIBUTE") or []
                if not isinstance(attrs, list):
                    attrs = [attrs]
                for a in attrs:
                    if a.get("@NAME","").lower().startswith("filter") and a.get("@VALUE"):
                        cond = a.get("@VALUE")
                if cond:
                    compact["filters"].append({"name": tname, "condition": " ".join(str(cond).split())})
            elif "aggregator" in ttype:
                aggs = []
                groupbys = []
                for f in tf:
                    exp = f.get("@EXPRESSION") or f.get("EXPRESSION")
                    if exp and any(fn in str(exp).upper() for fn in ["SUM(", "COUNT(", "AVG(", "MIN(", "MAX("]):
                        aggs.append({"field": f.get("@NAME"), "expression": " ".join(str(exp).split())})
                gb = t.get("GROUPBYPORT") or []
                if not isinstance(gb, list):
                    gb = [gb]
                for g in gb:
                    if isinstance(g, dict) and g.get("@NAME"):
                        groupbys.append(g.get("@NAME"))
                compact["aggregators"].append({"name": tname, "group_bys": groupbys, "aggregates": aggs})
        result["mappings"].append(compact)
    return result

compact = extract_compact_mapping(etl_dict)
compact_json = json.dumps(compact, ensure_ascii=False)
print("Compact JSON size (chars):", len(compact_json))
with open("/content/compact_summary.json","w",encoding="utf-8") as f:
    f.write(compact_json)
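
# For reference, the compact summary has the following shape (the values below are
# hypothetical and only illustrate the structure produced by extract_compact_mapping):
# {
#   "mappings": [{
#     "name": "m_load_sales",
#     "sources":     [{"name": "SRC_SALES", "owner": "STG", "fields": [{"name": "AMT", "datatype": "decimal"}]}],
#     "targets":     [{"name": "TGT_SALES", "owner": "DW",  "fields": [{"name": "AMT", "datatype": "decimal"}]}],
#     "expressions": [{"name": "EXP_CALC", "exprs": [{"field": "AMT_USD", "expression": "AMT * FX_RATE"}]}],
#     "filters":     [{"name": "FIL_ACTIVE", "condition": "STATUS = 'A'"}],
#     "aggregators": [{"name": "AGG_MONTH", "group_bys": ["MONTH_ID"], "aggregates": [{"field": "TOT_AMT", "expression": "SUM(AMT)"}]}]
#   }]
# }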

# ---------------------------
# Flatten expressions to items
# ---------------------------
expr_items = []
for mapping in compact["mappings"]:
    mapping_name = mapping["name"]
    for expr_block in mapping.get("expressions", []):
        for e in expr_block.get("exprs", []):
            expr_items.append({
                "mapping": mapping_name,
                "transform": expr_block.get("name"),
                "field": e["field"],
                "expression": e["expression"]
            })
print("Total expression items:", len(expr_items))

# ---------------------------
# Lightweight pre-translation / sanitization
# ---------------------------
def pre_translate_expr(inf_expr: str) -> str:
    s = inf_expr
    # IIF -> CASE WHEN (simple heuristic; the non-greedy match below does not handle
    # nested parentheses inside IIF(...), so those cases are left for the LLM to translate)
    def replace_iif(match):
        inner = match.group(1)
        depth = 0
        cur = ""
        parts = []
        for ch in inner:
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
            if ch == ',' and depth == 0 and len(parts) < 2:
                parts.append(cur)
                cur = ""
            else:
                cur += ch
        parts.append(cur)
        if len(parts) >= 3:
            cond = parts[0].strip()
            thenp = parts[1].strip()
            elsep = ",".join(parts[2:]).strip()
            return f"(CASE WHEN {cond} THEN {thenp} ELSE {elsep} END)"
        else:
            return match.group(0)
    for _ in range(4):
        s = re.sub(r"IIF\((.*?)\)", replace_iif, s, flags=re.IGNORECASE|re.DOTALL)
    s = re.sub(r"TO_INTEGER\s*\(\s*TO_CHAR\s*\(\s*([A-Za-z0-9_\.]+)\s*,\s*'YYYYMMDD'\s*\)\s*\)", 
               r"CAST(FORMAT_DATE('%Y%m%d', DATE(\1)) AS INT64)", s, flags=re.IGNORECASE)
    s = re.sub(r"ISNULL\s*\(\s*([A-Za-z0-9_\.]+)\s*\)", r"IFNULL(\1, NULL)", s, flags=re.IGNORECASE)
    s = re.sub(r"TO_CHAR\s*\(\s*([A-Za-z0-9_\.]+)\s*,\s*'YYYYMM'\s*\)", r"FORMAT_DATE('%Y%m', DATE(\1))", s, flags=re.IGNORECASE)
    return s
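
# Quick check of the heuristic on a hypothetical expression:
print(pre_translate_expr("IIF(STATUS = 'A', 1, 0)"))
# -> (CASE WHEN STATUS = 'A' THEN 1 ELSE 0 END)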

for it in expr_items:
    it["expr_sanitized"] = pre_translate_expr(it["expression"])

# ---------------------------
# Chunk expressions by tokens
# ---------------------------
per_call_token_limit = MAX_TOKENS_PER_CALL - PROMPT_OVERHEAD_TOKENS
# split the budget so each chunk leaves ample room for prompt text and the model's SQL output
per_chunk_tokens = max(2000, int(per_call_token_limit / 6))
expr_chunks = chunk_list_by_tokens(expr_items, key_fn=lambda it: it["expr_sanitized"], max_tokens=per_chunk_tokens)
print(f"Total chunks to generate: {len(expr_chunks)} (per chunk token budget ~{per_chunk_tokens})")

# ---------------------------
# Initialize LLM (LangChain) - use environment variable for API key
# ---------------------------
OPENAI_API_KEY = "your open ai key here"
if not OPENAI_API_KEY:
    raise RuntimeError("Please set OPENAI_API_KEY in your environment (do NOT hardcode).")
llm = ChatOpenAI(model=MODEL_NAME, openai_api_key=OPENAI_API_KEY, temperature=0)
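
# Optional: a minimal retry helper for transient API errors (a sketch; wrap the chain
# calls below with it if you see intermittent rate-limit/connection failures,
# e.g. call_with_retry(lambda: chain_fragment.run({...}))).
import time

def call_with_retry(fn, attempts=3, backoff_s=5):
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception as e:
            if attempt == attempts:
                raise
            print(f"LLM call failed (attempt {attempt}): {e}; retrying in {backoff_s}s")
            time.sleep(backoff_s)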

# ---------------------------
# Fragment prompt (use ChatPromptTemplate) and chain (create once)
# ---------------------------
fragment_prompt_template = """
You are a senior data engineer. You will receive a JSON list of transformation expressions from an Informatica mapping.
For each expression item, translate the expression into an equivalent BigQuery SQL expression.
Output a single SQL CTE named {cte_name} with columns:
 - mapping_name
 - field_name
 - computed_value (the computed column for that expression)

Requirements:
- Translate constructs to BigQuery syntax (CASE WHEN instead of IIF, IFNULL, FORMAT_DATE, CAST(... AS INT64), etc).
- The CTE should produce rows of the form (mapping_name, field_name, computed_value).
- Output only the SQL CTE, no extra text.
JSON Input:
{json_chunk}
"""
fragment_prompt = ChatPromptTemplate.from_template(fragment_prompt_template)
chain_fragment = LLMChain(llm=llm, prompt=fragment_prompt)

# ---------------------------
# Generate CTE fragments (use chain_fragment.run with variables)
# ---------------------------
cte_list = []
for i, chunk in enumerate(expr_chunks):
    cte_name = f"cte_expr_{i}"
    small_chunk = [{"mapping": it["mapping"], "transform": it["transform"], "field": it["field"], "expr": it["expr_sanitized"]} for it in chunk]
    json_chunk = json.dumps(small_chunk, ensure_ascii=False)
    print(f"Requesting fragment {i+1}/{len(expr_chunks)} - items: {len(chunk)}")
    # IMPORTANT: pass variables to .run(); do NOT pre-format the prompt object
    sql_cte = chain_fragment.run({"cte_name": cte_name, "json_chunk": json_chunk})
    sql_cte = sql_cte.strip()
    cte_list.append({"name": cte_name, "sql": sql_cte})

# Save generated CTEs
with open("/content/generated_ctes.sql","w",encoding="utf-8") as f:
    for c in cte_list:
        f.write(c["sql"] + "\n\n")
print("Completed generating CTE fragments. Saved to /content/generated_ctes.sql")

# ---------------------------
# Final assembly: use single prompt and chain (create once)
# ---------------------------
final_summary = {"mappings": compact["mappings"]}
final_summary_json = json.dumps(final_summary, ensure_ascii=False)

final_prompt_template = """
You are a senior data engineer. You have been provided:
1) A compact JSON mapping summary:
{mapping_json}

2) A list of previously generated CTEs (each CTE is saved in the SQL file '/content/generated_ctes.sql').
Each CTE is named: {cte_names}

Task:
Using BigQuery SQL, compose a single MERGE statement that implements the mapping for the mapping named: {mapping_to_use}
Rules:
- Use source and target table names from the mapping summary.
- Use natural keys found in target fields for MERGE ON.
- Use the generated CTEs as needed (reference them in a WITH clause).
- Apply filters and aggregations per the mapping summary.
- Set LAST_UPDATE_TS = CURRENT_TIMESTAMP() on UPDATE and INSERT.
- Output ONLY the final MERGE SQL (no comments).
Mapping JSON:
{mapping_json}
CTE names available: {cte_names}
CTE file location: /content/generated_ctes.sql
Mapping to use: {mapping_to_use}
"""
final_prompt = ChatPromptTemplate.from_template(final_prompt_template)
chain_final = LLMChain(llm=llm, prompt=final_prompt)

mapping_name = compact["mappings"][0]["name"] if compact["mappings"] else ""
# Pass the CTE bodies themselves (not just their names) so the model can include them
# in the WITH clause; if this exceeds the per-call budget, trim or summarize the CTEs.
cte_sql_text = "\n\n".join(c["sql"] for c in cte_list) if cte_list else "(none)"

print("Requesting final MERGE SQL assembly...")
final_sql = chain_final.run({
    "mapping_json": final_summary_json,
    "cte_sql": cte_sql_text,
    "mapping_to_use": mapping_name,
})
final_sql = final_sql.strip()
print("===== FINAL MERGE SQL =====")
print(final_sql)

with open("/content/final_merge.sql","w",encoding="utf-8") as f:
    f.write(final_sql)
print("Saved final MERGE SQL to /content/final_merge.sql")