What tricks are you thinking about?
Sharing the history still means you need to save the state of the autoregressive transformer, which is usually prohibitively large?
I'm talking about inference. What we need to save are the keys (and the values): we need all of them to compute the next tokens, but we don't need the queries. We can also exploit the fact that, with causal attention, each token depends only on the previous ones. The same holds for whatever comes out of each transformer block: the outputs for earlier positions never change when a new token is appended. Let's call this 'history'. For each block it is a 2D array of shape [prev_size, embed_size]. A typical size would be 1024x512 = 0.5M values, possibly more depending on the model, but that still looks affordable. Here prev_size grows from 0 up to max_prompt_size as inference proceeds.

The idea is that we don't need to recompute the history every time; we just append one row as we compute each next token. And if we want to try several alternative tokens, we can put them in one batch, and they will all share the same 'history'. We need only a copy of it, or better, a reference. This way branching is almost free, as opposed to the 'normal' way, where everything is recomputed for each alternative token.
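A minimal sketch of the idea, assuming a single attention head in a single layer, small random weights, and tokens given directly as embedding vectors. All names here (`History`, `branch`, `full_recompute`, `D`) are hypothetical and chosen for illustration. It caches keys and values, appends one row per token, and scores several alternative next tokens against one shared history without modifying it. A real model would keep one such history per layer (and per head).

```python
import math
import random

D = 4  # hypothetical tiny embedding size (the example above uses 512)
rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

# Hypothetical projection weights for one attention layer.
W_Q, W_K, W_V = rand_matrix(D, D), rand_matrix(D, D), rand_matrix(D, D)

def matvec(x, w):
    """Row vector x times matrix w."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def attend(q, keys, values):
    """Single-head scaled dot-product attention of one query over the given rows."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]

def full_recompute(xs):
    """The 'normal' way: recompute keys/values of the whole prefix at every step."""
    outs = []
    for t in range(len(xs)):
        keys = [matvec(x, W_K) for x in xs[: t + 1]]
        values = [matvec(x, W_V) for x in xs[: t + 1]]
        outs.append(attend(matvec(xs[t], W_Q), keys, values))
    return outs

class History:
    """Cached keys/values for one layer; grows by one row per generated token."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x):
        """Append the new token's key/value, then attend over the whole history."""
        self.keys.append(matvec(x, W_K))
        self.values.append(matvec(x, W_V))
        return attend(matvec(x, W_Q), self.keys, self.values)

def branch(history, alternatives):
    """Try several candidate next tokens against one shared, unmodified history."""
    outs = []
    for x in alternatives:
        k, v = matvec(x, W_K), matvec(x, W_V)
        # The concatenation copies references to the cached rows, not the rows themselves,
        # and leaves history.keys / history.values untouched.
        outs.append(attend(matvec(x, W_Q), history.keys + [k], history.values + [v]))
    return outs
```

Because attention is causal, calling `History.step` on each prompt token should give the same outputs as `full_recompute`, while doing only one new key/value projection per token. Each result of `branch(h, alternatives)` should likewise match `full_recompute(prompt + [alt])[-1]`, and the shared history still holds only the prompt afterwards.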