AI News Hub
Rewritten on-site, 5 min read

Salesforce CodeGen Tutorial: Generate, Validate, and Rerank Python Functions With Unit Tests and Safety Checks

We implement an end-to-end workflow for Salesforce CodeGen, loaded from Hugging Face. We move past basic inference by adding function extraction, syntax checking, static safety checks, and unit-test validation. We rerank best-of-N candidates, compose multi-turn program synthesis, and experiment with prompt styles. We finish by visualizing a mini benchmark and exporting the generated artifacts as reusable files.

Source: MarkTechPost. Author: Sana Hassan

In this tutorial, we implement an end-to-end workflow for Salesforce CodeGen. We load a CodeGen model from Hugging Face, prepare it for code generation, and use it to generate Python functions from natural-language prompts. We then move beyond basic inference by adding function extraction, syntax checking, static safety checks, unit-test-based validation, best-of-N candidate reranking, multi-step program synthesis, prompt-style experimentation, benchmark visualization, and artifact export. Through this workflow, we learn how CodeGen can be used not only as a code completion model but also as part of a structured code-generation pipeline that evaluates, filters, and organizes generated solutions.

Loading the Salesforce CodeGen Model from Hugging Face


import os, sys, subprocess, textwrap, json, re, time, math, ast, tempfile, multiprocessing as mp
from pathlib import Path


def sh(cmd):
    print(f"\n$ {cmd}")
    subprocess.run(cmd, shell=True, check=True)


# Install the dependencies before importing them.
sh(f"{sys.executable} -m pip install -q -U transformers accelerate safetensors einops datasets evaluate pandas matplotlib tqdm rich radon tiktoken")

import torch
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from rich import print
from rich.panel import Panel
from rich.syntax import Syntax
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from radon.complexity import cc_visit

OUT_DIR = Path("/content/codegen_advanced_tutorial")
OUT_DIR.mkdir(parents=True, exist_ok=True)
set_seed(42)

print(Panel.fit("Salesforce CodeGen Advanced Tutorial", style="bold green"))
print("\nRuntime information")
print("Python:", sys.version.split()[0])
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("CUDA memory GB:", round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2))

# The model can be overridden through the CODEGEN_MODEL_ID environment variable.
MODEL_ID = os.environ.get("CODEGEN_MODEL_ID", "Salesforce/codegen-350M-mono")
MODEL_OPTIONS = {
    "easy_colab_default": "Salesforce/codegen-350M-mono",
    "larger_codegen1": "Salesforce/codegen-2B-mono",
    "codegen2_1b": "Salesforce/codegen2-1B_P",
    "codegen25_7b_mono": "Salesforce/codegen25-7b-mono_P",
}
print("\nSelected model:", MODEL_ID)
print("Available model examples:", MODEL_OPTIONS)

trust_remote_code = any(x in MODEL_ID.lower() for x in ["codegen2", "codegen25"])
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=trust_remote_code)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading model...")
load_kwargs = {
    "trust_remote_code": trust_remote_code,
    "low_cpu_mem_usage": True,
}
if torch.cuda.is_available():
    load_kwargs["torch_dtype"] = dtype
    load_kwargs["device_map"] = "auto"
else:
    load_kwargs["torch_dtype"] = torch.float32

# The keyword arguments must be unpacked, not passed as a positional dict.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
if not torch.cuda.is_available():
    model.to(device)
model.eval()


def count_parameters(model):
    return sum(p.numel() for p in model.parameters())


print(f"Loaded {MODEL_ID}")
print(f"Parameter count: {count_parameters(model)/1e6:.1f}M")


def generate_text(
    prompt,
    max_new_tokens=180,
    temperature=0.35,
    top_p=0.92,
    top_k=50,
    do_sample=True,
    num_return_sequences=1,
    repetition_penalty=1.05,
):
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        # Unpack the tokenizer output so input_ids and attention_mask are passed by name.
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=num_return_sequences,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return decoded


def print_code(title, code):
    print(Panel.fit(title, style="bold cyan"))
    print(Syntax(code, "python", theme="monokai", line_numbers=True))

We install all required libraries and prepare the environment for running Salesforce CodeGen. We check the runtime, detect GPU availability, select the CodeGen model, and load both the tokenizer and model from Hugging Face. We also define helper functions for text generation and for displaying formatted code so that the rest of the tutorial is easier to follow.
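
The loading decisions described above can be isolated in a small pure function, which makes them easy to inspect without downloading a model. This is only a sketch: `choose_load_settings` is a hypothetical helper that does not appear in the tutorial, and it reports the dtype as a plain string where the tutorial passes `torch.float16` or `torch.float32`.

```python
# Hypothetical helper (not part of the tutorial): mirrors the tutorial's
# loading decisions as a pure function that needs neither torch nor a download.
def choose_load_settings(model_id, cuda_available):
    # The tutorial enables trust_remote_code only when the model ID
    # contains "codegen2" or "codegen25".
    needs_remote_code = any(tag in model_id.lower() for tag in ("codegen2", "codegen25"))
    settings = {
        "trust_remote_code": needs_remote_code,
        "low_cpu_mem_usage": True,
        # Assumption for illustration: dtype as a string instead of a torch dtype.
        "dtype": "float16" if cuda_available else "float32",
    }
    if cuda_available:
        # With a GPU, the tutorial also sets device_map="auto".
        settings["device_map"] = "auto"
    return settings


cpu_settings = choose_load_settings("Salesforce/codegen-350M-mono", cuda_available=False)
gpu_settings = choose_load_settings("Salesforce/codegen25-7b-mono_P", cuda_available=True)
print(cpu_settings)
print(gpu_settings)
```

On a CPU-only runtime the default 350M model loads in full precision with no remote code, while a CodeGen2.5 ID on a GPU runtime switches on half precision, automatic device placement, and `trust_remote_code`.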

Building Extraction, Safety, and Unit-Test Validation Utilities


def extract_function_source(full_text, function_name):
    text = full_text.replace("\r\n", "\n")
    # If the model wrapped its answer in a Markdown code fence, keep only the fenced body.
    fence = re.search(r"`{3}(?:python)?\n(.*?)`{3}", text, flags=re.S | re.I)
    if fence:
        text = fence.group(1)
    pattern = rf"^def\s+{re.escape(function_name)}\s*\("
    match = re.search(pattern, text, flags=re.M)
    if not match:
        return ""
    chunk = text[match.start():]
    lines = chunk.splitlines()
    collected = []
    for i, line in enumerate(lines):
        if i > 0:
            # Stop at the next top-level definition, main guard, or assignment.
            if line.startswith("def ") or line.startswith("class "):
                break
            if line.startswith("if __name__"):
                break
            if line and not line.startswith((" ", "\t", "#")) and re.match(r"^[A-Za-z_][A-Za-z0-9_]*\s*=", line):
                break
        collected.append(line)
    source = "\n".join(collected).rstrip()
    try:
        ast.parse(source)
        return source
    except SyntaxError:
        # Try to fall back to the longest prefix of the collected lines that still parses.
        fixed_lines = []
        for line in collected:
            fixed_lines.append(line)
            candidate = "\n".join(fixed_lines).rstrip()
            try:
                ast.parse(candidate)
                source = candidate
            except SyntaxError:
                pass
        return source if source.strip().startswith("def ") else ""


def syntax_ok(source):
    try:
        ast.parse(source)
        return True, ""
    except SyntaxError as e:
        return False, str(e)


FORBIDDEN_NAMES = {
    "eval", "exec", "compile", "open", "input", "__import__", "globals", "locals", "vars",
    "dir", "getattr", "setattr", "delattr", "help", "breakpoint", "exit", "quit",
}
FORBIDDEN_NODES = (
    ast.Import, ast.ImportFrom, ast.Global, ast.Nonlocal, ast.With, ast.AsyncWith,
    ast.AsyncFunctionDef, ast.ClassDef, ast.Delete, ast.Raise,
)
ALLOWED_BUILTINS = {
    "abs": abs, "all": all, "any": any, "bool": bool, "dict": dict, "enumerate": enumerate,
    "float": float, "int": int, "isinstance": isinstance, "len": len, "list": list, "map": map,
    "max": max, "min": min, "pow": pow, "range": range, "reversed": reversed, "round": round,
    "set": set, "sorted": sorted, "str": str, "sum": sum, "tuple": tuple, "zip": zip,
}


def static_safety_check(source):
    try:
        tree = ast.parse(source)
    except SyntaxError as e:
        return False, f"SyntaxError: {e}"
    for node in ast.walk(tree):
        if isinstance(node, FORBIDDEN_NODES):
            return False, f"Forbidden AST node: {type(node).__name__}"
        if isinstance(node, ast.Name):
            # Reject blocked builtins and any double-underscore name.
            if node.id in FORBIDDEN_NAMES or node.id.startswith("__"):
                return False, f"Forbidden name: {node.id}"
        if isinstance(node, ast.Attribute):
            if node.attr.startswith("__"):
                return False, f"Forbidden attribute: {node.attr}"
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name) and node.func.id in FORBIDDEN_NAMES:
                return False, f"Forbidden call: {node.func.id}"
    return True, "passed"


def _worker_run_tests(source, function_name, tests, queue):
    try:
        # Only the allow-listed builtins are visible to the generated code.
        # A single namespace lets generated functions call themselves recursively.
        safe_globals = {"__builtins__": ALLOWED_BUILTINS}
        compiled = compile(source, "", "exec")
        exec(compiled, safe_globals)
        fn = safe_globals.get(function_name)
        if fn is None:
            queue.put({"ok": False, "error": f"{function_name} not found", "passed": 0, "total": len(tests)})
            return
        passed = 0
        details = []
        for test in tests:
            args = test.get("args", [])
            kwargs = test.get("kwargs", {})
            expected = test["expected"]
            result = fn(*args, **kwargs)
            ok = result == expected
            passed += int(ok)
            details.append({
                "args": args,
                "kwargs": kwargs,
                "expected": expected,
                "result": result,
                "ok": ok,
            })
        queue.put({"ok": passed == len(tests), "error": "", "passed": passed, "total": len(tests), "details": details})
    except Exception as e:
        queue.put({"ok": False, "error": repr(e), "passed": 0, "total": len(tests)})


def run_unit_tests_safely(source, function_name, tests, timeout_seconds=3):
    safe, reason = static_safety_check(source)
    if not safe:
        return {"ok": False, "error": reason, "passed": 0, "total": len(tests), "details": []}
    # Run the tests in a child process so a hanging candidate can be terminated.
    ctx = mp.get_context("fork")
    queue = ctx.Queue()
    process = ctx.Process(target=_worker_run_tests, args=(source, function_name, tests, queue))
    process.start()
    process.join(timeout_seconds)
    if process.is_alive():
        process.terminate()
        process.join()
        return {"ok": False, "error": "timeout", "passed": 0, "total": len(tests), "details": []}
    if queue.empty():
        return {"ok": False, "error": "no result returned", "passed": 0, "total": len(tests), "details": []}
    return queue.get()


def code_complexity(source):
    try:
        blocks = cc_visit(source)
        if not blocks:
            return 1
        return max(block.complexity for block in blocks)
    except Exception:
        return None


def score_candidate(source, test_result):
    syntax_score = 1 if syntax_ok(source)[0] else 0
    safety_score = 1 if static_safety_check(source)[0] else 0
    passed = test_result.get("passed", 0)
    total = max(test_result.get("total", 1), 1)
    test_score = passed / total
    complexity = code_complexity(source)
    # Cap the complexity penalty so it can only break ties between similar candidates.
    complexity_penalty = 0 if complexity is None else min(complexity / 20, 0.25)
    return syntax_score + safety_score + 3 * test_score - complexity_penalty

We build the utility layer that extracts generated Python functions from raw model outputs. We add syntax validation, static safety checks, restricted execution, unit-test execution, and timeout handling to make generated code easier to evaluate. We also calculate code complexity and create a scoring function to rank generated candidates by correctness, safety, and simplicity.
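
To make the ranking idea concrete, here is a minimal, standard-library-only sketch of test-based reranking over three hand-written candidates. It reuses the tutorial's weighting of one point for valid syntax plus three times the test pass rate, but it is a simplification: `passes_syntax`, `count_passed`, and `rerank` are hypothetical names, the candidates are trusted strings rather than model output, and the safety check, restricted builtins, subprocess timeout, and complexity penalty are all left out.

```python
import ast


def passes_syntax(source):
    """Return True when the source parses as Python."""
    try:
        ast.parse(source)
        return True
    except SyntaxError:
        return False


def count_passed(source, function_name, tests):
    """Run the tests in-process and count how many pass.

    ASSUMPTION: the candidates are trusted, hand-written strings. Model
    output should go through a sandboxed runner with a timeout instead.
    """
    namespace = {}
    try:
        exec(source, namespace)
        fn = namespace[function_name]
    except Exception:
        return 0
    passed = 0
    for test in tests:
        try:
            if fn(*test["args"]) == test["expected"]:
                passed += 1
        except Exception:
            pass  # a crashing test counts as a failure
    return passed


def rerank(candidates, function_name, tests):
    """Return (score, source) pairs sorted best-first."""
    scored = []
    for source in candidates:
        syntax = 1 if passes_syntax(source) else 0
        pass_rate = count_passed(source, function_name, tests) / len(tests) if syntax else 0.0
        scored.append((syntax + 3 * pass_rate, source))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored


candidates = [
    "def factorial(n):\n    return 1\n",  # valid syntax, passes only factorial(0)
    "def factorial(n):\n    result = 1\n    for i in range(2, n + 1):\n        result *= i\n    return result\n",
    "def factorial(n)\n    return 1\n",  # missing colon, does not parse
]
tests = [{"args": [0], "expected": 1}, {"args": [5], "expected": 120}]

ranked = rerank(candidates, "factorial", tests)
best_score, best_source = ranked[0]
print(best_score)
print(best_source)
```

Because test results carry three times the weight of the syntax check, a candidate that parses but fails most tests still ranks well below one that passes them all.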

Generating Code and Defining Benchmark Tasks

print("\n" + "=" * 90)
print("Demo 1: Basic natural-language-to-code completion")
print("=" * 90)

# The prompt is written as Python comments followed by the function signature to complete.
basic_prompt = """# Write a Python function that returns the area of a circle.
# The function should be named circle_area and should accept radius as input.
# Do not print anything. Return the numeric result.

def circle_area(radius):
"""

basic_output = generate_text(
    basic_prompt,
    max_new_tokens=120,
    temperature=0.25,
    do_sample=True,
    num_return_sequences=1,
)[0]
print_code("Raw CodeGen output", basic_output)

circle_source = extract_function_source(basic_output, "circle_area")
print_code("Extracted function", circle_source if circle_source else "# No function extracted")

circle_tests = [
    {"args": [1], "expected": math.pi},
    {"args": [2], "expected": 4 * math.pi},
]
if circle_source:
    print("Syntax:", syntax_ok(circle_source))
    print("Safety:", static_safety_check(circle_source))
    print("Complexity:", code_complexity(circle_source))

print("\n" + "=" * 90)
print("Demo 2: Best-of-N generation with test-based reranking")
print("=" * 90)

TASKS = [
    {
        "name": "factorial",
        "signature": "def factorial(n):",
        "instruction": "Return n factorial for a non-negative integer n. Use 1 for factorial(0).",
        "tests": [
            {"args": [0], "expected": 1},
            {"args": [1], "expected": 1},
            {"args": [5], "expected": 120},
            {"args": [7], "expected": 5040},
        ],
    },
    {
        "name": "is_palindrome",
        "signature": "def is_palindrome(text):",
        "instruction": "Return True if text is a palindrome after removing spaces and ignoring case, otherwise return False.",
        "tests": [
            {"args": ["Race car"], "expected": True},
            {"args": ["hello"], "expected": False},
            {"args": ["Never odd or even"], "expected": True},
        ],
    },
    {
        "name": "fibonacci",
        "signature": "def fibonacci(n):",
        "instruction": "Return the nth Fibonacci number where fibonacci(0)=0 and fibonacci(1)=1.",
        "tests": [
            {"args": [0], "expected": 0},
            {"args": [1], "expected": 1},
            {"args": [8], "expected": 21},
            {"args": [10], "expected": 55},
        ],
    },
    {
        "name": "dedupe_keep_order",
        "signature": "def dedupe_keep_order(items):",
        "instruction": "Return a list with duplicate values removed while preserving the first occurrence order.",
        "tests": [
            {"a

[truncated for AI cost control]
