concurrent.futures is the easiest way to run parallel jobs in a best-effort manner. If the heavy jobs run off-interpreter (e.g., in NumPy), using the thread pool from concurrent.futures can give you a noticeable performance benefit.
One example of using concurrent.futures is as follows:
import concurrent.futures
import datetime
import math
import resource

def worker(n):
    print(f"{datetime.datetime.now()}: worker {n} started")
    x = []
    for i in range(10_000_000):  # just some slow operation
        x.append(math.sin(i))
    print(f"{datetime.datetime.now()}: worker {n} return")
    return x

def main():
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        jobs = [executor.submit(worker, n) for n in range(50)]
        for future in concurrent.futures.as_completed(jobs):
            _ = future.result()
            print(f"{datetime.datetime.now()}: Memory usage {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")

if __name__ == "__main__":
    main()
This is the simplest way of using concurrent.futures: first create the thread pool executor, then run jobs via the executor by specifying the worker function and the arguments (executor.submit(fn, *args, **kwargs)). We can send a lot of jobs to the executor, but the thread pool only runs a limited number of jobs concurrently. In the example above, we submit 50 jobs in one shot but run only 3 jobs at a time. Running the above code gives you the following:
2024-02-09 15:19:54.083284: worker 0 started
2024-02-09 15:19:54.083414: worker 1 started
2024-02-09 15:19:54.095935: worker 2 started
2024-02-09 15:19:57.032095: worker 0 return
2024-02-09 15:19:57.032184: worker 3 started
2024-02-09 15:19:57.039709: worker 2 return
2024-02-09 15:19:57.039772: worker 4 started
2024-02-09 15:19:57.046087: Memory usage 1039187968
2024-02-09 15:19:57.069717: worker 1 return
2024-02-09 15:19:57.082388: Memory usage 1054179328
2024-02-09 15:19:57.099680: worker 5 started
2024-02-09 15:19:57.130026: Memory usage 1074479104
2024-02-09 15:19:59.946041: worker 3 return
2024-02-09 15:19:59.946126: worker 6 started
2024-02-09 15:19:59.975450: worker 4 return
2024-02-09 15:19:59.975521: worker 7 started
2024-02-09 15:19:59.983047: Memory usage 2039955456
2024-02-09 15:19:59.994621: Memory usage 2044674048
...
2024-02-09 15:20:38.770484: worker 42 return
2024-02-09 15:20:38.770559: worker 45 started
2024-02-09 15:20:38.775674: Memory usage 14824030208
2024-02-09 15:20:38.779237: worker 43 return
2024-02-09 15:20:38.812374: worker 46 started
2024-02-09 15:20:38.824755: Memory usage 14840037376
2024-02-09 15:20:38.979729: worker 44 return
2024-02-09 15:20:38.979810: worker 47 started
2024-02-09 15:20:38.986076: Memory usage 14882963456
2024-02-09 15:20:41.812591: worker 45 return
2024-02-09 15:20:41.812663: worker 48 started
2024-02-09 15:20:41.818818: Memory usage 15393980416
2024-02-09 15:20:41.840824: worker 46 return
2024-02-09 15:20:41.853475: worker 49 started
2024-02-09 15:20:41.859693: Memory usage 15393980416
2024-02-09 15:20:41.977037: worker 47 return
2024-02-09 15:20:41.996684: Memory usage 15393980416
2024-02-09 15:20:43.886611: worker 48 return
2024-02-09 15:20:43.899240: Memory usage 15393980416
2024-02-09 15:20:43.917258: worker 49 return
2024-02-09 15:20:43.917344: Memory usage 15393980416
We can see that job 3 does not start until one of jobs 0, 1, 2 has finished. Therefore, concurrent.futures keeps the workload within the specified bound.
The worker above generates a list of 10 million floats, which, even assuming 32-bit floats, would be 40MB before adding the overhead that Python introduces to its data structures. It is obvious that the memory usage keeps climbing while the jobs are executed by the thread pool. The reason for this “memory leak” is the future objects: you can always get back the return value using future.result(), and the future objects are those created by executor.submit() — they are held in the jobs list created before the for-loop.
Therefore, if we can discard each future object once it has completed, we can save some memory. This is easy to do in Python: the as_completed() function needs only an iterable. Hence we can use a generator instead of a list, like the following:
import concurrent.futures
import datetime
import math
import resource

def worker(n):
    print(f"{datetime.datetime.now()}: worker {n} started")
    x = []
    for i in range(10_000_000):  # just some slow operation
        x.append(math.sin(i))
    print(f"{datetime.datetime.now()}: worker {n} return")
    return x

def main():
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        jobs = (executor.submit(worker, n) for n in range(50))
        for future in concurrent.futures.as_completed(jobs):
            _ = future.result()
            print(f"{datetime.datetime.now()}: Memory usage {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")

if __name__ == "__main__":
    main()
In this case, the memory consumption stayed flat:
2024-02-09 15:25:37.403959: worker 0 started
2024-02-09 15:25:37.404110: worker 1 started
2024-02-09 15:25:37.416765: worker 2 started
2024-02-09 15:25:40.267059: worker 0 return
2024-02-09 15:25:40.267160: worker 3 started
2024-02-09 15:25:40.279696: Memory usage 1044201472
2024-02-09 15:25:40.347980: worker 2 return
2024-02-09 15:25:40.348048: worker 4 started
2024-02-09 15:25:40.421628: worker 1 return
2024-02-09 15:25:40.421696: worker 5 started
2024-02-09 15:25:40.432426: Memory usage 1066827776
2024-02-09 15:25:40.542792: Memory usage 1066827776
2024-02-09 15:25:43.285008: worker 3 return
2024-02-09 15:25:43.285091: worker 6 started
2024-02-09 15:25:43.387106: Memory usage 1385267200
2024-02-09 15:25:43.447038: worker 4 return
2024-02-09 15:25:43.447112: worker 7 started
2024-02-09 15:25:43.521575: worker 5 return
2024-02-09 15:25:43.521615: Memory usage 1385267200
...
2024-02-09 15:26:23.008981: worker 42 return
2024-02-09 15:26:23.021109: worker 45 started
2024-02-09 15:26:23.110888: Memory usage 1512325120
2024-02-09 15:26:23.193712: worker 43 return
2024-02-09 15:26:23.193791: worker 46 started
2024-02-09 15:26:23.275376: Memory usage 1512325120
2024-02-09 15:26:23.379593: worker 44 return
2024-02-09 15:26:23.391285: worker 47 started
2024-02-09 15:26:23.466048: Memory usage 1512325120
2024-02-09 15:26:26.115932: worker 46 return
2024-02-09 15:26:26.116017: worker 48 started
2024-02-09 15:26:26.199723: Memory usage 1512325120
2024-02-09 15:26:26.210967: worker 45 return
2024-02-09 15:26:26.211051: worker 49 started
2024-02-09 15:26:26.297508: Memory usage 1512325120
2024-02-09 15:26:26.433841: worker 47 return
2024-02-09 15:26:26.519633: Memory usage 1512325120
2024-02-09 15:26:28.275848: worker 49 return
2024-02-09 15:26:28.345301: Memory usage 1512325120
2024-02-09 15:26:28.345372: worker 48 return
2024-02-09 15:26:28.399259: Memory usage 1512325120
I believe using a generator is the cleanest way to do this. Otherwise, you can always remove the item from the list using jobs.remove(future) at the end of the for-loop, or something similar if you use a set or dict instead of a list.
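For completeness, below is a minimal sketch of the set-based alternative, using concurrent.futures.wait() to drop each future as soon as it completes. The function name and arguments are illustrative, not from the code above.

import concurrent.futures

def run_all(worker, args, max_workers=3):
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        pending = {executor.submit(worker, a) for a in args}
        while pending:
            # wait() returns the completed futures and the still-pending ones
            done, pending = concurrent.futures.wait(
                pending, return_when=concurrent.futures.FIRST_COMPLETED)
            for future in done:
                _ = future.result()  # the future is dropped after this iteration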
In a color photo, the colors may be distorted. Several adjustments can be used for correction, as described below.
Let’s first consider how to tune the brightness and contrast.
If we consider a grayscale image $I(x,y)$ with 8-bit unsigned integer intensity values, we can correct it into $I'(x,y)$ by a linear transform, such as:
\[I' = \alpha I+\beta = \frac{255}{I_{\max} - I_{\min}}I-\frac{255I_{\min}}{I_{\max} - I_{\min}}\]This is stretching, resulting in an image with a minimum pixel value of 0 and a maximum pixel value of 255. A variation is to consider the pixel intensities as a histogram or a probability distribution. Then, the minimum and maximum values in the formula above are defined not at the absolute minimum and maximum but at a low and high percentile, respectively. Similar stretching can be applied to separate channels in an RGB image and combined. In code, using NumPy and OpenCV:
import cv2
import numpy as np

def channel_stretch(img: np.ndarray, min_percentile=0.01, max_percentile=0.99) -> np.ndarray:
    """Apply linear contrast stretch on an RGB image"""
    # apply stretch per channel
    channels = []
    for ch in cv2.split(img):
        # get histogram, then convert to CDF
        hist, bins = np.histogram(ch.flatten(), 256, [0, 255])
        cdf = hist.astype(float).cumsum()
        cdf = cdf / cdf.max()
        # linear contrast stretch
        min_pixel = bins[np.searchsorted(cdf, min_percentile)]
        max_pixel = bins[np.searchsorted(cdf, max_percentile)]
        alpha = 255.0 / (max_pixel - min_pixel)
        beta = -min_pixel * 255.0 / (max_pixel - min_pixel)
        new_ch = cv2.convertScaleAbs(ch, alpha=alpha, beta=beta)
        channels.append(new_ch)
    return cv2.merge(channels)
If we consider the pixels as a histogram, we can level it out. That is, first convert the histogram into a cumulative distribution function $F(x)$, where $x$ is a uint8 value running from 0 to 255. Then each pixel $x$ is replaced with $y=\lfloor 255\times F(x)\rfloor$. This is called histogram equalization. It is not a linear transform but depends on the intensity distribution. The code to do this on each channel of an image:
def histogram_equalize(img: np.ndarray) -> np.ndarray:
    """Apply histogram equalization on each channel"""
    return cv2.merge([cv2.equalizeHist(ch) for ch in cv2.split(img)])
This equalization depends on the distribution of a channel over the entire image. An alternative is to perform equalization based on neighboring pixels only. This is contrast limited adaptive histogram equalization, or CLAHE. To apply CLAHE to all channels in an RGB image:
def clahe_channels(img: np.ndarray) -> np.ndarray:
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    channels = [clahe.apply(ch) for ch in cv2.split(img)]
    return cv2.merge(channels)
Instead of applying it to all RGB channels, we can also apply it only to the L channel in the CIELAB color space, or only to the V channel in the HSV color space. This adjusts only the brightness without touching the hue components.
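As a sketch of that variant (assuming the same cv2 and numpy imports as above), applying CLAHE to the lightness channel only could look like this:

def clahe_lab(img: np.ndarray) -> np.ndarray:
    """Apply CLAHE on the lightness channel only, leaving hue untouched"""
    l, a, b = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2LAB))
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l = clahe.apply(l)
    return cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)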
Another nonlinear transformation of pixel intensity is gamma correction, which historically is how TV corrected the nonlinear intensity response of cathode-ray tubes. In essence, it updates the pixel intensity with \(I' = I^{1/\gamma}\), where the intensity value is assumed to be between 0 and 1. For pixels of uint8 values, we can implement gamma correction as:
def adjust_gamma(img: np.ndarray, gamma=1.0) -> np.ndarray:
    """Gamma adjustment on RGB image, in which gamma>1 brightens the image
    and gamma<1 darkens the image
    """
    invgamma = 1.0 / gamma
    table = (np.power(np.arange(256)/255.0, invgamma) * 255).astype(np.uint8)
    new_img = cv2.LUT(img, table)
    return new_img
Using a lookup table in OpenCV is faster than manipulating the matrix with NumPy. Compared to the previous algorithms, this requires the parameter $\gamma$ to be predefined. One way to figure out its value is to consider
\[\gamma = \frac{\log I_\text{in}}{\log I_\text{out}}\]and if we assume the expected midtone is 127.5 (the midpoint of 0 to 255) and take the midtone of the original image to be its average intensity, we can find $\gamma$ using:
def find_optimal_gamma(img: np.ndarray, mid=0.5) -> float:
    # with midtone set, best gamma is log(mean)/log(mid*255)
    mean = np.mean(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY))
    gamma = np.log(mean)/np.log(mid*255)
    return gamma
This code allows an adjustable midtone if 127.5 is not the best value.
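A hypothetical usage of the two functions above, chaining the estimate into the correction ("photo.jpg" is a placeholder filename):

img = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2RGB)
gamma = find_optimal_gamma(img)
corrected = adjust_gamma(img, gamma=gamma)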
Let’s consider the problem of color balance.
An image whose RGB channels have different means or ranges can look color-distorted. Without knowing what the actual colors should be, it is difficult to correct. Of course, we can make assumptions, such as that the range of each color should be the same or the means should be aligned. Applying histogram equalization across each channel, as in the code above, does just this.
White balance is a special case of color balance, which describes how white appears in the image on a spectrum from yellow to blue. One algorithm to adjust white balance automatically is the gray world assumption. This means the average of all colors in a picture should be gray, i.e., neutral to any color.
Applying gray world assumption to white balance correction is as follows:
def grayworld(img: np.ndarray) -> np.ndarray:
    """White balance by gray world assumption"""
    gray = img.mean()
    ch_mean = img.mean(axis=(0, 1))
    channels = [
        cv2.convertScaleAbs(ch, alpha=gray/ch_mean[i])
        for i, ch in enumerate(cv2.split(img))
    ]
    return cv2.merge(channels)
This takes the target gray level to be the simple mean over all channels and all pixels (so the overall brightness is not adjusted). Then, the mean intensity is computed per color channel. The quotient between the two is the scaling factor, so the channel with higher mean intensity receives a lower scaling factor.
Another way of doing white balance is to consider the CIELAB color space, in which the “a” and “b” channels carry the hue spectrum and the center is neutral. Hence, we can correct the white balance by shifting these two channels:
def lab_shift(img: np.ndarray) -> np.ndarray:
    """Adjust white balance in Lab color space.
    OpenCV's LAB color space run from 0 to 255 on all channels
    """
    l, a, b = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2LAB))
    delta_a, delta_b = 127.5-a.mean(), 127.5-b.mean()
    a = cv2.convertScaleAbs(a, beta=delta_a)
    b = cv2.convertScaleAbs(b, beta=delta_b)
    return cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)
One way to let a model handle context lengths significantly longer than what it was pretrained on is position interpolation, such as arXiv:2306.15595. This paper uses NTK-aware interpolation and its improvements, dynamic NTK and NTK-by-parts interpolation.
RoPE considers the feature vector of length $D$ as $x\in\mathbb{R}^D$ and it can also be represented in complex vector space $x\in\mathbb{C}^{D/2}$. The vector inner product in $\mathbb{R}^D$ should be Hermitian inner product in $\mathbb{C}^{D/2}$ to make them equal in real part. RoPE is to multiply $e^{im\theta}$ for some $\theta$ to the $m$-th token in the sequence $x_m\in\mathbb{C}^{D/2}$. This way, the inner product of two encoded tokens $x_m,x_n$ will depend only on the difference $m-n$ but not the exact positions $m$ or $n$.
The paper denotes the position encoding as the function $f(x_m, m, \theta_d)$, where the $d$-th component of $x_m$ is $x_{m,d}\in\mathbb{C}$ and it is transformed into $x_{m,d}e^{im\theta_d}$. The original design of RoPE sets $\theta_d = b^{-2d/D}$ for the large constant $b=10^4$.
Position interpolation avoids extrapolating the position encoding, because the model has not learned the encoding beyond the range it was trained on. If the model was pretrained with context length $L$, i.e., $m=1,\dots,L$, then to extend the model to length $L'$ we make the encoding function
\[f'(x_m,m,\theta_d) = f\Big(x_m,\frac{mL}{L'}, \theta_d\Big) = f(x_m, m/s, \theta_d)\]where we denote by $s=L'/L > 1$ the scale factor from the pretrained context length.
We also define the wavelength at hidden dimension $d$ as
\[\lambda_d = \frac{2\pi}{\theta_d} = 2\pi b^{2d/D}\]Tancik et al. (2020) suggest looking at the encoding through Neural Tangent Kernel theory, under which the network cannot learn high-frequency information if the input dimension is low and the embedding lacks high-frequency components. RoPE with PI for longer context lengths does not introduce any higher frequencies, which the authors of this paper argue is the reason perplexity increases when the model is fine-tuned on a longer context length but used with shorter inputs afterward.
The solution is NTK-aware interpolation on RoPE, which is
\[\begin{aligned} f'(x_m,m,\theta_d) &= f(x_m,m,b'^{-2d/D}) & \text{where }b' &= b\cdot s^{D/(D-2)} \end{aligned}\]with the scale factor $s$ folded into the frequencies $\theta_d$ rather than the position. This scheme is used by Code Llama (arXiv:2308.12950) with the scaled base $b=10^6$. Note that this changes the frequencies $\theta_d$ in the encoding.
However, this treats all hidden dimensions $d$ equally, while the wavelengths $\lambda_d$ vary. For a given pretraining length $L$, there are some $d$ for which $\lambda_d > L$, which means that dimension's encodings are not distributed evenly over a full rotation, so it works like an absolute positional encoding. But if $\lambda_d \ll L$, only relative positional information is provided.
Scaling up RoPE with factor $s$ or a larger base $b'$ essentially makes the dot product of two vectors rotate by a lesser amount, hence impairing the LLM's ability to understand local relationships. Therefore, the authors argue the model would get confused about the positional order of nearby tokens. This motivates the NTK-by-parts scheme.
The proposal first introduces the ratio $r_d=L/\lambda_d$ as the ratio between the original context size $L$ and wavelength $\lambda_d$. That is,
\[r_d = \frac{L}{\lambda_d} = \frac{L}{2\pi b'^{2d/D}}\]Then two thresholds $\alpha,\beta$ are introduced, with the ramp function defined as
\[\gamma(r) = \begin{cases} 0 & \text{if }r<\alpha \\ 1 & \text{if }r>\beta \\ \dfrac{r-\alpha}{\beta-\alpha} & \text{otherwise} \end{cases}\]and the NTK-by-part scheme is to use encoding function $f’(x_m, m, \theta_d) = f(x_m, m, h(\theta_d))$ where
\[h(\theta_d) = \big(1-\gamma(r_d)\big)\frac{\theta_d}{s} + \gamma(r_d)\theta_d\]So for hidden dimensions with $r_d<\alpha$ (wavelength longer than roughly the context), the frequency is fully interpolated as in PI ($\theta_d/s$); for $r_d>\beta$ (short wavelengths), the frequency is left unchanged; and in between, the two are blended by the ramp $\gamma(r_d)$.
The paper suggests setting $\alpha=1$ and $\beta=32$ for the LLaMA family of models.
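A small NumPy sketch of the $h(\theta_d)$ adjustment above, using the notation in the text ($D$, base $b$, scale $s$, thresholds $\alpha,\beta$); this is illustrative, not the paper's reference code:

import numpy as np

def ntk_by_parts_theta(D, L, s, b=10000.0, alpha=1.0, beta=32.0):
    d = np.arange(0, D, 2)                   # the D/2 rotation pairs
    theta = b ** (-d / D)                    # original frequencies theta_d
    wavelength = 2 * np.pi / theta           # lambda_d
    r = L / wavelength                       # r_d = L / lambda_d
    gamma = np.clip((r - alpha) / (beta - alpha), 0.0, 1.0)   # ramp function
    return (1 - gamma) * theta / s + gamma * theta            # h(theta_d)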
In autoregressive generation, the sequence length increases at each step. There are two ways to run inference. We can set a fixed $s=L'/L$ for the entire inference cycle, where $L'$ is the extended context size. This incurs a performance discount at lengths shorter than $L$ and degradation when the length goes beyond $L'$.
The alternative is to dynamically use a different scale factor $s=\max(1, l'/L)$ in each forward pass, where $l'$ is the current sequence length. This allows the model to degrade gracefully rather than break down all of a sudden. This is called Dynamic NTK interpolation. The downside is that when the model uses kv-caching for autoregressive generation, we need to modify the code to cache the kv-embeddings without RoPE and apply the encoding to the entire cache in each iteration.
The authors also propose adding a temperature $t$ to the softmax step, named attention scaling:
\(\text{softmax}_n\Big(\frac{q_m^\top k_n}{t\sqrt{D}}\Big)\) In implementation, this simply scales both $q_m$ and $k_n$ by a factor of $\sqrt{1/t}$, leaving the previous softmax implementation intact. This temperature parameter has a uniform impact on perplexity regardless of the token position. For LLaMA models, it is recommended to set $\sqrt{1/t}=0.1\ln(s)+1$.
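A minimal PyTorch sketch of folding the temperature into the query and key tensors, assuming the $0.1\ln(s)+1$ rule quoted above:

import math
import torch

def scale_qk_for_yarn(q: torch.Tensor, k: torch.Tensor, s: float):
    # sqrt(1/t) per the recommendation; no scaling when s <= 1
    m = 0.1 * math.log(s) + 1.0 if s > 1.0 else 1.0
    return q * m, k * m   # the q.k score is then divided by t inside the usual softmax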
The YaRN method combines attention scaling with the NTK-by-parts interpolation.
The authors took 400 training steps, with 0.1% of the original pretraining corpus, to extend the context window. The training and evaluation procedure follows arXiv:2306.15595.
Code Llama, trained on 16K-context data, has shown that the network can extrapolate up to a 100K context without ever seeing such context sizes during training. The authors of this paper showed that the $s=32$ model can extrapolate up to a 128K context even with only 64K-context data in training. Therefore, YaRN is efficient at transfer learning to larger scales $s$.
The fine-tuned model is evaluated on the perplexity score, the passkey retrieval task, and other common LLM benchmarks, and the results show that YaRN is able to train short and test long.
The structure of the repo is as follows. It is the LLaMA 2 model in its main branch. The older version is moved to another branch:
llama
├── __init__.py
├── generation.py
├── model.py
└── tokenizer.py
This is just the model code. To use it, such as to pretrain it from scratch (if you have such resources) or fine-tune it, you need to look into the scripts on the repo llama-recipes.
All language models for text start with a tokenizer that breaks a string into tokens. In LLaMA 2, tokenizer.py defines the class Tokenizer with the encode() and decode() methods.
The LLaMA tokenizer uses a SentencePiece model, which is based on BPE and Unigram. It adds BOS, EOS, and PAD as special tokens. The tokens are numbered as integers, and the encode and decode functions convert between a string and a list of such integers.
There are several building block functions defined in model.py.
The static one is the dataclass ModelArgs, which holds the parameters for inference or construction of the model. The defaults are:
- dim = 4096
- n_heads = n_layers = 32
- max_seq_len = 2048, but this number is doubled in the model, hence the effective sequence length limit is 4096
The building block functions are described one by one as follows:
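Before going into those functions, here is a trimmed-down sketch of what the ModelArgs dataclass looks like with the defaults quoted above (the real class in model.py carries more fields, e.g., n_kv_heads, multiple_of, norm_eps):

from dataclasses import dataclass

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    vocab_size: int = -1      # filled in from the tokenizer later
    max_batch_size: int = 32
    max_seq_len: int = 2048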
The LLaMA model uses RoPE as the position encoding. Essentially it performs complex multiplication $xe^{it}$ for input $x$ and position $t$. The $e^{it}$ part is pre-computed and cached, up to the max sequence length, in the following function:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
Here the function torch.polar(x, t) computes $xe^{it}$, with $x=1$. The output of this function is a tensor of shape $T\times D$. When it is used, the input tensor is usually of shape $N\times T\times d$ for a batch of $N$ sequences. To perform element-wise multiplication of the position encoding tensor and the input tensor, the following function reshapes the position encoding tensor to fit the input tensor using view():
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
In most cases, the above function is to add a new dimension as the first axis to the position encoding. The actual rotary embedding operation is defined in another function:
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
The first two lines transform xq and xk from shape $(T,D)$ into shape $(T,\frac{D}{2},2)$, viewed as complex numbers. Then the third line reshapes the frequency tensor, and finally xq_out is the original xq_ elementwise-multiplied with the position encoding, converted back to its real representation, and reshaped into shape $(T,D)$.
The rotary embedding (RoPE) is explained as follows: the $d$-dimensional input vector $x$ is rearranged as $d/2$ pairs $(x_m^{(1)},x_m^{(2)})$. Here $m=1,\dots,T$ denotes the sequence position. Considering each pair as coordinates in a 2D plane, the transformation rotates it by an angle $m\theta$:
\[\textrm{RoPE}(x_m^{(1)},x_m^{(2)},m) = \begin{pmatrix}\cos m\theta & -\sin m\theta\\ \sin m\theta & \cos m\theta \end{pmatrix}\begin{pmatrix}x_m^{(1)}\\ x_m^{(2)}\end{pmatrix}\]The transformed output fulfills:
\[\begin{aligned} & \langle \text{RoPE}(x_m^{(1)},x_m^{(2)},m), \text{RoPE}(x_m^{(1)},x_m^{(2)},n)\rangle\\ = & \langle \text{RoPE}(x_m^{(1)},x_m^{(2)},m-n), \text{RoPE}(x_m^{(1)},x_m^{(2)},0)\rangle \end{aligned}\]which means the dot product is relative. In the implementation, we can consider $x$ at position $m$ is encoded as $xe^{im\epsilon}$ for some $\epsilon\in(0,\frac{\pi}{2N}]$. So the features $1,\dots,d$ will have feature $i$ pair up with feature $i+d/2$, and using angle $\theta_i = 10000^{-2(i-1)/d}$. The inner product above $\langle x, y\rangle = x\cdot \bar{y}$ is using complex inner product.
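As a quick sanity check of the relative-position property, the snippet below (assuming the three functions above are in scope) rotates the same vector at different positions and compares dot products at equal offsets:

import torch

freqs_cis = precompute_freqs_cis(dim=8, end=16)
x = torch.randn(1, 1, 1, 8).repeat(1, 16, 1, 1)   # same vector at every position
xq, xk = apply_rotary_emb(x, x, freqs_cis)
# both pairs of positions are 2 apart, so the dot products should match
d1 = (xq[0, 3, 0] * xk[0, 1, 0]).sum()
d2 = (xq[0, 7, 0] * xk[0, 5, 0]).sum()
print(torch.allclose(d1, d2, atol=1e-5))   # expected: True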
Lastly, to help multi-head attention, there is a function to replicate the input tensor x n_rep times on the third dimension:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
At input, the shape of x is (bs, slen, n_kv_heads, head_dim) and at output it becomes (bs, slen, n_rep*n_kv_heads, head_dim). The replication is done with expand(), which creates a view rather than copying, before the final reshape().
The overall transformer model is defined in the class Transformer, which has the following workflow:
- the input tokens is a tensor of token ids (shape of batch size $N\times$ sequence length $L$) and start_pos as an integer
- start_pos is for caching, useful in case of distributed model
- convert tokens into embedding h using the ParallelEmbedding module from fairscale (a PyTorch extension). The output tensor h is of dimension $d$
- slice freqs_cis at position range start_pos:start_pos+seq_len to match the input
- build a mask of size $L\times L$ such that its upper triangular elements above offset start_pos are all $-\infty$ and all other values are zero. This is the mask for causal inference.
- pass h through the stack of TransformerBlock. Each block transforms h, with the other parameters: start_pos, freqs_cis, and mask.
- normalize h with RMSNorm
- the output layer is ColumnParallelLinear from fairscale with no activation function, where h is a sequence of vectors of hidden dimension $d$
Fairscale is an alternative to torch.nn.parallel.DistributedDataParallel (DDP) to do sharding on data, model, and optimizer.
The building blocks to support the overall transformer architecture are as follows:
The class RMSNorm calculates
\[\bar{x} = \frac{x}{\sqrt{\operatorname{mean}(x^2)+\epsilon}}\cdot g\]
where the mean is applied on the square $x^2$ along the last dimension (i.e., the embedding dimension) and $g$ is a learned gain. This serves as the layer normalization.
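A minimal functional sketch of that computation (the real class wraps the gain $g$ as a learnable nn.Parameter named weight):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # normalize by the root mean square along the embedding dimension, then apply the gain
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight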
The class Attention: This is where the parallelism is applied in case of distributed model execution.
- self.wq, self.wk, self.wv to multiply with x
- self.wo for output, of shape $d\times d$
- self.cache_k and self.cache_v as tensors of shape $(N,T,H_{kv},d_H)$ for batch size $N$, sequence length $T$, number of key/value heads $H_{kv}$, and attention dimension $d_H$
- the forward function takes x, start_pos, freqs_cis, and mask
- the query and key tensors are rotated with freqs_cis
- the attention score matrix $A$ is computed with torch.matmul()
- the output is torch.matmul(A, Xv) and it has shape $(N,H,T,d_H)$, which is then transposed into $(N,T,H,d_H)$. The resulting dimension $T$ is from a 1-to-1 matching between $X_Q$ and $X_V$
The class FeedForward
: This is just a fully-connected block used by TransformerBlock that computes
\[\mathrm{FFN}(x) = W_2\big(\mathrm{SiLU}(W_1 x)\odot W_3 x\big)\]
where:
- the hidden dimension is suggested by TransformerBlock as $4d_{in}$, but then adjusted to $\lfloor\frac23 d_h\rfloor$ and rounded up to a multiple of a factor (e.g., 256)
- $W_1$ and $W_3$ are implemented as ColumnParallelLinear and $W_2$ is implemented as RowParallelLinear; only W_2 has gather_output=True set to synchronize the parallel runs
The class TransformerBlock
: It connects an Attention module, a FeedForward module, and two RMSNorm modules (one for attention and one for feedforward). Its workflow is:
- h = x + attention(attention_norm(x), start_pos, freqs_cis, mask)
- out = h + feed_forward(ffn_norm(h))
of which the normalization is applied before each sublayer (pre-norm) and the addition is the residual connection.
In generation.py, the class Llama and the function sample_top_p() are defined.
The function sample_top_p() takes a probability distribution tensor probs and a probability threshold p as input. Then:
- it sorts probs in descending order into probs_sort (and remembers the original index), then computes the cumulative sum probs_sum
- it builds a mask of probs_sum - probs_sort > p; the mask at position $i$ means the cumulative probability of positions 0 to $i-1$ is strictly beyond $p$
- it zeroes out the masked entries of probs_sort (i.e., from position $i$ until the end) and renormalizes it to make it sum to 1
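A sketch of those steps in PyTorch, plus the final multinomial draw that maps the sample back to the original token ids (close in spirit to the repo's function, not a verbatim copy):

import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p        # cumulative prob before this position exceeds p
    probs_sort[mask] = 0.0                   # zero out the tail
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))    # renormalize
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)            # back to original token ids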
The class Llama ties everything together. In its constructor, it creates a Transformer object and a Tokenizer object. In the static factory method build(), it takes inputs ckpt_dir (checkpoint dir), tokenizer_path, max_seq_len, max_batch_size, model_parallel_size, and seed, and returns a Llama object. In this build() function:
- initialize torch.distributed with the nccl backend (the only backend for GPU; the other backends gloo and mpi are for CPU only)
- call torch.manual_seed() to seed the RNG to 1
- find the *.pth files in the checkpoint dir; based on the current machine's index in the distributed cluster, load the file indexed by get_model_parallel_rank()
- the checkpoint files (consolidated.{00,01}.pth) are of the same byte size
- read params.json from the checkpoint dir and update ModelArgs, except the parameters max_seq_len and max_batch_size, which are provided to build() only. Example:
{"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
- create the Tokenizer from the tokenizer_path, e.g., to set the vocabulary size
- create the Transformer object with the prepared arguments, then load the checkpoint. This works as one shard
- create the Llama object with the Tokenizer and Transformer models
In the Llama
class, the method generate()
is decorated with a PyTorch
inference mode decorator. It takes as input prompt_tokens
(a batch of
prompts, as a list of list of vocab id), max_gen_len
, temperature
, top_p
,
logprobs
(boolean), echo
(boolean). It returns the list of output tokens
(list of list of ids) as well as the corresponding logprob.
total_len
to be the min of max_seq_len
and max_gen_len +
max_prompt_len
; this should be the max output sequence length that this
model can producetokens
: tensor of size (batch_size
, total_len
) filled with the pad
id, which is retrieved from the tokenizer
tokens
tensor with the input prompts (left aligned)token_logprobs
: tensor of size (batch_size
, total_len
) filled with zeroeos_reached
: tensor of size (1, batch_size
), filled with False to indicate if a prediction reached the EOS tokeninput_text_mask
as boolean tensor to tell if tokens
is not pad idThe workflow for generate()
is as follows:
min_prompt_len
reached total_len
, run the model once to get the
logit, then calculate token_lgoprobs
based on the cross entropy between
the logits and the original tokensmin_prompt_len
reached total_len
, with the position cursor cur_pos
:
a. if temperature=0, simply pick the next token by argmax
b. if temperature > 0, use softmax and sample_top_p() to pick the next token
c. the next token is filled into the tensor tokens at position cur_pos only if that position is masked as non-prompt
d. then update token_logprobs by comparing the model output logits to the “next token”; the model is run on the slice prev_pos:cur_pos across the batch to generate the next token of each position, and only the last output is used for next_token, which corresponds to position cur_pos+1; initially prev_pos is 0, and each iteration updates prev_pos to cur_pos; the cross entropy over prev_pos+1:cur_pos+1 tells how accurate the model output is using as much information as possible, and all tokens at prev_pos+1:cur_pos+1 are weighted equally because the entire sequence is presented to the model
e. update eos_reached to check if EOS has been generated at cur_pos; terminate the autoregressive for-loop if all inputs in the batch have an EOS
f. update prev_pos to cur_pos before the next iteration
Finally, convert the tokens tensor into lists of tokens, cut off at EOS, and optionally also produce the list of logprobs.
The methods text_completion()
and chat_completion() are applications of the generate() method. Both take similar inputs (e.g., temperature, top-p). In particular, text_completion():
- tokenizes the prompt strings and calls self.generate() to get the generated tokens and logprobs
and chat_completion():
- takes a list of Dialog objects, each of which is a list of Message, a typed dict of role and content
- uses the special tags [INST], [/INST], <<SYS>>, and <</SYS>>
- formats the input: the first message's role can be “system”. If so, merge the first message with the second using the template
<<SYS>>
{first message}
<</SYS>>
{second message}
- the list of dialog is assumed to have prompts and answers interleaving; each pair is then formatted as
[INST]
{prompt}
[/INST]
{answer}
and then the pairs are concatenated. The last message in the dialog is the final prompt, also concatenated using the same template.
- the formatted strings are tokenized into prompt_tokens and sent to self.generate()
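A rough sketch of that prompt assembly with the tags spelled out (the real chat_completion() works on token ids and also validates the role ordering):

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

def format_dialog(dialog):
    """dialog: list of {"role": ..., "content": ...} dicts, user/assistant alternating"""
    if dialog[0]["role"] == "system":
        # merge the system message into the first user message
        merged = B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"]
        dialog = [{"role": dialog[1]["role"], "content": merged}] + dialog[2:]
    parts = [
        f"{B_INST} {prompt['content']} {E_INST} {answer['content']}"
        for prompt, answer in zip(dialog[::2], dialog[1::2])
    ]
    parts.append(f"{B_INST} {dialog[-1]['content']} {E_INST}")
    return "".join(parts)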
The repo has an example script, example_text_completion.py. This is not vanilla PyTorch code, but uses fairscale. Hence you cannot run the script with a barebone Python interpreter. The suggested command (from its readme) is:
torchrun --nproc_per_node 1 example_text_completion.py \
--ckpt_dir ../llama-2-7b \
--tokenizer_path ../llama-2-tokenizer/tokenizer.model
The version in Hugging Face has these removed, hence we can run with the python interpreter directly. Example code:
from transformers import AutoTokenizer
import transformers
import torch
import accelerate

model = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    max_length=1000,
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)
sequences = pipeline(
    'Hi! I like cooking. Can you suggest some recipes?\n')
for seq in sequences:
    print(f"Result: {seq['generated_text']}")
where trust_remote_code is required to download the weights from the hub, and device_map should be set to "auto" when using the accelerate library. The other parameters control the decoding strategies, such as multinomial sampling, beam-search, top-K, and top-P. Only beam-search and sampling support num_return_sequences greater than 1.
It is an alternative to PyTorch DDP. It is a tool for sharding. To use it:
- create an optimizer with fairscale.optim.oss.OSS(optim=torch.optim.SGD, params=model.parameters, **otherargs) instead of the plain torch.optim.SGD; then replace the original model with the wrapper ShardedDDP(model, optimizer); afterward, run model.train() as usual
- use fairscale.nn.Pipe(model, balance=[a,b]) to run a model across two GPUs such that the layers are split between them at the given ratio (e.g., 2:1)
In the LLaMA-2 code, the parallelism is determined by the parameters WORLD_SIZE and LOCAL_RANK, where WORLD_SIZE should match the number of checkpoints, e.g., 2 for 13B. The model will have sharding on the number of attention heads in the multi-head attention module.
- ColumnParallelLinear and RowParallelLinear are the sharded linear layers
- in the Attention class, self.cache_k and self.cache_v are set up to the size of the local heads to manage xk and xv (i.e., the key and value tensors)
- as ColumnParallelLinear is used for wq, wk, and wv, the option gather_output=False makes all matrix multiplications local; as RowParallelLinear is used for wo, the output is gathered, and input_is_parallel=True is specified
- for the 13B model, dim=5120, n_heads=40, n_layers=40. Based on the code, the attention dimension head_dim should be 5120/40=128 and the matrices $W^Q,W^K,W^V,W^O$ should all be 5120×5120 (after the input vector is multiplied with the matrices, it is reshaped to (40,128) for the attention score computation). But in the .pth file, these matrices are of shape 5120×2560.
My notes below:
A chart of signature models is presented in Fig. 2 of the paper.
By consensus, a language model is considered large if there are 10B parameters or more.
Pretrained models can be fine-tuned for downstream tasks
Instruction fine-tuning: incorporating additional prompts or instructions during fine-tuning to guide the output, such that users can get more fine-grained control over the output.
Parameter-efficient fine-tuning (PEFT): updating only a small number of parameters
Tokenization is a preprocessing step to convert strings into characters, subwords, or symbols. Several segmentation techniques are in common use.
Self-attention, aka intra-attention. This layer connects all the sequence positions to learn about long-range dependencies \(A=\sigma(\dfrac{QK^\top}{\sqrt{n}})\) Cross-attention is seen in encoder-decoder architectures: the decoder provides the queries, while the encoder's outputs provide the keys and values.
The native way of computing attention by the full sequence is called full attention. The complexity is $O(n^2)$. Alternatively, we can approximate it with sparse attention. To speed up the access to GPU memory, there is a Flash Attention technique to tackle the bottleneck of memory in computing attention by tiling memory blocks. Flash attention allows longer context-length in LLMs.
Transformer introduced positional encoding because attention is agnostic to input sequence order. There are absolute (assigning unique identifier to each position) and relative encodings (to signal relative dependencies). Encodings are added to the input sequence. Relative positional encoding can be learned.
Two famous relative encodings are Alibi and RoPE:
The common variations of activation functions used in LLMs are as follows:
Goal of LN is for faster convergence. There are two ways of doing normalization:
The training data for LLM are filtered by (1) a classifier such that only high-quality text is selected (2) heuristics-based rules based on language, metrics, statistics, and keywords.
Training data are also deduplicated at multiple levels, from sentences to documents, and removing personal information using heuristics.
Transformers are designed as sequence transduction models for machine translation. The original architecture is an encoder-decoder.
Objective of an LLM is to predict the next token based on the input sequence. The causal decoder architecture restricts the flow of information backward, i.e., predicting token $\hat{t}_k$ using only the tokens $t_{1:k-1}$. This is the most widely used variant of LLMs.
Another decoder-only architecture is prefix decoder, which the sequence is always fully visible on a portion of the input (usually the first few tokens of a sequence). The prefix decoder is non-causal.
Alignment tuning: To align the model to human preferences
Pretraining: The first stage in building LLM. The model is trained in a self-supervised manner on a large corpus to predict the next tokens given the input.
Fine-tuning: There are multiple adaptations:
Parameter-efficient tuning: Train LLMs with fewer resources. Multiple approaches are used
Prompt engineering: In-context learning is also known as few-shot learning.
Reasoning: If LLM has been trained on reasoning datasets, it can generate reasons using a prompting style. Techniques of reasoning include chain-of-thought (CoT), which the model generates outcomes with step-by-step reasoning; tree-of-thought (ToT), which the model explores multiple reasoning paths with look ahead and backtrack, and self-consistency (arXiv:2203.11171) which generates multiple responses and select the most frequent answer.
T5: an encoder-decoder model for text-to-text NLP problems. Train with masked language modeling objective but the masked span is replaced with a single mask token. The model can be fine-tuned using adapter layers for downstream tasks.
GPT-3: same architecture as GPT-2 but with dense and sparse attention. It can train on a larger batch size with a lower learning rate. GPT-3 has 175B model parameters.
mT5: A multilingual T5 model trained on mC4 dataset with 101 languages. The vocab size is 250K. It is suggested to use a small amount of pretraining datasets to include all languages together with English data for fine-tuning.
Gopher: A family of models from 44M to 280B parameters to study the effect of scale on LLM performance.
GPT-NeoX-20B: A model following GPT-3 that is trained on the Pile dataset without deduplication. It has parallel attention and FC layers in a transformer block, uses rotary positional encoding (RoPE), and only dense attention.
OPT: as a clone of GPT-3, which training employs dynamic loss scaling and restarts from earlier checkpoint with lower learning rate whenever loss divergence is observed.
BLOOM: Causal decoder model trained on the ROOTS corpus. It uses ALiBi positional encoding and adds a normalization layer after the embedding layer. These changes are found to stabilize training.
Chinchilla: Causal decoder trained on the same dataset as Gopher. Using AdamW optimizer. Chinchilla found the relationship that model size should double for every doubling of training tokens.
PaLM: Causal decoder with parallel attention and FC layers. It uses SwiGLU activation and RoPE. If the loss spikes during training, it restarts from 100 steps earlier and skips 200-500 batches of training data. PaLM-2 is a smaller multilingual variant of it, but trained on a better-quality dataset for more iterations.
LLaMA: Decoder-only language models, famous for parameter-efficient and instruction tuning. LLaMA-1 implements causal attention by not storing and computing masked attention weights and key/query scores. LLaMA-2 has a chat model for dialog generation.
CodeGen: Trained on both natural language and programming language data, in the sequence of Pile, BigQuery, and BigPython datasets. It also proposed a multi-step approach to synthesizing code. There is a multi-turn programming benchmark (MTPB) to evaluate multi-step program synthesis.
CodeT5+: Modified CodeT5 with shallow encoder and deep decoder to train on multiple stages, first on code data, then on text-code pairs.
StarCoder: Decoder-only model with Flash Attention to scale up the context length to 8K. It outperforms PaLM, LLaMA on HumanEval and MBPP benchmarks.
LaMDA: Decoder-only model pretrained on public dialog data. It produces responses with high quality, safety, and groundedness.
T0, mT0: Use templates to convert existing dataset into prompt datasets then trained the model.
Tk-Instruct: Fine-tuned T5 with in-context instructions
OPT-IML, Flan: trained with task datasets
Flan-T5: Trained on 1.88M CoT samples
Techniques in fine-tuning:
To make it an assistant, you need to integrate it with your code editor. Neovim, unfortunately, has a smaller user base and not much progress yet. But there is a VSCode plugin that can use FauxPilot. FauxPilot is a project that tries to let you self-host a server compatible with GitHub Copilot. Theoretically you could use the GitHub Copilot plugin in your editor (Neovim has one) but switch to the FauxPilot backend. But since GitHub hard-coded the server address in the plugin, you need to somehow hack the plugin to make it work. The dedicated FauxPilot plugin allows you to configure a different hostname/port number for the server, and that is more convenient. Of course, how to communicate with the editor so that you can extract the context and provide suggestions seamlessly from the UX perspective is another story. But the point is, there is a solution for the client (i.e., an editor such as VSCode or Neovim). And there is a model (e.g., CodeGen2). FauxPilot is hard-coded to use Salesforce's CodeGen model.
Indeed, I believe FauxPilot made things too complicated. Of course, its merit is a professional deployment of the self-hosted Copilot clone using Docker. But if the goal is to try out models with your IDE, that is too heavy. Therefore, I trimmed down FauxPilot to take only the web interface part: endpoints are implemented as a REST API using FastAPI and uvicorn (hence the server code can be asynchronous). From the web requests, we get the code that the user typed as a string (together with some parameters such as model temperature) and we can invoke the model to produce output. The interaction with the model should be a black box to the REST API, and hence it is designed as such.
The code is here: https://github.com/righthandabacus/fauxpilot_lite
See the readme for more details.
The other way to run this would be using conda. It is special because conda is not just a Python virtualenv: an environment in conda can come with other binaries, such as the CUDA library. Hence you can conda install cudatoolkit and then conda install pytorch. These are conda-specific builds that assume the libraries are installed in non-standard locations.
At the time of writing, we have Python 3.11.4, PyTorch 2.0.1, and TensorFlow 2.13.0. Luckily, PyTorch 2.0 and TensorFlow 2.13 both depend on CUDA 11.8 (CUDA 12 is not supported yet). But, unfortunately, conda does not have TensorFlow 2.13 yet. To get everything in the same conda environment, this seems to be what should be done:
sudo apt-get install cuda-11-8
mamba create -n <name> python=3.11.4
mamba activate <name>
mamba install -c pytorch -c nvidia pytorch torchvision torchaudio pytorch-cuda=11.8
pip install tensorflow # 2.13.0 using system CUDA
If TensorFlow 2.12 is acceptable, we can indeed do (specifying the exact build is necessary):
mamba install tensorflow=2.12.0=gpu_py311h65739b5_0
But in the case above, we have TWO copies of CUDA installed: one at the system level and another inside the conda environment. The pip-installed tensorflow uses the former, and the conda-installed pytorch uses the latter.
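A quick way to check which CUDA build each framework ended up with (an illustrative check, not from the original setup steps):

import torch
import tensorflow as tf

print(torch.version.cuda, torch.cuda.is_available())
print(tf.sysconfig.get_build_info().get("cuda_version"),
      tf.config.list_physical_devices("GPU"))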
SSH-BASED VIRTUAL PRIVATE NETWORKS
ssh contains support for Virtual Private Network (VPN) tunnelling using
the tun(4) network pseudo-device, allowing two networks to be joined
securely. The sshd_config(5) configuration option PermitTunnel controls
whether the server supports this, and at what level (layer 2 or 3
traffic).
The following example would connect client network 10.0.50.0/24 with
remote network 10.0.99.0/24 using a point-to-point connection from
10.1.1.1 to 10.1.1.2, provided that the SSH server running on the gateway
to the remote network, at 192.168.1.15, allows it.
On the client:
# ssh -f -w 0:1 192.168.1.15 true
# ifconfig tun0 10.1.1.1 10.1.1.2 netmask 255.255.255.252
# route add 10.0.99.0/24 10.1.1.2
On the server:
# ifconfig tun1 10.1.1.2 10.1.1.1 netmask 255.255.255.252
# route add 10.0.50.0/24 10.1.1.1
Client access may be more finely tuned via the /root/.ssh/authorized_keys
file (see below) and the PermitRootLogin server option. The following
entry would permit connections on tun(4) device 1 from user “jane” and on
tun device 2 from user “john”, if PermitRootLogin is set to
“forced-commands-only”:
tunnel="1",command="sh /etc/netstart tun1" ssh-rsa ... jane
tunnel="2",command="sh /etc/netstart tun2" ssh-rsa ... john
Since an SSH-based setup entails a fair amount of overhead, it may be more
suited to temporary setups, such as for wireless VPNs. More permanent
VPNs are better provided by tools such as ipsecctl(8) and isakmpd(8).
The command to launch the VPN is as follows (routing is still needed):
ssh \
-o PermitLocalCommand=yes \
-o LocalCommand="sudo ifconfig tun5 192.168.244.2 pointopoint 192.168.244.1 netmask 255.255.255.0" \
-o ServerAliveInterval=60 \
-w 5:5 vpn@example.com \
'sudo ifconfig tun5 192.168.244.1 pointopoint 192.168.244.2 netmask 255.255.255.0; echo tun0 ready'
The contributions of this paper: (1) a process to create a high-quality synthetic dataset for super-resolution, and (2) a network for SR, notably using a U-Net discriminator with spectral normalization to increase the discriminator's capability.
The classical degradation model includes blur, downsampling, noise, and JPEG compression:
\[\mathbf{x} = D(\mathbf{y}) = \big[(\mathbf{y} \circledast \mathbf{k})\downarrow_r + \mathbf{n}\big]_{\mathrm{JPEG}}\]Key ideas in the paper:
- first-order vs. high-order degradation modeling for real-world degradation
- sinc filter for ringing and overshoot artifacts
- a discriminator of more powerful capability: the gradient feedback from the discriminator needs to be more accurate for local detail enhancement, hence a U-Net design with spectral normalization (SN) regularization
The optimization target for a super-resolution algorithm is usually the MSE between the pixels of the high-resolution image (HR) and those of the super-resolution output (SR). Minimizing the MSE effectively maximizes the PSNR. But MSE cannot capture perceptual differences, such as texture details. The major contribution of this paper is to introduce a perceptual loss function using the feature maps of the VGG network.
The SR-GAN is using a CNN architecture to process images. Some key design components in the CNN:
Adversarial min-max problem:
\(\min_{\theta_G} \max_{\theta_D} \Big\{ \mathbb{E}_{I^{\text{HR}}\in T^{\text{HR}}}\Big[\log D_{\theta_D}(I^{\text{HR}})\Big] + \mathbb{E}_{I^{\text{LR}}\in T^{\text{LR}}}\Big[\log\Big(1-D_{\theta_D}\big(G_{\theta_G}(I^{\text{LR}})\big)\Big)\Big] \Big\}\)
Model design: (Fig.4 in the paper on page 5)
The key to GAN training is the perceptual loss function, defined as equation (3) in the paper: \(l^{\text{SR}} = l_X^{\text{SR}} + 10^{-3} l_{\text{Gen}}^{\text{SR}}\) which $l_X^{\text{SR}}$ is the content loss (based on VGG19 network features) and $l_{\text{Gen}}^{\text{SR}}$ is the adversarial loss (based on discriminator).
The content loss $l_X^{\text{SR}}$ is modeled after the pixel-wise MSE loss, which is shown to positively correlate with PSNR. However, MSE loss on pixels tends to overly smooth textures, as it fails to account for high-frequency content. The paper proposes to compute the MSE loss on feature output from VGG19 at the layer $\phi_{5,4}$, i.e., the 4th conv layer in the block preceding the 5th pooling layer (note: not the 4th conv layer from the beginning, but the 4th in that block starting from an activation layer). There are $C=512$ feature channels. The MSE is computed elementwise.
The adversarial loss or the generative loss $l_{\text{Gen}}^{\text{SR}}$ is the cross entropy over all training samples:
\[l_{\text{Gen}}^{\text{SR}} = \sum_{n=1}^{N} -\log D_{\theta_D}\big(G_{\theta_G}(I^{\text{LR}}_n)\big)\]
The model is trained as follows: the BSD300 dataset is used as the test set. The model is designed for a scale factor of $4\times$, i.e., $16\times$ the pixel count. PSNR (in dB) and SSIM are used as the evaluation metrics. The model is compared to the upscaling algorithms nearest neighbor, bicubic, SRCNN (Dong et al., 2014), and SelfExSR (Huang et al., 2015).
The training set is a random sample of 350K images from the ImageNet database, where the LR images are obtained by bicubic $4\times$ downsampling of the HR (RGB) images. Then 16 random $96\times 96$ sub-images cropped from distinct image samples form a mini-batch.
The LR images (input to the generator) are in the pixel range $[0,1]$ and the HR images (output target of the generator) are in the range $[-1,1]$. The VGG output features are scaled by a factor of $1/12.75$ to make the MSE on VGG features comparable to the pixel MSE loss.
Training is using Adam optimizer ($\beta_1=0.9$) with learning rate of $10^{-4}$ for the first 100K update steps and learning rate $10^{-5}$ for another 100K update steps.
There are quite a number of implementations on the web. Below is what I polished from various sources:
#!/usr/bin/env python
# coding: utf-8

"""
Based on the paper
"""

import os

import cv2
import numpy as np
import tensorflow as tf
import tqdm
from tensorflow.keras.layers import \
    Input, Conv2D, LeakyReLU, BatchNormalization, Flatten, Dense, PReLU, Add, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.losses import BinaryCrossentropy, binary_crossentropy, mean_squared_error
from tensorflow.keras.optimizers import Adam, SGD

#
# Data generator
#

def make_dataset(image_dir, hires_size=(256,256), lores_size=(64,64), batch_size=8):
    """Tensorflow dataset of batches of (lores,hires) images"""
    hires = tf.keras.utils.image_dataset_from_directory(image_dir, labels=None,
                                                        color_mode="rgb",
                                                        image_size=hires_size,
                                                        batch_size=None)
    hires = hires.batch(batch_size, drop_remainder=True)
    lores = hires.map(lambda nhwc: tf.image.resize(nhwc, lores_size))
    dataset = tf.data.Dataset.zip((hires, lores))
    return dataset
#
# Discriminator
#

def discriminator_block(input, n_filters, strides=1, bn=True, name_prefix=""):
    """Repeated discriminator block. Batch normalization is not used on the first block"""
    y = Conv2D(n_filters, (3, 3), strides, padding="same", name=name_prefix+"_conv")(input)
    if bn:
        y = BatchNormalization(momentum=0.8, name=name_prefix+"_bn")(y)
    y = LeakyReLU(alpha=0.2, name=name_prefix+"lrelu")(y)
    return y

def discriminator_model(input, name="discriminator"):
    """The complete discriminator that takes an input image and output a logit value"""
    n_filters = 64
    # k3n64s1 and k3n64s2
    y = discriminator_block(input, n_filters, bn=False, name_prefix="block1")
    y = discriminator_block(y, n_filters, strides=2, name_prefix="block2")
    # k3n128s1 and k3n128s2
    y = discriminator_block(y, n_filters*2, name_prefix="block3")
    y = discriminator_block(y, n_filters*2, strides=2, name_prefix="block4")
    # k3n256s1 and k3n256s2
    y = discriminator_block(y, n_filters*4, name_prefix="block5")
    y = discriminator_block(y, n_filters*4, strides=2, name_prefix="block6")
    # k3n512s1 and k3n512s2
    y = discriminator_block(y, n_filters*8, name_prefix="block7")
    y = discriminator_block(y, n_filters*8, strides=2, name_prefix="block8")
    # Dense layers and logit output
    y = Flatten(name="flatten")(y)
    y = Dense(n_filters*16, name="fc1")(y)
    y = LeakyReLU(alpha=0.2, name="lrelu")(y)
    output = Dense(1, name="fc2")(y)  # no sigmoid act, to make logit output
    return Model(inputs=input, outputs=output, name=name)
#
# Generator
#

def residual_block(input, name_prefix=""):
    """Residual block in generator"""
    # two layers of k3n64s1
    y = Conv2D(64, (3, 3), padding="same", name=name_prefix+"_conv1")(input)
    y = BatchNormalization(momentum=0.5, name=name_prefix+"_bn1")(y)
    y = PReLU(shared_axes=[1, 2], name=name_prefix+"_prelu")(y)
    y = Conv2D(64, (3, 3), padding="same", name=name_prefix+"_conv2")(y)
    y = BatchNormalization(momentum=0.5, name=name_prefix+"_bn2")(y)
    y = Add(name=name_prefix+"_add")([input, y])  # skip connection
    return y

def upscale_block(input, name_prefix=""):
    """Upscale the image 2x, used at the end of the generator network
    """
    # k3n256s1
    y = Conv2D(256, (3, 3), padding="same", name=name_prefix+"_conv")(input)
    y = tf.nn.depth_to_space(y, 2)  # 2x upsampling
    y = PReLU(shared_axes=[1, 2], name=name_prefix+"_prelu")(y)
    return y

def generator_model(input, num_res_blocks=16, name="generator"):
    """Create the generator model of SR-GAN for 4x super-resolution"""
    # k9n64s1 and PReLU layer before the residual block
    y = Conv2D(64, (9, 9), padding="same", name="entry_conv")(input)
    y = PReLU(shared_axes=[1, 2], name="entry_prelu")(y)
    # B times the residual blocks
    res_input = y
    for n in range(num_res_blocks):
        y = residual_block(y, name_prefix=f"residual{n}")
    # k3n64s1 Conv+BN block
    y = Conv2D(64, (3, 3), padding="same", name="mid_conv")(y)
    y = BatchNormalization(momentum=0.5, name="mid_bn")(y)
    y = Add(name="mid_add")([y, res_input])
    # two upscale blocks
    y = upscale_block(y, name_prefix="up1")
    y = upscale_block(y, name_prefix="up2")
    # k9n3s1 conv at output
    output = Conv2D(3, (9, 9), padding="same", name="out_conv")(y)
    return Model(inputs=input, outputs=output, name=name)
#
# VGG model for content loss
#

def vgg_model(output_layer=20):
    """Create VGG19 model for measuring the perceptual loss
    """
    # take VGG model from Keras, output at layer "block5_conv4" (20),
    # paper referred this layer as \phi_{5,4}
    vgg = tf.keras.applications.VGG19(input_shape=(None, None, 3), weights="imagenet", include_top=False)
    model = Model(inputs=vgg.input, outputs=vgg.layers[output_layer].output, name="VGG19")
    model.trainable = False  # need model.compile()
    for layer in model.layers:
        layer.trainable = False  # no need model.compile()
    return model

#
# Training
#

def save_weights(generator, discriminator, epoch, basedir="checkpoint"):
    """Syntax sugar for saving the generator and discriminator models"""
    os.makedirs(basedir, exist_ok=True)
    gen_path = os.path.join(basedir, f"generator_{epoch}.h5")
    disc_path = os.path.join(basedir, f"discriminator_{epoch}.h5")
    generator.save(gen_path)
    discriminator.save(disc_path)
def main():
    image_dir = "dataset_images"
    batch_size = 8
    n_epochs = 100
    # try to build and print the discriminator
    hr_input = Input(shape=(256, 256, 3))
    discriminator = discriminator_model(hr_input)
    discriminator.summary(line_length=120, expand_nested=True, show_trainable=True)
    # try to build and print the generator (1/4 size of the discriminator input)
    lr_input = Input(shape=(64, 64, 3))
    generator = generator_model(lr_input)
    generator.summary(line_length=120, expand_nested=True, show_trainable=True)
    # VGG model to reuse for feature extraction during loss calculation
    vgg = vgg_model()
    vgg.summary(line_length=120, expand_nested=True, show_trainable=True)
    # The label tensors for the discriminator loss
    ones = tf.ones(batch_size)
    zeros = tf.zeros(batch_size)
    def content_loss(hires, supres):
        """Use VGG model to compare features extracted from hires and super-res images.
        Keras VGG model expects "caffe" image format (BGR, mean-shifted), hence
        preprocess_input() is required. This function is for use with model.compile()
        Args:
            hires: Hires image, pixels in [0,255]
            supres: Generator output, pixels in [-1,1] supposedly
        Returns:
            tf.Tensor of a scalar value
        """
        supres = tf.keras.applications.vgg19.preprocess_input(tf.clip_by_value((supres+1)*127.5, 0, 255))
        hires = tf.keras.applications.vgg19.preprocess_input(hires)
        hires_feat = vgg(hires, training=False) / 12.75
        supres_feat = vgg(supres, training=False) / 12.75
        return tf.math.reduce_mean(tf.math.squared_difference(hires_feat, supres_feat))
    disc_loss = BinaryCrossentropy(from_logits=True)
    def gan_loss(hires, supres):
        """Generator perceptual loss = content loss + 1e-3 * adversarial loss"""
        disc_output = discriminator(supres, training=False)
        content = content_loss(hires, supres)
        adversarial = disc_loss(ones, disc_output)
        return content + 1e-3 * adversarial
    # Optimizers for use in training: separate because these optimizers are stateful
    gen_opt = Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
    disc_opt = Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
    # compile models
    generator.compile(loss=gan_loss, optimizer=gen_opt)
    discriminator.compile(loss=disc_loss, optimizer=disc_opt)
    # training loop
    dataset = make_dataset(image_dir, batch_size=batch_size).prefetch(tf.data.AUTOTUNE)
    p_mean = tf.keras.metrics.Mean()  # to average perceptual loss
    d_mean = tf.keras.metrics.Mean()  # to average discriminator loss
    for epoch in range(n_epochs):
        with tqdm.tqdm(dataset, unit="step", desc=f"Epoch {epoch}") as tqdmbar:
            for hires_batch, lores_batch in tqdmbar:
                # train the discriminator; generator input is [0,1], output is [-1,1]
                lores_batch /= 255.0
                supres_batch = generator(lores_batch, training=False)  # output pixel [-1,1]
                disc_loss0 = discriminator.train_on_batch(supres_batch, zeros)
                disc_loss1 = discriminator.train_on_batch(hires_batch/127.5-1, ones)  # convert [0,255] -> [-1,1]
                # train the generator
                percep_loss = generator.train_on_batch(lores_batch, hires_batch)
                p_mean.update_state(percep_loss)
                d_mean.update_state(disc_loss0+disc_loss1)
                tqdmbar.set_postfix(percep=f"{p_mean.result():.3f}",
                                    disc=f"{d_mean.result():.3f}")
        # save model at end of each epoch
        save_weights(generator, discriminator, epoch+1)
        p_mean.reset_states()
        d_mean.reset_states()

main()