Leveraging Constrained Sampling for Fill-In-the-Middle Code Completion

Two octopodes flying through space... we want to control where they go to!

Existing constrained sampling methods only focus on generating code that will be a syntactically well-formed program on its own after completion. What if we have some remaining code suffix and want the entirety to become a valid program? In this post I explore how to leverage existing tooling to create well-formed Copilot-like Fill-In-the-Middle (FIM) code completions.

LLMs are great at writing code. However, they tend to generate invalid code. Constrained sampling has emerged to resolve this issue. Impressively, it can handle arbitrary Context-Free Grammars (CFGs), which cover most programming language syntax. It not only yields formally guaranteed syntactic correctness but can even be faster than unconstrained sampling [1]. Constrained sampling even made it into the popular OpenAI Chat API, in case you wondered how the response_format parameter is implemented.

Fill-In-the-Middle is not handled by existing approaches

In popular posts announcing the existing methods, people usually applaud these results, but I also repeatedly found requests for a commonly used feature: support for Fill-In-the-Middle (FIM) completions. In FIM, as when requesting a completion from Copilot, we have code in front of the cursor, denoted \(p\), and code behind the cursor, denoted \(s\). Ideally, we want to sample a code completion \(x\) that is syntactically valid together with \(p\) and \(s\), i.e. we want the concatenation \(pxs\) to be a valid program.

Consider the following example. In this JavaScript program, we would like a completion at the cursor, marked with the pipe “|”. Valid completions could be “2;” or “2; let three = 3;”, but not “2; }”. Existing methods would allow all of these completions, because each of them can still be extended to a valid program when the suffix is ignored.

function foo() {
   let one = 1;
   let two = |
   let four = 4;
}
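
To make the difference concrete, here is a minimal sketch that uses brace balance as a toy stand-in for the full JavaScript grammar. This is an assumption for illustration only: real constrained samplers use a parser, and the helper names `is_valid_prefix` and `is_complete` are hypothetical.

```python
def depths(code: str):
    """Yield the running brace depth after each character."""
    depth = 0
    for ch in code:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
        yield depth


def is_valid_prefix(code: str) -> bool:
    """Prefix-only check: the braces could still be balanced by more code."""
    return all(d >= 0 for d in depths(code))


def is_complete(code: str) -> bool:
    """Whole-program check: depth never goes negative and ends at zero."""
    ds = list(depths(code))
    return all(d >= 0 for d in ds) and (ds[-1] if ds else 0) == 0


PREFIX = "function foo() {\n   let one = 1;\n   let two = "
SUFFIX = "\n   let four = 4;\n}"
CANDIDATES = ["2;", "2; let three = 3;", "2; }"]

# What a suffix-unaware method checks: can prefix + completion be extended?
prefix_only = {x: is_valid_prefix(PREFIX + x) for x in CANDIDATES}
# What FIM needs: is prefix + completion + suffix a complete program?
fim_aware = {x: is_complete(PREFIX + x + SUFFIX) for x in CANDIDATES}
```

In this toy model the prefix-only check accepts all three candidates, while the check that includes the suffix rejects “2; }”, since the closing brace in the suffix no longer has a matching opening brace.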

You may wonder: “Even though ‘2; }’ is not a valid completion, we can still add more tokens to make it valid, e.g. turn it into ‘2; } function bar() {’. Is this therefore just an issue related to the maximum number of tokens?”

The answer is no, even though the solutions to these two issues can be combined freely. To see why, consider the following language \(L\) of balanced 0s and 1s, i.e. words of the form \(0^n1^n\).

S → ε
S → 0S1

If we have the prefix \(p = 0\) and the suffix \(s = 111\), then sampling \(001\) should be disallowed, even though it would be perfectly legal when constraining to the original language \(L\) alone. We need to prevent sampling \(001\) because no future completion can turn \(0001 \ldots 111\) into a valid word of \(L\): the word would contain only three 0s but at least four 1s. We need to adjust our grammar to encode the fact that we want to enforce a given suffix.
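
The argument can be checked with a small brute-force sketch. The helpers `in_L` and `completable` are hypothetical names, and the search over continuations is bounded by `max_extra`, so in general it can only confirm that a completion exists; here the negative result agrees with the counting argument above.

```python
from itertools import product


def in_L(word: str) -> bool:
    """Membership in L = { 0^n 1^n | n >= 0 }."""
    n, rem = divmod(len(word), 2)
    return rem == 0 and word == "0" * n + "1" * n


def completable(p: str, x: str, s: str, max_extra: int = 6) -> bool:
    """Search for a continuation y (at most max_extra symbols) with p+x+y+s in L."""
    for k in range(max_extra + 1):
        for y in product("01", repeat=k):
            if in_L(p + x + "".join(y) + s):
                return True
    return False


# Ignoring the suffix, the sample 001 after prefix 0 still looks fine ...
ignoring_suffix = completable("0", "001", "")
# ... but with the suffix 111 enforced, no continuation exists.
with_suffix = completable("0", "001", "111")
# The sample 00, in contrast, completes directly to 000111.
shorter_sample = completable("0", "00", "111")
```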

Unifying constrained sampling and FIM

To resolve this issue, we recall a result from formal language theory: context-free languages are closed under intersection with regular languages [2]. This means that we can take any regular language (those describable by basic regular expressions) and construct a language that accepts exactly the words that are *both* valid according to the context-free grammar and contained in the regular language. This new language is again context-free.

To generate such a language for our FIM problem with given \(p\) and \(s\), we define the following regular language: \( R_{(p,s)} = \left\{ p\,x_1 \ldots x_n\,s \mid n \geq 0,\ x_{i} \in T \right\} \). Here \(T\) denotes the tokens of the language model we are sampling from, not the tokens of the CFG to constrain to.

If you are not convinced that this is indeed a regular language, think of the regular expression p({"|".join(T)})*s, informally written as a pseudo-Python format string. I provide a proper Python implementation of this expression in the next section.
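
As a quick sanity check, the expression can be written with Python's built-in `re` module (used here purely for illustration, not PyFormLang). The function name `fim_pattern` and the two-symbol vocabulary are assumptions of this sketch.

```python
import re


def fim_pattern(p: str, s: str, tokens: list[str]) -> re.Pattern:
    """Build the regular expression p(t1|t2|...)*s with all parts escaped."""
    alternation = "|".join(re.escape(t) for t in tokens)
    return re.compile(f"{re.escape(p)}(?:{alternation})*{re.escape(s)}")


pattern = fim_pattern("0", "111", ["0", "1"])
words = ["000111", "0111", "0011", "1111"]
matches = {w: bool(pattern.fullmatch(w)) for w in words}
```

Note that `0111` matches the pattern although it is not a balanced word: the regular language only pins down the prefix and suffix, and the grammar still has to be enforced separately.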

Our target language to constrain sampling to is then \( L \cap R_{(p,s)} \), where \( L \) is the original CFG we want to constrain to.

The great part is that this generates a CFG that we can feed into any of the existing CFG constraining methods and thus leverage them directly for proper well-formedness guarantees.
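
To give an idea of how such a language would be used during sampling, the following toy sketch computes a token mask for the balanced 0s and 1s example. It is not how the existing methods work internally: it replaces the incremental parser with a bounded brute-force search, and the vocabulary `TOKENS` as well as the names `viable` and `token_mask` are hypothetical.

```python
from itertools import product

TOKENS = ["0", "1", "00", "01", "11"]  # hypothetical multi-character LLM vocabulary T


def in_L(word: str) -> bool:
    """Membership in L = { 0^n 1^n | n >= 0 }."""
    n, rem = divmod(len(word), 2)
    return rem == 0 and word == "0" * n + "1" * n


def viable(p: str, x: str, s: str, max_extra: int = 6) -> bool:
    """Bounded search: is there a continuation y such that p+x+y+s is in L?"""
    return any(
        in_L(p + x + "".join(y) + s)
        for k in range(max_extra + 1)
        for y in product("01", repeat=k)
    )


def token_mask(p: str, x: str, s: str):
    """Return the tokens that keep the sample completable, and whether we may stop."""
    allowed = [t for t in TOKENS if viable(p, x + t, s)]
    may_stop = in_L(p + x + s)  # stopping is allowed only if pxs is already valid
    return allowed, may_stop


first_step = token_mask("0", "", "111")
later_step = token_mask("0", "00", "111")
```

With prefix 0 and suffix 111, only tokens consisting of 0s are allowed at first. After sampling 00, the sampler may either stop (giving 000111) or continue with 0, 00 or 01.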

Putting it all together in Python

I am using the library PyFormLang[3] to perform the previously described manipulations. Below, you can find an implementation of how to construct the constrained grammar and an example of it applied to the standard example of a CFG, the set of words with balanced 0s and 1s. Of course, this construction applies to any CFG, including the syntax of languages like C or JavaScript.

Note that, strictly speaking, only the suffix \(s\) needs to be baked into the grammar. In practice you will likely not need to create a new grammar for every prefix; you can instead run the prefix together with the sampled tokens through the derived parser.

"""
Construct the Fill-In-the-Middle (FIM) language of a grammar
"""
from typing import List

from pyformlang.regular_expression import PythonRegex
from pyformlang.regular_expression.python_regex import TRANSFORMATIONS
from pyformlang.cfg import CFG, Terminal
from pyformlang.pda import PDA


SPECIAL_CHARS = set(TRANSFORMATIONS.keys())


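def escape_regex(s: str) -> str:
    # Escape character by character in a single pass, so that a backslash
    # inserted for one special character is never escaped a second time.
    return "".join(rf"\{c}" if c in SPECIAL_CHARS else c for c in s)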


def get_vocab(lang: CFG | PDA):
    if isinstance(lang, CFG):
        vocab = [x.value for x in lang.terminals]
    elif isinstance(lang, PDA):
        vocab = [x.value for x in lang.input_symbols]
    else:
        raise NotImplementedError("wrong type of lang")
    return vocab


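def fim_lang(
    lang: CFG | PDA,
    prefix: str = "",
    suffix: str = "",
    vocab: List[str] | None = None,
) -> PDA | CFG:
    if vocab is None:
        vocab = get_vocab(lang)
    # Regular language R_(p,s): the prefix, any sequence of vocabulary items,
    # then the suffix. Vocabulary items are escaped as well, since they may
    # contain regex metacharacters.
    alternatives = "|".join(escape_regex(v) for v in vocab)
    fim_regex = PythonRegex(
        f"{escape_regex(prefix)}({alternatives})*{escape_regex(suffix)}"
    )
    # Intersecting with a regular language again yields a context-free language.
    return lang.intersection(fim_regex)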


balanced_01s = CFG.from_text(
    """
S -> ε
S -> 0 S 1
"""
)


def split_to_terms(s: str):
    return list(Terminal(c) for c in s)


def test_balanced_fim():
    fim_l = fim_lang(balanced_01s, "0", "111")
    assert fim_l.contains(split_to_terms("000111"))
    assert fim_l.contains(split_to_terms("00001111"))
    assert not fim_l.contains(split_to_terms("0011"))
    assert not fim_l.contains(split_to_terms("00111"))

Limitations and Future Directions

This section is basically the same as in my post about limiting the number of tokens in constrained sampling. For every new prefix/suffix combination, a new CFG is created, which makes generating the token masks for sampling costly. How this compares to online masking in the spirit of the Completion Engine introduced by Synchromesh [4] is unclear without further experiments. Future work could look for unified algorithms that directly decide whether a word can be completed with a given suffix, without the expensive step of generating the CFG again, or could exploit common structure in the generated CFGs to enable reuse and thus amortization.

Citation

If you end up using these results for your work, please cite as follows:

@misc{muendler2024fimconstrains,
      title={Leveraging Constrained Sampling for Fill-In-the-Middle Code Completion}, 
      author={Niels Mündler},
      year={2024},
      url={https://blog.nielstron.de/2024/08/14/leveraging-constrained-sampling-for-fill-in-the-middle-code-completion/}
}

References
  • [1] Beurer-Kellner et al. 2024, Guiding LLMs The Right Way: Fast, Non-Invasive Constrained Generation; Ugare et al. 2024, SynCode: LLM Generation with Grammar Augmentation
  • [2] Hopcroft & Ullman 1979, pp. 135–136, Theorem 6.5
  • [3] Romero 2021, Pyformlang: An Educational Library for Formal Language Manipulation
  • [4] Poesia et al. 2022, Synchromesh: Reliable code generation from pre-trained language models
