001/*
002Copyright 2006 Jerry Huxtable
003
004Licensed under the Apache License, Version 2.0 (the "License");
005you may not use this file except in compliance with the License.
006You may obtain a copy of the License at
007
008   http://www.apache.org/licenses/LICENSE-2.0
009
010Unless required by applicable law or agreed to in writing, software
011distributed under the License is distributed on an "AS IS" BASIS,
012WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013See the License for the specific language governing permissions and
014limitations under the License.
015*/
016
017package com.jhlabs.image;
018
019import java.awt.*;
020import java.awt.geom.*;
021import java.awt.image.*;
022import java.awt.color.*;
023import com.jhlabs.math.*;
024
025/**
026 * A filter which use FFTs to simulate lens blur on an image.
027 */
028public class LensBlurFilter extends AbstractBufferedImageOp {
029    
030    private float radius = 10;
031        private float bloom = 2;
032        private float bloomThreshold = 255;
033    private float angle = 0;
034        private int sides = 5;
035
036        /**
037         * Set the radius of the kernel, and hence the amount of blur.
038         * @param radius the radius of the blur in pixels.
039     * @see #getRadius
040         */
041        public void setRadius(float radius) {
042                this.radius = radius;
043        }
044        
045        /**
046         * Get the radius of the kernel.
047         * @return the radius
048     * @see #setRadius
049         */
050        public float getRadius() {
051                return radius;
052        }
053
054        /**
055         * Set the number of sides of the aperture.
056         * @param sides the number of sides
057     * @see #getSides
058         */
059        public void setSides(int sides) {
060                this.sides = sides;
061        }
062        
063        /**
064         * Get the number of sides of the aperture.
065         * @return the number of sides
066     * @see #setSides
067         */
068        public int getSides() {
069                return sides;
070        }
071
072        /**
073         * Set the bloom factor.
074         * @param bloom the bloom factor
075     * @see #getBloom
076         */
077        public void setBloom(float bloom) {
078                this.bloom = bloom;
079        }
080        
081        /**
082         * Get the bloom factor.
083         * @return the bloom factor
084     * @see #setBloom
085         */
086        public float getBloom() {
087                return bloom;
088        }
089
090        /**
091         * Set the bloom threshold.
092         * @param bloomThreshold the bloom threshold
093     * @see #getBloomThreshold
094         */
095        public void setBloomThreshold(float bloomThreshold) {
096                this.bloomThreshold = bloomThreshold;
097        }
098        
099        /**
100         * Get the bloom threshold.
101         * @return the bloom threshold
102     * @see #setBloomThreshold
103         */
104        public float getBloomThreshold() {
105                return bloomThreshold;
106        }
107
108
109    public BufferedImage filter( BufferedImage src, BufferedImage dst ) {
110        int width = src.getWidth();
111        int height = src.getHeight();
112        int rows = 1, cols = 1;
113        int log2rows = 0, log2cols = 0;
114        int iradius = (int)Math.ceil(radius);
115        int tileWidth = 128;
116        int tileHeight = tileWidth;
117
118        int adjustedWidth = (int)(width + iradius*2);
119        int adjustedHeight = (int)(height + iradius*2);
120
121                tileWidth = iradius < 32 ? Math.min(128, width+2*iradius) : Math.min(256, width+2*iradius);
122                tileHeight = iradius < 32 ? Math.min(128, height+2*iradius) : Math.min(256, height+2*iradius);
123
124        if ( dst == null )
125            dst = new BufferedImage( width, height, BufferedImage.TYPE_INT_ARGB );
126
127        while (rows < tileHeight) {
128            rows *= 2;
129            log2rows++;
130        }
131        while (cols < tileWidth) {
132            cols *= 2;
133            log2cols++;
134        }
135        int w = cols;
136        int h = rows;
137
138                tileWidth = w;
139                tileHeight = h;//FIXME-tileWidth, w, and cols are always all the same
140
141        FFT fft = new FFT( Math.max(log2rows, log2cols) );
142
143        int[] rgb = new int[w*h];
144        float[][] mask = new float[2][w*h];
145        float[][] gb = new float[2][w*h];
146        float[][] ar = new float[2][w*h];
147
148        // Create the kernel
149                double polyAngle = Math.PI/sides;
150                double polyScale = 1.0f / Math.cos(polyAngle);
151                double r2 = radius*radius;
152                double rangle = Math.toRadians(angle);
153                float total = 0;
154        int i = 0;
155        for ( int y = 0; y < h; y++ ) {
156            for ( int x = 0; x < w; x++ ) {
157                double dx = x-w/2f;
158                double dy = y-h/2f;
159                                double r = dx*dx+dy*dy;
160                                double f = r < r2 ? 1 : 0;
161                                if (f != 0) {
162                                        r = Math.sqrt(r);
163                                        if ( sides != 0 ) {
164                                                double a = Math.atan2(dy, dx)+rangle;
165                                                a = ImageMath.mod(a, polyAngle*2)-polyAngle;
166                                                f = Math.cos(a) * polyScale;
167                                        } else
168                                                f = 1;
169                                        f = f*r < radius ? 1 : 0;
170                                }
171                                total += (float)f;
172
173                                mask[0][i] = (float)f;
174                mask[1][i] = 0;
175                i++;
176            }
177        }
178                
179        // Normalize the kernel
180        i = 0;
181        for ( int y = 0; y < h; y++ ) {
182            for ( int x = 0; x < w; x++ ) {
183                mask[0][i] /= total;
184                i++;
185            }
186        }
187
188        fft.transform2D( mask[0], mask[1], w, h, true );
189
190        for ( int tileY = -iradius; tileY < height; tileY += tileHeight-2*iradius ) {
191            for ( int tileX = -iradius; tileX < width; tileX += tileWidth-2*iradius ) {
192//                System.out.println("Tile: "+tileX+" "+tileY+" "+tileWidth+" "+tileHeight);
193
194                // Clip the tile to the image bounds
195                int tx = tileX, ty = tileY, tw = tileWidth, th = tileHeight;
196                int fx = 0, fy = 0;
197                if ( tx < 0 ) {
198                    tw += tx;
199                    fx -= tx;
200                    tx = 0;
201                }
202                if ( ty < 0 ) {
203                    th += ty;
204                    fy -= ty;
205                    ty = 0;
206                }
207                if ( tx+tw > width )
208                    tw = width-tx;
209                if ( ty+th > height )
210                    th = height-ty;
211                src.getRGB( tx, ty, tw, th, rgb, fy*w+fx, w );
212
213                // Create a float array from the pixels. Any pixels off the edge of the source image get duplicated from the edge.
214                i = 0;
215                for ( int y = 0; y < h; y++ ) {
216                    int imageY = y+tileY;
217                    int j;
218                    if ( imageY < 0 )
219                        j = fy;
220                    else if ( imageY > height )
221                        j = fy+th-1;
222                    else
223                        j = y;
224                    j *= w;
225                    for ( int x = 0; x < w; x++ ) {
226                        int imageX = x+tileX;
227                        int k;
228                        if ( imageX < 0 )
229                            k = fx;
230                        else if ( imageX > width )
231                            k = fx+tw-1;
232                        else
233                            k = x;
234                        k += j;
235
236                        ar[0][i] = ((rgb[k] >> 24) & 0xff);
237                        float r = ((rgb[k] >> 16) & 0xff);
238                        float g = ((rgb[k] >> 8) & 0xff);
239                        float b = (rgb[k] & 0xff);
240
241                                                // Bloom...
242                        if ( r > bloomThreshold )
243                                                        r *= bloom;
244//                                                      r = bloomThreshold + (r-bloomThreshold) * bloom;
245                        if ( g > bloomThreshold )
246                                                        g *= bloom;
247//                                                      g = bloomThreshold + (g-bloomThreshold) * bloom;
248                        if ( b > bloomThreshold )
249                                                        b *= bloom;
250//                                                      b = bloomThreshold + (b-bloomThreshold) * bloom;
251
252                                                ar[1][i] = r;
253                                                gb[0][i] = g;
254                                                gb[1][i] = b;
255
256                        i++;
257                        k++;
258                    }
259                }
260
261                // Transform into frequency space
262                fft.transform2D( ar[0], ar[1], cols, rows, true );
263                fft.transform2D( gb[0], gb[1], cols, rows, true );
264
265                // Multiply the transformed pixels by the transformed kernel
266                i = 0;
267                for ( int y = 0; y < h; y++ ) {
268                    for ( int x = 0; x < w; x++ ) {
269                        float re = ar[0][i];
270                        float im = ar[1][i];
271                        float rem = mask[0][i];
272                        float imm = mask[1][i];
273                        ar[0][i] = re*rem-im*imm;
274                        ar[1][i] = re*imm+im*rem;
275                        
276                        re = gb[0][i];
277                        im = gb[1][i];
278                        gb[0][i] = re*rem-im*imm;
279                        gb[1][i] = re*imm+im*rem;
280                        i++;
281                    }
282                }
283
284                // Transform back
285                fft.transform2D( ar[0], ar[1], cols, rows, false );
286                fft.transform2D( gb[0], gb[1], cols, rows, false );
287
288                // Convert back to RGB pixels, with quadrant remapping
289                int row_flip = w >> 1;
290                int col_flip = h >> 1;
291                int index = 0;
292
293                //FIXME-don't bother converting pixels off image edges
294                for ( int y = 0; y < w; y++ ) {
295                    int ym = y ^ row_flip;
296                    int yi = ym*cols;
297                    for ( int x = 0; x < w; x++ ) {
298                        int xm = yi + (x ^ col_flip);
299                        int a = (int)ar[0][xm];
300                        int r = (int)ar[1][xm];
301                        int g = (int)gb[0][xm];
302                        int b = (int)gb[1][xm];
303
304                                                // Clamp high pixels due to blooming
305                                                if ( r > 255 )
306                                                        r = 255;
307                                                if ( g > 255 )
308                                                        g = 255;
309                                                if ( b > 255 )
310                                                        b = 255;
311                        int argb = (a << 24) | (r << 16) | (g << 8) | b;
312                        rgb[index++] = argb;
313                    }
314                }
315
316                // Clip to the output image
317                tx = tileX+iradius;
318                ty = tileY+iradius;
319                tw = tileWidth-2*iradius;
320                th = tileHeight-2*iradius;
321                if ( tx+tw > width )
322                    tw = width-tx;
323                if ( ty+th > height )
324                    th = height-ty;
325                dst.setRGB( tx, ty, tw, th, rgb, iradius*w+iradius, w );
326            }
327        }
328        return dst;
329    }
330
331        public String toString() {
332                return "Blur/Lens Blur...";
333        }
334}